-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
43 lines (35 loc) · 1.67 KB
/
test.py
File metadata and controls
43 lines (35 loc) · 1.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import torch
import time
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import average_precision_score, roc_auc_score, recall_score, precision_score
from sklearn.metrics import precision_recall_curve, roc_curve
import numpy as np
from Tester import Tester
from argparser import args
from utils.set_logger import set_logger
if __name__ == '__main__':
torch.manual_seed(args.random_seed)
if args.cache_dir == '':
summarywriter_dir = os.path.join(args.experiment_dir, 'runs', 'train')
else:
summarywriter_dir = os.path.join(args.cache_dir, 'runs', 'train')
writer = SummaryWriter(summarywriter_dir)
tester = Tester(args)
logger = set_logger(summarywriter_dir)
if args.cache_dir == '':
model_path = os.path.join(args.experiment_dir, 'models', args.weight_name)
else:
model_path = os.path.join(args.cache_dir, 'models', args.weight_name)
begin_time = time.time()
# cur_roc, cur_pr, cur_test_loss, cur_label, cur_predict, total_uid = tester.eval(
cur_roc, cur_pr, precision, recall, F1, cur_test_loss, cur_label, cur_predict, total_uid = tester.eval(
state_dict_path=model_path)
end_time = time.time()
logger.info("Time: %.4f" % (end_time - begin_time))
np.save(os.path.join(summarywriter_dir, 'valid_label.npy'), cur_label)
np.save(os.path.join(summarywriter_dir, 'valid_predict.npy'), cur_predict)
np.save(os.path.join(summarywriter_dir, 'valid_uid.npy'), total_uid)
logger.info("ROC-AUC: %.4f, PR-AUC: %.4f, PRECISION: %.4f, RECALL: %.4f, F1: %.4f, VALID LOSS: %.4f" % (cur_roc, cur_pr, precision, recall, F1, cur_test_loss))
writer.flush()
writer.close()