From db08aa82d8282bf5f9982b2b77861a6233d13ca5 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Thu, 6 Nov 2025 13:01:38 +1000 Subject: [PATCH 01/18] Added initial problem directory structure --- recognition/{ => ImprovedUNet3D}/README.md | 0 recognition/ImprovedUNet3D/dataset.py | 0 recognition/ImprovedUNet3D/module.py | 0 recognition/ImprovedUNet3D/predict.py | 0 recognition/ImprovedUNet3D/train.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename recognition/{ => ImprovedUNet3D}/README.md (100%) create mode 100644 recognition/ImprovedUNet3D/dataset.py create mode 100644 recognition/ImprovedUNet3D/module.py create mode 100644 recognition/ImprovedUNet3D/predict.py create mode 100644 recognition/ImprovedUNet3D/train.py diff --git a/recognition/README.md b/recognition/ImprovedUNet3D/README.md similarity index 100% rename from recognition/README.md rename to recognition/ImprovedUNet3D/README.md diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py new file mode 100644 index 000000000..e69de29bb diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py new file mode 100644 index 000000000..e69de29bb diff --git a/recognition/ImprovedUNet3D/predict.py b/recognition/ImprovedUNet3D/predict.py new file mode 100644 index 000000000..e69de29bb diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py new file mode 100644 index 000000000..e69de29bb From 0e4cda98439045ea3e49127484c0185e546c4c06 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:41:30 +1000 Subject: [PATCH 02/18] Added initial dataset implementation --- recognition/ImprovedUNet3D/dataset.py | 158 ++++++++++++++++++++++++++ 1 file changed, 158 insertions(+) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index e69de29bb..94cc2077b 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, Dataset +import torchvision.transforms as transforms +from torchvision.datasets import OxfordIIITPet +import torchvision.transforms.functional as TF + +import numpy as np +import matplotlib.pyplot as plt +import os +from PIL import Image +from tqdm import tqdm +import random +import nibabel as nib +import utils + +# Check if CUDA is available +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Using device: {device}') + +# Set random seeds for reproducibility +torch.manual_seed(42) +np.random.seed(42) +random.seed(42) +if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + +def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray: + channels = np.unique(arr) + res = np.zeros(arr.shape + (len(channels),), dtype=dtype) + for c in channels: + c = int(c) + res[..., c:c+1][arr == c] = 1 + return res + +def load_data_3D(imageNames, normImage=False, categorical=False, dtype=np.float32, + getAffines=False, orient=False, early_stop=False): + ''' + Load medical image data from names, cases list provided into a list for each. + + This function pre-allocates 5D arrays for conv3d to avoid excessive memory usage. + + normImage: bool (normalise the image 0.0-1.0) + orient: Apply orientation and resample image? 
Good for images with large slice + thickness or anisotropic resolution + dtype: Type of the data. If dtype=np.uint8, it is assumed that the data is labels + early_stop: Stop loading pre-maturely? Leaves arrays mostly empty, for quick + loading and testing scripts. + ''' + affines = [] + + # interp = 'continuous' + interp = 'linear' + if dtype == np.uint8: # assume labels + interp = 'nearest' + + # get fixed size + num = len(imageNames) + niftiImage = nib.load(imageNames[0]) + if orient: + niftiImage = im.applyOrientation(niftiImage, interpolation=interp, scale=1) + # testResultName = "oriented.nii.gz" + # niftiImage.to_filename(testResultName) + first_case = niftiImage.get_fdata(caching='unchanged') + + if len(first_case.shape) == 4: + first_case = first_case[:, :, :, 0] # sometimes extra dims, remove + + if categorical: + first_case = to_channels(first_case, dtype=dtype) + rows, cols, depth, channels = first_case.shape + images = np.zeros((num, rows, cols, depth, channels), dtype=dtype) + else: + rows, cols, depth = first_case.shape + images = np.zeros((num, rows, cols, depth), dtype=dtype) + + for i, inName in enumerate(tqdm(imageNames)): + niftiImage = nib.load(inName) + if orient: + niftiImage = im.applyOrientation(niftiImage, interpolation=interp, scale=1) + inImage = niftiImage.get_fdata(caching='unchanged') # read disk only + affine = niftiImage.affine + if len(inImage.shape) == 4: + inImage = inImage[:, :, :, 0] # sometimes extra dims in HipMRI_study data + inImage = inImage[:, :, :depth] # clip slices + inImage = inImage.astype(dtype) + + if normImage: + # inImage = inImage / np.linalg.norm(inImage) + # inImage = 255. * inImage / inImage.max() + inImage = (inImage - inImage.mean()) / inImage.std() + + if categorical: + inImage = utils.to_channels(inImage, dtype=dtype) + # images[i, :, :, :, :] = inImage + images[i, :inImage.shape[0], :inImage.shape[1], :inImage.shape[2], :inImage.shape[3]] = inImage # with pad + else: + # images[i, :, :, :] = inImage + images[i, :inImage.shape[0], :inImage.shape[1], :inImage.shape[2]] = inImage # with pad + + affines.append(affine) + + if i > 20 and early_stop: + break + + if getAffines: + return images, affines + else: + return images + +class HipMriDataset3D(Dataset): + """Dataset for prostate cancer 3D Images.""" + + def __init__(self, image_path, mask_path, transform=None): + # Download and load the dataset + self.image_dataset_path = image_path + self.mask_dataset_path = mask_path + self.transform = transform + self.dataset = [] + + image_paths = [os.path.join(self.image_dataset_path, img_name) + for img_name in sorted(os.listdir(self.image_dataset_path))] + mask_paths = [os.path.join(self.mask_dataset_path, mask_name) + for mask_name in sorted(os.listdir(self.image_dataset_path))] + + for case in range(len(image_paths)): + case_image = nib.load(image_paths[case]) + case_mask = nib.load(mask_paths[case]) + self.dataset.append((case_image, case_mask)) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + # Get image and mask + image, mask = self.dataset[idx] + + image_np = image.get_fdata().astype(np.float32) + + mask_np = mask.get_fdata().as_type(np.uint8) # convert + binary_mask = np.zeros_like(mask_np, dtype=np.uint8) + + affine = mask.affine + + # Apply transforms to image + if self.transform: + image_np = self.transform(image_np) + + binary_mask[mask_np != 5] = 0 # prostate_voxels + binary_mask[mask_np == 5] = 1 # prostate voxels + + # Convert to tensor + binary_mask = 
torch.from_numpy(binary_mask).unsqueeze(0).int() # add channel dim + image_np = torch.from_numpy(image_np).unsqueeze(0).float() # add channel dim + + return image_np, binary_mask, affine \ No newline at end of file From f430d152d090f3c6399f380a42e39c5ae27712aa Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:29:13 +1000 Subject: [PATCH 03/18] Added normalisation to dataset --- recognition/ImprovedUNet3D/dataset.py | 165 ++++++++------------------ 1 file changed, 49 insertions(+), 116 deletions(-) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index 94cc2077b..932a9ebb2 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -1,114 +1,48 @@ import torch -import torch.nn as nn +from torch.utils.data import Dataset import torch.nn.functional as F -import torch.optim as optim -from torch.utils.data import DataLoader, Dataset -import torchvision.transforms as transforms -from torchvision.datasets import OxfordIIITPet -import torchvision.transforms.functional as TF import numpy as np -import matplotlib.pyplot as plt import os -from PIL import Image -from tqdm import tqdm -import random import nibabel as nib -import utils - -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f'Using device: {device}') - -# Set random seeds for reproducibility -torch.manual_seed(42) -np.random.seed(42) -random.seed(42) -if torch.cuda.is_available(): - torch.cuda.manual_seed(42) - -def to_channels(arr: np.ndarray, dtype=np.uint8) -> np.ndarray: - channels = np.unique(arr) - res = np.zeros(arr.shape + (len(channels),), dtype=dtype) - for c in channels: - c = int(c) - res[..., c:c+1][arr == c] = 1 - return res - -def load_data_3D(imageNames, normImage=False, categorical=False, dtype=np.float32, - getAffines=False, orient=False, early_stop=False): - ''' - Load medical image data from names, cases list provided into a list for each. - - This function pre-allocates 5D arrays for conv3d to avoid excessive memory usage. - - normImage: bool (normalise the image 0.0-1.0) - orient: Apply orientation and resample image? Good for images with large slice - thickness or anisotropic resolution - dtype: Type of the data. If dtype=np.uint8, it is assumed that the data is labels - early_stop: Stop loading pre-maturely? Leaves arrays mostly empty, for quick - loading and testing scripts. 
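For reference, the binary-mask construction in patch 02's __getitem__ above keeps only label 5 (the prostate class) and zeroes every other voxel; the two masked assignments are equivalent to a single boolean comparison. A minimal sanity check on a toy label volume (an illustrative sketch; the "labels" array below is hypothetical and not part of the patch):

    import numpy as np

    # Toy 3D label volume with classes 0..5, mirroring the HipMRI semantic labels
    labels = np.random.randint(0, 6, size=(4, 4, 4), dtype=np.uint8)

    # Single-expression equivalent of the pair of masked assignments
    binary_mask = (labels == 5).astype(np.uint8)

    assert set(np.unique(binary_mask)) <= {0, 1}
    assert np.array_equal(binary_mask.astype(bool), labels == 5)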
- ''' - affines = [] - - # interp = 'continuous' - interp = 'linear' - if dtype == np.uint8: # assume labels - interp = 'nearest' - - # get fixed size - num = len(imageNames) - niftiImage = nib.load(imageNames[0]) - if orient: - niftiImage = im.applyOrientation(niftiImage, interpolation=interp, scale=1) - # testResultName = "oriented.nii.gz" - # niftiImage.to_filename(testResultName) - first_case = niftiImage.get_fdata(caching='unchanged') - - if len(first_case.shape) == 4: - first_case = first_case[:, :, :, 0] # sometimes extra dims, remove - - if categorical: - first_case = to_channels(first_case, dtype=dtype) - rows, cols, depth, channels = first_case.shape - images = np.zeros((num, rows, cols, depth, channels), dtype=dtype) - else: - rows, cols, depth = first_case.shape - images = np.zeros((num, rows, cols, depth), dtype=dtype) - - for i, inName in enumerate(tqdm(imageNames)): - niftiImage = nib.load(inName) - if orient: - niftiImage = im.applyOrientation(niftiImage, interpolation=interp, scale=1) - inImage = niftiImage.get_fdata(caching='unchanged') # read disk only - affine = niftiImage.affine - if len(inImage.shape) == 4: - inImage = inImage[:, :, :, 0] # sometimes extra dims in HipMRI_study data - inImage = inImage[:, :, :depth] # clip slices - inImage = inImage.astype(dtype) - - if normImage: - # inImage = inImage / np.linalg.norm(inImage) - # inImage = 255. * inImage / inImage.max() - inImage = (inImage - inImage.mean()) / inImage.std() - - if categorical: - inImage = utils.to_channels(inImage, dtype=dtype) - # images[i, :, :, :, :] = inImage - images[i, :inImage.shape[0], :inImage.shape[1], :inImage.shape[2], :inImage.shape[3]] = inImage # with pad - else: - # images[i, :, :, :] = inImage - images[i, :inImage.shape[0], :inImage.shape[1], :inImage.shape[2]] = inImage # with pad - - affines.append(affine) - - if i > 20 and early_stop: - break - if getAffines: - return images, affines - else: - return images +def zScoreNormalize(image): + mean = image.mean() + std = image.std() + if std > 0: + image = (image - mean) / std + else: + image = image - mean + return image + +def RandomFlip(image, mask): + axes = [0, 1, 2] # D, H, W axes + for axis in axes: + if np.random.rand() > 0.5: + image = np.flip(image, axis=axis) + mask = np.flip(mask, axis=axis) + return image, mask + +def RandomRotate_90(image, mask): + k = np.random.randint(0, 4) # 0, 90, 180, 270 degrees + axes = (1, 2) # rotate in-plane (H, W) + image = np.rot90(image, k, axes) + mask = np.rot90(mask, k, axes) + return image, mask + +def TrainingTransform(image, mask): + image, mask = RandomFlip(image, mask) + image, mask = RandomRotate_90(image, mask) + image = zScoreNormalize(image) + return image, mask + +def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): + """ + img_tensor: torch tensor of shape (C, D, H, W) + """ + img_tensor = img_tensor.unsqueeze(0) # add batch dim + img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type, align_corners=False) + return img_resized.squeeze(0) class HipMriDataset3D(Dataset): """Dataset for prostate cancer 3D Images.""" @@ -123,31 +57,27 @@ def __init__(self, image_path, mask_path, transform=None): image_paths = [os.path.join(self.image_dataset_path, img_name) for img_name in sorted(os.listdir(self.image_dataset_path))] mask_paths = [os.path.join(self.mask_dataset_path, mask_name) - for mask_name in sorted(os.listdir(self.image_dataset_path))] + for mask_name in sorted(os.listdir(self.mask_dataset_path))] for case in 
range(len(image_paths)): - case_image = nib.load(image_paths[case]) - case_mask = nib.load(mask_paths[case]) - self.dataset.append((case_image, case_mask)) + self.dataset.append((image_paths[case], mask_paths[case])) def __len__(self): return len(self.dataset) def __getitem__(self, idx): # Get image and mask - image, mask = self.dataset[idx] + image = nib.load(self.dataset[idx][0]) + mask = nib.load(self.dataset[idx][1]) image_np = image.get_fdata().astype(np.float32) - - mask_np = mask.get_fdata().as_type(np.uint8) # convert - binary_mask = np.zeros_like(mask_np, dtype=np.uint8) - - affine = mask.affine + mask_np = mask.get_fdata().astype(np.uint8) # convert # Apply transforms to image if self.transform: - image_np = self.transform(image_np) + image_np, mask_np = self.transform(image_np, mask_np) + binary_mask = np.zeros_like(mask_np, dtype=np.uint8) binary_mask[mask_np != 5] = 0 # prostate_voxels binary_mask[mask_np == 5] = 1 # prostate voxels @@ -155,4 +85,7 @@ def __getitem__(self, idx): binary_mask = torch.from_numpy(binary_mask).unsqueeze(0).int() # add channel dim image_np = torch.from_numpy(image_np).unsqueeze(0).float() # add channel dim - return image_np, binary_mask, affine \ No newline at end of file + binary_mask = Resize3dTensor(binary_mask, target_shape=(128,128,128), mode_type='nearest') + image_np = Resize3dTensor(image_np, target_shape=(128,128,128)) + + return image_np, binary_mask \ No newline at end of file From 46879d2565be1da835aefc9f283af0effaad090a Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Thu, 6 Nov 2025 17:31:55 +1000 Subject: [PATCH 04/18] Added first implementation of module --- recognition/ImprovedUNet3D/module.py | 108 +++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index e69de29bb..daa7b509c 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ----------------------------- +# Pre-activation Residual Block (Context Module) +# ----------------------------- +class ResidualBlock3D(nn.Module): + def __init__(self, in_channels, out_channels, dropout=0.3, stride=1): + super().__init__() + self.stride = stride + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) + self.norm1 = nn.InstanceNorm3d(out_channels) + self.act1 = nn.LeakyReLU(0.01) + + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) + self.norm2 = nn.InstanceNorm3d(out_channels) + self.act2 = nn.LeakyReLU(0.01) + + self.dropout = nn.Dropout3d(dropout) + + if in_channels != out_channels or stride > 1: + self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride) + else: + self.skip = nn.Identity() + + def forward(self, x): + residual = self.skip(x) + x = self.conv1(x) + x = self.dropout(x) + x = self.norm2(x) + x = self.act2(x) + x = self.conv2(x) + return x + residual + +# ----------------------------- +# Localization Module +# -----------------------------x +class LocalizationModule(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv3x3 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) + self.act = nn.LeakyReLU(0.01) + self.conv1x1 = nn.Conv3d(out_channels, out_channels // 2, kernel_size=1) + + def forward(self, x): + x = self.act(self.conv3x3(x)) + x = self.conv1x1(x) + return 
x + +# ----------------------------- +# Upsample by voxel repetition +# ----------------------------- +def upsample_repeat(x): + # Double each spatial dimension by repeating voxels + return x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) + +# ----------------------------- +# Full 3D U-Net with deep supervision +# ----------------------------- +class ImprovedUNet3D(nn.Module): + def __init__(self, in_channels=1, out_channels=3, base_filters=16, dropout=0.3): + super().__init__() + self.conv1 = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1) + + # Encoder / context pathway + self.enc1 = ResidualBlock3D(base_filters, base_filters, dropout) + self.enc2 = ResidualBlock3D(base_filters*2, base_filters*2, dropout, stride=2) + self.enc3 = ResidualBlock3D(base_filters*4, base_filters*4, dropout, stride=2) + self.enc4 = ResidualBlock3D(base_filters*8, base_filters*8, dropout, stride=2) + self.enc5 = ResidualBlock3D(base_filters*16, base_filters*16, dropout, stride=2) + + # Decoder / localization pathway + self.loc3 = LocalizationModule(base_filters*8 + base_filters*4, base_filters*4) + self.loc2 = LocalizationModule(base_filters*4 + base_filters*2, base_filters*2) + self.loc1 = LocalizationModule(base_filters*2 + base_filters, base_filters) + + # Deep supervision outputs + self.ds3 = nn.Conv3d(base_filters*4, out_channels, kernel_size=1) + self.ds2 = nn.Conv3d(base_filters*2, out_channels, kernel_size=1) + self.ds1 = nn.Conv3d(base_filters, out_channels, kernel_size=1) + + def forward(self, x): + # Encoder + e1 = self.enc1(x) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + + # Decoder with upsample + concat + localization + u3 = upsample_repeat(e4) + u3 = torch.cat([u3, e3], dim=1) + l3 = self.loc3(u3) + + u2 = upsample_repeat(l3) + u2 = torch.cat([u2, e2], dim=1) + l2 = self.loc2(u2) + + u1 = upsample_repeat(l2) + u1 = torch.cat([u1, e1], dim=1) + l1 = self.loc1(u1) + + # Deep supervision: sum outputs from different levels + out = F.interpolate(self.ds3(l3), size=l1.shape[2:], mode='trilinear', align_corners=False) + out += F.interpolate(self.ds2(l2), size=l1.shape[2:], mode='trilinear', align_corners=False) + out += self.ds1(l1) + + return out \ No newline at end of file From 317684282cb2b89e73bbd13857ef1e0cfd83859f Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Thu, 6 Nov 2025 22:38:13 +1000 Subject: [PATCH 05/18] Finished implementation of models --- recognition/ImprovedUNet3D/module.py | 133 ++++++++++++++++++++------- 1 file changed, 102 insertions(+), 31 deletions(-) diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index daa7b509c..e3be9be2f 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -2,6 +2,17 @@ import torch.nn as nn import torch.nn.functional as F +class Upsample3D(nn.Module): + def __init__(self, in_channels, scale_factor=2): + super().__init__() + self.scale_factor = scale_factor + self.conv1 = nn.Conv3d(in_channels, in_channels/2, kernel_size=3, padding=1) + + def forward(self, x): + x.repeat_interleave(self.scale_factor, dim=2).repeat_interleave(self.scale_factor, dim=3).repeat_interleave(self.scale_factor, dim=4) + x = self.conv1(x) + return x + # ----------------------------- # Pre-activation Residual Block (Context Module) # ----------------------------- @@ -61,7 +72,18 @@ def upsample_repeat(x): class ImprovedUNet3D(nn.Module): def __init__(self, in_channels=1, 
out_channels=3, base_filters=16, dropout=0.3): super().__init__() - self.conv1 = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1) + self.upsample1 = Upsample3D(base_filters*16) + self.upsample2 = Upsample3D(base_filters*8) + self.upsample3 = Upsample3D(base_filters*4) + self.upsample4 = Upsample3D(base_filters*2) + + self.convInput = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1) + self.convOutput = nn.Conv3d(base_filters*2, base_filters*2, kernel_size=3, padding=1) + + self.StrideConv1 = nn.Conv3d(base_filters*2, base_filters*2, kernel_size=3, padding=1, stride=2) + self.StrideConv2 = nn.Conv3d(base_filters*4, base_filters*4, kernel_size=3, padding=1, stride=2) + self.StrideConv3 = nn.Conv3d(base_filters*8, base_filters*8, kernel_size=3, padding=1, stride=2) + self.StrideConv4 = nn.Conv3d(base_filters*16, base_filters*16, kernel_size=3, padding=1, stride=2) # Encoder / context pathway self.enc1 = ResidualBlock3D(base_filters, base_filters, dropout) @@ -71,38 +93,87 @@ def __init__(self, in_channels=1, out_channels=3, base_filters=16, dropout=0.3): self.enc5 = ResidualBlock3D(base_filters*16, base_filters*16, dropout, stride=2) # Decoder / localization pathway - self.loc3 = LocalizationModule(base_filters*8 + base_filters*4, base_filters*4) - self.loc2 = LocalizationModule(base_filters*4 + base_filters*2, base_filters*2) - self.loc1 = LocalizationModule(base_filters*2 + base_filters, base_filters) - - # Deep supervision outputs - self.ds3 = nn.Conv3d(base_filters*4, out_channels, kernel_size=1) - self.ds2 = nn.Conv3d(base_filters*2, out_channels, kernel_size=1) - self.ds1 = nn.Conv3d(base_filters, out_channels, kernel_size=1) + self.loc3 = LocalizationModule(base_filters*8 + base_filters*8, base_filters*4) + self.loc2 = LocalizationModule(base_filters*4 + base_filters*4, base_filters*2) + self.loc1 = LocalizationModule(base_filters*2 + base_filters*2, base_filters) + + # Segmentation layers + self.seg1 = nn.Conv3d(base_filters*4, 1, kernel_size=1) + self.seg2= nn.Conv3d(base_filters*2, 1, kernel_size=1) + self.seg3 = nn.Conv3d(base_filters*2, 1, kernel_size=1) def forward(self, x): + x = self.convInput(x) # Encoder e1 = self.enc1(x) - e2 = self.enc2(e1) - e3 = self.enc3(e2) - e4 = self.enc4(e3) - - # Decoder with upsample + concat + localization - u3 = upsample_repeat(e4) - u3 = torch.cat([u3, e3], dim=1) - l3 = self.loc3(u3) - - u2 = upsample_repeat(l3) - u2 = torch.cat([u2, e2], dim=1) - l2 = self.loc2(u2) - - u1 = upsample_repeat(l2) - u1 = torch.cat([u1, e1], dim=1) - l1 = self.loc1(u1) - - # Deep supervision: sum outputs from different levels - out = F.interpolate(self.ds3(l3), size=l1.shape[2:], mode='trilinear', align_corners=False) - out += F.interpolate(self.ds2(l2), size=l1.shape[2:], mode='trilinear', align_corners=False) - out += self.ds1(l1) + + e2 = self.StrideConv1(e1) + e2 = self.enc2(e2) + + e3 = self.StrideConv2(e2) + e3 = self.enc3(e3) + + e4 = self.StrideConv3(e3) + e4 = self.enc4(e4) + + e5 = self.StrideConv4(e4) + e5 = self.enc5(e5) + + u1 = self.upsample1(e5) + u1 = torch.concat((u1, e4)) + u1 = self.loc3(u1) + + u2 = self.upsample2(u1) + u2 = torch.concat(u2, e3) + u2 = self.loc2(u2) + + res1 = self.seg1(u1) + res1 = F.interpolate(res1, size=e1.shape[2:], mode='nearest') + + u3 = self.upsample3(u2) + u3 = torch.concat(u3, e2) + u3 = self.loc1(u3) + + res2 = self.seg2(u3) + res2 = res1 + res2 + res2 = F.interpolate(res2, size=e1.shape[2:], mode='nearest') + + u4 = self.upsample4(u3) + u4 = torch.concat(u4, e1) + + out = 
self.convOutput(u4) + out = self.seg3(out) + out = res2 + out + out = F.softmax(out, dim=1) - return out \ No newline at end of file + return out + +class DiceLoss(nn.Module): + """Dice Loss for binary segmentation. + + Dice Loss = 1 - Dice Coefficient + Dice Coefficient = (2 * |X โˆฉ Y|) / (|X| + |Y|) + + Args: + smooth (float): Smoothing factor to avoid division by zero (default: 1e-6) + """ + def __init__(self, smooth=1e-6): + super(DiceLoss, self).__init__() + self.smooth = smooth + + def forward(self, predictions, targets): + """ + Args: + predictions: Sigmoid output from model [B, H, W] (values between 0-1) + targets: Binary ground truth [B, H, W] (values 0 or 1) + """ + # Flatten tensors using reshape to handle non-contiguous memory layout + predictions = predictions.reshape(-1) + targets = targets.reshape(-1).float() + + # Calculate intersection and union + intersection = (predictions * targets).sum() + dice_coeff = (2.0 * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth) + + # Return Dice Loss (1 - Dice Coefficient) + return 1 - dice_coeff \ No newline at end of file From 734b7d93f410781cb3a4d4cf359c6b855ada9c32 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 00:28:17 +1000 Subject: [PATCH 06/18] Added training loop, and debugged modules for shape mismatch --- recognition/ImprovedUNet3D/dataset.py | 13 +++- recognition/ImprovedUNet3D/module.py | 61 +++++++++--------- recognition/ImprovedUNet3D/train.py | 93 +++++++++++++++++++++++++++ 3 files changed, 133 insertions(+), 34 deletions(-) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index 932a9ebb2..08846a839 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -41,13 +41,13 @@ def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear' img_tensor: torch tensor of shape (C, D, H, W) """ img_tensor = img_tensor.unsqueeze(0) # add batch dim - img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type, align_corners=False) + img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type) return img_resized.squeeze(0) class HipMriDataset3D(Dataset): """Dataset for prostate cancer 3D Images.""" - def __init__(self, image_path, mask_path, transform=None): + def __init__(self, image_path, mask_path, transform=None, train=True): # Download and load the dataset self.image_dataset_path = image_path self.mask_dataset_path = mask_path @@ -59,6 +59,13 @@ def __init__(self, image_path, mask_path, transform=None): mask_paths = [os.path.join(self.mask_dataset_path, mask_name) for mask_name in sorted(os.listdir(self.mask_dataset_path))] + if train: + image_paths = image_paths[:int(0.8*len(image_paths))] + mask_paths = mask_paths[:int(0.8*len(mask_paths))] + else: + image_paths = image_paths[int(0.8*len(image_paths)):] + mask_paths = mask_paths[int(0.8*len(mask_paths)):] + for case in range(len(image_paths)): self.dataset.append((image_paths[case], mask_paths[case])) @@ -82,7 +89,7 @@ def __getitem__(self, idx): binary_mask[mask_np == 5] = 1 # prostate voxels # Convert to tensor - binary_mask = torch.from_numpy(binary_mask).unsqueeze(0).int() # add channel dim + binary_mask = torch.from_numpy(binary_mask).unsqueeze(0).float() # add channel dim image_np = torch.from_numpy(image_np).unsqueeze(0).float() # add channel dim binary_mask = Resize3dTensor(binary_mask, target_shape=(128,128,128), mode_type='nearest') diff --git 
a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index e3be9be2f..2102eea82 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -6,10 +6,10 @@ class Upsample3D(nn.Module): def __init__(self, in_channels, scale_factor=2): super().__init__() self.scale_factor = scale_factor - self.conv1 = nn.Conv3d(in_channels, in_channels/2, kernel_size=3, padding=1) + self.conv1 = nn.Conv3d(in_channels, in_channels // 2, kernel_size=3, padding=1) def forward(self, x): - x.repeat_interleave(self.scale_factor, dim=2).repeat_interleave(self.scale_factor, dim=3).repeat_interleave(self.scale_factor, dim=4) + x = x.repeat_interleave(self.scale_factor, dim=2).repeat_interleave(self.scale_factor, dim=3).repeat_interleave(self.scale_factor, dim=4) x = self.conv1(x) return x @@ -21,11 +21,17 @@ def __init__(self, in_channels, out_channels, dropout=0.3, stride=1): super().__init__() self.stride = stride self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) - self.norm1 = nn.InstanceNorm3d(out_channels) + if (in_channels == 1): + self.norm1 = nn.InstanceNorm3d(out_channels) + else: + self.norm1 = nn.Identity() self.act1 = nn.LeakyReLU(0.01) self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) - self.norm2 = nn.InstanceNorm3d(out_channels) + if (in_channels == 1): + self.norm2 = nn.InstanceNorm3d(out_channels) + else: + self.norm2 = nn.Identity() self.act2 = nn.LeakyReLU(0.01) self.dropout = nn.Dropout3d(dropout) @@ -50,27 +56,20 @@ def forward(self, x): class LocalizationModule(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() - self.conv3x3 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1) + self.conv3x3 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) self.act = nn.LeakyReLU(0.01) - self.conv1x1 = nn.Conv3d(out_channels, out_channels // 2, kernel_size=1) + self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) def forward(self, x): x = self.act(self.conv3x3(x)) x = self.conv1x1(x) return x -# ----------------------------- -# Upsample by voxel repetition -# ----------------------------- -def upsample_repeat(x): - # Double each spatial dimension by repeating voxels - return x.repeat_interleave(2, dim=2).repeat_interleave(2, dim=3).repeat_interleave(2, dim=4) - # ----------------------------- # Full 3D U-Net with deep supervision # ----------------------------- class ImprovedUNet3D(nn.Module): - def __init__(self, in_channels=1, out_channels=3, base_filters=16, dropout=0.3): + def __init__(self, in_channels=1, base_filters=16, dropout=0.3): super().__init__() self.upsample1 = Upsample3D(base_filters*16) self.upsample2 = Upsample3D(base_filters*8) @@ -80,22 +79,22 @@ def __init__(self, in_channels=1, out_channels=3, base_filters=16, dropout=0.3): self.convInput = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1) self.convOutput = nn.Conv3d(base_filters*2, base_filters*2, kernel_size=3, padding=1) - self.StrideConv1 = nn.Conv3d(base_filters*2, base_filters*2, kernel_size=3, padding=1, stride=2) - self.StrideConv2 = nn.Conv3d(base_filters*4, base_filters*4, kernel_size=3, padding=1, stride=2) - self.StrideConv3 = nn.Conv3d(base_filters*8, base_filters*8, kernel_size=3, padding=1, stride=2) - self.StrideConv4 = nn.Conv3d(base_filters*16, base_filters*16, kernel_size=3, padding=1, stride=2) + self.StrideConv1 = nn.Conv3d(base_filters, base_filters*2, kernel_size=3, padding=1, stride=2) + self.StrideConv2 = 
nn.Conv3d(base_filters*2, base_filters*4, kernel_size=3, padding=1, stride=2) + self.StrideConv3 = nn.Conv3d(base_filters*4, base_filters*8, kernel_size=3, padding=1, stride=2) + self.StrideConv4 = nn.Conv3d(base_filters*8, base_filters*16, kernel_size=3, padding=1, stride=2) # Encoder / context pathway self.enc1 = ResidualBlock3D(base_filters, base_filters, dropout) - self.enc2 = ResidualBlock3D(base_filters*2, base_filters*2, dropout, stride=2) - self.enc3 = ResidualBlock3D(base_filters*4, base_filters*4, dropout, stride=2) - self.enc4 = ResidualBlock3D(base_filters*8, base_filters*8, dropout, stride=2) - self.enc5 = ResidualBlock3D(base_filters*16, base_filters*16, dropout, stride=2) + self.enc2 = ResidualBlock3D(base_filters*2, base_filters*2, dropout) + self.enc3 = ResidualBlock3D(base_filters*4, base_filters*4, dropout) + self.enc4 = ResidualBlock3D(base_filters*8, base_filters*8, dropout) + self.enc5 = ResidualBlock3D(base_filters*16, base_filters*16, dropout) # Decoder / localization pathway - self.loc3 = LocalizationModule(base_filters*8 + base_filters*8, base_filters*4) - self.loc2 = LocalizationModule(base_filters*4 + base_filters*4, base_filters*2) - self.loc1 = LocalizationModule(base_filters*2 + base_filters*2, base_filters) + self.loc3 = LocalizationModule(base_filters*8 + base_filters*8, base_filters*8) + self.loc2 = LocalizationModule(base_filters*4 + base_filters*4, base_filters*4) + self.loc1 = LocalizationModule(base_filters*2 + base_filters*2, base_filters*2) # Segmentation layers self.seg1 = nn.Conv3d(base_filters*4, 1, kernel_size=1) @@ -120,18 +119,18 @@ def forward(self, x): e5 = self.enc5(e5) u1 = self.upsample1(e5) - u1 = torch.concat((u1, e4)) + u1 = torch.cat((u1, e4), dim=1) u1 = self.loc3(u1) u2 = self.upsample2(u1) - u2 = torch.concat(u2, e3) + u2 = torch.cat((u2, e3), dim=1) u2 = self.loc2(u2) - res1 = self.seg1(u1) - res1 = F.interpolate(res1, size=e1.shape[2:], mode='nearest') + res1 = self.seg1(u2) + res1 = F.interpolate(res1, size=e2.shape[2:], mode='nearest') u3 = self.upsample3(u2) - u3 = torch.concat(u3, e2) + u3 = torch.cat((u3, e2), dim=1) u3 = self.loc1(u3) res2 = self.seg2(u3) @@ -139,7 +138,7 @@ def forward(self, x): res2 = F.interpolate(res2, size=e1.shape[2:], mode='nearest') u4 = self.upsample4(u3) - u4 = torch.concat(u4, e1) + u4 = torch.cat((u4, e1), dim=1) out = self.convOutput(u4) out = self.seg3(out) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index e69de29bb..bc4e87c39 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -0,0 +1,93 @@ +import dataset +import matplotlib as plt +import torch +import torch.optim as optim +import numpy as np +import torch.nn.functional as F +from module import ImprovedUNet3D, DiceLoss +import dataset +import random + +# Check if CUDA is available +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Using device: {device}') + +# Set random seeds for reproducibility +torch.manual_seed(42) +np.random.seed(42) +random.seed(42) +if torch.cuda.is_available(): + torch.cuda.manual_seed(42) + +# Quick visualization of loss +def plot_loss(losses, loss_type='dice'): + plt.figure(figsize=(8, 4)) + plt.plot(losses, 'bo-', linewidth=2, markersize=8) + + title_map = { + 'bce': '๐Ÿ”ฅ Training Loss (BCE)', + 'dice': '๐Ÿ”ฅ Training Loss (Dice)', + 'combined': '๐Ÿ”ฅ Training Loss (Combined BCE + Dice)' + } + plt.title(title_map.get(loss_type, '๐Ÿ”ฅ Training Loss'), fontsize=14, fontweight='bold') + 
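With the stride-convolution channel widths corrected above, the encoder halves each spatial dimension while doubling the filter count: 16, 32, 64, 128, 256 channels at 128^3, 64^3, 32^3, 16^3, 8^3 voxels respectively. A quick smoke test of the model as it stands after this patch (a sketch, assuming the module.py from this series is importable and inputs have been resized to 128^3 as in dataset.py):

    import torch
    from module import ImprovedUNet3D

    model = ImprovedUNet3D(in_channels=1, base_filters=16, dropout=0.3)
    model.eval()  # disable dropout for the shape check
    x = torch.randn(1, 1, 128, 128, 128)  # (B, C, D, H, W)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([1, 1, 128, 128, 128])

Note that at this point in the series the forward pass ends with F.softmax over a single output channel, which evaluates to 1.0 everywhere; later patches widen the segmentation heads to six classes and move the softmax into the loss.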
plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.grid(True, alpha=0.3) + plt.show() + +def train(model, train_loader, test_dataset, epochs=100, lr=0.001, visualize_every=1): + model.to(device) + criterion = DiceLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + losses = [] + + print(" Starting training with Instance Norm, LeakyReLU, and Softmax activation...") + for epoch in range(epochs): + model.train() + epoch_loss = 0 + + # Training loop with progress + for batch_idx, (images, masks) in enumerate(train_loader): + images, masks = images.to(device), masks.to(device) + + optimizer.zero_grad() + outputs = model(images) + + print(f"pred_pet shape: {outputs.shape}, masks shape: {masks.shape}") + loss = criterion(outputs, masks) + + # Backward pass + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(train_loader) + losses.append(avg_loss) + print(f"๐Ÿ“ˆ Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f}") + + # Visualize predictions after each epoch (or every few epochs) + #if (epoch) % visualize_every == 0: + # show_epoch_predictions(model, test_dataset, epoch + 1, n=3) + + print(" Training complete with enhanced U-Net!") + plot_loss(losses, loss_type='dice') + return losses + +if __name__ == "__main__": + # Example usage + train_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', + mask_path='semantic_labels_anon', + transform=dataset.TrainingTransform, + train=True) + test_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', + mask_path='semantic_labels_anon', + train=False) + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) + + model = ImprovedUNet3D(in_channels=1, base_filters=16, dropout=0.3) + + train(model, train_loader, test_dataset, epochs=50, lr=0.001, visualize_every=5) From 4e86805bee4e4d189f82a95dd2f8b6f7a9fd9f68 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 11:00:15 +1000 Subject: [PATCH 07/18] Seeing wrong dice coefficient, attempting fix --- recognition/ImprovedUNet3D/train.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index bc4e87c39..3626e9099 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -1,5 +1,5 @@ import dataset -import matplotlib as plt +import matplotlib.pyplot as plt import torch import torch.optim as optim import numpy as np @@ -35,7 +35,7 @@ def plot_loss(losses, loss_type='dice'): plt.grid(True, alpha=0.3) plt.show() -def train(model, train_loader, test_dataset, epochs=100, lr=0.001, visualize_every=1): +def train(model, train_loader, test_dataset, epochs=100, lr=0.001): model.to(device) criterion = DiceLoss() optimizer = optim.Adam(model.parameters(), lr=lr) @@ -71,6 +71,24 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001, visualize_eve #if (epoch) % visualize_every == 0: # show_epoch_predictions(model, test_dataset, epoch + 1, n=3) + model.eval() + with torch.no_grad(): + total_dice = 0 + for i in range(len(test_dataset)): + image, mask = test_dataset[i] + image = image.unsqueeze(0).to(device) + mask = mask.to(device) + + output = model(image) + pred_mask = (output > 0.5).float() + + intersection = (pred_mask * mask).sum() + dice_score = (2.0 * intersection) / (pred_mask.sum() + mask.sum() + 
1e-6) + total_dice += dice_score.item() + + avg_dice = total_dice / len(test_dataset) + print(f"๐Ÿงช Validation Dice Score after Epoch {epoch+1}: {avg_dice:.4f}") + print(" Training complete with enhanced U-Net!") plot_loss(losses, loss_type='dice') return losses From 8b4c3faf22829c841424125bbd15a197906b4852 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:38:59 +1000 Subject: [PATCH 08/18] Realised that I need to segment all 6 different types. --- recognition/ImprovedUNet3D/dataset.py | 26 ++++++++++++++++++-------- recognition/ImprovedUNet3D/module.py | 22 +++++++++------------- recognition/ImprovedUNet3D/train.py | 16 ++++++---------- 3 files changed, 33 insertions(+), 31 deletions(-) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index 08846a839..e7ea77bac 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -6,6 +6,14 @@ import os import nibabel as nib +def to_channels(label_volume, dtype=np.float32): + """Convert label map (3D) to one-hot channels (4D).""" + num_classes = int(label_volume.max()) + 1 + out = np.zeros((num_classes,) + label_volume.shape, dtype=dtype) + for c in range(num_classes): + out[c] = (label_volume == c) + return out + def zScoreNormalize(image): mean = image.mean() std = image.std() @@ -34,6 +42,10 @@ def TrainingTransform(image, mask): image, mask = RandomFlip(image, mask) image, mask = RandomRotate_90(image, mask) image = zScoreNormalize(image) + + image[mask == 0] = 0 + image = np.clip(image, -5, 5) + image = (image + 5) / 10.0 return image, mask def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): @@ -79,20 +91,18 @@ def __getitem__(self, idx): image_np = image.get_fdata().astype(np.float32) mask_np = mask.get_fdata().astype(np.uint8) # convert - + # Apply transforms to image if self.transform: image_np, mask_np = self.transform(image_np, mask_np) - binary_mask = np.zeros_like(mask_np, dtype=np.uint8) - binary_mask[mask_np != 5] = 0 # prostate_voxels - binary_mask[mask_np == 5] = 1 # prostate voxels + mask_np = to_channels(mask_np, dtype=np.uint8) # convert to one-hot channels # Convert to tensor - binary_mask = torch.from_numpy(binary_mask).unsqueeze(0).float() # add channel dim - image_np = torch.from_numpy(image_np).unsqueeze(0).float() # add channel dim + image_np = torch.from_numpy(image_np).unsqueeze(0).float() + mask_np = torch.from_numpy(mask_np).float() - binary_mask = Resize3dTensor(binary_mask, target_shape=(128,128,128), mode_type='nearest') + mask_np = Resize3dTensor(mask_np, target_shape=(128,128,128), mode_type='nearest') image_np = Resize3dTensor(image_np, target_shape=(128,128,128)) - return image_np, binary_mask \ No newline at end of file + return image_np, mask_np \ No newline at end of file diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index 2102eea82..af50c1e96 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -21,17 +21,11 @@ def __init__(self, in_channels, out_channels, dropout=0.3, stride=1): super().__init__() self.stride = stride self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) - if (in_channels == 1): - self.norm1 = nn.InstanceNorm3d(out_channels) - else: - self.norm1 = nn.Identity() + self.norm1 = nn.InstanceNorm3d(out_channels) self.act1 = nn.LeakyReLU(0.01) self.conv2 = nn.Conv3d(out_channels, out_channels, 
kernel_size=3, padding=1) - if (in_channels == 1): - self.norm2 = nn.InstanceNorm3d(out_channels) - else: - self.norm2 = nn.Identity() + self.norm2 = nn.InstanceNorm3d(out_channels) self.act2 = nn.LeakyReLU(0.01) self.dropout = nn.Dropout3d(dropout) @@ -44,10 +38,12 @@ def __init__(self, in_channels, out_channels, dropout=0.3, stride=1): def forward(self, x): residual = self.skip(x) x = self.conv1(x) + x = self.norm1(x) + x = self.act1(x) x = self.dropout(x) + x = self.conv2(x) x = self.norm2(x) x = self.act2(x) - x = self.conv2(x) return x + residual # ----------------------------- @@ -69,7 +65,7 @@ def forward(self, x): # Full 3D U-Net with deep supervision # ----------------------------- class ImprovedUNet3D(nn.Module): - def __init__(self, in_channels=1, base_filters=16, dropout=0.3): + def __init__(self, in_channels=6, base_filters=16, dropout=0.3): super().__init__() self.upsample1 = Upsample3D(base_filters*16) self.upsample2 = Upsample3D(base_filters*8) @@ -97,9 +93,9 @@ def __init__(self, in_channels=1, base_filters=16, dropout=0.3): self.loc1 = LocalizationModule(base_filters*2 + base_filters*2, base_filters*2) # Segmentation layers - self.seg1 = nn.Conv3d(base_filters*4, 1, kernel_size=1) - self.seg2= nn.Conv3d(base_filters*2, 1, kernel_size=1) - self.seg3 = nn.Conv3d(base_filters*2, 1, kernel_size=1) + self.seg1 = nn.Conv3d(base_filters*4, 6, kernel_size=1) + self.seg2= nn.Conv3d(base_filters*2, 6, kernel_size=1) + self.seg3 = nn.Conv3d(base_filters*2, 6, kernel_size=1) def forward(self, x): x = self.convInput(x) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 3626e9099..91958a9e8 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -48,13 +48,14 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): epoch_loss = 0 # Training loop with progress - for batch_idx, (images, masks) in enumerate(train_loader): + for images, masks in train_loader: images, masks = images.to(device), masks.to(device) optimizer.zero_grad() outputs = model(images) print(f"pred_pet shape: {outputs.shape}, masks shape: {masks.shape}") + loss = criterion(outputs, masks) # Backward pass @@ -74,17 +75,12 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): model.eval() with torch.no_grad(): total_dice = 0 - for i in range(len(test_dataset)): - image, mask = test_dataset[i] - image = image.unsqueeze(0).to(device) + for image, mask in test_dataset: + image = image.to(device) mask = mask.to(device) output = model(image) - pred_mask = (output > 0.5).float() - - intersection = (pred_mask * mask).sum() - dice_score = (2.0 * intersection) / (pred_mask.sum() + mask.sum() + 1e-6) - total_dice += dice_score.item() + total_dice += 1 - criterion(output, mask).item() avg_dice = total_dice / len(test_dataset) print(f"๐Ÿงช Validation Dice Score after Epoch {epoch+1}: {avg_dice:.4f}") @@ -108,4 +104,4 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): model = ImprovedUNet3D(in_channels=1, base_filters=16, dropout=0.3) - train(model, train_loader, test_dataset, epochs=50, lr=0.001, visualize_every=5) + train(model, train_loader, test_dataset, epochs=50, lr=0.001) From 794e8537c9a5a85f5b2d82d92fe60c729fe7980c Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 12:59:48 +1000 Subject: [PATCH 09/18] Minor fix to use dataloader and not dataset --- recognition/ImprovedUNet3D/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 
deletion(-) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 91958a9e8..9a118fae1 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -75,7 +75,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): model.eval() with torch.no_grad(): total_dice = 0 - for image, mask in test_dataset: + for image, mask in test_loader: image = image.to(device) mask = mask.to(device) @@ -95,6 +95,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): mask_path='semantic_labels_anon', transform=dataset.TrainingTransform, train=True) + test_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', mask_path='semantic_labels_anon', train=False) From 39efbf5b530878bba9a7c1c6b4f52e2074ef829e Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:41:24 +1000 Subject: [PATCH 10/18] Added scheduler and fixed the DiceLoss coefficient metric. --- recognition/ImprovedUNet3D/module.py | 39 +++++++++++++++++----------- recognition/ImprovedUNet3D/train.py | 16 ++++++++++++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index af50c1e96..cf4f72295 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -144,13 +144,14 @@ def forward(self, x): return out class DiceLoss(nn.Module): - """Dice Loss for binary segmentation. + """ + Multi-class Dice Loss supporting one-hot encoded targets. - Dice Loss = 1 - Dice Coefficient - Dice Coefficient = (2 * |X โˆฉ Y|) / (|X| + |Y|) + Dice Loss = 1 - (2 * |X โˆฉ Y| + smooth) / (|X| + |Y| + smooth) - Args: - smooth (float): Smoothing factor to avoid division by zero (default: 1e-6) + Works for both 2D and 3D tensors: + predictions: [B, C, H, W] or [B, C, D, H, W] + targets: same shape (one-hot encoded) """ def __init__(self, smooth=1e-6): super(DiceLoss, self).__init__() @@ -159,16 +160,24 @@ def __init__(self, smooth=1e-6): def forward(self, predictions, targets): """ Args: - predictions: Sigmoid output from model [B, H, W] (values between 0-1) - targets: Binary ground truth [B, H, W] (values 0 or 1) + predictions (torch.Tensor): Model outputs after sigmoid or softmax [B, C, ...] + targets (torch.Tensor): One-hot encoded ground truth [B, C, ...] """ - # Flatten tensors using reshape to handle non-contiguous memory layout - predictions = predictions.reshape(-1) - targets = targets.reshape(-1).float() + # Ensure floating point + predictions = predictions.float() + targets = targets.float() + + # Flatten across spatial dimensions but keep class and batch + dims = tuple(range(2, predictions.ndim)) # e.g. 
(2,3,4) for 3D data + + # Compute intersection and union per class + intersection = torch.sum(predictions * targets, dims) + pred_sum = torch.sum(predictions, dims) + target_sum = torch.sum(targets, dims) + + dice_per_class = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth) - # Calculate intersection and union - intersection = (predictions * targets).sum() - dice_coeff = (2.0 * intersection + self.smooth) / (predictions.sum() + targets.sum() + self.smooth) + # Average over classes and batch + dice_loss = 1.0 - dice_per_class.mean() - # Return Dice Loss (1 - Dice Coefficient) - return 1 - dice_coeff \ No newline at end of file + return dice_loss diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 9a118fae1..eb1a58c68 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt import torch import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau import numpy as np import torch.nn.functional as F from module import ImprovedUNet3D, DiceLoss @@ -39,6 +40,14 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): model.to(device) criterion = DiceLoss() optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=5, + min_lr=1e-6, + verbose=True + ) losses = [] @@ -61,6 +70,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): # Backward pass loss.backward() optimizer.step() + scheduler.step(loss) epoch_loss += loss.item() @@ -73,6 +83,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): # show_epoch_predictions(model, test_dataset, epoch + 1, n=3) model.eval() + atThreshold = True with torch.no_grad(): total_dice = 0 for image, mask in test_loader: @@ -81,10 +92,15 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): output = model(image) total_dice += 1 - criterion(output, mask).item() + if (atThreshold and 1 - criterion(output, mask).item() < 0.7): + atThreshold = False avg_dice = total_dice / len(test_dataset) print(f"๐Ÿงช Validation Dice Score after Epoch {epoch+1}: {avg_dice:.4f}") + if (atThreshold): + break + print(" Training complete with enhanced U-Net!") plot_loss(losses, loss_type='dice') return losses From f6315c0d07c6412561e6dc19bda37e30db1c424a Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:43:46 +1000 Subject: [PATCH 11/18] Removed random seed --- recognition/ImprovedUNet3D/train.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index eb1a58c68..550bf965c 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -13,13 +13,6 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f'Using device: {device}') -# Set random seeds for reproducibility -torch.manual_seed(42) -np.random.seed(42) -random.seed(42) -if torch.cuda.is_available(): - torch.cuda.manual_seed(42) - # Quick visualization of loss def plot_loss(losses, loss_type='dice'): plt.figure(figsize=(8, 4)) @@ -78,10 +71,6 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): losses.append(avg_loss) print(f"๐Ÿ“ˆ Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f}") - # Visualize predictions after each epoch (or every few epochs) - #if (epoch) % visualize_every == 0: - # 
show_epoch_predictions(model, test_dataset, epoch + 1, n=3) - model.eval() atThreshold = True with torch.no_grad(): From 645e8df200a4f46b6666d5486b2df96f57f4e603 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:45:05 +1000 Subject: [PATCH 12/18] Fixed deprecated arg for scheduler --- recognition/ImprovedUNet3D/train.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 550bf965c..69567408d 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -38,8 +38,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): mode='min', factor=0.5, patience=5, - min_lr=1e-6, - verbose=True + min_lr=1e-6 ) losses = [] From 2153b5892f6e0feb69b7bf105bb28799cf388892 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 13:49:37 +1000 Subject: [PATCH 13/18] Fixed value passed to scheduler. --- recognition/ImprovedUNet3D/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 69567408d..a710c62ba 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -62,7 +62,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): # Backward pass loss.backward() optimizer.step() - scheduler.step(loss) + scheduler.step(loss.detach().item()) epoch_loss += loss.item() From 9aac7acbd85a2eafe3858eaec21cfe030283735d Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:29:45 +1000 Subject: [PATCH 14/18] Optimised model to run in less time --- recognition/ImprovedUNet3D/module.py | 13 +++++++++---- recognition/ImprovedUNet3D/train.py | 28 +++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index cf4f72295..a617c85ef 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -5,11 +5,11 @@ class Upsample3D(nn.Module): def __init__(self, in_channels, scale_factor=2): super().__init__() - self.scale_factor = scale_factor + self.up = nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True) self.conv1 = nn.Conv3d(in_channels, in_channels // 2, kernel_size=3, padding=1) def forward(self, x): - x = x.repeat_interleave(self.scale_factor, dim=2).repeat_interleave(self.scale_factor, dim=3).repeat_interleave(self.scale_factor, dim=4) + x = self.up(x) x = self.conv1(x) return x @@ -43,8 +43,8 @@ def forward(self, x): x = self.dropout(x) x = self.conv2(x) x = self.norm2(x) - x = self.act2(x) - return x + residual + x = x + residual + return self.act2(x) # ----------------------------- # Localization Module @@ -80,6 +80,11 @@ def __init__(self, in_channels=6, base_filters=16, dropout=0.3): self.StrideConv3 = nn.Conv3d(base_filters*4, base_filters*8, kernel_size=3, padding=1, stride=2) self.StrideConv4 = nn.Conv3d(base_filters*8, base_filters*16, kernel_size=3, padding=1, stride=2) + #self.maxPool1 = nn.MaxPool3d(2) + #self.maxPool2 = nn.MaxPool3d(2) + #self.maxPool3 = nn.MaxPool3d(2) + #self.maxPool4 = nn.MaxPool3d(2) + # Encoder / context pathway self.enc1 = ResidualBlock3D(base_filters, base_filters, dropout) self.enc2 = ResidualBlock3D(base_filters*2, base_filters*2, dropout) diff --git 
a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index a710c62ba..97f75ff3c 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -29,7 +29,32 @@ def plot_loss(losses, loss_type='dice'): plt.grid(True, alpha=0.3) plt.show() -def train(model, train_loader, test_dataset, epochs=100, lr=0.001): +import numpy as np +import matplotlib.pyplot as plt + +def visualize_cnn_slices(cnn_output, slice_indices=None): + """ + Visualize CNN segmentation slice-by-slice using matplotlib. + + cnn_output: C x H x W x D (numpy array) + slice_indices: list of slice indices along depth (D) + """ + # Convert to label indices + label_volume = np.argmax(cnn_output, axis=0) + + D = label_volume.shape[2] + if slice_indices is None: + slice_indices = [D // 4, D // 2, 3 * D // 4] + + fig, axes = plt.subplots(1, len(slice_indices), figsize=(15, 5)) + + for i, idx in enumerate(slice_indices): + axes[i].imshow(label_volume[:, :, idx], cmap='tab20') + axes[i].set_title(f'Slice {idx}') + axes[i].axis('off') + plt.show() + +def train(model, train_loader, test_dataset, epochs=1, lr=0.001): model.to(device) criterion = DiceLoss() optimizer = optim.Adam(model.parameters(), lr=lr) @@ -87,6 +112,7 @@ def train(model, train_loader, test_dataset, epochs=100, lr=0.001): print(f"๐Ÿงช Validation Dice Score after Epoch {epoch+1}: {avg_dice:.4f}") if (atThreshold): + visualize_cnn_slices(cnn_output=output) break print(" Training complete with enhanced U-Net!") From e2b119c1f4e2663ee6286376c8830c44003566f0 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:57:18 +1000 Subject: [PATCH 15/18] Changed softmax to be in loss --- recognition/ImprovedUNet3D/dataset.py | 3 --- recognition/ImprovedUNet3D/module.py | 3 +-- recognition/ImprovedUNet3D/train.py | 3 ++- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index e7ea77bac..24b9c2421 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -43,9 +43,6 @@ def TrainingTransform(image, mask): image, mask = RandomRotate_90(image, mask) image = zScoreNormalize(image) - image[mask == 0] = 0 - image = np.clip(image, -5, 5) - image = (image + 5) / 10.0 return image, mask def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index a617c85ef..7226a753a 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -144,7 +144,6 @@ def forward(self, x): out = self.convOutput(u4) out = self.seg3(out) out = res2 + out - out = F.softmax(out, dim=1) return out @@ -169,7 +168,7 @@ def forward(self, predictions, targets): targets (torch.Tensor): One-hot encoded ground truth [B, C, ...] 
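        Example (an illustrative sketch, not part of the original patch; with
        this change the softmax is applied inside the loss, so raw logits are
        passed directly):

            >>> criterion = DiceLoss()
            >>> logits = torch.randn(2, 6, 8, 8, 8)   # raw outputs, [B, C, D, H, W]
            >>> labels = torch.randint(0, 6, (2, 8, 8, 8))
            >>> onehot = F.one_hot(labels, 6).permute(0, 4, 1, 2, 3).float()
            >>> loss = criterion(logits, onehot)       # scalar tensor in [0, 1]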
""" # Ensure floating point - predictions = predictions.float() + predictions = F.softmax(predictions.float()) targets = targets.float() # Flatten across spatial dimensions but keep class and batch diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 97f75ff3c..644ef40a2 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -87,11 +87,12 @@ def train(model, train_loader, test_dataset, epochs=1, lr=0.001): # Backward pass loss.backward() optimizer.step() - scheduler.step(loss.detach().item()) + epoch_loss += loss.item() avg_loss = epoch_loss / len(train_loader) + scheduler.step(avg_loss) losses.append(avg_loss) print(f"๐Ÿ“ˆ Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f}") From 90592880b8aeb92dc28515fe485a965818e07213 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 14:59:06 +1000 Subject: [PATCH 16/18] Changed the softmax function --- recognition/ImprovedUNet3D/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py index 7226a753a..118fcba5c 100644 --- a/recognition/ImprovedUNet3D/module.py +++ b/recognition/ImprovedUNet3D/module.py @@ -168,7 +168,7 @@ def forward(self, predictions, targets): targets (torch.Tensor): One-hot encoded ground truth [B, C, ...] """ # Ensure floating point - predictions = F.softmax(predictions.float()) + predictions = F.softmax(predictions.float(), dim=1) targets = targets.float() # Flatten across spatial dimensions but keep class and batch From b088ae5c13f514ee3ffc3e853637fccd46bd5f00 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:00:47 +1000 Subject: [PATCH 17/18] Added a test transform --- recognition/ImprovedUNet3D/dataset.py | 5 +++++ recognition/ImprovedUNet3D/train.py | 1 + 2 files changed, 6 insertions(+) diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py index 24b9c2421..08d61fe61 100644 --- a/recognition/ImprovedUNet3D/dataset.py +++ b/recognition/ImprovedUNet3D/dataset.py @@ -45,6 +45,11 @@ def TrainingTransform(image, mask): return image, mask +def TestTransform(image, mask): + image = zScoreNormalize(image) + + return image, mask + def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): """ img_tensor: torch tensor of shape (C, D, H, W) diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py index 644ef40a2..5a87678e0 100644 --- a/recognition/ImprovedUNet3D/train.py +++ b/recognition/ImprovedUNet3D/train.py @@ -129,6 +129,7 @@ def train(model, train_loader, test_dataset, epochs=1, lr=0.001): test_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', mask_path='semantic_labels_anon', + transform=dataset.TestTransform, train=False) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True) From 5838e14d4d68a21d3b8efcbb27a3827199bc2107 Mon Sep 17 00:00:00 2001 From: Anhadh Virk <43526788+anhadh676842@users.noreply.github.com> Date: Fri, 7 Nov 2025 15:44:40 +1000 Subject: [PATCH 18/18] Switched to doing 2D Unet last minute on preprocessed slices --- recognition/ImprovedUNet3D/README.md | 10 - recognition/ImprovedUNet3D/dataset.py | 110 ----------- recognition/ImprovedUNet3D/module.py | 187 ------------------ recognition/ImprovedUNet3D/predict.py | 0 recognition/ImprovedUNet3D/train.py | 140 
------------- recognition/s48027522-HipMRI-2DUnet/README.md | 21 ++ .../s48027522-HipMRI-2DUnet/dataset.py | 90 +++++++++ recognition/s48027522-HipMRI-2DUnet/module.py | 118 +++++++++++ .../s48027522-HipMRI-2DUnet/predict.py | 49 +++++ recognition/s48027522-HipMRI-2DUnet/train.py | 140 +++++++++++++ 10 files changed, 418 insertions(+), 447 deletions(-) delete mode 100644 recognition/ImprovedUNet3D/README.md delete mode 100644 recognition/ImprovedUNet3D/dataset.py delete mode 100644 recognition/ImprovedUNet3D/module.py delete mode 100644 recognition/ImprovedUNet3D/predict.py delete mode 100644 recognition/ImprovedUNet3D/train.py create mode 100644 recognition/s48027522-HipMRI-2DUnet/README.md create mode 100644 recognition/s48027522-HipMRI-2DUnet/dataset.py create mode 100644 recognition/s48027522-HipMRI-2DUnet/module.py create mode 100644 recognition/s48027522-HipMRI-2DUnet/predict.py create mode 100644 recognition/s48027522-HipMRI-2DUnet/train.py diff --git a/recognition/ImprovedUNet3D/README.md b/recognition/ImprovedUNet3D/README.md deleted file mode 100644 index 32c99e899..000000000 --- a/recognition/ImprovedUNet3D/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Recognition Tasks -Various recognition tasks solved in deep learning frameworks. - -Tasks may include: -* Image Segmentation -* Object detection -* Graph node classification -* Image super resolution -* Disease classification -* Generative modelling with StyleGAN and Stable Diffusion \ No newline at end of file diff --git a/recognition/ImprovedUNet3D/dataset.py b/recognition/ImprovedUNet3D/dataset.py deleted file mode 100644 index 08d61fe61..000000000 --- a/recognition/ImprovedUNet3D/dataset.py +++ /dev/null @@ -1,110 +0,0 @@ -import torch -from torch.utils.data import Dataset -import torch.nn.functional as F - -import numpy as np -import os -import nibabel as nib - -def to_channels(label_volume, dtype=np.float32): - """Convert label map (3D) to one-hot channels (4D).""" - num_classes = int(label_volume.max()) + 1 - out = np.zeros((num_classes,) + label_volume.shape, dtype=dtype) - for c in range(num_classes): - out[c] = (label_volume == c) - return out - -def zScoreNormalize(image): - mean = image.mean() - std = image.std() - if std > 0: - image = (image - mean) / std - else: - image = image - mean - return image - -def RandomFlip(image, mask): - axes = [0, 1, 2] # D, H, W axes - for axis in axes: - if np.random.rand() > 0.5: - image = np.flip(image, axis=axis) - mask = np.flip(mask, axis=axis) - return image, mask - -def RandomRotate_90(image, mask): - k = np.random.randint(0, 4) # 0, 90, 180, 270 degrees - axes = (1, 2) # rotate in-plane (H, W) - image = np.rot90(image, k, axes) - mask = np.rot90(mask, k, axes) - return image, mask - -def TrainingTransform(image, mask): - image, mask = RandomFlip(image, mask) - image, mask = RandomRotate_90(image, mask) - image = zScoreNormalize(image) - - return image, mask - -def TestTransform(image, mask): - image = zScoreNormalize(image) - - return image, mask - -def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): - """ - img_tensor: torch tensor of shape (C, D, H, W) - """ - img_tensor = img_tensor.unsqueeze(0) # add batch dim - img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type) - return img_resized.squeeze(0) - -class HipMriDataset3D(Dataset): - """Dataset for prostate cancer 3D Images.""" - - def __init__(self, image_path, mask_path, transform=None, train=True): - # Download and load the dataset - self.image_dataset_path = image_path - 
self.mask_dataset_path = mask_path - self.transform = transform - self.dataset = [] - - image_paths = [os.path.join(self.image_dataset_path, img_name) - for img_name in sorted(os.listdir(self.image_dataset_path))] - mask_paths = [os.path.join(self.mask_dataset_path, mask_name) - for mask_name in sorted(os.listdir(self.mask_dataset_path))] - - if train: - image_paths = image_paths[:int(0.8*len(image_paths))] - mask_paths = mask_paths[:int(0.8*len(mask_paths))] - else: - image_paths = image_paths[int(0.8*len(image_paths)):] - mask_paths = mask_paths[int(0.8*len(mask_paths)):] - - for case in range(len(image_paths)): - self.dataset.append((image_paths[case], mask_paths[case])) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, idx): - # Get image and mask - image = nib.load(self.dataset[idx][0]) - mask = nib.load(self.dataset[idx][1]) - - image_np = image.get_fdata().astype(np.float32) - mask_np = mask.get_fdata().astype(np.uint8) # convert - - # Apply transforms to image - if self.transform: - image_np, mask_np = self.transform(image_np, mask_np) - - mask_np = to_channels(mask_np, dtype=np.uint8) # convert to one-hot channels - - # Convert to tensor - image_np = torch.from_numpy(image_np).unsqueeze(0).float() - mask_np = torch.from_numpy(mask_np).float() - - mask_np = Resize3dTensor(mask_np, target_shape=(128,128,128), mode_type='nearest') - image_np = Resize3dTensor(image_np, target_shape=(128,128,128)) - - return image_np, mask_np \ No newline at end of file diff --git a/recognition/ImprovedUNet3D/module.py b/recognition/ImprovedUNet3D/module.py deleted file mode 100644 index 118fcba5c..000000000 --- a/recognition/ImprovedUNet3D/module.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class Upsample3D(nn.Module): - def __init__(self, in_channels, scale_factor=2): - super().__init__() - self.up = nn.Upsample(scale_factor=scale_factor, mode='trilinear', align_corners=True) - self.conv1 = nn.Conv3d(in_channels, in_channels // 2, kernel_size=3, padding=1) - - def forward(self, x): - x = self.up(x) - x = self.conv1(x) - return x - -# ----------------------------- -# Pre-activation Residual Block (Context Module) -# ----------------------------- -class ResidualBlock3D(nn.Module): - def __init__(self, in_channels, out_channels, dropout=0.3, stride=1): - super().__init__() - self.stride = stride - self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride) - self.norm1 = nn.InstanceNorm3d(out_channels) - self.act1 = nn.LeakyReLU(0.01) - - self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1) - self.norm2 = nn.InstanceNorm3d(out_channels) - self.act2 = nn.LeakyReLU(0.01) - - self.dropout = nn.Dropout3d(dropout) - - if in_channels != out_channels or stride > 1: - self.skip = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride) - else: - self.skip = nn.Identity() - - def forward(self, x): - residual = self.skip(x) - x = self.conv1(x) - x = self.norm1(x) - x = self.act1(x) - x = self.dropout(x) - x = self.conv2(x) - x = self.norm2(x) - x = x + residual - return self.act2(x) - -# ----------------------------- -# Localization Module -# -----------------------------x -class LocalizationModule(nn.Module): - def __init__(self, in_channels, out_channels): - super().__init__() - self.conv3x3 = nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1) - self.act = nn.LeakyReLU(0.01) - self.conv1x1 = nn.Conv3d(in_channels, out_channels, kernel_size=1) - - 
def forward(self, x): - x = self.act(self.conv3x3(x)) - x = self.conv1x1(x) - return x - -# ----------------------------- -# Full 3D U-Net with deep supervision -# ----------------------------- -class ImprovedUNet3D(nn.Module): - def __init__(self, in_channels=6, base_filters=16, dropout=0.3): - super().__init__() - self.upsample1 = Upsample3D(base_filters*16) - self.upsample2 = Upsample3D(base_filters*8) - self.upsample3 = Upsample3D(base_filters*4) - self.upsample4 = Upsample3D(base_filters*2) - - self.convInput = nn.Conv3d(in_channels, base_filters, kernel_size=3, padding=1) - self.convOutput = nn.Conv3d(base_filters*2, base_filters*2, kernel_size=3, padding=1) - - self.StrideConv1 = nn.Conv3d(base_filters, base_filters*2, kernel_size=3, padding=1, stride=2) - self.StrideConv2 = nn.Conv3d(base_filters*2, base_filters*4, kernel_size=3, padding=1, stride=2) - self.StrideConv3 = nn.Conv3d(base_filters*4, base_filters*8, kernel_size=3, padding=1, stride=2) - self.StrideConv4 = nn.Conv3d(base_filters*8, base_filters*16, kernel_size=3, padding=1, stride=2) - - #self.maxPool1 = nn.MaxPool3d(2) - #self.maxPool2 = nn.MaxPool3d(2) - #self.maxPool3 = nn.MaxPool3d(2) - #self.maxPool4 = nn.MaxPool3d(2) - - # Encoder / context pathway - self.enc1 = ResidualBlock3D(base_filters, base_filters, dropout) - self.enc2 = ResidualBlock3D(base_filters*2, base_filters*2, dropout) - self.enc3 = ResidualBlock3D(base_filters*4, base_filters*4, dropout) - self.enc4 = ResidualBlock3D(base_filters*8, base_filters*8, dropout) - self.enc5 = ResidualBlock3D(base_filters*16, base_filters*16, dropout) - - # Decoder / localization pathway - self.loc3 = LocalizationModule(base_filters*8 + base_filters*8, base_filters*8) - self.loc2 = LocalizationModule(base_filters*4 + base_filters*4, base_filters*4) - self.loc1 = LocalizationModule(base_filters*2 + base_filters*2, base_filters*2) - - # Segmentation layers - self.seg1 = nn.Conv3d(base_filters*4, 6, kernel_size=1) - self.seg2= nn.Conv3d(base_filters*2, 6, kernel_size=1) - self.seg3 = nn.Conv3d(base_filters*2, 6, kernel_size=1) - - def forward(self, x): - x = self.convInput(x) - # Encoder - e1 = self.enc1(x) - - e2 = self.StrideConv1(e1) - e2 = self.enc2(e2) - - e3 = self.StrideConv2(e2) - e3 = self.enc3(e3) - - e4 = self.StrideConv3(e3) - e4 = self.enc4(e4) - - e5 = self.StrideConv4(e4) - e5 = self.enc5(e5) - - u1 = self.upsample1(e5) - u1 = torch.cat((u1, e4), dim=1) - u1 = self.loc3(u1) - - u2 = self.upsample2(u1) - u2 = torch.cat((u2, e3), dim=1) - u2 = self.loc2(u2) - - res1 = self.seg1(u2) - res1 = F.interpolate(res1, size=e2.shape[2:], mode='nearest') - - u3 = self.upsample3(u2) - u3 = torch.cat((u3, e2), dim=1) - u3 = self.loc1(u3) - - res2 = self.seg2(u3) - res2 = res1 + res2 - res2 = F.interpolate(res2, size=e1.shape[2:], mode='nearest') - - u4 = self.upsample4(u3) - u4 = torch.cat((u4, e1), dim=1) - - out = self.convOutput(u4) - out = self.seg3(out) - out = res2 + out - - return out - -class DiceLoss(nn.Module): - """ - Multi-class Dice Loss supporting one-hot encoded targets. - - Dice Loss = 1 - (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth) - - Works for both 2D and 3D tensors: - predictions: [B, C, H, W] or [B, C, D, H, W] - targets: same shape (one-hot encoded) - """ - def __init__(self, smooth=1e-6): - super(DiceLoss, self).__init__() - self.smooth = smooth - - def forward(self, predictions, targets): - """ - Args: - predictions (torch.Tensor): Model outputs after sigmoid or softmax [B, C, ...]
- targets (torch.Tensor): One-hot encoded ground truth [B, C, ...] - """ - # Ensure floating point - predictions = F.softmax(predictions.float(), dim=1) - targets = targets.float() - - # Flatten across spatial dimensions but keep class and batch - dims = tuple(range(2, predictions.ndim)) # e.g. (2,3,4) for 3D data - - # Compute intersection and union per class - intersection = torch.sum(predictions * targets, dims) - pred_sum = torch.sum(predictions, dims) - target_sum = torch.sum(targets, dims) - - dice_per_class = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth) - - # Average over classes and batch - dice_loss = 1.0 - dice_per_class.mean() - - return dice_loss diff --git a/recognition/ImprovedUNet3D/predict.py b/recognition/ImprovedUNet3D/predict.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/recognition/ImprovedUNet3D/train.py b/recognition/ImprovedUNet3D/train.py deleted file mode 100644 index 5a87678e0..000000000 --- a/recognition/ImprovedUNet3D/train.py +++ /dev/null @@ -1,140 +0,0 @@ -import dataset -import matplotlib.pyplot as plt -import torch -import torch.optim as optim -from torch.optim.lr_scheduler import ReduceLROnPlateau -import numpy as np -import torch.nn.functional as F -from module import ImprovedUNet3D, DiceLoss -import dataset -import random - -# Check if CUDA is available -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -print(f'Using device: {device}') - -# Quick visualization of loss -def plot_loss(losses, loss_type='dice'): - plt.figure(figsize=(8, 4)) - plt.plot(losses, 'bo-', linewidth=2, markersize=8) - - title_map = { - 'bce': '🔥 Training Loss (BCE)', - 'dice': '🔥 Training Loss (Dice)', - 'combined': '🔥 Training Loss (Combined BCE + Dice)' - } - plt.title(title_map.get(loss_type, '🔥 Training Loss'), fontsize=14, fontweight='bold') - plt.xlabel('Epoch') - plt.ylabel('Loss') - plt.grid(True, alpha=0.3) - plt.show() - -import numpy as np -import matplotlib.pyplot as plt - -def visualize_cnn_slices(cnn_output, slice_indices=None): - """ - Visualize CNN segmentation slice-by-slice using matplotlib.
- - cnn_output: C x H x W x D (numpy array) - slice_indices: list of slice indices along depth (D) - """ - # Convert to label indices - label_volume = np.argmax(cnn_output, axis=0) - - D = label_volume.shape[2] - if slice_indices is None: - slice_indices = [D // 4, D // 2, 3 * D // 4] - - fig, axes = plt.subplots(1, len(slice_indices), figsize=(15, 5)) - - for i, idx in enumerate(slice_indices): - axes[i].imshow(label_volume[:, :, idx], cmap='tab20') - axes[i].set_title(f'Slice {idx}') - axes[i].axis('off') - plt.show() - -def train(model, train_loader, test_dataset, epochs=1, lr=0.001): - model.to(device) - criterion = DiceLoss() - optimizer = optim.Adam(model.parameters(), lr=lr) - scheduler = ReduceLROnPlateau( - optimizer, - mode='min', - factor=0.5, - patience=5, - min_lr=1e-6 - ) - - losses = [] - - print(" Starting training with Instance Norm, LeakyReLU, and Softmax activation...") - for epoch in range(epochs): - model.train() - epoch_loss = 0 - - # Training loop with progress - for images, masks in train_loader: - images, masks = images.to(device), masks.to(device) - - optimizer.zero_grad() - outputs = model(images) - - print(f"pred_pet shape: {outputs.shape}, masks shape: {masks.shape}") - - loss = criterion(outputs, masks) - - # Backward pass - loss.backward() - optimizer.step() - - - epoch_loss += loss.item() - - avg_loss = epoch_loss / len(train_loader) - scheduler.step(avg_loss) - losses.append(avg_loss) - print(f"📈 Epoch {epoch+1}/{epochs} Complete: Avg Loss = {avg_loss:.4f}") - - model.eval() - atThreshold = True - with torch.no_grad(): - total_dice = 0 - for image, mask in test_loader: - image = image.to(device) - mask = mask.to(device) - - output = model(image) - total_dice += 1 - criterion(output, mask).item() - if (atThreshold and 1 - criterion(output, mask).item() < 0.7): - atThreshold = False - - avg_dice = total_dice / len(test_dataset) - print(f"🧪 Validation Dice Score after Epoch {epoch+1}: {avg_dice:.4f}") - - if (atThreshold): - visualize_cnn_slices(cnn_output=output) - break - - print(" Training complete with enhanced U-Net!") - plot_loss(losses, loss_type='dice') - return losses - -if __name__ == "__main__": - # Example usage - train_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', - mask_path='semantic_labels_anon', - transform=dataset.TrainingTransform, - train=True) - - test_dataset = dataset.HipMriDataset3D(image_path='semantic_MRs_anon', - mask_path='semantic_labels_anon', - transform=dataset.TestTransform, - train=False) - - train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=True) - test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) - - model = ImprovedUNet3D(in_channels=1, base_filters=16, dropout=0.3) - - train(model, train_loader, test_dataset, epochs=50, lr=0.001) diff --git a/recognition/s48027522-HipMRI-2DUnet/README.md b/recognition/s48027522-HipMRI-2DUnet/README.md new file mode 100644 index 000000000..5ce8584aa --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/README.md @@ -0,0 +1,21 @@ +# Prostate MRI Segmentation with 2D U-Net + +## Project Overview + +This project implements a **2D U-Net** convolutional architecture for segmenting the prostate in MRI scans. The workflow is designed to handle **pre-processed 2D slices of prostate MRIs**, enabling efficient training and evaluation of segmentation models.
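+
+For reference, a minimal training invocation mirrors the `__main__` block of `train.py` (a sketch; the dataset paths are the shared HipMRI slice directories assumed throughout this repo):
+
+```python
+import torch
+import dataset
+from dataset import HipMriDataset2D
+from module import UNet2D
+from train import train
+
+train_ds = HipMriDataset2D(
+    image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_train',
+    mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_train',
+    transform=dataset.TrainingTransform)
+val_ds = HipMriDataset2D(
+    image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_validate',
+    mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_validate',
+    transform=dataset.TestTransform)
+
+model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1)
+train(model,
+      torch.utils.data.DataLoader(train_ds, batch_size=4, shuffle=True),
+      torch.utils.data.DataLoader(val_ds, batch_size=1, shuffle=False),
+      epochs=50, lr=1e-3, dice_threshold=0.75)
+```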
+ +The goal is to accurately delineate the prostate from surrounding tissues, which is critical for clinical applications such as **radiotherapy planning, disease diagnosis, and progression monitoring**. + +--- + +## Features + +- **2D Slice-Based Training**: Uses individual 2D slices extracted from 3D MRI volumes, enabling faster training and lower memory requirements. +- **U-Net Architecture**: Lightweight **2D U-Net** with encoder-decoder structure and skip connections for high-resolution segmentation. +- **Binary Dice Loss**: Implements a **Dice similarity coefficient loss** for robust training in binary segmentation (prostate vs. background). +- **Data Augmentation**: Supports **random flips, rotations, and z-score normalization** for better generalization. +- **Early Stopping**: Training halts automatically when the **Dice coefficient of the prostate exceeds 0.75**, avoiding overfitting. +- **Visualization Tools**: Includes slice-level visualizations of input images, ground truth masks, and model predictions. +- **Modular Design**: Easily replaceable components for experimenting with different models, loss functions, and transforms. + +--- \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/dataset.py b/recognition/s48027522-HipMRI-2DUnet/dataset.py new file mode 100644 index 000000000..78be0b200 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/dataset.py @@ -0,0 +1,90 @@ +import torch +from torch.utils.data import Dataset +import torch.nn.functional as F + +import numpy as np +import os +import nibabel as nib + +def zScoreNormalize(image): + mean = image.mean() + std = image.std() + if std > 0: + image = (image - mean) / std + else: + image = image - mean + return image + +def RandomFlip(image, mask): + axes = [0, 1] # H, W axes of a 2D slice + for axis in axes: + if np.random.rand() > 0.5: + image = np.flip(image, axis=axis) + mask = np.flip(mask, axis=axis) + return image, mask + +def RandomRotate_90(image, mask): + k = np.random.randint(0, 4) # 0, 90, 180, 270 degrees + axes = (0, 1) # rotate in-plane (H, W) + image = np.rot90(image, k, axes) + mask = np.rot90(mask, k, axes) + return image, mask + +def TrainingTransform(image, mask): + image, mask = RandomFlip(image, mask) + image, mask = RandomRotate_90(image, mask) + image = zScoreNormalize(image) + + return image, mask + +def TestTransform(image, mask): + image = zScoreNormalize(image) + + return image, mask + +def Resize3dTensor(img_tensor, target_shape=(128,128,128), mode_type='trilinear'): + """ + img_tensor: torch tensor of shape (C, D, H, W) + """ + img_tensor = img_tensor.unsqueeze(0) # add batch dim + img_resized = F.interpolate(img_tensor, size=target_shape, mode=mode_type) + return img_resized.squeeze(0) + +def to_channels(label_slice, dtype=np.float32): + """Convert 2D label slice to one-hot channels.""" + num_classes = max(int(label_slice.max()) + 1, 2) # at least background + prostate, so empty slices one-hot to a consistent channel count + out = np.zeros((num_classes,) + label_slice.shape, dtype=dtype) + for c in range(num_classes): + out[c] = (label_slice == c) + return out + +class HipMriDataset2D(Dataset): + """Dataset for pre-saved 2D slices.""" + + def __init__(self, image_path, mask_path, transform=None): + self.image_paths = sorted([os.path.join(image_path, f) for f in os.listdir(image_path)]) + self.mask_paths = sorted([os.path.join(mask_path, f) for f in os.listdir(mask_path)]) + self.transform = transform + + assert len(self.image_paths) == len(self.mask_paths), "Number of images and masks must match" + + def __len__(self): + return len(self.image_paths) + + def
__getitem__(self, idx): + # Load 2D slice + image = nib.load(self.image_paths[idx]).get_fdata().astype(np.float32) + mask = nib.load(self.mask_paths[idx]).get_fdata().astype(np.uint8) + + # Apply transforms + if self.transform: + image, mask = self.transform(image, mask) + + # Convert mask to one-hot channels + mask = to_channels(mask, dtype=np.uint8) + + # Convert to tensors (copy first: np.flip/np.rot90 return negatively-strided views that torch.from_numpy rejects) + image_tensor = torch.from_numpy(np.ascontiguousarray(image)).unsqueeze(0).float() # [1, H, W] + mask_tensor = torch.from_numpy(np.ascontiguousarray(mask)).float() # [C, H, W] + + return image_tensor, mask_tensor \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/module.py b/recognition/s48027522-HipMRI-2DUnet/module.py new file mode 100644 index 000000000..61c2ea400 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/module.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# ----------------------------- +# 2D Convolutional Block +# ----------------------------- +class ConvBlock2D(nn.Module): + def __init__(self, in_channels, out_channels, dropout=0.0): + super().__init__() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) + self.bn1 = nn.BatchNorm2d(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) + self.bn2 = nn.BatchNorm2d(out_channels) + self.act = nn.ReLU(inplace=True) + self.dropout = nn.Dropout2d(dropout) + + def forward(self, x): + x = self.act(self.bn1(self.conv1(x))) + x = self.dropout(x) + x = self.act(self.bn2(self.conv2(x))) + return x + +# ----------------------------- +# 2D U-Net +# ----------------------------- +class UNet2D(nn.Module): + def __init__(self, in_channels=1, out_channels=1, base_filters=64, dropout=0.0): + super().__init__() + + # Encoder + self.enc1 = ConvBlock2D(in_channels, base_filters, dropout) + self.enc2 = ConvBlock2D(base_filters, base_filters*2, dropout) + self.enc3 = ConvBlock2D(base_filters*2, base_filters*4, dropout) + self.enc4 = ConvBlock2D(base_filters*4, base_filters*8, dropout) + + self.pool = nn.MaxPool2d(2) + + # Bottleneck + self.bottleneck = ConvBlock2D(base_filters*8, base_filters*16, dropout) + + # Decoder + self.up4 = nn.ConvTranspose2d(base_filters*16, base_filters*8, kernel_size=2, stride=2) + self.dec4 = ConvBlock2D(base_filters*16, base_filters*8, dropout) + + self.up3 = nn.ConvTranspose2d(base_filters*8, base_filters*4, kernel_size=2, stride=2) + self.dec3 = ConvBlock2D(base_filters*8, base_filters*4, dropout) + + self.up2 = nn.ConvTranspose2d(base_filters*4, base_filters*2, kernel_size=2, stride=2) + self.dec2 = ConvBlock2D(base_filters*4, base_filters*2, dropout) + + self.up1 = nn.ConvTranspose2d(base_filters*2, base_filters, kernel_size=2, stride=2) + self.dec1 = ConvBlock2D(base_filters*2, base_filters, dropout) + + # Output + self.segmentation_head = nn.Conv2d(base_filters, out_channels, kernel_size=1) + + def forward(self, x): + # Encoder + e1 = self.enc1(x) + e2 = self.enc2(self.pool(e1)) + e3 = self.enc3(self.pool(e2)) + e4 = self.enc4(self.pool(e3)) + + # Bottleneck + b = self.bottleneck(self.pool(e4)) + + # Decoder + d4 = self.up4(b) + d4 = self.dec4(torch.cat([d4, e4], dim=1)) + + d3 = self.up3(d4) + d3 = self.dec3(torch.cat([d3, e3], dim=1)) + + d2 = self.up2(d3) + d2 = self.dec2(torch.cat([d2, e2], dim=1)) + + d1 = self.up1(d2) + d1 = self.dec1(torch.cat([d1, e1], dim=1)) + + out = self.segmentation_head(d1) # logits + return out + +class BinaryDiceLoss(nn.Module): + """ + Dice Loss for binary segmentation.
+ + Dice Loss = 1 - (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth) + + Works for 2D or 3D tensors: + predictions: [B, 1, H, W] or [B, 1, D, H, W] (logits) + targets: [B, 1, H, W] or [B, 1, D, H, W] (binary 0/1) + """ + def __init__(self, smooth=1e-6): + super().__init__() + self.smooth = smooth + + def forward(self, predictions, targets): + """ + Args: + predictions (torch.Tensor): logits from the model [B, 1, ...] + targets (torch.Tensor): binary ground truth [B, 1, ...] + """ + # Apply sigmoid to convert logits to probabilities + probs = torch.sigmoid(predictions) + targets = targets.float() + + # Flatten spatial dimensions per batch + dims = tuple(range(1, predictions.ndim)) # flatten everything except batch + + intersection = torch.sum(probs * targets, dims) + pred_sum = torch.sum(probs, dims) + target_sum = torch.sum(targets, dims) + + dice_score = (2.0 * intersection + self.smooth) / (pred_sum + target_sum + self.smooth) + dice_loss = 1.0 - dice_score.mean() # average over batch + + return dice_loss diff --git a/recognition/s48027522-HipMRI-2DUnet/predict.py b/recognition/s48027522-HipMRI-2DUnet/predict.py new file mode 100644 index 000000000..bd8732682 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/predict.py @@ -0,0 +1,49 @@ +import dataset +import torch + +from module import UNet2D # 2D U-Net implementation +from dataset import HipMriDataset2D # 2D slice dataset + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f'Using device: {device}') + +if __name__ == "__main__": + test_dataset = HipMriDataset2D( + image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_test', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_test', + transform=dataset.TestTransform + ) + + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) + + model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1).to(device) + # NOTE: train.py does not save a checkpoint yet; restore trained weights here + # before evaluating, e.g. model.load_state_dict(torch.load(<path>, map_location=device)) + + model.eval() + total_dice = 0.0 + with torch.no_grad(): + for images, masks in test_loader: + images, masks = images.to(device), masks.to(device) + outputs = model(images) + probs = torch.sigmoid(outputs) + pred_mask = (probs > 0.5).float() + + # Dice coefficient for the prostate class (assumes prostate is one-hot channel 1) + prostate = masks[:, 1:2] + intersection = (pred_mask * prostate).sum(dim=(1,2,3)) + union = pred_mask.sum(dim=(1,2,3)) + prostate.sum(dim=(1,2,3)) + dice_score = ((2*intersection + 1e-6) / (union + 1e-6)).mean().item() + total_dice += dice_score + + avg_dice = total_dice / len(test_loader) + print(f"Prediction complete! Average prostate Dice: {avg_dice:.4f}") \ No newline at end of file diff --git a/recognition/s48027522-HipMRI-2DUnet/train.py b/recognition/s48027522-HipMRI-2DUnet/train.py new file mode 100644 index 000000000..289d4d296 --- /dev/null +++ b/recognition/s48027522-HipMRI-2DUnet/train.py @@ -0,0 +1,140 @@ +import dataset +import matplotlib.pyplot as plt +import torch +import torch.optim as optim +from torch.optim.lr_scheduler import ReduceLROnPlateau + +from module import UNet2D, BinaryDiceLoss # 2D U-Net implementation +from dataset import HipMriDataset2D # 2D slice dataset + +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
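+# NOTE (sketch): train() below never saves the learned weights, so predict.py
+# has nothing to restore. A minimal checkpoint pattern (the filename
+# 'unet2d_prostate.pth' is a hypothetical choice) would be:
+#     torch.save(model.state_dict(), 'unet2d_prostate.pth')                           # at the end of train()
+#     model.load_state_dict(torch.load('unet2d_prostate.pth', map_location=device))   # in predict.py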
+print(f'Using device: {device}') + +# ----------------------------- +# Quick visualization of loss +# ----------------------------- +def plot_loss(losses): + plt.figure(figsize=(8, 4)) + plt.plot(losses, 'bo-', linewidth=2, markersize=8) + plt.title('🔥 Training Loss (Dice)', fontsize=14, fontweight='bold') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.grid(True, alpha=0.3) + plt.show() + +def visualize_slice(image, mask_pred, mask_gt): + """Visualize a single 2D slice with prediction and ground truth""" + _, axes = plt.subplots(1, 3, figsize=(12, 4)) + axes[0].imshow(image[0], cmap='gray') + axes[0].set_title("Input Image") + axes[1].imshow(mask_gt[0], cmap='Reds') + axes[1].set_title("Ground Truth") + axes[2].imshow(mask_pred[0], cmap='Reds') + axes[2].set_title("Predicted Mask") + for ax in axes: + ax.axis('off') + plt.show() + +# ----------------------------- +# Training function +# ----------------------------- +def train(model, train_loader, test_loader, epochs=50, lr=1e-3, dice_threshold=0.75): + model.to(device) + criterion = BinaryDiceLoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + scheduler = ReduceLROnPlateau( + optimizer, + mode='min', + factor=0.5, + patience=5, + min_lr=1e-6 + ) + + losses = [] + + print("Starting 2D slice training...") + + for epoch in range(epochs): + model.train() + epoch_loss = 0.0 + + for images, masks in train_loader: + images, masks = images.to(device), masks.to(device) + + optimizer.zero_grad() + outputs = model(images) # logits [B, 1, H, W] + loss = criterion(outputs, masks[:, 1:2]) # compare against the prostate one-hot channel (assumes class=1) + loss.backward() + optimizer.step() + + epoch_loss += loss.item() + + avg_loss = epoch_loss / len(train_loader) + scheduler.step(avg_loss) + losses.append(avg_loss) + print(f"Epoch {epoch+1}/{epochs} - Avg Loss: {avg_loss:.4f}") + + # ----------------------------- + # Validation & early stopping + # ----------------------------- + model.eval() + total_dice = 0.0 + with torch.no_grad(): + for images, masks in test_loader: + images, masks = images.to(device), masks.to(device) + outputs = model(images) + probs = torch.sigmoid(outputs) + pred_mask = (probs > 0.5).float() + + # Dice coefficient for the prostate class (assumes class=1) + prostate = masks[:, 1:2] + intersection = (pred_mask * prostate).sum(dim=(1,2,3)) + union = pred_mask.sum(dim=(1,2,3)) + prostate.sum(dim=(1,2,3)) + dice_score = ((2*intersection + 1e-6) / (union + 1e-6)).mean().item() + total_dice += dice_score + + avg_dice = total_dice / len(test_loader) + print(f"Validation Dice (Prostate) after Epoch {epoch+1}: {avg_dice:.4f}") + + # Early stopping if prostate Dice exceeds threshold + if avg_dice >= dice_threshold: + print(f"Early stopping: Prostate Dice {avg_dice:.4f} >= {dice_threshold}") + # visualize a random slice + sample_img, sample_mask = next(iter(test_loader)) + sample_img = sample_img.to(device) + sample_mask = sample_mask.to(device) + sample_output = model(sample_img) + pred_mask = (torch.sigmoid(sample_output) > 0.5).float() + visualize_slice(sample_img[0].cpu().numpy(), pred_mask[0].cpu().numpy(), sample_mask[0, 1:2].cpu().numpy()) + break + + print("🎯 Training complete!") + plot_loss(losses) + return losses + +# ----------------------------- +# Main +# ----------------------------- +if __name__ == "__main__": + train_dataset = HipMriDataset2D( + image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_train', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_train', + transform=dataset.TrainingTransform + ) + + test_dataset = HipMriDataset2D(
image_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_validate', + mask_path='/home/groups/comp3710/HipMRI_Study_open/keras_slices_data/keras_slices_seg_validate', + transform=dataset.TestTransform + ) + + train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True) + test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False) + + model = UNet2D(in_channels=1, out_channels=1, base_filters=16, dropout=0.1) + + train(model, train_loader, test_loader, epochs=50, lr=1e-3, dice_threshold=0.75) \ No newline at end of file
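As a closing smoke test, the final model and loss compose as below (a sketch using random tensors; the spatial size just needs to be divisible by 16 for the four pooling stages):

```python
import torch
from module import UNet2D, BinaryDiceLoss

model = UNet2D(in_channels=1, out_channels=1, base_filters=16)
x = torch.randn(2, 1, 256, 128)                      # two grayscale slices
target = (torch.rand(2, 1, 256, 128) > 0.5).float()  # toy binary masks
loss = BinaryDiceLoss()(model(x), target)
print(loss.item())                                   # scalar Dice loss in (0, 1)
```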