diff --git a/yolo/model/yolo.py b/yolo/model/yolo.py
index 42d722087..7e3b1c044 100644
--- a/yolo/model/yolo.py
+++ b/yolo/model/yolo.py
@@ -135,27 +135,35 @@ def save_load_weights(self, weights: Union[Path, OrderedDict]):
             weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
         if "model_state_dict" in weights:
             weights = weights["model_state_dict"]
+        if "state_dict" in weights:
+            weights = weights["state_dict"]
 
-        model_state_dict = self.model.state_dict()
+        model_state_dict = self.state_dict()
 
         # TODO1: autoload old version weight
         # TODO2: weight transform if num_class difference
 
         error_dict = {"Mismatch": set(), "Not Found": set()}
         for model_key, model_weight in model_state_dict.items():
-            if model_key not in weights:
+            # Try the key as-is first, then the two other key layouts found in saved weights.
+            weights_key = model_key
+            if weights_key not in weights:  # .ckpt: keys carry an extra "model." prefix
+                weights_key = "model." + model_key
+            if weights_key not in weights:  # old .pt: keys lack the leading "model."
+                weights_key = model_key.removeprefix("model.")
+            if weights_key not in weights:
                 error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                 continue
-            if model_weight.shape != weights[model_key].shape:
+            if model_weight.shape != weights[weights_key].shape:
                 error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                 continue
-            model_state_dict[model_key] = weights[model_key]
+            model_state_dict[model_key] = weights[weights_key]
 
         for error_name, error_set in error_dict.items():
             for weight_name in error_set:
                 logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")
 
-        self.model.load_state_dict(model_state_dict)
+        self.load_state_dict(model_state_dict)
 
 
 def create_model(model_cfg: ModelConfig, weight_path: Union[bool, Path] = True, class_num: int = 80) -> YOLO:
diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py
index c44f00c68..bf2903f17 100644
--- a/yolo/tools/data_loader.py
+++ b/yolo/tools/data_loader.py
@@ -234,7 +234,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
 
 
 class StreamDataLoader:
-    def __init__(self, data_cfg: DataConfig):
+    def __init__(self, data_cfg: DataConfig, asynchronous: bool = True):
         self.source = data_cfg.source
         self.running = True
         self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
@@ -249,8 +249,12 @@ def __init__(self, data_cfg: DataConfig):
         else:
             self.source = Path(self.source)
             self.queue = Queue()
-            self.thread = Thread(target=self.load_source)
-            self.thread.start()
+
+            if asynchronous:
+                self.thread = Thread(target=self.load_source)
+                self.thread.start()
+            else:
+                self.load_source()
 
     def load_source(self):
         if self.source.is_dir():  # image folder
@@ -272,20 +276,22 @@ def process_image(self, image_path):
         image = Image.open(image_path).convert("RGB")
         if image is None:
             raise ValueError(f"Error loading image: {image_path}")
-        self.process_frame(image)
+        self.process_frame(image, image_path)
 
     def load_video_file(self, video_path):
         import cv2
 
         cap = cv2.VideoCapture(str(video_path))
+        frame_idx = 0
         while self.running:
             ret, frame = cap.read()
             if not ret:
                 break
-            self.process_frame(frame)
+            self.process_frame(frame, f"{video_path.stem}_frame{frame_idx:04d}.png")
+            frame_idx += 1
         cap.release()
 
-    def process_frame(self, frame):
+    def process_frame(self, frame, image_path):
         if isinstance(frame, np.ndarray):
             # TODO: we don't need cv2
             import cv2
@@ -297,9 +303,9 @@ def process_frame(self, frame):
         frame = frame[None]
         rev_tensor = rev_tensor[None]
         if not self.is_stream:
-            self.queue.put((frame, rev_tensor, origin_frame))
+            self.queue.put((frame, rev_tensor, origin_frame, image_path))
         else:
-            self.current_frame = (frame, rev_tensor, origin_frame)
+            self.current_frame = (frame, rev_tensor, origin_frame, image_path)
 
     def __iter__(self) -> Generator[Tensor, None, None]:
         return self
@@ -310,7 +316,7 @@ def __next__(self) -> Tensor:
             if not ret:
                 self.stop()
                 raise StopIteration
-            self.process_frame(frame)
+            self.process_frame(frame, "stream_frame.png")
             return self.current_frame
         else:
             try:
diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py
index c20b1ab3d..f86f82d73 100644
--- a/yolo/tools/solver.py
+++ b/yolo/tools/solver.py
@@ -6,7 +6,7 @@
 
 from yolo.config.config import Config
 from yolo.model.yolo import create_model
-from yolo.tools.data_loader import create_dataloader
+from yolo.tools.data_loader import StreamDataLoader, create_dataloader
 from yolo.tools.drawer import draw_bboxes
 from yolo.tools.loss_functions import create_loss_function
 from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
@@ -112,7 +112,9 @@ def __init__(self, cfg: Config):
         super().__init__(cfg)
         self.cfg = cfg
         # TODO: Add FastModel
-        self.predict_loader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
+        # StreamDataLoader has to be synchronous, otherwise not all images are loaded
+        # TODO: Make this load in parallel
+        self.predict_loader = StreamDataLoader(cfg.task.data, asynchronous=False)
 
     def setup(self, stage):
         self.vec2box = create_converter(
@@ -124,15 +126,30 @@ def predict_dataloader(self):
         return self.predict_loader
 
     def predict_step(self, batch, batch_idx):
-        images, rev_tensor, origin_frame = batch
+        images, rev_tensor, origin_frame, image_path = batch
         predicts = self.post_process(self(images), rev_tensor=rev_tensor)
         img = draw_bboxes(origin_frame, predicts, idx2label=self.cfg.dataset.class_list)
+
         if getattr(self.predict_loader, "is_stream", None):
             fps = self._display_stream(img)
         else:
             fps = None
+
         if getattr(self.cfg.task, "save_predict", None):
             self._save_image(img, batch_idx)
+
+        # Append one space-separated line per detection to a single shared results file.
+        output_txt_file = Path(self.cfg.out_path) / "results.txt"
+        # Video and stream frames arrive labelled with a plain str, so normalise before taking .name.
+        image_name = Path(image_path).name
+        with open(output_txt_file, "a", encoding="utf-8") as f:
+            for bboxes in predicts:
+                for bbox in bboxes:
+                    class_id, x_min, y_min, x_max, y_max, *conf = [float(val) for val in bbox]
+                    f.write(f"{image_name} {int(class_id)} {x_min} {y_min} {x_max} {y_max} {conf[0]}\n")
+
+        print(f"💾 Saved predictions at {output_txt_file}")
+
         return img, fps
 
     def _save_image(self, img, batch_idx):
diff --git a/yolo/utils/logging_utils.py b/yolo/utils/logging_utils.py
index f60410d4f..9f577253b 100644
--- a/yolo/utils/logging_utils.py
+++ b/yolo/utils/logging_utils.py
@@ -107,7 +107,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch: Any, batch_idx:
         epoch_descript = "[cyan]Train [white]|"
         batch_descript = "[green]Train [white]|"
         metrics = self.get_metrics(trainer, pl_module)
-        metrics.pop("v_num")
+        metrics.pop("v_num", None)  # "v_num" may be absent; drop it only when present
         for metrics_name, metrics_val in metrics.items():
             if "Loss_step" in metrics_name:
                 epoch_descript += f"{metrics_name.removesuffix('_step').split('/')[1]: ^9}|"