diff --git a/docs/source/User-guide/Classify/Train.rst b/docs/source/User-guide/Classify/Train.rst index cd35a86c..0d1b43f0 100644 --- a/docs/source/User-guide/Classify/Train.rst +++ b/docs/source/User-guide/Classify/Train.rst @@ -30,7 +30,7 @@ For example, if you have set up your directory as recommended in our `Input Guid .. admonition:: Advanced usage :class: dropdown - Other arguments you may want to specify when adding metadata to your images include: + Other arguments you may want to specify when loading your annotations include: - ``delimiter`` - By default, this is set to "\t" so will assume your ``csv`` file is tab delimited. You will need to specify the ``delimiter`` argument if your file is saved in another format. - ``id_col``, ``patch_paths_col``, ``label_col`` - These are used to indicate the column headings for the columns which contain image IDs, patch file paths and labels respectively. By default, these are set to "image_id", "image_path" and "label". @@ -184,31 +184,41 @@ To split your annotated images and create your dataloaders, use: By default, this will split your annotated images using the :ref:`default train:val:test ratios` and apply the :ref:`default image transforms` to each by calling the ``.create_datasets()`` method. It will then create a dataloader for each dataset, using a batch size of 16 and the :ref:`default sampler`. -To change the ratios used to split your annotations, you can specify ``frac_train``, ``frac_val`` and ``frac_test``: +To change the batch size used when creating your dataloaders, use the ``batch_size`` argument: .. code-block:: python #EXAMPLE - dataloaders = annotated_images.create_dataloaders(frac_train=0.6, frac_val=0.3, frac_test=0.1) + dataloaders = annotated_images.create_dataloaders(batch_size=24) -This will result in a split of 60% (train), 30% (val) and 10% (test). +.. admonition:: Advanced usage + :class: dropdown -To change the batch size used when creating your dataloaders, use the ``batch_size`` argument: + Other arguments you may want to specify when creating your dataloaders include: + + - ``sampler`` - By default, this is set to ``default`` and so the :ref:`default sampler` will be used when creating your dataloaders and batches. You can choose not to use a sampler by specifying ``sampler=None`` or, you can define a custom sampler using `pytorch's sampler class `__. + - ``shuffle`` - If your datasets are ordered (e.g. ``"a","a","a","a","b","c"``), you can use ``shuffle=True`` to create dataloaders which contain shuffled batches of data. This cannot be used in conjunction with a sampler and so, by default, ``shuffle=False``. + + +If you would like to use custom settings when creating your datasets, you should call the ``create_datasets()`` method directly instead of via the ``create_dataloaders()`` method. +You should then run the ``create_dataloaders()`` method afterwards to create your dataloaders as before. + +For example, to change the ratios used to split your annotations, you can specify ``frac_train``, ``frac_val`` and ``frac_test``: .. code-block:: python #EXAMPLE - dataloaders = annotated_images.create_dataloaders(batch_size=24) + annotated_images.create_datasets(frac_train=0.6, frac_val=0.3, frac_test=0.1) + dataloaders = annotated_images.create_dataloaders() + +This will result in a split of 60% (train), 30% (val) and 10% (test). .. admonition:: Advanced usage - :class: dropdown - Other arguments you may want to specify when adding metadata to your images include: + Other arguments you may want to specify when creating your datasets include: - - ``sampler`` - By default, this is set to ``default`` and so the :ref:`default sampler` will be used when creating your dataloaders and batches. You can choose not to use a sampler by specifying ``sampler=None`` or, you can define a custom sampler using `pytorch's sampler class `__. - - ``shuffle`` - If your datasets are ordered (e.g. ``"a","a","a","a","b","c"``), you can use ``shuffle=True`` to create dataloaders which contain shuffled batches of data. This cannot be used in conjunction with a sampler and so, by default, ``shuffle=False``. - ``train_transform``, ``val_transform`` and ``test_transform`` - By default, these are set to "train", "val" and "test" respectively and so the :ref:`default image transforms` for each of these sets are applied to the images. You can define your own transforms, using `torchvision's transforms module `__, and apply these to your datasets by specifying the ``train_transform``, ``val_transform`` and ``test_transform`` arguments. - + - ``context_dataset`` - By default, this is set to ``False`` and so only the patches themselves are used as inputs to the model. Setting ``context_dataset=True`` will result in datasets which return both the patches and their context as inputs for the model. Train ------ @@ -336,6 +346,30 @@ There are a number of options for the ``model`` argument: .. note:: You will need to install the `timm `__ library to do this (``pip install timm``). +.. admonition:: Context models + :class: dropdown + + If you have created context datasets, you will need to load two models (one for processing patches and one for processing patches plus context) using the methods above. + You should then pass these models to MapReaders ``twoParrallelModels`` class which combines their outputs through one fully connected layer: + + .. code:: python + + # define fc layer inputs and output + import torch + + fc_layer = torch.nn.Linear(1004, len(annotated_images.labels_map)) + + The number of inputs to your fully connected layer should be the sum of the number of outputs from your two models and the number of outputs should be the number of classes (labels) you are using. + + Your models and ``fc_layer`` should then be used to set up your custom model: + + .. code:: python + + from mapreader.classify.custom_models import twoParrallelModels + + my_model = twoParrallelModels(patch_model, context_model, fc_layer) + + Define criterion, optimizer and scheduler ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -381,7 +415,7 @@ In order to train/fine-tune your model, will need to define: You should change this to suit your needs. The ``params2optimize`` argument can be used to select which parameters to optimize during training. - By default, this is set to ``"infer"``, meaning that all trainable parameters will be optimized. + By default, this is set to ``"default"``, meaning that all trainable parameters will be optimized. When training/fine-tuning your model, you can either use one learning rate for all layers in your neural network or define layerwise learning rates (i.e. different learning rates for each layer in your neural network). Normally, when fine-tuning pre-trained models, layerwise learning rates are favoured, with smaller learning rates assigned to the first layers and larger learning rates assigned to later layers. @@ -401,6 +435,8 @@ In order to train/fine-tune your model, will need to define: #EXAMPLE params2optimize = my_classifier.generate_layerwise_lrs(min_lr=1e-4, max_lr=1e-3, spacing="geomspace") + .. note:: If you are using a context model, you should also set ``parameter_groups=True`` when running the ``generate_layerwise_lrs()`` method. This will ensure the two branches of your models are optimized properly. + You should then pass your ``params2optimize`` list to the ``.initialize_optimizer()`` method: .. code-block:: python diff --git a/mapreader/__init__.py b/mapreader/__init__.py index 25d63b7a..4534161c 100644 --- a/mapreader/__init__.py +++ b/mapreader/__init__.py @@ -10,7 +10,6 @@ from mapreader.classify.datasets import PatchDataset from mapreader.classify.datasets import PatchContextDataset from mapreader.classify.classifier import ClassifierContainer -from mapreader.classify.classifier_context import ClassifierContextContainer from mapreader.classify import custom_models from mapreader.process import process diff --git a/mapreader/annotate/annotator.py b/mapreader/annotate/annotator.py index 87abb9f2..2d660a1a 100644 --- a/mapreader/annotate/annotator.py +++ b/mapreader/annotate/annotator.py @@ -134,7 +134,7 @@ def __init__( raise ValueError( "[ERROR] ``patch_df`` must be a path to a csv or a pandas DataFrame." ) - self._eval_df(patch_df) # eval tuples/lists in df + patch_df = self._eval_df(patch_df) # eval tuples/lists in df if parent_df is not None: if isinstance(parent_df, str): @@ -150,7 +150,7 @@ def __init__( raise ValueError( "[ERROR] ``parent_df`` must be a path to a csv or a pandas DataFrame." ) - self._eval_df(parent_df) # eval tuples/lists in df + parent_df = self._eval_df(parent_df) # eval tuples/lists in df if patch_df is None: # If we don't get patch data provided, we'll use the patches and parents to create the dataframes @@ -183,7 +183,6 @@ def __init__( # Add label column if not present if label_col not in patch_df.columns: patch_df[label_col] = None - patch_df["changed"] = False # Check for image paths column if patch_paths_col not in patch_df.columns: @@ -214,47 +213,26 @@ def __init__( # Ensure unique values in list labels = sorted(set(labels), key=labels.index) - # Test for existing file + # Test for existing patch annotation file if os.path.exists(annotations_file): - print(f"[INFO] Loading existing annotations for {username}.") - existing_annotations = pd.read_csv( - annotations_file, index_col=0, sep=delimiter + print("[INFO] Loading existing patch annotations.") + patch_df = self._load_annotations( + patch_df=patch_df, + annotations_file=annotations_file, + labels=labels, + col=label_col, + delimiter=delimiter, ) - if label_col not in existing_annotations.columns: - raise ValueError( - f"[ERROR] Your existing annotations do not have the label column: {label_col}." - ) - - print(existing_annotations[label_col].dtype) - - if existing_annotations[label_col].dtype == int: - # convert label indices (ints) to labels (strings) - # this is to convert old annotations format to new annotations format - existing_annotations[label_col] = existing_annotations[label_col].apply( - lambda x: labels[x] - ) - - patch_df = patch_df.join( - existing_annotations, how="left", lsuffix="_x", rsuffix="_y" - ) - patch_df[label_col] = patch_df["label_y"].fillna(patch_df[f"{label_col}_x"]) - patch_df = patch_df.drop( - columns=[ - f"{label_col}_x", - f"{label_col}_y", - ] - ) - patch_df["changed"] = patch_df[label_col].apply( - lambda x: True if x else False - ) - - patch_df[patch_paths_col] = patch_df[f"{patch_paths_col}_x"] - patch_df = patch_df.drop( - columns=[ - f"{patch_paths_col}_x", - f"{patch_paths_col}_y", - ] + # Test for existing context annotation file + if os.path.exists(f"{annotations_file[:-4]}_context.csv"): + print("[INFO] Loading existing context annotations.") + patch_df = self._load_annotations( + patch_df=patch_df, + annotations_file=f"{annotations_file[:-4]}_context.csv", + labels=labels, + col="context_label", + delimiter=delimiter, ) # initiate as a DataFrame @@ -283,13 +261,12 @@ def __init__( self.auto_save = auto_save self.username = username self.task_name = task_name + self._annotate_context = False # set up for the annotator self._min_values = min_values or {} self._max_values = max_values or {} - self.patch_width, self.patch_height = self.get_patch_size() - # Create annotations_dir Path(annotations_dir).mkdir(parents=True, exist_ok=True) @@ -324,7 +301,7 @@ def __init__( self._setup_box() # Setup queue - self._queue = self.get_queue() + self._queue = [] @staticmethod def _load_dataframes( @@ -373,32 +350,61 @@ def _load_dataframes( return parent_df, patch_df - def _eval_df(self, df): + @staticmethod + def _eval_df(df): for col in df.columns: try: df[col] = df[col].apply(literal_eval) except (ValueError, TypeError, SyntaxError): pass + return df - def get_patch_size(self): - """ - Calculate and return the width and height of the patches based on the - first patch of the DataFrame, assuming the same shape of patches - across the frame. + @staticmethod + def _load_annotations( + patch_df: pd.DataFrame, + annotations_file: str, + labels: list, + col: str, + delimiter: str, + ): + """Load existing annotations from file. + + Parameters + ---------- + patch_df : pd.DataFrame + Current patch dataframe. + annotations_file : str + Name of the annotations file + labels : list + List of labels for annotation. + col : str + Name of the column in which labels are stored in annotations file + delimiter : str + Delimiter used in CSV files - Returns - ------- - Tuple[int, int] - Width and height of the patches. """ - patch_width = ( - self.sort_values("min_x").max_x[0] - self.sort_values("min_x").min_x[0] - ) - patch_height = ( - self.sort_values("min_y").max_y[0] - self.sort_values("min_y").min_y[0] + existing_annotations = pd.read_csv(annotations_file, index_col=0, sep=delimiter) + + if col not in existing_annotations.columns: + raise ValueError( + f"[ERROR] Your existing annotations do not have the label column: {col}." + ) + + if existing_annotations[col].dtype == int: + # convert label indices (ints) to labels (strings) + # this is to convert old annotations format to new annotations format + existing_annotations[col] = existing_annotations[col].apply( + lambda x: labels[x] + ) + + patch_df = patch_df.join( + existing_annotations[col], how="left", rsuffix="_existing" ) + if f"{col}_existing" in patch_df.columns: + patch_df[col].fillna(patch_df[f"{col}_existing"], inplace=True) + patch_df.drop(columns=f"{col}_existing", inplace=True) - return patch_width, patch_height + return patch_df def _setup_buttons(self) -> None: """ @@ -450,7 +456,7 @@ def get_queue( self, as_type: str | None = "list" ) -> list[int] | (pd.Index | pd.Series): """ - Gets the indices of rows which are legible for annotation. + Gets the indices of rows which are eligible for annotation. Parameters ---------- @@ -466,8 +472,8 @@ def get_queue( pd.Index object, or a pd.Series of legible rows. """ - def check_legibility(row): - if row.label is not None: + def check_eligibility(row): + if row.label not in [np.NaN, None]: return False test = [ @@ -479,18 +485,48 @@ def check_legibility(row): return True - test = self.copy() - test["eligible"] = test.apply(check_legibility, axis=1) - test = test[ - ["eligible"] + [col for col in test.columns if not col == "eligible"] - ] + queue_df = self.copy(deep=True) + queue_df = queue_df[queue_df[self.label_col].isna()] # only unlabelled + queue_df["eligible"] = queue_df.apply(check_eligibility, axis=1) + queue_df = queue_df[queue_df.eligible].sample(frac=1) # shuffle - indices = test[test.eligible].index + indices = queue_df.index if as_type == "list": return list(indices) if as_type == "index": return indices - return test[test.eligible] + return queue_df + + def get_context_queue( + self, as_type: str | None = "list" + ) -> list[int] | (pd.Index | pd.Series): + """ + Gets the indices of rows which are eligible for annotation at the context-level. + + Parameters + ---------- + as_type : str, optional + The format in which to return the indices. Options: "list", + "index". Default is "list". If any other value is provided, it + returns a pandas.Series. + + Returns + ------- + List[int] or pandas.Index or pandas.Series + Depending on "as_type", returns either a list of indices, a + pd.Index object, or a pd.Series of legible rows. + """ + + queue_df = self.copy(deep=True) + queue_df = queue_df[queue_df["context_label"].isna()] # only unlabelled + queue_df = queue_df[queue_df[self.label_col].notna()].sample(frac=1) # shuffle + + indices = queue_df.index + if as_type == "list": + return list(indices) + if as_type == "index": + return indices + return queue_df def get_context(self): """ @@ -507,16 +543,27 @@ def get_path(image_path, dim=True): # Resize the image im = Image.open(image_path) + # Never dim when annotating context + if self._annotate_context: + dim = False + # Dim the image - if dim is True or dim == "True": + if dim in [True, "True"]: im_array = np.array(im) im_array = 256 - (256 - im_array) * 0.4 # lighten image im = Image.fromarray(im_array.astype(np.uint8)) return im - def get_empty_square(): + def get_empty_square(patch_size: tuple[int, int]): + """Generates an empty square image. + + Parameters + ---------- + patch_size : tuple[int, int] + Patch size in pixels as tuple of `(width, height)`. + """ im = Image.new( - size=(self.patch_width, self.patch_height), + size=patch_size, mode="RGB", color="white", ) @@ -531,17 +578,26 @@ def get_empty_square(): ix = self._queue[self.current_index] - x = self.at[ix, "min_x"] - y = self.at[ix, "min_y"] - current_parent = self.at[ix, "parent_id"] + min_x = self.at[ix, "min_x"] + min_y = self.at[ix, "min_y"] + + # cannot assume all patches are same size + try: + height, width, _ = self.at[ix, "shape"] + except KeyError: + im_path = self.at[ix, self.patch_paths_col] + im = Image.open(im_path) + height = im.height + width = im.width + current_parent = self.at[ix, "parent_id"] parent_frame = self.query(f"parent_id=='{current_parent}'") deltas = list(range(-self.surrounding, self.surrounding + 1)) y_and_x = list( product( - [y + y_delta * self.patch_height for y_delta in deltas], - [x + x_delta * self.patch_width for x_delta in deltas], + [min_y + y_delta * height for y_delta in deltas], + [min_x + x_delta * width for x_delta in deltas], ) ) queries = [f"min_x == {x} & min_y == {y}" for y, x in y_and_x] @@ -562,12 +618,15 @@ def get_empty_square(): # split them into rows per_row = len(deltas) images = [ - [get_path(x[0], dim=x[1]) if x[0] else get_empty_square() for x in lst] + [ + get_path(x[0], dim=x[1]) if x[0] else get_empty_square((width, height)) + for x in lst + ] for lst in array_split(image_list, per_row) ] - total_width = (2 * self.surrounding + 1) * self.patch_width - total_height = (2 * self.surrounding + 1) * self.patch_height + total_width = (2 * self.surrounding + 1) * width + total_height = (2 * self.surrounding + 1) * height context_image = Image.new("RGB", (total_width, total_height)) @@ -576,8 +635,8 @@ def get_empty_square(): x_offset = 0 for image in row: context_image.paste(image, (x_offset, y_offset)) - x_offset += self.patch_width - y_offset += self.patch_height + x_offset += width + y_offset += height if self.resize_to is not None: context_image = ImageOps.contain( @@ -600,7 +659,7 @@ def annotate( resize_to: int | None = None, max_size: int | None = None, ) -> None: - """ + """Annotate at the patch-level of the current patch. Renders the annotation interface for the first image. Parameters @@ -624,15 +683,99 @@ def annotate( The size in pixels for the longest side to which constrain each patch image. Default: 100. - Returns - ------- - None + Notes + ----- + This method is a wrapper for the ``_annotate`` method. """ + + self._annotate_context = False + if min_values is not None: self._min_values = min_values if max_values is not None: self._max_values = max_values + # re-set up queue using new min/max values + self._queue = self.get_queue() + + self._annotate( + show_context=show_context, + surrounding=surrounding, + resize_to=resize_to, + max_size=max_size, + ) + + def annotate_context( + self, + resize_to: int | None = None, + max_size: int | None = None, + ) -> None: + """Annotate at the context-level of the current patch. + Renders the annotation interface for the first image plus surrounding context. + + Parameters + ---------- + min_values : dict or None, optional + Minimum values for each property to filter images for annotation. + It should be provided as a dictionary consisting of column names + (keys) and minimum values as floating point values (values). + Default is None. + max_values : dict or None, optional + Maximum values for each property to filter images for annotation. + It should be provided as a dictionary consisting of column names + (keys) and minimum values as floating point values (values). + Default is None + surrounding : int or None, optional + The number of surrounding images to show for context. Default: 1. + max_size : int or None, optional + The size in pixels for the longest side to which constrain each + patch image. Default: 100. + + Notes + ----- + This method is a wrapper for the ``_annotate`` method. + """ + self._annotate_context = True + + if "context_label" not in self.columns: + self["context_label"] = None + + # re-set up queue for context images + self._queue = self.get_context_queue() + + self._annotate( + show_context=True, + surrounding=1, + resize_to=resize_to, + max_size=max_size, + ) + + def _annotate( + self, + show_context: bool | None = None, + surrounding: int | None = None, + resize_to: int | None = None, + max_size: int | None = None, + ): + """ + Renders the annotation interface for the first image. + + Parameters + ---------- + show_context : bool or None, optional + Whether or not to display the surrounding context for each image. + Default is None. + surrounding : int or None, optional + The number of surrounding images to show for context. Default: 1. + max_size : int or None, optional + The size in pixels for the longest side to which constrain each + patch image. Default: 100. + + Returns + ------- + None + """ + self.current_index = -1 for button in self._buttons: button.disabled = False @@ -646,9 +789,6 @@ def annotate( if max_size is not None: self.max_size = max_size - # re-set up queue - self._queue = self.get_queue() - self.out = widgets.Output(layout=_CENTER_LAYOUT) display(self.box) display(self.navbox) @@ -667,21 +807,12 @@ def _next_example(self, *_) -> tuple[int, int, str]: Tuple[int, int, str] Previous index, current index, and path of the current image. """ - if not len(self._queue): + if self.current_index == len(self._queue): self.render_complete() return - if isinstance(self.current_index, type(None)) or self.current_index == -1: - self.current_index = 0 - else: - current_index = self.current_index + 1 - - try: - self._queue[current_index] - self.previous_index = self.current_index - self.current_index = current_index - except IndexError: - pass + self.previous_index = self.current_index + self.current_index += 1 ix = self._queue[self.current_index] @@ -699,21 +830,13 @@ def _prev_example(self, *_) -> tuple[int, int, str]: Tuple[int, int, str] Previous index, current index, and path of the current image. """ - if not len(self._queue): + if self.current_index == len(self._queue): self.render_complete() return - current_index = self.current_index - 1 - - if current_index < 0: - current_index = 0 - - try: - self._queue[current_index] - self.previous_index = current_index - 1 - self.current_index = current_index - except IndexError: - pass + if self.current_index > 0: + self.previous_index = self.current_index + self.current_index -= 1 ix = self._queue[self.current_index] @@ -738,7 +861,6 @@ def render(self) -> None: self.render_complete() return - # ix = self.iloc[self.current_index].name ix = self._queue[self.current_index] # render buttons @@ -750,7 +872,8 @@ def render(self) -> None: # disable skip button when at last example button.disabled = self.current_index >= len(self) - 1 elif button.description != "submit": - if self.at[ix, self.label_col] == button.description: + col = "context_label" if self._annotate_context else self.label_col + if self.at[ix, col] == button.description: button.icon = "check" else: button.icon = "" @@ -791,13 +914,13 @@ def render(self) -> None: ) ) - def get_patch_image(self, ix: int) -> Image: + def get_patch_image(self, ix) -> Image: """ Returns the image at the given index. Parameters ---------- - ix : int + ix : int | str The index of the image in the dataframe. Returns @@ -831,8 +954,10 @@ def _add_annotation(self, annotation: str) -> None: """ # ix = self.iloc[self.current_index].name ix = self._queue[self.current_index] - self.at[ix, self.label_col] = annotation - self.at[ix, "changed"] = True + if self._annotate_context: + self.at[ix, "context_label"] = annotation + else: + self.at[ix, self.label_col] = annotation if self.auto_save: self._auto_save() self._next_example() @@ -845,11 +970,16 @@ def _auto_save(self): ------- None """ - self.get_labelled_data(sort=True).to_csv(self.annotations_file) + if self._annotate_context: + annotations_file = f"{self.annotations_file[:-4]}_context.csv" + self.get_labelled_data(sort=True, context=True).to_csv(annotations_file) + else: + self.get_labelled_data(sort=True).to_csv(self.annotations_file) def get_labelled_data( self, sort: bool = True, + context: bool = False, index_labels: bool = False, include_paths: bool = True, ) -> pd.DataFrame: @@ -861,6 +991,8 @@ def get_labelled_data( sort : bool, optional Whether to sort the dataframe by the order of the images in the input data, by default True + context : bool, optional + Whether to save the context annotations or not, by default False index_labels : bool, optional Whether to return the label's index number (in the labels list provided in setting up the instance) or the human-readable label @@ -875,27 +1007,29 @@ def get_labelled_data( A dataframe containing the labelled images and their associated label index. """ - if index_labels: - col1 = self.filtered[self.label_col].apply(lambda x: self._labels.index(x)) + if context: + filtered_df = self[self["context_label"].notna()].copy(deep=True) else: - col1 = self.filtered[self.label_col] + filtered_df = self[self[self.label_col].notna()].copy(deep=True) + + # force image_id to be index (incase of integer index) + # TODO: Force all indices to be integers so this is not needed + if "image_id" in filtered_df.columns: + filtered_df.set_index("image_id", drop=True, inplace=True) + + if sort: + filtered_df.sort_values(by=["parent_id", "min_x", "min_y"], inplace=True) - if include_paths: - col2 = self.filtered[self.patch_paths_col] - df = pd.DataFrame( - {self.patch_paths_col: col2, self.label_col: col1}, - index=pd.Index(col1.index, name="image_id"), + if index_labels: + filtered_df[self.label_col] = filtered_df[self.label_col].apply( + lambda x: self._labels.index(x) ) - else: - df = pd.DataFrame(col1, index=pd.Index(col1.index, name="image_id")) - if not sort: - return df + if context: + filtered_df["context_label"] = filtered_df["context_label"].apply( + lambda x: self._labels.index(x) + ) - df["sort_value"] = df.index.to_list() - df["sort_value"] = df["sort_value"].apply( - lambda x: f"{x.split('#')[1]}-{x.split('#')[0]}" - ) - return df.sort_values("sort_value").drop(columns=["sort_value"]) + return filtered_df @property def filtered(self) -> pd.DataFrame: diff --git a/mapreader/classify/classifier.py b/mapreader/classify/classifier.py index 6b736a93..89ec957c 100644 --- a/mapreader/classify/classifier.py +++ b/mapreader/classify/classifier.py @@ -32,12 +32,13 @@ class ClassifierContainer: def __init__( self, - model: str | (nn.Module | None), + model: str | nn.Module | None, labels_map: dict[int, str] | None, dataloaders: dict[str, DataLoader] | None = None, device: str | None = "default", input_size: int | None = (224, 224), - is_inception: bool | None = False, + is_inception: bool = False, + context: bool = False, load_path: str | None = None, force_device: bool | None = False, **kwargs, @@ -65,6 +66,9 @@ def __init__( is_inception : bool, optional Whether the model is an Inception-style model. Default is ``False``. + context : bool, optional + Whether the model is uses patch and context inputs. + Default is `False`. load_path : str, optional The path to an ``.obj`` file containing a force_device : bool, optional @@ -88,8 +92,10 @@ def __init__( The model. input_size : None or tuple of int The size of the input to the model. - is_inception : None or bool + is_inception : bool A flag indicating if the model is an Inception model. + context : bool + A flag indicating if the model uses patch and context as inputs. optimizer : None or torch.optim.Optimizer The optimizer being used for training the model. scheduler : None or torch.optim.lr_scheduler._LRScheduler @@ -127,15 +133,15 @@ def __init__( raise ValueError( "[ERROR] ``labels_map`` and ``load_path`` cannot be used together - please set one to ``None``." ) - + # load object self.load(load_path=load_path, force_device=force_device) - + # add any extra dataloaders if dataloaders: for set_name, dataloader in dataloaders.items(): - self.dataloaders[set_name]=dataloader - + self.dataloaders[set_name] = dataloader + else: if model is None or labels_map is None: raise ValueError( @@ -144,12 +150,13 @@ def __init__( self.labels_map = labels_map - # set up model and move to device + # set up model and move to device print("[INFO] Initializing model.") if isinstance(model, nn.Module): self.model = model.to(self.device) self.input_size = input_size self.is_inception = is_inception + self.context = context elif isinstance(model, str): self._initialize_model(model, **kwargs) @@ -170,15 +177,13 @@ def __init__( ) # add colors for printing/logging - self._print_colors() + self._set_up_print_colors() # add dataloaders and labels_map self.dataloaders = dataloaders if dataloaders else {} - + for set_name, dataloader in self.dataloaders.items(): - print( - f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.' - ) + print(f'[INFO] Loaded "{set_name}" with {len(dataloader.dataset)} items.') def generate_layerwise_lrs( self, @@ -208,27 +213,65 @@ def generate_layerwise_lrs( list of dicts A list of dictionaries containing the parameters and learning rates for each layer. + + Notes + ----- + parameter_groups : bool, optional + When using context mode, whether to consider parameters belonging to the patch model and context model as separate groups. + If True, layers belonging to each group will be assigned the same learning rate. + Defaults to ``False``. """ - if spacing.lower() == "linspace": - lrs = np.linspace(min_lr, max_lr, len(list(self.model.named_parameters()))) - elif spacing.lower() in ["log", "geomspace"]: - lrs = np.geomspace(min_lr, max_lr, len(list(self.model.named_parameters()))) - else: + + if spacing.lower() not in ["linspace", "geomspace"]: raise NotImplementedError( '[ERROR] ``spacing`` must be one of "linspace" or "geomspace"' ) - params2optimize = [ - {"params": params, "learning rate": lrs[i]} - for i, (_, params) in enumerate(self.model.named_parameters()) - ] + if self.context: + params2optimize = [] + + for group in set( + tuple[0].split(".")[0] for tuple in [*self.model.named_parameters()] + ): + group_params = [ + params + for (name, params) in self.model.named_parameters() + if group in name + ] + + if spacing.lower() == "linspace": + lrs = np.linspace(min_lr, max_lr, len(group_params)) + elif spacing.lower() in ["log", "geomspace"]: + lrs = np.geomspace(min_lr, max_lr, len(group_params)) + + params2optimize.extend( + [ + {"params": params, "learning rate": lr} + for params, lr in zip(group_params, lrs) + ] + ) + + else: + if spacing.lower() == "linspace": + lrs = np.linspace( + min_lr, max_lr, len(list(self.model.named_parameters())) + ) + elif spacing.lower() in ["log", "geomspace"]: + lrs = np.geomspace( + min_lr, max_lr, len(list(self.model.named_parameters())) + ) + + params2optimize = [ + {"params": params, "learning rate": lr} + for (_, params), lr in zip(self.model.named_parameters(), lrs) + ] return params2optimize def initialize_optimizer( self, optim_type: str | None = "adam", - params2optimize: str | Iterable | None = "infer", + params2optimize: str | Iterable | None = "default", optim_param_dict: dict | None = None, add_optim: bool | None = True, ) -> torch.optim.Optimizer | None: @@ -242,9 +285,9 @@ def initialize_optimizer( The type of optimizer to use. Can be set to ``"adam"`` (default), ``"adamw"``, or ``"sgd"``. params2optimize : str or iterable, optional - The parameters to optimize. If set to ``"infer"``, all model - parameters that require gradients will be optimized, by default - ``"infer"``. + The parameters to optimize. If set to ``"default"``, all model + parameters that require gradients will be optimized. + Default is ``"default"``. optim_param_dict : dict, optional The parameters to pass to the optimizer constructor as a dictionary, by default ``{"lr": 1e-3}``. @@ -276,7 +319,11 @@ def initialize_optimizer( """ if optim_param_dict is None: optim_param_dict = {"lr": 0.001} - if params2optimize == "infer": + if params2optimize == "default": + if self.context: + raise ValueError( + "[ERROR] When using context model, first call `params2optimize` cannot be set to `default`." + ) params2optimize = filter(lambda p: p.requires_grad, self.model.parameters()) if optim_type.lower() in ["adam"]: @@ -612,7 +659,7 @@ def inference( The name of the dataset to run inference on, by default ``"infer"``. verbose : bool, optional - Whether to print verbose outputs, by default False. + Whether to print verbose outputs, by default False. print_info_batch_freq : int, optional The frequency of printouts, by default ``5``. @@ -640,7 +687,7 @@ def inference( def train_component_summary(self) -> None: """ - Print a summary of the optimizer, criterion and trainable model + Print a summary of the optimizer, criterion, and trainable model components. Returns: @@ -656,17 +703,17 @@ def train_component_summary(self) -> None: print(str(self.criterion)) print(divider) print("* Model:") - self.model_summary(only_trainable=True) + self.model_summary(trainable_col=True) def train( self, phases: list[str] | None = None, num_epochs: int | None = 25, save_model_dir: str | None | None = "models", - verbose: bool | None = False, + verbose: bool = False, tensorboard_path: str | None | None = None, tmp_file_save_freq: int | None | None = 2, - remove_after_load: bool | None = True, + remove_after_load: bool = True, print_info_batch_freq: int | None | None = 5, ) -> None: """ @@ -744,7 +791,7 @@ def train_core( phases: list[str] | None = None, num_epochs: int | None = 25, save_model_dir: str | None | None = "models", - verbose: bool | None = False, + verbose: bool = False, tensorboard_path: str | None | None = None, tmp_file_save_freq: int | None | None = 2, print_info_batch_freq: int | None | None = 5, @@ -866,8 +913,10 @@ def train_core( for batch_idx, (inputs, _labels, label_indices) in enumerate( self.dataloaders[phase] ): - inputs = inputs.to(self.device) - label_indices = label_indices.to(self.device) + inputs = tuple(input.to(self.device) for input in inputs) + label_indices = tuple( + label_index.to(self.device) for label_index in label_indices + ) if self.optimizer is None: if phase.lower() in train_phase_names: @@ -892,37 +941,49 @@ def train_core( raise ValueError( "[ERROR] Criterion is not yet defined.\n\n\ Use ``add_criterion`` to define one." - ) + ) if self.is_inception and ( phase.lower() in train_phase_names ): - outputs, aux_outputs = self.model(inputs) - - if not all( - isinstance(out, torch.Tensor) - for out in [outputs, aux_outputs] - ): - try: - outputs = outputs.logits - aux_outputs = aux_outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - - loss1 = self.criterion(outputs, label_indices) - loss2 = self.criterion(aux_outputs, label_indices) - # XXX From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 # noqa + outputs, aux_outputs = self.model(*inputs) + + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + if not isinstance(aux_outputs, torch.Tensor): + aux_outputs = self._get_logits(aux_outputs) + + loss1 = self.criterion(outputs, *label_indices) + loss2 = self.criterion(aux_outputs, *label_indices) + # https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 loss = loss1 + 0.4 * loss2 + elif self.context: + (patch_outputs, context_outputs), outputs = self.model( + *inputs + ) + + if not isinstance(outputs, torch.Tensor): + outputs = self._get_logits(outputs) + if not isinstance(patch_outputs, torch.Tensor): + patch_outputs = self._get_logits(patch_outputs) + if not isinstance(context_outputs, torch.Tensor): + context_outputs = self._get_logits(context_outputs) + + loss1 = self.criterion(outputs, label_indices[0]) + loss2 = self.criterion(patch_outputs, label_indices[0]) + loss3 = self.criterion(outputs, label_indices[1]) + + loss = loss1 + 0.4 * loss2 + 0.4 * loss3 + else: - outputs = self.model(inputs) + outputs = self.model(*inputs) if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - loss = self.criterion(outputs, label_indices) + outputs = self._get_logits(outputs) + + loss = self.criterion(outputs, *label_indices) + print(loss, type(loss)) _, pred_label_indices = torch.max(outputs, dim=1) @@ -932,19 +993,21 @@ def train_core( self.optimizer.step() # XXX (why multiply?) - running_loss += loss.item() * inputs.size(0) + running_loss += loss.item() * inputs[0].size(0) # TQDM # batch_loop.set_postfix(loss=loss.data) # batch_loop.refresh() else: - outputs = self.model(inputs) + if self.context: + (patch_outputs, context_outputs), outputs = self.model( + *inputs + ) + else: + outputs = self.model(*inputs) if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) + self._get_logits(outputs) _, pred_label_indices = torch.max(outputs, dim=1) @@ -952,7 +1015,7 @@ def train_core( torch.nn.functional.softmax(outputs, dim=1).cpu().tolist() ) running_pred_label_indices.extend(pred_label_indices.cpu().tolist()) - running_orig_label_indices.extend(label_indices.cpu().tolist()) + running_orig_label_indices.extend(label_indices[0].cpu().tolist()) if batch_idx % print_info_batch_freq == 0: curr_inp_counts = min( @@ -967,12 +1030,12 @@ def train_core( if phase.lower() in valid_phase_names: epoch_msg += f"Loss: {loss.data:.3f}" - self.cprint("[INFO]", self.__color_dred, epoch_msg) + self.cprint("[INFO]", "dred", epoch_msg) elif phase.lower() in train_phase_names: epoch_msg += f"Loss: {loss.data:.3f}" - self.cprint("[INFO]", self.__color_dgreen, epoch_msg) + self.cprint("[INFO]", "dgreen", epoch_msg) else: - self.cprint("[INFO]", self.__color_dgreen, epoch_msg) + self.cprint("[INFO]", "dgreen", epoch_msg) # --- END: one batch # scheduler @@ -1005,9 +1068,9 @@ def train_core( epoch_msg = self._gen_epoch_msg(phase, epoch_msg) if phase.lower() in valid_phase_names: - self.cprint("[INFO]", self.__color_dred, epoch_msg + "\n") + self.cprint("[INFO]", "dred", epoch_msg + "\n") else: - self.cprint("[INFO]", self.__color_dgreen, epoch_msg) + self.cprint("[INFO]", "dgreen", epoch_msg) # labels/confidence self.pred_conf.extend(running_pred_conf) @@ -1023,7 +1086,11 @@ def train_core( if phase.lower() in valid_phase_names: if epoch % tmp_file_save_freq == 0: tmp_str = f'[INFO] Checkpoint file saved to "{self.tmp_save_filename}".' # noqa - print(self.__color_lgrey + tmp_str + self.__color_reset) + print( + self._print_colors["lgrey"] + + tmp_str + + self._print_colors["reset"] + ) self.last_epoch = epoch self.save(self.tmp_save_filename, force=True) @@ -1053,7 +1120,15 @@ def train_core( print( f"[INFO] Model at epoch {self.best_epoch} has least valid loss ({self.best_loss:.4f}) so will be saved.\n\ [INFO] Path: {save_model_path}" - ) # noqa + ) + + @staticmethod + def _get_logits(out): + try: + out = out.logits + except AttributeError as err: + raise AttributeError(err.message) + return out def calculate_add_metrics( self, @@ -1459,6 +1534,7 @@ def _initialize_model( self.model = model_dw.to(self.device) self.input_size = input_size self.is_inception = is_inception + self.context = False def show_sample( self, @@ -1527,12 +1603,13 @@ def show_sample( inputs, labels, label_indices = next(dl_iter) # Make a grid from batch - out = torchvision.utils.make_grid(inputs) - self._imshow( - out, - title=f"{labels}\n{label_indices.tolist()}", - figsize=figsize, - ) + for input in inputs: + out = torchvision.utils.make_grid(input) + self._imshow( + out, + title=f"{labels[0]}\n{label_indices[0].tolist()}", + figsize=figsize, + ) def print_batch_info(self, set_name: str | None = "train") -> None: """ @@ -1655,16 +1732,18 @@ def show_inference_sample_results( plt.figure(figsize=figsize) with torch.no_grad(): for inputs, _labels, label_indices in iter(self.dataloaders[set_name]): - inputs = inputs.to(self.device) - label_indices = label_indices.to(self.device) + inputs = tuple(input.to(self.device) for input in inputs) + label_indices = tuple( + label_index.to(self.device) for label_index in label_indices + ) - outputs = self.model(inputs) + if self.context: + _, outputs = self.model(*inputs) + else: + outputs = self.model(*inputs) if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) + self._get_logits(outputs) pred_conf = torch.nn.functional.softmax(outputs, dim=1) * 100.0 _, preds = torch.max(outputs, 1) @@ -1693,7 +1772,7 @@ def show_inference_sample_results( ax.axis("off") ax.set_title(f"{label} | {conf_score:.3f}") - inp = inputs.cpu().data[j].numpy().transpose((1, 2, 0)) + inp = inputs[0].cpu().data[j].numpy().transpose((1, 2, 0)) inp = np.clip(inp, 0, 1) plt.imshow(inp) @@ -1863,38 +1942,39 @@ def load( except: pass - def _print_colors(self): + def _set_up_print_colors(self): """Private function, setting color attributes on the object.""" - # color - self.__color_lgrey = "\033[1;90m" - self.__color_grey = "\033[90m" # boring information - self.__color_yellow = "\033[93m" # FYI - self.__color_orange = "\033[0;33m" # Warning + self._print_colors = {} - self.__color_lred = "\033[1;31m" # there is smoke - self.__color_red = "\033[91m" # fire! - self.__color_dred = "\033[2;31m" # Everything is on fire + # color + self._print_colors["lgrey"] = "\033[1;90m" + self._print_colors["grey"] = "\033[90m" # boring information + self._print_colors["yellow"] = "\033[93m" # FYI + self._print_colors["orange"] = "\033[0;33m" # Warning - self.__color_lblue = "\033[1;34m" - self.__color_blue = "\033[94m" - self.__color_dblue = "\033[2;34m" + self._print_colors["lred"] = "\033[1;31m" # there is smoke + self._print_colors["red"] = "\033[91m" # fire! + self._print_colors["dred"] = "\033[2;31m" # Everything is on fire - self.__color_lgreen = "\033[1;32m" # all is normal - self.__color_green = "\033[92m" # something else - self.__color_dgreen = "\033[2;32m" # even more interesting + self._print_colors["lblue"] = "\033[1;34m" + self._print_colors["blue"] = "\033[94m" + self._print_colors["dblue"] = "\033[2;34m" - self.__color_lmagenta = "\033[1;35m" - self.__color_magenta = "\033[95m" # for title - self.__color_dmagenta = "\033[2;35m" + self._print_colors["lgreen"] = "\033[1;32m" # all is normal + self._print_colors["green"] = "\033[92m" # something else + self._print_colors["dgreen"] = "\033[2;32m" # even more interesting - self.__color_cyan = "\033[96m" # system time - self.__color_white = "\033[97m" # final time + self._print_colors["lmagenta"] = "\033[1;35m" + self._print_colors["magenta"] = "\033[95m" # for title + self._print_colors["dmagenta"] = "\033[2;35m" - self.__color_black = "\033[0;30m" + self._print_colors["cyan"] = "\033[96m" # system time + self._print_colors["white"] = "\033[97m" # final time + self._print_colors["black"] = "\033[0;30m" - self.__color_reset = "\033[0m" - self.__color_bold = "\033[1m" - self.__color_under = "\033[4m" + self._print_colors["reset"] = "\033[0m" + self._print_colors["bold"] = "\033[1m" + self._print_colors["under"] = "\033[4m" def _get_dtime(self) -> str: """ @@ -1929,10 +2009,15 @@ def cprint(self, type_info: str, bc_color: str, text: str) -> None: host_name = socket.gethostname().split(".")[0][:10] print( - self.__color_green + self._get_dtime() + self.__color_reset, - self.__color_magenta + host_name + self.__color_reset, - self.__color_bold + self.__color_grey + type_info + self.__color_reset, - bc_color + text + self.__color_reset, + self._print_colors["green"] + + self._get_dtime() + + self._print_colors["reset"], + self._print_colors["magenta"] + host_name + self._print_colors["reset"], + self._print_colors["bold"] + + self._print_colors["grey"] + + type_info + + self._print_colors["reset"], + self._print_colors[bc_color] + text + self._print_colors["reset"], ) def update_progress( diff --git a/mapreader/classify/classifier_context.py b/mapreader/classify/classifier_context.py deleted file mode 100644 index 83ab0ac4..00000000 --- a/mapreader/classify/classifier_context.py +++ /dev/null @@ -1,665 +0,0 @@ -#!/usr/bin/env python -from __future__ import annotations - -import copy -import os -import time - -# from tqdm.autonotebook import tqdm -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchvision - -from .classifier import ClassifierContainer - - -class ClassifierContextContainer(ClassifierContainer): - def train( - self, - phases: list[str] | None = None, - num_epochs: int | None = 25, - save_model_dir: str | None | None = "models", - verbosity_level: int | None = 1, - tensorboard_path: str | None | None = None, - tmp_file_save_freq: int | None | None = 2, - remove_after_load: bool | None = True, - print_info_batch_freq: int | None | None = 5, - ) -> None: - """ - Train the model on the specified phases for a given number of epochs. - Wrapper function for ``train_core`` method to capture exceptions (with - supported exceptions so far: ``KeyboardInterrupt``). Refer to - ``train_core`` for more information. - - Parameters - ---------- - phases : list of str, optional - The phases to train the model on for each epoch. Default is - ``["train", "val"]``. - num_epochs : int, optional - The number of epochs to train the model for. Default is ``25``. - save_model_dir : str or None, optional - The directory to save the model in. Default is ``"models"``. If - set to ``None``, the model is not saved. - verbosity_level : int, optional - The level of verbosity during training: - - - ``0`` is silent, - - ``1`` is progress bar and metrics, - - ``2`` is detailed information. - - Default is ``1``. - tensorboard_path : str or None, optional - The path to the directory to save TensorBoard logs in. If set to - ``None``, no TensorBoard logs are saved. Default is ``None``. - tmp_file_save_freq : int, optional - The frequency (in epochs) to save a temporary file of the model. - Default is ``2``. If set to ``0`` or ``None``, no temporary file - is saved. - remove_after_load : bool, optional - Whether to remove the temporary file after loading it. Default is - ``True``. - print_info_batch_freq : int, optional - The frequency (in batches) to print training information. Default - is ``5``. If set to ``0`` or ``None``, no training information is - printed. - - Returns - ------- - None - The function saves the model to the ``save_model_dir`` directory, - and optionally to a temporary file. If interrupted with a - ``KeyboardInterrupt``, the function tries to load the temporary - file. If no temporary file is found, it continues without loading. - """ - - if phases is None: - phases = ["train", "val"] - try: - self.train_core( - phases, - num_epochs, - save_model_dir, - verbosity_level, - tensorboard_path, - tmp_file_save_freq, - print_info_batch_freq=print_info_batch_freq, - ) - except KeyboardInterrupt: - print("[INFO] Exiting...") - if os.path.isfile(self.tmp_save_filename): - print(f'[INFO] Loading "{self.tmp_save_filename}" as model.') - self.load(self.tmp_save_filename, remove_after_load=remove_after_load) - else: - print("[INFO] No checkpoint file found - model has not been updated.") - - def train_core( - self, - phases: list[str] | None = None, - num_epochs: int | None = 25, - save_model_dir: str | None | None = "models", - verbosity_level: int | None = 1, - tensorboard_path: str | None | None = None, - tmp_file_save_freq: int | None | None = 2, - print_info_batch_freq: int | None | None = 5, - ) -> None: - """ - Trains/fine-tunes a classifier for the specified number of epochs on - the given phases using the specified hyperparameters. - - Parameters - ---------- - phases : list of str, optional - The phases to train the model on for each epoch. Default is - ``["train", "val"]``. - num_epochs : int, optional - The number of epochs to train the model for. Default is ``25``. - save_model_dir : str or None, optional - The directory to save the model in. Default is ``"models"``. If - set to ``None``, the model is not saved. - verbosity_level : int, optional - The level of verbosity during training: - - - ``0`` is silent, - - ``1`` is progress bar and metrics, - - ``2`` is detailed information. - - Default is ``1``. - tensorboard_path : str or None, optional - The path to the directory to save TensorBoard logs in. If set to - ``None``, no TensorBoard logs are saved. Default is ``None``. - tmp_file_save_freq : int, optional - The frequency (in epochs) to save a temporary file of the model. - Default is ``2``. If set to ``0`` or ``None``, no temporary file - is saved. - print_info_batch_freq : int, optional - The frequency (in batches) to print training information. Default - is ``5``. If set to ``0`` or ``None``, no training information is - printed. - - Raises - ------ - ValueError - If the criterion is not set. Use the ``add_criterion`` method to - set the criterion. - - If the optimizer is not set and the phase is "train". Use the - ``initialize_optimizer`` or ``add_optimizer`` method to set the - optimizer. - - KeyError - If the specified phase cannot be found in the object's dataloader - with keys. - - Returns - ------- - None - """ - - if phases is None: - phases = ["train", "val"] - if self.criterion is None: - raise ValueError( - "[ERROR] Criterion is not yet defined.\n\n\ -Use ``add_criterion`` to define one." - ) - - print(f"[INFO] Each epoch will pass: {phases}.") - - for phase in phases: - if phase not in self.dataloaders.keys(): - raise KeyError( - f'[ERROR] "{phase}" dataloader cannot be found in dataloaders.\n\ - Valid options for ``phases`` argument are: {self.dataloaders.keys()}' # noqa - ) - - if verbosity_level >= 1: - self.train_component_summary() - - since = time.time() - - # initialize variables - train_phase_names = ["train", "training"] - valid_phase_names = ["val", "validation", "eval", "evaluation"] - best_model_wts = copy.deepcopy(self.model.state_dict()) - self.pred_conf = [] - self.pred_label = [] - self.orig_label = [] - if save_model_dir is not None: - save_model_dir = os.path.abspath(save_model_dir) - - # Check if SummaryWriter (for tensorboard) can be imported - tboard_writer = None - if tensorboard_path is not None: - try: - from torch.utils.tensorboard import SummaryWriter - - tboard_writer = SummaryWriter(tensorboard_path) - except ImportError: - print( - "[WARNING] could not import SummaryWriter from torch.utils.tensorboard" # noqa - ) - print("[WARNING] continue without tensorboard.") - tensorboard_path = None - - start_epoch = self.last_epoch + 1 - end_epoch = self.last_epoch + num_epochs - - # --- Main train loop - for epoch in range(start_epoch, end_epoch + 1): - # --- loop, phases - for phase in phases: - if phase.lower() in train_phase_names: - self.model.train() - else: - self.model.eval() - - # initialize vars with one epoch lifetime - running_loss = 0.0 - running_pred_conf = [] - running_pred_label_indices = [] - running_orig_label_indices = [] - - # TQDM - # batch_loop = tqdm(iter(self.dataloaders[phase]), total=len(self.dataloaders[phase]), leave=False) # noqa - # if phase.lower() in train_phase_names+valid_phase_names: - # batch_loop.set_description(f"Epoch {epoch}/{end_epoch}") - - phase_batch_size = self.dataloaders[phase].batch_size - total_inp_counts = len(self.dataloaders[phase].dataset) - - # --- loop, batches - for batch_idx, (inputs1, inputs2, _labels, label_indices) in enumerate( - self.dataloaders[phase] - ): - inputs1 = inputs1.to(self.device) - inputs2 = inputs2.to(self.device) - label_indices = label_indices.to(self.device) - - if self.optimizer is None: - if phase.lower() in train_phase_names: - raise ValueError( - f"[ERROR] An optimizer should be defined for {phase} phase.\n\ -Use ``initialize_optimizer`` or ``add_optimizer`` to add one." # noqa - ) - else: - self.optimizer.zero_grad() - - if phase.lower() in train_phase_names + valid_phase_names: - # forward, track history if only in train - with torch.set_grad_enabled(phase.lower() in train_phase_names): - # Get model outputs and calculate loss - # Special case for inception because in training - # it has an auxiliary output. - # In train mode we calculate the loss by - # summing the final output and the auxiliary - # output but in testing we only consider the - # final output. - if self.is_inception and ( - phase.lower() in train_phase_names - ): - outputs, aux_outputs = self.model(inputs1, inputs2) - - if not all( - isinstance(out, torch.Tensor) - for out in [outputs, aux_outputs] - ): - try: - outputs = outputs.logits - aux_outputs = aux_outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - - loss1 = self.criterion(outputs, label_indices) - loss2 = self.criterion(aux_outputs, label_indices) - # XXX From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 # noqa - loss = loss1 + 0.4 * loss2 - else: - outputs = self.model(inputs1, inputs2) - # labels = labels.long().squeeze_() - if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - - loss = self.criterion(outputs, label_indices) - - _, pred_label_indices = torch.max(outputs, dim=1) - - # backward + optimize only if in training phase - if phase.lower() in train_phase_names: - loss.backward() - self.optimizer.step() - - # XXX (why multiply?) - running_loss += loss.item() * inputs1.size(0) - - # TQDM - # batch_loop.set_postfix(loss=loss.data) - # batch_loop.refresh() - else: - outputs = self.model(inputs1, inputs2) - - if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - - _, pred_label_indices = torch.max(outputs, dim=1) - - running_pred_conf.extend( - torch.nn.functional.softmax(outputs, dim=1).cpu().tolist() - ) - running_pred_label_indices.extend(pred_label_indices.cpu().tolist()) - running_orig_label_indices.extend(label_indices.cpu().tolist()) - - if batch_idx % print_info_batch_freq == 0: - curr_inp_counts = min( - total_inp_counts, - (batch_idx + 1) * phase_batch_size, - ) - progress_perc = curr_inp_counts / total_inp_counts * 100.0 - tmp_str = f"{curr_inp_counts}/{total_inp_counts} ({progress_perc:5.1f}%)" # noqa - - epoch_msg = f"{phase: <8} -- {epoch}/{end_epoch} -- " - epoch_msg += f"{tmp_str: >20} -- " - - if phase.lower() in valid_phase_names: - epoch_msg += f"Loss: {loss.data:.3f}" - self.cprint("[INFO]", self.color_dred, epoch_msg) - elif phase.lower() in train_phase_names: - epoch_msg += f"Loss: {loss.data:.3f}" - self.cprint("[INFO]", self.color_dgreen, epoch_msg) - else: - self.cprint("[INFO]", self.color_dgreen, epoch_msg) - # --- END: one batch - - # scheduler - if phase.lower() in train_phase_names and (self.scheduler is not None): - self.scheduler.step() - - if phase.lower() in train_phase_names + valid_phase_names: - # --- collect statistics - epoch_loss = running_loss / len(self.dataloaders[phase].dataset) - self._add_metrics(f"epoch_loss_{phase}", epoch_loss) - - if tboard_writer is not None: - tboard_writer.add_scalar( - f"loss/{phase}", - self.metrics[f"epoch_loss_{phase}"][-1], - epoch, - ) - - # other metrics (precision/recall/F1) - self.calculate_add_metrics( - running_orig_label_indices, - running_pred_label_indices, - running_pred_conf, - phase, - epoch, - tboard_writer, - ) - - epoch_msg = f"{phase: <8} -- {epoch}/{end_epoch} -- " - epoch_msg = self.gen_epoch_msg(phase, epoch_msg) - - if phase.lower() in valid_phase_names: - self.cprint("[INFO]", self.color_dred, epoch_msg + "\n") - else: - self.cprint("[INFO]", self.color_dgreen, epoch_msg) - - # labels/confidence - self.pred_conf.extend(running_pred_conf) - self.pred_label_indices.extend(running_pred_label_indices) - self.orig_label_indices.extend(running_orig_label_indices) - - # Update best_loss and _epoch? - if phase.lower() in valid_phase_names and epoch_loss < self.best_loss: - self.best_loss = epoch_loss - self.best_epoch = epoch - best_model_wts = copy.deepcopy(self.model.state_dict()) - - if phase.lower() in valid_phase_names: - if epoch % tmp_file_save_freq == 0: - tmp_str = f'[INFO] Checkpoint file saved to "{self.tmp_save_filename}".' # noqa - print(self.color_lgrey + tmp_str + self.color_reset) - self.last_epoch = epoch - self.save(self.tmp_save_filename, force=True) - - self.pred_label = [ - self.labels_map.get(i, None) for i in self.pred_label_indices - ] - self.orig_label = [ - self.labels_map.get(i, None) for i in self.orig_label_indices - ] - - time_elapsed = time.time() - since - print(f"[INFO] Total time: {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s") - - # load best model weights - self.model.load_state_dict(best_model_wts) - - # --- SAVE model/object - if phase.lower() in train_phase_names + valid_phase_names: - self.last_epoch = epoch - if save_model_dir is not None: - save_filename = f"checkpoint_{self.best_epoch}.pkl" - save_model_path = os.path.join(save_model_dir, save_filename) - self.save(save_model_path, force=True) - info_path = os.path.join(save_model_dir, "info.txt") - with open(info_path, "a+") as f: - f.writelines(f"{save_filename},{self.best_loss:.5f}\n") - - print( - f"[INFO] Model at epoch {self.best_epoch} has least valid loss ({self.best_loss:.4f}) so will be saved.\n\ -[INFO] Path: {save_model_path}" - ) # noqa - - def show_sample( - self, - set_name: str | None = "train", - batch_number: int | None = 1, - print_batch_info: bool | None = True, - figsize: tuple[int, int] | None = (15, 10), - ) -> None: - """ - Displays a sample of training or validation data in a grid format with - their corresponding class labels. - - Parameters - ---------- - set_name : str, optional - Name of the dataset (``train``/``validation``) to display the - sample from, by default ``"train"``. - batch_number : int, optional - Number of batches to display, by default ``1``. - print_batch_info : bool, optional - Whether to print information about the batch size, by default - ``True``. - figsize : tuple, optional - Figure size (width, height) in inches, by default ``(15, 10)``. - - Returns - ------- - None - Displays the sample images with their corresponding class labels. - - Raises - ------ - StopIteration - If the specified number of batches to display exceeds the total - number of batches in the dataset. - - Notes - ----- - This method uses the dataloader of the ``ImageClassifierData`` class - and the ``torchvision.utils.make_grid`` function to display the sample - data in a grid format. It also calls the ``_imshow`` method of the - ``ImageClassifierData`` class to show the sample data. - """ - if set_name not in self.dataloaders.keys(): - raise ValueError( - f"[ERROR] ``set_name`` must be one of {list(self.dataloaders.keys())}." - ) - - if print_batch_info: - # print info about batch size - self.batch_info() - - dataloader = self.dataloaders[set_name] - - num_batches = int(np.ceil(len(dataloader.dataset) / dataloader.batch_size)) - if min(num_batches, batch_number) != batch_number: - print( - f'[INFO] "{set_name}" only contains {num_batches}.\n\ -Output will show batch number {num_batches}.' - ) - batch_number = num_batches - - dl_iter = iter(dataloader) - for _ in range(batch_number): - # Get a batch of training data - inputs1, inputs2, labels, label_indices = next(dl_iter) - - # Make a grid from batch - out = torchvision.utils.make_grid(inputs1) - self._imshow( - out, - title=f"{labels}\n{label_indices.tolist()}", - figsize=figsize, - ) - - out = torchvision.utils.make_grid(inputs2) - self._imshow( - out, - title=f"{labels}\n{label_indices.tolist()}", - figsize=figsize, - ) - - def generate_layerwise_lrs( - self, - min_lr: float, - max_lr: float, - spacing: str | None = "linspace", - sep_group_names: list[str] = None, - ) -> list[dict]: - """ - Calculates layer-wise learning rates for a given set of model - parameters. - - Parameters - ---------- - min_lr : float - The minimum learning rate to be used. - max_lr : float - The maximum learning rate to be used. - spacing : str, optional - The type of sequence to use for spacing the specified interval - learning rates. Can be either ``"linspace"`` or ``"geomspace"``, - where `"linspace"` uses evenly spaced learning rates over a - specified interval and `"geomspace"` uses learning rates spaced - evenly on a log scale (a geometric progression). By default ``"linspace"``. - sep_group_names : list, optional - A list of strings containing the names of parameter groups. Layers - belonging to each group will be assigned the same learning rate. - Defaults to ``["features1", "features2"]``. - - Returns - ------- - list of dicts - A list of dictionaries containing the parameters and learning - rates for each layer. - """ - if sep_group_names is None: - sep_group_names = ["features1", "features2"] - params2optimize = [] - - for group in range(len(sep_group_names)): - # count number of layers in this group - num_grp_layers = 0 - for _i, (name, _) in enumerate(self.model.named_parameters()): - if sep_group_names[group] in name: - num_grp_layers += 1 - - # define layer-wise learning rates - if spacing.lower() == "linspace": - list_lrs = np.linspace(min_lr, max_lr, num_grp_layers) - elif spacing.lower() in ["log", "geomspace"]: - list_lrs = np.geomspace(min_lr, max_lr, num_grp_layers) - else: - raise NotImplementedError( - '[ERROR] ``spacing`` must be one of "linspace" or "geomspace"' - ) - - # assign learning rates - i_count = 0 - for _, (name, params) in enumerate(self.model.named_parameters()): - if sep_group_names[group] not in name: - continue - params2optimize.append({"params": params, "lr": list_lrs[i_count]}) - i_count += 1 - - return params2optimize - - def show_inference_sample_results( - self, - label: str, - num_samples: int | None = 6, - set_name: str | None = "train", - min_conf: None | float | None = None, - max_conf: None | float | None = None, - figsize: tuple[int, int] | None = (15, 15), - ) -> None: - """ - Shows a sample of the results of the inference. - - Parameters - ---------- - label : str, optional - The label for which to display results. - num_samples : int, optional - The number of sample results to display. Defaults to ``6``. - set_name : str, optional - The name of the dataset split to use for inference. Defaults to - ``"train"``. - min_conf : float, optional - The minimum confidence score for a sample result to be displayed. - Samples with lower confidence scores will be skipped. Defaults to - ``None``. - max_conf : float, optional - The maximum confidence score for a sample result to be displayed. - Samples with higher confidence scores will be skipped. Defaults to - ``None``. - figsize : tuple[int, int], optional - Figure size (width, height) in inches, displaying the sample - results. Defaults to ``(15, 15)``. - - Returns - ------- - None - """ - - # eval mode, keep track of the current mode - was_training = self.model.training - self.model.eval() - - counter = 0 - plt.figure(figsize=figsize) - with torch.no_grad(): - for inputs1, inputs2, labels, label_indices in iter( - self.dataloaders[set_name] - ): - inputs1 = inputs1.to(self.device) - inputs2 = inputs2.to(self.device) - label_indices = label_indices.to(self.device) - - outputs = self.model(inputs1, inputs2) - - if not isinstance(outputs, torch.Tensor): - try: - outputs = outputs.logits - except AttributeError as err: - raise AttributeError(err.message) - - pred_conf = torch.nn.functional.softmax(outputs, dim=1) * 100.0 - _, preds = torch.max(outputs, 1) - - label_index_dict = { - label: index for label, index in zip(labels, label_indices) - } - - # go through images in batch - for j in range(len(preds)): - predicted_index = int(preds[j]) - if predicted_index != label_index_dict[label]: - continue - if (min_conf is not None) and ( - pred_conf[j][predicted_index] < min_conf - ): - continue - if (max_conf is not None) and ( - pred_conf[j][predicted_index] > max_conf - ): - continue - - counter += 1 - - conf_score = pred_conf[j][predicted_index] - ax = plt.subplot(int(num_samples / 2.0), 3, counter) - ax.axis("off") - ax.set_title(f"{label} | {conf_score:.3f}") - - inp = inputs1.cpu().data[j].numpy().transpose((1, 2, 0)) - inp = np.clip(inp, 0, 1) - plt.imshow(inp) - - if counter == num_samples: - self.model.train(mode=was_training) - plt.show() - return - - self.model.train(mode=was_training) - plt.show() diff --git a/mapreader/classify/custom_models.py b/mapreader/classify/custom_models.py index 69072594..44e68cbe 100644 --- a/mapreader/classify/custom_models.py +++ b/mapreader/classify/custom_models.py @@ -1,52 +1,58 @@ #!/usr/bin/env python from __future__ import annotations +import copy + import torch -class twoParallelModels(torch.nn.Module): +class PatchContextModel(torch.nn.Module): """ - A class for building a model that contains two parallel branches, with - separate input pipelines, but shares a fully connected layer at the end. + Model that contains two parallel branches, with separate input pipelines, but one shared fully connected layer at the end. This class inherits from PyTorch's nn.Module. """ def __init__( self, - feature1: torch.nn.Module, - feature2: torch.nn.Module, + patch_model: torch.nn.Module, + context_model: torch.nn.Module, fc_layer: torch.nn.Linear, ): """ - Initializes a new instance of the twoParallelModels class. + Initializes a new instance of the PatchContextModel class. Parameters: ----------- - feature1 : nn.Module - The feature extractor module for the first input pipeline. - feature2 : nn.Module - The feature extractor module for the second input pipeline. + patch_model : nn.Module + The feature extractor module for the first patch only pipeline. + context_model : nn.Module + The feature extractor module for the second context pipeline. fc_layer : nn.Linear The fully connected layer at the end of the model. + Input size should be output size of patch_model + output size of context_model. + Output size should be number of classes (labels) at the patch level. """ super().__init__() - self.features1 = feature1 - self.features2 = feature2 + + if patch_model is context_model: + context_model = copy.deepcopy(context_model) + + self.patch_model = patch_model + self.context_model = context_model self.fc_layer = fc_layer - def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: + def forward(self, patch: torch.Tensor, context: torch.Tensor) -> torch.Tensor: """ - Defines the computation performed at every forward pass. Receives two - inputs, x1 and x2, and feeds them through the respective feature - extractor modules, then concatenates the output and passes it through + Defines the computation performed at every forward pass. + Receives two inputs, patch and context, and feeds them through the respective feature extractor modules, then concatenates the output and passes it through the fully connected layer. Parameters: ----------- - x1 : torch.Tensor - The input tensor for the first input pipeline. - x2 : torch.Tensor - The input tensor for the second input pipeline. + patch : torch.Tensor + The input tensor for the patch pipeline. + context : torch.Tensor + The input tensor for the context pipeline. Returns: -------- @@ -54,13 +60,10 @@ def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: The output tensor of the model. """ - x1 = self.features1(x1) - x1 = x1.view(x1.size(0), -1) - - x2 = self.features2(x2) - x2 = x2.view(x2.size(0), -1) + patch_output = self.patch_model(patch) + context_output = self.context_model(context) # Concatenate in dim1 (feature dimension) - x = torch.cat((x1, x2), 1) - x = self.fc_layer(x) - return x + out = torch.cat((patch_output, context_output), 1) + out = self.fc_layer(out) + return (patch_output, context_output), out diff --git a/mapreader/classify/datasets.py b/mapreader/classify/datasets.py index ceef28cc..3ce57c2c 100644 --- a/mapreader/classify/datasets.py +++ b/mapreader/classify/datasets.py @@ -2,14 +2,15 @@ from __future__ import annotations import os +from ast import literal_eval +from itertools import product from typing import Callable import matplotlib.pyplot as plt -import numpy as np import pandas as pd import torch -from PIL import Image, ImageOps -from torch.utils.data import Dataset, DataLoader +from PIL import Image +from torch.utils.data import DataLoader, Dataset from torchvision import transforms # Import parhugin @@ -105,6 +106,8 @@ def __init__( if os.path.isfile(patch_df): print(f'[INFO] Reading "{patch_df}".') patch_df = pd.read_csv(patch_df, sep=delimiter) + # ensure tuple/list columns are read as such + patch_df = self._eval_df(patch_df) self.patch_df = patch_df else: raise ValueError(f'[ERROR] "{patch_df}" cannot be found.') @@ -114,6 +117,12 @@ def __init__( "[ERROR] Please pass ``patch_df`` as a string (path to csv file) or pd.DataFrame." ) + # force index to be integer + if self.patch_df.index.name == "image_id": + if "image_id" in self.patch_df.columns: + self.patch_df.drop(columns=["image_id"], inplace=True) + self.patch_df.reset_index(drop=False, names="image_id", inplace=True) + self.label_col = label_col self.label_index_col = label_index_col self.image_mode = image_mode @@ -152,6 +161,15 @@ def __init__( else: self.transform = transform + @staticmethod + def _eval_df(df): + for col in df.columns: + try: + df[col] = df[col].apply(literal_eval) + except (ValueError, TypeError, SyntaxError): + pass + return df + def __len__(self) -> int: """ Return the length of the dataset. @@ -163,7 +181,9 @@ def __len__(self) -> int: """ return len(self.patch_df) - def __getitem__(self, idx: int | torch.Tensor) -> tuple[torch.Tensor, str, int]: + def __getitem__( + self, idx: int | torch.Tensor + ) -> tuple[tuple[torch.Tensor], str, int]: """ Return the image, its label and the index of that label at the given index in the dataset. @@ -206,7 +226,7 @@ def __getitem__(self, idx: int | torch.Tensor) -> tuple[torch.Tensor, str, int]: else: image_label_index = -1 - return img, image_label, image_label_index + return (img,), (image_label,), (image_label_index,) def return_orig_image(self, idx: int | torch.Tensor) -> Image: """ @@ -325,9 +345,9 @@ def _get_label_index(self, label: str) -> int: def create_dataloaders( self, set_name: str = "infer", - batch_size: Optional[int] = 16, - shuffle: Optional[bool] = False, - num_workers: Optional[int] = 0, + batch_size: int = 16, + shuffle: bool = False, + num_workers: int = 0, **kwargs, ) -> None: """Creates a dictionary containing a PyTorch dataloader. @@ -338,7 +358,7 @@ def create_dataloaders( The name to use for the dataloader. batch_size : int, optional The batch size to use for the dataloader. By default ``16``. - shuffle : Optional[bool], optional + shuffle : bool, optional Whether to shuffle the PatchDataset, by default False num_workers : int, optional The number of worker threads to use for loading data. By default ``0``. @@ -351,34 +371,36 @@ def create_dataloaders( Dictionary containing dataloaders. """ - dataloaders = {set_name: DataLoader( - self, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - **kwargs, - )} + dataloaders = { + set_name: DataLoader( + self, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + **kwargs, + ) + } return dataloaders + # --- Dataset that returns an image, its context and its label class PatchContextDataset(PatchDataset): def __init__( self, patch_df: pd.DataFrame | str, - transform1: str, - transform2: str, + patch_transform: str, + context_transform: str, delimiter: str = ",", patch_paths_col: str | None = "image_path", label_col: str | None = None, label_index_col: str | None = None, + context_label_col: str | None = None, + context_label_index_col: str | None = None, image_mode: str | None = "RGB", - context_save_path: str | None = "./maps/maps_context", - create_context: bool | None = False, + context_dir: str | None = "./maps/maps_context", + create_context: bool = False, parent_path: str | None = "./maps", - x_offset: float | None = 1.0, - y_offset: float | None = 1.0, - slice_method: str | None = "scale", ): """ A PyTorch Dataset class for loading contextual information about image @@ -388,10 +410,10 @@ def __init__( ---------- patch_df : pandas.DataFrame or str DataFrame or path to csv file containing the paths to image patches and their labels. - transform1 : str + patch_transform : str Torchvision transform to be applied to input images. Either "train" or "val". - transform2 : str + context_transform : str Torchvision transform to be applied to target images. Either "train" or "val". delimiter : str @@ -402,24 +424,20 @@ def __init__( The name of the column containing the image labels. Default is None. label_index_col : str, optional The name of the column containing the indices of the image labels. Default is None. + context_label_col : str, optional + The name of the column containing the context labels. Default is None. + context_label_index_col : str, optional + The name of the column containing the indices of the context labels. Default is None. image_mode : str, optional The color space of the images. Default is "RGB". - context_save_path : str, optional - The path to save context maps to. Default is "./maps/maps_context". + context_dir : str, optional + The path to context maps (or, where to save context if not created yet). + Default is "./maps/maps_context". create_context : bool, optional Whether or not to create context maps. Default is False. parent_path : str, optional The path to the directory containing parent images. Default is "./maps". - x_offset : float, optional - The size of the horizontal offset around objects, as a fraction of - the image width. Default is 1.0. - y_offset : float, optional - The size of the vertical offset around objects, as a fraction of - the image height. Default is 1.0. - slice_method : str, optional - The method used to slice images. Either "scale" or "absolute". - Default is "scale". Attributes ---------- @@ -430,6 +448,10 @@ def __init__( The name of the column containing the image labels. label_index_col : str The name of the column containing the labels indices. + context_label_col : str + The name of the column containing the context labels. + context_label_index_col : str + The name of the column containing the context labels indices. patch_paths_col : str The name of the column in the DataFrame containing the image paths. @@ -437,33 +459,12 @@ def __init__( The color space of the images. parent_path : str The path to the directory containing parent images. - x_offset : float - The size of the horizontal offset around objects, as a fraction of - the image width. - y_offset : float - The size of the vertical offset around objects, as a fraction of - the image height. - slice_method : str - The method used to slice images. create_context : bool Whether or not to create context maps. - context_save_path : str - The path to save context maps to. + context_dir : str + The path to context maps. unique_labels : list or str - The unique labels in ``label_col``, or "NS" if ``label_col`` not in - ``patch_df``. - - Methods - ---------- - __getitem__(idx) - Retrieves the patch image, the context image and the label at the - given index in the dataset. - save_parents() - Saves parent images. - save_parents_idx(idx) - Saves parent image at index ``idx``. - return_orig_image(idx) - Return the original image associated with the given index. + The unique labels in ``label_col``. """ if isinstance(patch_df, pd.DataFrame): @@ -482,72 +483,87 @@ def __init__( "[ERROR] Please pass ``patch_df`` as a string (path to csv file) or pd.DataFrame." ) + # force index to be integer + if self.patch_df.index.name in ["image_id", "name"]: + if "image_id" in self.patch_df.columns: + self.patch_df.drop(columns=["image_id"], inplace=True) + self.patch_df.reset_index(drop=False, names="image_id", inplace=True) + self.label_col = label_col self.label_index_col = label_index_col + self.context_label_col = context_label_col + self.context_label_index_col = context_label_index_col self.image_mode = image_mode self.patch_paths_col = patch_paths_col self.parent_path = parent_path - self.x_offset = x_offset - self.y_offset = y_offset - self.slice_method = slice_method self.create_context = create_context - self.context_save_path = os.path.abspath( - context_save_path - ) # we need this either way I think? + self.context_dir = os.path.abspath(context_dir) if self.label_col: if self.label_col not in self.patch_df.columns: raise ValueError( - f"[ERROR] Label column ({label_col}) not in dataframe." + f"[ERROR] Label column ({self.label_col}) not in dataframe." ) + if self.context_label_col: + if self.context_label_col not in self.patch_df.columns: + raise ValueError( + f"[ERROR] Context label column ({self.context_label_col}) not in dataframe." + ) + else: + unique_labels = ( + self.patch_df[self.label_col].unique().tolist() + + self.patch_df[self.context_label_col].unique().tolist() + ) + self.unique_labels = list(set(unique_labels)) else: self.unique_labels = self.patch_df[self.label_col].unique().tolist() if self.label_index_col: if self.label_index_col not in self.patch_df.columns: - if self.label_col: + print( + f"[INFO] Label index column ({label_index_col}) not in dataframe. Creating column." + ) + self.patch_df[self.label_index_col] = self.patch_df[ + self.label_col + ].apply(self._get_label_index) + if self.context_label_index_col: + if self.context_label_index_col not in self.patch_df.columns: print( - f"[INFO] Label index column ({label_index_col}) not in dataframe. Creating column." + f"[INFO] Context label index column ({context_label_index_col}) not in dataframe. Creating column." ) - self.patch_df[self.label_index_col] = self.patch_df[ - self.label_col + self.patch_df[self.context_label_index_col] = self.patch_df[ + self.context_label_col ].apply(self._get_label_index) - else: - raise ValueError( - f"[ERROR] Label index column ({label_index_col}) not in dataframe." - ) - if isinstance(transform1, str): - if transform1 in ["train", "val", "test"]: - self.transform1 = self._default_transform(transform1) + if isinstance(patch_transform, str): + if patch_transform in ["train", "val", "test"]: + self.patch_transform = self._default_transform(patch_transform) else: raise ValueError( '[ERROR] ``transform`` can only be "train", "val" or "test" or, a transform.' ) else: - self.transform1 = transform1 + self.patch_transform = patch_transform - if isinstance(transform2, str): - if transform2 in ["train", "val", "test"]: - self.transform2 = self._default_transform(transform2) + if isinstance(context_transform, str): + if context_transform in ["train", "val", "test"]: + self.context_transform = self._default_transform(context_transform) else: raise ValueError( '[ERROR] ``transform`` can only be "train", "val" or "test" or, a transform.' ) else: - self.transform2 = transform2 + self.context_transform = context_transform - def save_parents( + def save_context( self, - processors: int | None = 10, - sleep_time: float | None = 0.001, - use_parhugin: bool | None = True, - parent_delimiter: str | None = "#", - loc_delimiter: str | None = "-", - overwrite: bool | None = False, + processors: int = 10, + sleep_time: float = 0.001, + use_parhugin: bool = True, + overwrite: bool = False, ) -> None: """ - Save parent patches for all patches in the patch_df. + Save context images for all patches in the patch_df. Parameters ---------- @@ -556,17 +572,9 @@ def save_parents( sleep_time : float, optional The time to wait between jobs, by default 0.001. use_parhugin : bool, optional - Flag indicating whether to use Parhugin to parallelize the job, by - default True. - parent_delimiter : str, optional - The delimiter used to separate parent IDs in the patch filename, by - default "#". - loc_delimiter : str, optional - The delimiter used to separate patch pixel bounds in the patch - filename, by default "-". + Whether to use Parhugin to parallelize the job, by default True. overwrite : bool, optional - Flag indicating whether to overwrite existing parent files, by - default False. + Whether to overwrite existing parent files, by default False. Returns ------- @@ -578,37 +586,57 @@ def save_parents( multiple CPU cores. The method uses Parhugin to parallelize the computation of saving parent patches to disk. When Parhugin is installed and ``use_parhugin`` is set to True, the method parallelizes - the calling of the ``save_parents_idx`` method and its corresponding + the calling of the ``get_context_id`` method and its corresponding arguments. If Parhugin is not installed or ``use_parhugin`` is set to False, the method executes the loop over patch indices sequentially instead. """ if parhugin_installed and use_parhugin: - myproc = multiFunc(processors=processors, sleep_time=sleep_time) + my_proc = multiFunc(processors=processors, sleep_time=sleep_time) list_jobs = [] - for idx in range(len(self.patch_df)): + for idx in self.patch_df.index: list_jobs.append( [ - self.save_parents_idx, - (idx, parent_delimiter, loc_delimiter, overwrite), + self.save_context_id( + idx, + overwrite=overwrite, + save_context=True, + return_image=False, + ), ] ) print(f"Total number of jobs: {len(list_jobs)}") - # and then adding them to myproc - myproc.add_list_jobs(list_jobs) - myproc.run_jobs() + # and then adding them to my_proc + my_proc.add_list_jobs(list_jobs) + my_proc.run_jobs() else: - for idx in range(len(self.patch_df)): - self.save_parents_idx(idx) + for idx in self.patch_df.index: + self.get_context_id( + idx, + overwrite=overwrite, + save_context=True, + return_image=False, + ) - def save_parents_idx( + @staticmethod + def _get_empty_square( + patch_size: tuple[int, int], + ): + """Get an empty square image with size (width, height) equal to `patch_size`.""" + im = Image.new( + size=patch_size, + mode="RGB", + color=None, + ) + return im + + def get_context_id( self, idx: int, - parent_delimiter: str | None = "#", - loc_delimiter: str | None = "-", - overwrite: bool | None = False, - return_image: bool | None = False, + overwrite: bool = False, + save_context: bool = False, + return_image: bool = True, ) -> None: """ Save the parents of a specific patch to the specified location. @@ -617,15 +645,13 @@ def save_parents_idx( ---------- idx : int Index of the patch in the dataset. - parent_delimiter : str, optional - Delimiter to split the parent names in the file path. Default - is "#". - loc_delimiter : str, optional - Delimiter to split the location of the patch in the file path. - Default is "-". overwrite : bool, optional Whether to overwrite the existing parent files. Default is False. + save_context : bool, optional + Whether to save the context image. Default is False. + return_image : bool, optional + Whether to return the context image. Default is True. Raises ------ @@ -636,84 +662,107 @@ def save_parents_idx( ------- None """ - img_path = self.patch_df.iloc[idx][self.patch_paths_col] - - if os.path.exists(img_path): - img = Image.open(img_path).convert(self.image_mode) - else: - raise ValueError( - f'[ERROR] "{img_path} cannot be found.\n\n\ -Please check the image exists, your file paths are correct and that ``.patch_paths_col`` is set to the correct column.' - ) - - if not return_image: - os.makedirs(self.context_save_path, exist_ok=True) - - path2save_context = os.path.join( - self.context_save_path, os.path.basename(img_path) - ) + patch_df = self.patch_df.copy(deep=True) - if os.path.isfile(path2save_context) and (not overwrite): - return + if not all( + [col in patch_df.columns for col in ["min_x", "min_y", "max_x", "max_y"]] + ): + patch_df[["min_x", "min_y", "max_x", "max_y"]] = [*patch_df.pixel_bounds] - if self.slice_method in ["scale"]: - # size: (width, height) - tar_y_offset = int(img.size[1] * self.y_offset) - tar_x_offset = int(img.size[0] * self.x_offset) - else: - tar_y_offset = self.y_offset - tar_x_offset = self.x_offset - - par_name = os.path.basename(img_path).split(parent_delimiter)[1] - split_path = os.path.basename(img_path).split(loc_delimiter) - min_x, min_y, max_x, max_y = ( - int(split_path[1]), - int(split_path[2]), - int(split_path[3]), - int(split_path[4]), + patch_image = Image.open(patch_df.iloc[idx][self.patch_paths_col]).convert( + self.image_mode ) - - if self.parent_path in ["dynamic"]: - parent_path2read = os.path.join( - os.path.dirname(os.path.dirname(os.path.abspath(img_path))), - par_name, + patch_width, patch_height = (patch_image.width, patch_image.height) + parent_id = patch_df.iloc[idx]["parent_id"] + min_x = patch_df.iloc[idx]["min_x"] + min_y = patch_df.iloc[idx]["min_y"] + max_x = patch_df.iloc[idx]["max_x"] + max_y = patch_df.iloc[idx]["max_y"] + + # get a pixel bounds of context images + context_grid = [ + *product( + [ + (patch_df["min_y"], min_y), + (min_y, max_y), + (max_y, patch_df["max_y"]), + ], + [ + (patch_df["min_x"], min_x), + (min_x, max_x), + (max_x, patch_df["max_x"]), + ], ) - else: - parent_path2read = os.path.join(os.path.abspath(self.parent_path), par_name) - - par_img = Image.open(parent_path2read).convert(self.image_mode) - - min_y_par = max(0, min_y - tar_y_offset) - min_x_par = max(0, min_x - tar_x_offset) - max_x_par = min(max_x + tar_x_offset, np.shape(par_img)[1]) - max_y_par = min(max_y + tar_y_offset, np.shape(par_img)[0]) - - pad_activate = False - top_pad = left_pad = right_pad = bottom_pad = 0 - if (min_y - tar_y_offset) < 0: - top_pad = abs(min_y - tar_y_offset) - pad_activate = True - if (min_x - tar_x_offset) < 0: - left_pad = abs(min_x - tar_x_offset) - pad_activate = True - if (max_x + tar_x_offset) > np.shape(par_img)[1]: - right_pad = max_x + tar_x_offset - np.shape(par_img)[1] - pad_activate = True - if (max_y + tar_y_offset) > np.shape(par_img)[0]: - bottom_pad = max_y + tar_y_offset - np.shape(par_img)[0] - pad_activate = True - - # par_img = par_img[min_y_par:max_y_par, min_x_par:max_x_par] - par_img = par_img.crop((min_x_par, min_y_par, max_x_par, max_y_par)) - - if pad_activate: - padding = (left_pad, top_pad, right_pad, bottom_pad) - par_img = ImageOps.expand(par_img, padding) + ] + # reshape to min_x, min_y, max_x, max_y + context_grid = [ + (coord[1][0], coord[0][0], coord[1][1], coord[0][1]) + for coord in context_grid + ] + + # get a list of context images + context_list = [ + patch_df[ + (patch_df["min_x"] == context_loc[0]) + & (patch_df["min_y"] == context_loc[1]) + & (patch_df["max_x"] == context_loc[2]) + & (patch_df["max_y"] == context_loc[3]) + & (patch_df["parent_id"] == parent_id) + ] + for context_loc in context_grid + ] + if any([len(context_patch) > 1 for context_patch in context_list]): + raise ValueError(f"[ERROR] Multiple context patches found for patch {idx}.") + if len(context_list) != 9: + raise ValueError(f"[ERROR] Missing context images for patch {idx}.") + + context_paths = [ + ( + context_patch[self.patch_paths_col].values[0] + if len(context_patch) + else None + ) + for context_patch in context_list + ] + context_images = [ + ( + Image.open(context_path).convert(self.image_mode) + if context_path is not None + else self._get_empty_square((patch_width, patch_height)) + ) + for context_path in context_paths + ] + + # split into rows (3x3 grid) + context_images = [ + context_images[i : i + 3] for i in range(0, len(context_images), 3) + ] + + total_width = 3 * patch_width + total_height = 3 * patch_height + context_image = Image.new(self.image_mode, (total_width, total_height)) + + y_offset = 0 + for row in context_images: + x_offset = 0 + for image in row: + context_image.paste(image, (x_offset, y_offset)) + x_offset += patch_width + y_offset += patch_height + + if save_context: + os.makedirs(self.context_dir, exist_ok=True) + context_path = os.path.join( + self.context_dir, + os.path.basename(patch_df.iloc[idx][self.patch_paths_col]), + ) + if overwrite or not os.path.exists(context_path): + context_image.save(context_path) if return_image: - return par_img - elif not os.path.isfile(path2save_context): - par_img.save(path2save_context) + return context_image + else: + return def plot_sample(self, idx: int) -> None: """ @@ -740,13 +789,13 @@ def plot_sample(self, idx: int) -> None: """ plt.figure(figsize=(10, 5)) plt.subplot(1, 2, 1) - plt.imshow(transforms.ToPILImage()(self.__getitem__(idx)[0])) + plt.imshow(transforms.ToPILImage()(self.__getitem__(idx)[0][0])) plt.title("Patch", size=18) plt.xticks([]) plt.yticks([]) plt.subplot(1, 2, 2) - plt.imshow(transforms.ToPILImage()(self.__getitem__(idx)[1])) + plt.imshow(transforms.ToPILImage()(self.__getitem__(idx)[0][1])) plt.title("Context", size=18) plt.xticks([]) plt.yticks([]) @@ -755,7 +804,7 @@ def plot_sample(self, idx: int) -> None: def __getitem__( self, idx: int | torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor, str, int]: + ) -> tuple[tuple[torch.Tensor, torch.Tensor], str, int]: """ Retrieves the patch image, the context image and the label at the given index in the dataset (``idx``). @@ -789,23 +838,37 @@ def __getitem__( ) if self.create_context: - context_img = self.save_parents_idx(idx, return_image=True) + context_img = self.get_context_id(idx, return_image=True) else: context_img = Image.open( - os.path.join(self.context_save_path, os.path.basename(img_path)) + os.path.join(self.context_dir, os.path.basename(img_path)) ).convert(self.image_mode) - img = self.transform1(img) - context_img = self.transform2(context_img) + img = self.patch_transform(img) + context_img = self.context_transform(context_img) if self.label_col in self.patch_df.iloc[idx].keys(): image_label = self.patch_df.iloc[idx][self.label_col] else: image_label = "" + if self.context_label_col in self.patch_df.iloc[idx].keys(): + context_label = self.patch_df.iloc[idx][self.context_label_col] + else: + context_label = "" + if self.label_index_col in self.patch_df.iloc[idx].keys(): image_label_index = self.patch_df.iloc[idx][self.label_index_col] else: image_label_index = -1 - return img, context_img, image_label, image_label_index + if self.context_label_index_col in self.patch_df.iloc[idx].keys(): + context_label_index = self.patch_df.iloc[idx][self.context_label_index_col] + else: + context_label_index = -1 + + return ( + (img, context_img), + (image_label, context_label), + (image_label_index, context_label_index), + ) diff --git a/mapreader/classify/load_annotations.py b/mapreader/classify/load_annotations.py index aae7a97d..899a6ba7 100644 --- a/mapreader/classify/load_annotations.py +++ b/mapreader/classify/load_annotations.py @@ -14,7 +14,7 @@ from torch.utils.data import DataLoader, Sampler, WeightedRandomSampler from torchvision.transforms import Compose -from .datasets import PatchDataset +from .datasets import PatchContextDataset, PatchDataset class AnnotationsLoader: @@ -92,6 +92,7 @@ def load( print( f'[WARNING] ID column was previously "{self.id_col}, but will now be set to {id_col}.' ) + self.id_col = id_col if not self.patch_paths_col: self.patch_paths_col = patch_paths_col @@ -99,6 +100,7 @@ def load( print( f'[WARNING] Patch paths column was previously "{self.patch_paths_col}, but will now be set to {patch_paths_col}.' ) + self.patch_paths_col = patch_paths_col if not self.label_col: self.label_col = label_col @@ -106,6 +108,7 @@ def load( print( f'[WARNING] Label column was previously "{self.label_col}, but will now be set to {label_col}.' ) + self.label_col = label_col if not isinstance(annotations, (str, pd.DataFrame)): raise ValueError( @@ -115,6 +118,7 @@ def load( annotations = self._load_annotations_csv( annotations, delimiter, scramble_frame, reset_index ) + context_labels = True if "context_label" in annotations.columns else False if images_dir: abs_images_dir = os.path.abspath(images_dir) @@ -122,8 +126,9 @@ def load( lambda x: os.path.join(abs_images_dir, x) ) + cols = [self.label_col, "context_label"] if context_labels else [self.label_col] annotations = annotations.astype( - {self.label_col: str} + {col: str for col in cols} ) # ensure labels are interpreted as strings if append: @@ -136,14 +141,22 @@ def load( ) unique_labels = self.annotations[self.label_col].unique().tolist() + if context_labels: + unique_labels.extend(self.annotations["context_label"].unique().tolist()) + unique_labels = list(set(unique_labels)) self.unique_labels = unique_labels - self.annotations["label_index"] = self.annotations[self.label_col].apply( - self._get_label_index - ) labels_map = {i: label for i, label in enumerate(unique_labels)} self.labels_map = labels_map + self.annotations["label_index"] = self.annotations[self.label_col].apply( + self._get_label_index + ) + if context_labels: + self.annotations["context_label_index"] = self.annotations[ + "context_label" + ].apply(self._get_label_index) + print(self) def _load_annotations_csv( @@ -506,6 +519,7 @@ def create_datasets( train_transform: str | (Compose | Callable) | None = "train", val_transform: str | (Compose | Callable) | None = "val", test_transform: str | (Compose | Callable) | None = "test", + context_datasets: bool = False, ) -> None: """ Splits the dataset into three subsets: training, validation, and test sets (DataFrames) and saves them as a dictionary in ``self.datasets``. @@ -535,6 +549,8 @@ def create_datasets( The transform to use on the test dataset images. Options are "train", "test" or "val" or, a callable object (e.g. a torchvision transform or torchvision.transforms.Compose). By default "test". + context_datasets: bool, optional + Whether to create context datasets or not. By default False. Raises @@ -599,6 +615,39 @@ def create_datasets( df_test = None assert len(self.annotations) == len(df_train) + len(df_val) + if context_datasets: + datasets = self.create_patch_context_datasets( + train_transform, + val_transform, + test_transform, + df_train, + df_val, + df_test, + ) + else: + datasets = self.create_patch_datasets( + train_transform, + val_transform, + test_transform, + df_train, + df_val, + df_test, + ) + + dataset_sizes = { + set_name: len(datasets[set_name]) for set_name in datasets.keys() + } + + self.datasets = datasets + self.dataset_sizes = dataset_sizes + + print("[INFO] Number of annotations in each set:") + for set_name in datasets.keys(): + print(f" - {set_name}: {dataset_sizes[set_name]}") + + def create_patch_datasets( + self, train_transform, val_transform, test_transform, df_train, df_val, df_test + ): train_dataset = PatchDataset( df_train, train_transform, @@ -630,16 +679,55 @@ def create_datasets( else: datasets = {"train": train_dataset, "val": val_dataset} - dataset_sizes = { - set_name: len(datasets[set_name]) for set_name in datasets.keys() - } + return datasets - self.datasets = datasets - self.dataset_sizes = dataset_sizes + def create_patch_context_datasets( + self, train_transform, val_transform, test_transform, df_train, df_val, df_test + ): + train_dataset = PatchContextDataset( + df_train, + train_transform, + train_transform, + patch_paths_col=self.patch_paths_col, + label_col=self.label_col, + label_index_col="label_index", + context_label_col="context_label", + context_label_index_col="context_label_index", + create_context=True, + ) + val_dataset = PatchContextDataset( + df_val, + val_transform, + val_transform, + patch_paths_col=self.patch_paths_col, + label_col=self.label_col, + label_index_col="label_index", + context_label_col="context_label", + context_label_index_col="context_label_index", + create_context=True, + ) + if df_test is not None: + test_dataset = PatchContextDataset( + df_test, + test_transform, + test_transform, + patch_paths_col=self.patch_paths_col, + label_col=self.label_col, + label_index_col="label_index", + context_label_col="context_label", + context_label_index_col="context_label_index", + create_context=True, + ) + datasets = { + "train": train_dataset, + "val": val_dataset, + "test": test_dataset, + } - print("[INFO] Number of annotations in each set:") - for set_name in datasets.keys(): - print(f" - {set_name}: {dataset_sizes[set_name]}") + else: + datasets = {"train": train_dataset, "val": val_dataset} + + return datasets def create_dataloaders( self, diff --git a/mapreader/load/images.py b/mapreader/load/images.py index 545239af..ac4ad991 100644 --- a/mapreader/load/images.py +++ b/mapreader/load/images.py @@ -19,7 +19,7 @@ import pandas as pd import PIL import rasterio -from PIL import Image, ImageStat +from PIL import Image, ImageOps, ImageStat from pyproj import Transformer from rasterio.plot import reshape_as_raster from shapely import wkt @@ -985,7 +985,6 @@ def patchify_all( tree_level: str | None = "parent", path_save: str | None = None, add_to_parents: bool | None = True, - square_cuts: bool | None = False, resize_factor: bool | None = False, output_format: str | None = "png", rewrite: bool | None = False, @@ -1012,9 +1011,6 @@ def patchify_all( add_to_parents : bool, optional If True, patches will be added to the MapImages instance's ``images`` dictionary, by default ``True``. - square_cuts : bool, optional - If True, all patches will have the same number of pixels in - x and y, by default ``False``. resize_factor : bool, optional If True, resize the images before patchifying, by default ``False``. output_format : str, optional @@ -1067,7 +1063,6 @@ def patchify_all( patch_size=patch_size, path_save=path_save, add_to_parents=add_to_parents, - square_cuts=square_cuts, resize_factor=resize_factor, output_format=output_format, rewrite=rewrite, @@ -1080,7 +1075,6 @@ def _patchify_by_pixel( patch_size: int, path_save: str, add_to_parents: bool | None = True, - square_cuts: bool | None = False, resize_factor: bool | None = False, output_format: str | None = "png", rewrite: bool | None = False, @@ -1099,9 +1093,6 @@ def _patchify_by_pixel( add_to_parents : bool, optional If True, patches will be added to the MapImages instance's ``images`` dictionary, by default ``True``. - square_cuts : bool, optional - If True, all patches will have the same number of pixels in - x and y, by default ``False``. resize_factor : bool, optional If True, resize the images before patchifying, by default ``False``. output_format : str, optional @@ -1133,15 +1124,8 @@ def _patchify_by_pixel( max_x = min(x + patch_size, width) max_y = min(y + patch_size, height) - if ( - square_cuts - ): # move min_x and min_y back a bit so the patch is square - min_x = x - (patch_size - (max_x - x)) - min_y = y - (patch_size - (max_y - y)) - - else: - min_x = x - min_y = y + min_x = x + min_y = y patch_id = f"patch-{min_x}-{min_y}-{max_x}-{max_y}-#{image_id}#.{output_format}" patch_path = os.path.join(path_save, patch_id) @@ -1153,12 +1137,22 @@ def _patchify_by_pixel( ) else: - self._print_if_verbose( - f'[INFO] Creating "{patch_id}". Number of pixels in x,y: {max_x - min_x},{max_y - min_y}.', - verbose, - ) - patch = img.crop((min_x, min_y, max_x, max_y)) + if max_x == width: + patch = ImageOps.pad( + patch, (patch_size, patch.height), centering=(0, 0) + ) + if max_y == height: + patch = ImageOps.pad( + patch, (patch.width, patch_size), centering=(0, 0) + ) + + # check patch size + if patch.height != patch_size or patch.width != patch_size: + raise ValueError( + f"[ERROR] Patch size is {patch.height}x{patch.width} instead of {patch_size}x{patch_size}." + ) + patch.save(patch_path, output_format) if add_to_parents: @@ -2300,6 +2294,7 @@ def _save_patch_as_geotiff( patch_path = self.patches[patch_id]["image_path"] patch_dir = os.path.dirname(patch_path) + patch = Image.open(patch_path) if not os.path.exists(patch_dir): raise ValueError(f'[ERROR] Patch directory "{patch_dir}" does not exist.') @@ -2334,8 +2329,16 @@ def _save_patch_as_geotiff( if not crs: crs = self.patches[patch_id].get("crs", "EPSG:4326") + # for edge patches, crop the patch to the correct size first + min_x, min_y, max_x, max_y = self.patches[patch_id]["pixel_bounds"] + if width != max_x - min_x: + width = max_x - min_x + patch = patch.crop((0, 0, width, height)) + if height != max_y - min_y: + height = max_y - min_y + patch = patch.crop((0, 0, width, height)) + patch_affine = rasterio.transform.from_bounds(*coords, width, height) - patch = Image.open(patch_path) with rasterio.open( f"{geotiff_path}", diff --git a/tests/test_geo_pipeline.py b/tests/test_geo_pipeline.py index 297caf7b..c9d5c656 100644 --- a/tests/test_geo_pipeline.py +++ b/tests/test_geo_pipeline.py @@ -37,7 +37,6 @@ def test_pipeline(tmp_path, sample_dir): my_files.patchify_all( patch_size=300, # in pixels - square_cuts=True, path_save=f"{tmp_path}/patches_300_pixel", )