From 8557ab082a01f1c9539526b5fa0fccf1eac69345 Mon Sep 17 00:00:00 2001 From: harrypotty18 Date: Tue, 29 Mar 2022 04:08:55 +0800 Subject: [PATCH] Update main.py --- CIGIN_V2/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index 3be9f5f..26d2ce7 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -53,8 +53,8 @@ def collate(samples): solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) solute_graphs = dgl.batch(solute_graphs) solvent_graphs = dgl.batch(solvent_graphs) - solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes) - solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes) + solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().numpy()) + solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().numpy()) return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels