diff --git a/yolo/tools/data_loader.py b/yolo/tools/data_loader.py
index 785235a9..48288d69 100644
--- a/yolo/tools/data_loader.py
+++ b/yolo/tools/data_loader.py
@@ -34,7 +34,7 @@ def __init__(self, data_cfg: DataConfig, dataset_cfg: DatasetConfig, phase: str
         self.transform.get_more_data = self.get_more_data
         self.data = self.load_data(Path(dataset_cfg.path), phase_name)
 
-    def load_data(self, dataset_path: Path, phase_name: str):
+    def load_data(self, dataset_path: Path, phase_name: str) -> list:
         """
         Loads data from a cache or generates a new cache for a specific dataset phase.
 
@@ -43,7 +43,7 @@ def load_data(self, dataset_path: Path, phase_name: str):
             phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
 
         Returns:
-            dict: The loaded data from the cache for the specified phase.
+            list: The loaded data from the cache for the specified phase.
         """
         cache_path = dataset_path / f"{phase_name}.cache"
 
@@ -58,38 +58,48 @@ def load_data(self, dataset_path: Path, phase_name: str):
 
     def filter_data(self, dataset_path: Path, phase_name: str) -> list:
         """
-        Filters and collects dataset information by pairing images with their corresponding labels.
+        Filters and collects dataset information by pairing images with
+        their corresponding labels.
 
         Parameters:
-            images_path (Path): Path to the directory containing image files.
-            labels_path (str): Path to the directory containing label files.
+            dataset_path (Path): The root path to the dataset directory.
+            phase_name (str): The specific phase of the dataset
+                (e.g., 'train', 'test') to load or generate data for.
 
         Returns:
-            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
+            list: A list of tuples, each containing an image id, the path to an
+                image file, and its associated segmentation as a tensor.
+                For COCO formatted .json files, the image id is the `int`
+                `image_id` attribute of each annotation in the json file.
+                For YOLO formatted .txt files, the image id is the image file
+                name without the extension.
""" images_path = dataset_path / "images" / phase_name labels_path, data_type = locate_label_paths(dataset_path, phase_name) images_list = sorted([p.name for p in Path(images_path).iterdir() if p.is_file()]) if data_type == "json": - annotations_index, image_info_dict = create_image_metadata(labels_path) - + ( + annotations_dict, + image_info_dict, + image_name_to_id_dict + ) = create_image_metadata(labels_path) data = [] valid_inputs = 0 for image_name in track(images_list, description="Filtering data"): if not image_name.lower().endswith((".jpg", ".jpeg", ".png")): continue - image_id = Path(image_name).stem - if data_type == "json": + image_id = image_name_to_id_dict[image_name] image_info = image_info_dict.get(image_id, None) if image_info is None: continue - annotations = annotations_index.get(image_info["id"], []) + annotations = annotations_dict.get(image_id, []) image_seg_annotations = scale_segmentation(annotations, image_info) if not image_seg_annotations: continue elif data_type == "txt": + image_id = Path(image_name).stem label_path = labels_path / f"{image_id}.txt" if not label_path.is_file(): continue @@ -99,19 +109,24 @@ def filter_data(self, dataset_path: Path, phase_name: str) -> list: image_seg_annotations = [] labels = self.load_valid_labels(image_id, image_seg_annotations) - - img_path = images_path / image_name - data.append((img_path, labels)) + image_path = images_path / image_name + data.append((image_id, image_path, labels)) valid_inputs += 1 logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list)) return data - def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]: + def load_valid_labels( + self, + image_id: Union[int, str], + seg_data_one_img: list + ) -> Union[Tensor, None]: """ Loads and validates bounding box data is [0, 1] from a label file. Parameters: - label_path (str): The filepath to the label file containing bounding box data. + image_id (int | str): Image id. + If COCO .json file is used, image id is a `int`. + If YOLO .txt file is used, image id is a string. Returns: Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None. 
diff --git a/yolo/tools/solver.py b/yolo/tools/solver.py
index 51ceffc0..d0a69f9a 100644
--- a/yolo/tools/solver.py
+++ b/yolo/tools/solver.py
@@ -237,7 +237,7 @@ def solve(self, dataloader, epoch_idx=1):
         self.model.eval()
         predict_json, mAPs = [], defaultdict(list)
         self.progress.start_one_epoch(len(dataloader), task="Validate")
-        for batch_size, images, targets, rev_tensor, img_paths in dataloader:
+        for batch_size, images, targets, rev_tensor, image_ids in dataloader:
             images, targets, rev_tensor = images.to(self.device), targets.to(self.device), rev_tensor.to(self.device)
             with torch.no_grad():
                 predicts = self.model(images)
@@ -250,7 +250,7 @@ def solve(self, dataloader, epoch_idx=1):
 
             avg_mAPs = {key: 100 * torch.mean(torch.stack(val)) for key, val in mAPs.items()}
             self.progress.one_batch(avg_mAPs)
 
-            predict_json.extend(predicts_to_json(img_paths, predicts, rev_tensor))
+            predict_json.extend(predicts_to_json(image_ids, predicts, rev_tensor))
         self.progress.finish_one_epoch(avg_mAPs, epoch_idx=epoch_idx)
         self.progress.visualize_image(images, targets, predicts, epoch_idx=epoch_idx)
diff --git a/yolo/utils/dataset_utils.py b/yolo/utils/dataset_utils.py
index a6c6e1fd..0e02ef7f 100644
--- a/yolo/utils/dataset_utils.py
+++ b/yolo/utils/dataset_utils.py
@@ -37,47 +37,67 @@ def locate_label_paths(dataset_path: Path, phase_name: Path) -> Tuple[Path, Path
         return [], None
 
 
-def create_image_metadata(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
+def create_image_metadata(
+    labels_path: str,
+) -> Tuple[Dict[int, List], Dict[int, Dict], Dict[str, int]]:
     """
-    Create a dictionary containing image information and annotations indexed by image ID.
+    Returns three dictionaries: image id to list of annotations, image id to
+    image information, and image file name to image id.
+    Image id is the `int` `id` assigned to an image in the COCO formatted
+    .json file.
 
     Args:
         labels_path (str): The path to the annotation json file.
 
     Returns:
-        - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
-        - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
+        (annotations_dict, image_info_dict, image_name_to_id_dict):
+            annotations_dict is a dictionary where keys are image ids and
+            values are lists of annotation dictionaries.
+            image_info_dict is a dictionary where keys are image ids and
+            values are image information dictionaries.
+            image_name_to_id_dict is a dictionary with the image file name
+            (including extension) as key and the `int` image id as value.
     """
     with open(labels_path, "r") as file:
-        labels_data = json.load(file)
-    id_to_idx = discretize_categories(labels_data.get("categories", [])) if "categories" in labels_data else None
-    annotations_index = organize_annotations_by_image(labels_data, id_to_idx)  # check lookup is a good name?
-    image_info_dict = {Path(img["file_name"]).stem: img for img in labels_data["images"]}
-    return annotations_index, image_info_dict
+        json_data = json.load(file)
+    image_name_to_id_dict = {
+        Path(img["file_name"]).name: img["id"] for img in json_data["images"]
+    }
+    id_to_idx = discretize_categories(json_data.get("categories", [])) if "categories" in json_data else None
+    annotations_dict = organize_annotations_by_image(json_data, id_to_idx)
+    image_info_dict = {img["id"]: img for img in json_data["images"]}
+    return annotations_dict, image_info_dict, image_name_to_id_dict
+
+
+def organize_annotations_by_image(
+    json_data: Dict[str, Any],
+    category_id_to_idx: Optional[Dict[int, int]],
+) -> Dict[int, List[Dict]]:
+    """
+    Returns a dict mapping image id to a list of all corresponding annotations.
+    Annotations with "iscrowd" set to True are excluded. Image id is the `int`
+    `image_id` in the corresponding annotation dict stored in the
+    COCO formatted .json file.
 
-def organize_annotations_by_image(data: Dict[str, Any], id_to_idx: Optional[Dict[int, int]]):
-    """
-    Use image index to lookup every annotations
     Args:
-        data (Dict[str, Any]): A dictionary containing annotation data.
-
+        json_data: Data read from a COCO json file.
+        category_id_to_idx: For the COCO dataset, a dict mapping category_id
+            to (category_id - 1).
 
     Returns:
-        Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
-        Annotations with "iscrowd" set to True are excluded from the index.
-
+        image_id_to_annotation_dict_list: A dictionary where keys are image ids
+            and values are lists of annotation dictionaries.
     """
-    annotation_lookup = {}
-    for anno in data["annotations"]:
-        if anno["iscrowd"]:
+    image_id_to_annotation_dict_list = {}
+    for annotation_dict in json_data["annotations"]:
+        if annotation_dict["iscrowd"]:
             continue
-        image_id = anno["image_id"]
-        if id_to_idx:
-            anno["category_id"] = id_to_idx[anno["category_id"]]
-        if image_id not in annotation_lookup:
-            annotation_lookup[image_id] = []
-        annotation_lookup[image_id].append(anno)
-    return annotation_lookup
+        image_id = annotation_dict["image_id"]
+        if category_id_to_idx:
+            annotation_dict["category_id"] = category_id_to_idx[annotation_dict["category_id"]]
+        if image_id not in image_id_to_annotation_dict_list:
+            image_id_to_annotation_dict_list[image_id] = []
+        image_id_to_annotation_dict_list[image_id].append(annotation_dict)
+    return image_id_to_annotation_dict_list
 
 
 def scale_segmentation(
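
A minimal sketch (not part of the patch) of the three mappings returned by
create_image_metadata, using a toy COCO-style annotation file; all values are
invented for illustration, and the category remap assumes discretize_categories
maps category_id to (category_id - 1) as the docstring above describes.

    # instances.json (toy):
    # {
    #   "images": [{"id": 42, "file_name": "0001.jpg", "width": 640, "height": 480}],
    #   "annotations": [{"image_id": 42, "category_id": 1, "iscrowd": 0, "bbox": [10, 20, 30, 40]}],
    #   "categories": [{"id": 1, "name": "person"}]
    # }
    annotations_dict, image_info_dict, image_name_to_id_dict = create_image_metadata("instances.json")
    # annotations_dict      -> {42: [{"image_id": 42, "category_id": 0, ...}]}
    # image_info_dict       -> {42: {"id": 42, "file_name": "0001.jpg", ...}}
    # image_name_to_id_dict -> {"0001.jpg": 42}   (keyed by the full file name)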
""" - annotation_lookup = {} - for anno in data["annotations"]: - if anno["iscrowd"]: + image_id_to_annotation_dict_list = {} + for annotation_dict in json_data["annotations"]: + if annotation_dict["iscrowd"]: continue - image_id = anno["image_id"] - if id_to_idx: - anno["category_id"] = id_to_idx[anno["category_id"]] - if image_id not in annotation_lookup: - annotation_lookup[image_id] = [] - annotation_lookup[image_id].append(anno) - return annotation_lookup + image_id = annotation_dict["image_id"] + if category_id_to_idx: + annotation_dict["category_id"] = category_id_to_idx[annotation_dict["category_id"]] + if image_id not in image_id_to_annotation_dict_list: + image_id_to_annotation_dict_list[image_id] = [] + image_id_to_annotation_dict_list[image_id].append(annotation_dict) + return image_id_to_annotation_dict_list def scale_segmentation( diff --git a/yolo/utils/model_utils.py b/yolo/utils/model_utils.py index c35b6009..834ab7b0 100644 --- a/yolo/utils/model_utils.py +++ b/yolo/utils/model_utils.py @@ -160,19 +160,32 @@ def collect_prediction(predict_json: List, local_rank: int) -> List: return predict_json -def predicts_to_json(img_paths, predicts, rev_tensor): +def predicts_to_json( + image_ids:Union[tuple[int], tuple[str]], + predicts:list[Tensor], + rev_tensor:Tensor +) -> list[dict[str, any]]: """ - TODO: function document - turn a batch of imagepath and predicts(n x 6 for each image) to a List of diction(Detection output) + Returns a list of prediction dictionaries. Each dict contains, image_id, + category_id, bbox and score. + + Args: + image_ids: Tuple of image ids. + When using a COCO .json annotation file, image ids are int. + When using YOLO .txt annotation files, image ids are string. + predicts: For each iamge, contains a tensor of shape (n, 6), + where n is the number of detected bbox in the corresponding image. + rev_tensor: A tensor of shape (m,5), where m is the number of images. + TODO: add docstring of what this is. """ batch_json = [] - for img_path, bboxes, box_reverse in zip(img_paths, predicts, rev_tensor): + for image_id, bboxes, box_reverse in zip(image_ids, predicts, rev_tensor): scale, shift = box_reverse.split([1, 4]) bboxes[:, 1:5] = (bboxes[:, 1:5] - shift[None]) / scale[None] bboxes[:, 1:5] = transform_bbox(bboxes[:, 1:5], "xyxy -> xywh") for cls, *pos, conf in bboxes: bbox = { - "image_id": int(Path(img_path).stem), + "image_id": image_id, "category_id": IDX_TO_ID[int(cls)], "bbox": [float(p) for p in pos], "score": float(conf),