Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion slide2vec/configs/default_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ model:
arch: # architecture of custom model
pretrained_weights: # path to the pretrained weights when using a custom model
batch_size: 256
tile_size: ${tiling.params.tile_size}
input_size: ${tiling.params.tile_size}
restrict_to_tissue: false # whether to restrict tile content to tissue pixels only when feeding tile through encoder
patch_size: 256 # if level is "region", size used to unroll the region into patches
token_size: 16 # size of the tokens used model is a custom pretrained ViT
Expand Down
10 changes: 3 additions & 7 deletions slide2vec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def run_feature_extraction(config_file, output_dir, run_on_cpu: False):
"torch.distributed.run",
f"--master_port={free_port}",
"--nproc_per_node=gpu",
"-m",
"slide2vec.embed",
"slide2vec/embed.py",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
Expand All @@ -103,8 +102,7 @@ def run_feature_extraction(config_file, output_dir, run_on_cpu: False):
if run_on_cpu:
cmd = [
sys.executable,
"-m",
"slide2vec.embed",
"slide2vec/embed.py",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
Expand All @@ -127,11 +125,9 @@ def run_feature_extraction(config_file, output_dir, run_on_cpu: False):

def run_feature_aggregation(config_file, output_dir, run_on_cpu: False):
print("Running aggregate.py...")
# find a free port
cmd = [
sys.executable,
"-m",
"slide2vec.aggregate",
"slide2vec/aggregate.py",
"--config-file",
os.path.abspath(config_file),
"--output-dir",
Expand Down
14 changes: 7 additions & 7 deletions slide2vec/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
model = PathoJEPA(
pretrained_weights=options.pretrained_weights,
arch=options.arch,
input_size=options.tile_size,
input_size=options.input_size,
patch_size=options.token_size,
normalize_embeddings=options.normalize_embeddings,
)
Expand All @@ -77,13 +77,13 @@ def __init__(
model = PandaViT(
arch="vit_small",
pretrained_weights=options.pretrained_weights,
input_size=options.tile_size,
input_size=options.input_size,
)
elif options.name == "dino" and options.arch:
model = DINOViT(
arch=options.arch,
pretrained_weights=options.pretrained_weights,
input_size=options.tile_size,
input_size=options.input_size,
patch_size=options.token_size,
)
elif options.level == "region":
Expand Down Expand Up @@ -117,7 +117,7 @@ def __init__(
tile_encoder = PathoJEPA(
pretrained_weights=options.pretrained_weights,
arch=options.arch,
input_size=options.patch_size,
input_size=options.input_size,
patch_size=options.token_size,
normalize_embeddings=options.normalize_embeddings,
)
Expand All @@ -131,13 +131,13 @@ def __init__(
tile_encoder = PandaViT(
arch="vit_small",
pretrained_weights=options.pretrained_weights,
input_size=options.tile_size,
input_size=options.input_size,
)
elif options.name is None and options.arch:
tile_encoder = DINOViT(
arch=options.arch,
pretrained_weights=options.pretrained_weights,
input_size=options.patch_size,
input_size=options.input_size,
patch_size=options.token_size,
)
model = RegionFeatureExtractor(tile_encoder, tile_size=options.patch_size)
Expand Down Expand Up @@ -861,7 +861,7 @@ class RegionFeatureExtractor(nn.Module):
def __init__(
self,
tile_encoder: nn.Module,
tile_size: int = 256,
tile_size: int,
):
super(RegionFeatureExtractor, self).__init__()
self.tile_encoder = tile_encoder
Expand Down
Loading