-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbatch_run.py
More file actions
50 lines (40 loc) · 1.6 KB
/
batch_run.py
File metadata and controls
50 lines (40 loc) · 1.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
import subprocess
import torch
def _report_cuda_devices():
    """Print CUDA availability, the visible GPU count, and each GPU's name.

    Informational only: it prints and does not select or configure a device.
    """
    cuda_ready = torch.cuda.is_available()
    print(f"CUDA available: {cuda_ready}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    if not cuda_ready:
        return
    for index in range(torch.cuda.device_count()):
        print(f"Device {index}: {torch.cuda.get_device_name(index)}")


_report_cuda_devices()
# --- Batch configuration -----------------------------------------------------
# Model names to process; each is passed as the first positional argument to
# both the training and the test script.
models = ['WavePCNet']

# Forwarded verbatim as --gpus=... to both child scripts.
gpus = '0'

# Dataset names: trset -> --trset (training only), valset -> --val (both).
trset, valset = 'HKU_train', 'HKU_test'

# Forwarded as --batch=... to the training script only.
batch_size = 3

# Root directory searched for trained checkpoints once training succeeds.
weight_base_path = './weight/'

# Entry-point scripts, invoked as `python3 <script>` from the current directory.
train_script, test_script = 'train.py', 'test.py'
# Checkpoint directory handed to train.py via --weight together with --resume.
# Kept exactly as the original command line had it.
resume_weight_path = './weight/ECSSD/'


def _run_stage(argv):
    """Run one child command and return its exit status.

    *argv* is passed as an argument list with no shell, so model names,
    dataset names or checkpoint paths containing spaces or shell
    metacharacters reach the child script unchanged. If the executable
    cannot be started at all (e.g. ``python3`` is not on PATH), the OSError
    is reported and 127 is returned so the batch moves on to the next model
    instead of aborting.
    """
    try:
        return subprocess.run(argv, check=False).returncode
    except OSError as exc:
        print(f"[!] could not launch {argv[0]}: {exc}")
        return 127


def _checkpoint_epoch(filename):
    """Return the integer sort key encoded in a checkpoint filename, or None.

    Same rule as before: take the second-to-last '_'-separated token and drop
    its first five characters. NOTE(review): presumably that prefix is
    'epoch' -- confirm against the names train.py writes. Names that do not
    fit the rule yield None instead of raising.
    """
    parts = filename.split('_')
    if len(parts) < 2:
        return None
    try:
        return int(parts[-2][5:])
    except ValueError:
        return None


def _latest_weight(model_dir, model):
    """Return the path of the highest-keyed '{model}_resnet50*.pth' in *model_dir*.

    Returns None when the directory does not exist or holds no checkpoint
    whose name yields a numeric key. Unparseable candidates are reported and
    skipped rather than aborting the whole batch.
    """
    if not os.path.isdir(model_dir):
        return None
    prefix = f"{model}_resnet50"
    ranked = []
    for name in os.listdir(model_dir):
        if not (name.startswith(prefix) and name.endswith('.pth')):
            continue
        key = _checkpoint_epoch(name)
        if key is None:
            print(f"[!] skipping checkpoint with unexpected name: {name}")
            continue
        ranked.append((key, name))
    if not ranked:
        return None
    # Ties on the numeric key fall back to the filename, so the pick is deterministic.
    _, newest = max(ranked)
    return os.path.join(model_dir, newest)


for model in models:
    print(f"\n================ {model} ================\n")

    # Stage 1: train (resuming from resume_weight_path).
    train_cmd = [
        'python3', train_script, model,
        f"--gpus={gpus}",
        f"--trset={trset}",
        f"--batch={batch_size}",
        f"--val={valset}",
        '--resume',
        f"--weight={resume_weight_path}",
    ]
    ret_train = _run_stage(train_cmd)
    if ret_train != 0:
        print(f"[!] training failed for {model} (exit code {ret_train})")
        continue

    # Stage 2: locate the newest checkpoint the training run left behind.
    model_dir = os.path.join(weight_base_path, model, 'resnet50', 'base')
    weight_path = _latest_weight(model_dir, model)
    if weight_path is None:
        print(f"[!] can't find a checkpoint for {model} in {model_dir}")
        continue

    print(f"\n>>> testing {model}")
    print(f"    weight: {weight_path}\n")

    # Stage 3: evaluate that checkpoint and save predictions (--save).
    test_cmd = [
        'python3', test_script, model,
        f"--gpus={gpus}",
        f"--weight={weight_path}",
        '--save',
        f"--val={valset}",
    ]
    ret_test = _run_stage(test_cmd)
    if ret_test != 0:
        print(f"[!] testing failed for {model} (exit code {ret_test})")
    else:
        print(f"[✓] finished {model}\n")

# NOTE(review): the source's indentation was lost; this closing banner is
# assumed to print once after all models -- indent it into the loop if it was
# meant as a per-model footer.
print("\n================================\n")