diff --git a/CIGIN_V2/main.py b/CIGIN_V2/main.py index 3be9f5f..26d2ce7 100644 --- a/CIGIN_V2/main.py +++ b/CIGIN_V2/main.py @@ -53,8 +53,8 @@ def collate(samples): solute_graphs, solvent_graphs, labels = map(list, zip(*samples)) solute_graphs = dgl.batch(solute_graphs) solvent_graphs = dgl.batch(solvent_graphs) - solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes) - solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes) + solute_len_matrix = get_len_matrix(solute_graphs.batch_num_nodes().numpy()) + solvent_len_matrix = get_len_matrix(solvent_graphs.batch_num_nodes().numpy()) return solute_graphs, solvent_graphs, solute_len_matrix, solvent_len_matrix, labels