diff --git a/src/scaffoldmaker/meshtypes/meshtype_3d_nerve1.py b/src/scaffoldmaker/meshtypes/meshtype_3d_nerve1.py index 82363d74..fc6a6ba0 100644 --- a/src/scaffoldmaker/meshtypes/meshtype_3d_nerve1.py +++ b/src/scaffoldmaker/meshtypes/meshtype_3d_nerve1.py @@ -1,6 +1,3 @@ -import math -import logging - from cmlibs.maths.vectorops import ( add, cross, distance, dot, magnitude, matrix_mult, matrix_inv, mult, normalize, rejection, set_magnitude, sub) from cmlibs.utils.zinc.field import find_or_create_field_group, find_or_create_field_coordinates @@ -27,8 +24,11 @@ smoothCurveSideCrossDerivatives, track_curve_side_direction) from scaffoldmaker.utils.read_vagus_data import load_vagus_data from scaffoldmaker.utils.zinc_utils import ( - define_and_fit_field, find_or_create_field_zero_fibres, fit_hermite_curve, generate_curve_mesh, generate_datapoints,\ - generate_mesh_marker_points) + define_and_fit_field, find_or_create_field_zero_fibres, fit_hermite_curve, generate_curve_mesh, + generate_datapoints, generate_mesh_marker_points) +import copy +import logging +import math logger = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def getDefaultOptions(cls, parameterSetName="Default"): 'Trunk proportion': 1.0, 'Trunk fit number of iterations': 5, 'Default anterior direction': [0.0, 1.0, 0.0], - 'Default trunk diameter': 3.0, + 'Default trunk diameter': 3000.0, 'Branch diameter trunk proportion': 0.5 } return options @@ -584,6 +584,8 @@ def generateBaseMesh(cls, region, options): parent_parameters = {} parent_parameters[trunk_group_name] = (tx, td1, td2, td12, td3, td13, tnid) + parent_elements_counts = {} # map of parent branch name to list of numbers of elements in branches + parent_elements_counts[trunk_group_name] = [trunk_elements_count] # the first branch node mainly marks connection on the parent/trunk centroid, # hence the following code starts branches from a proportion of parent radius # out, up to an upper limit on the proportion of the first segment: @@ -605,22 +607,49 @@ def generateBaseMesh(cls, region, options): # iterate over branches off trunk, and branches of branches visited_branches_order = [] branch_root_parameters = {} - branch_data = vagus_data.get_branch_data() + branch_coordinates_data = vagus_data.get_branch_coordinates_data() + branch_sequences_data = vagus_data.get_branch_sequences_data() branch_parent_map = vagus_data.get_branch_parent_map() queue = [branch for branch in branch_parent_map.keys() if branch_parent_map[branch] == trunk_group_name] - while queue: - branch_name = queue.pop(0) - if branch_name in visited_branches_order: - logger.warning("already processed branch " + branch_name) - continue - visited_branches_order.append(branch_name) - - branch_px = [branch_x[0] for branch_x in branch_data[branch_name]] - branch_parent_name = branch_parent_map[branch_name] - trunk_is_parent = branch_parent_name == trunk_group_name - # print(branch_name, '<--', branch_parent_name) + branch_name = None + branch_parent_name = None + trunk_is_parent = False + branch_coordinates = [] + branch_data_nodes_counts = [] + branch_box_group = None + branch_box_mesh_group = None + branch_box_face_mesh_group = None + branch_box_line_mesh_group = None + + # iterate over branch names and distinct branches from the branch_sequences_data + while True: + if not branch_coordinates: + # get the next branch in queue + if not queue: + break + branch_name = queue.pop(0) + if branch_name in visited_branches_order: + logger.warning("already processed branch " + branch_name) + continue + visited_branches_order.append(branch_name) + branch_parent_name = branch_parent_map[branch_name] + trunk_is_parent = branch_parent_name == trunk_group_name + # print(branch_name, '<--', branch_parent_name) + branch_coordinates = copy.copy(branch_coordinates_data[branch_name]) + branch_data_nodes_counts = copy.copy(branch_sequences_data[branch_name]) + + # branch annotation groups + branch_box_group = AnnotationGroup(region, (branch_name, annotation_term_map[branch_name])) + annotation_groups.append(branch_box_group) + branch_box_mesh_group = branch_box_group.getMeshGroup(mesh3d) + branch_box_face_mesh_group = branch_box_group.getMeshGroup(mesh2d) + branch_box_line_mesh_group = branch_box_group.getMeshGroup(mesh1d) tx, td1, td2, td12, td3, td13, tnid = parent_parameters[branch_parent_name] + branch_nodes_count = branch_data_nodes_counts[0] + branch_px = branch_coordinates[:branch_nodes_count] + branch_coordinates = branch_coordinates[branch_nodes_count:] + branch_data_nodes_counts.pop(0) # get point in trunk volume closest to first point in branch data # parent_group = trunk_group @@ -666,14 +695,24 @@ def generateBaseMesh(cls, region, options): logger.error("Nerve: branch " + branch_name + " fitted start point could not be found in parent nerve") continue parent_first_element = parent_mesh_group.createElementiterator().next() - parent_location = (parent_element.getIdentifier() - parent_first_element.getIdentifier(), parent_xi[0]) - if (not trunk_is_parent) and (parent_location[0] == 0): - # can't have branch from the root element of a branch - if parent_mesh_group.getSize() == 1: - logger.error("Nerve: can't make branch " + branch_name + - " off single element parent " + branch_parent_name) - continue - parent_location = (1, 0.0) + parent_index = parent_element.getIdentifier() - parent_first_element.getIdentifier() + branch_in_first_parent_element = False + # handle element indexes when there are several branches of the same name + # because the parent parameters are just appended after each branch = there isn't an element in between. + parent_start_index = 0 + for elements_count in parent_elements_counts[branch_parent_name]: + if parent_index < (parent_start_index + elements_count): + break + parent_index += 1 + parent_start_index += elements_count + 1 + + parent_location = (parent_index, parent_xi[0]) + if (not trunk_is_parent) and (parent_index - parent_start_index == 0): + if parent_location[0] < 0.99: + logger.warning("Nerve: attaching branch " + branch_name + + " at end of first element of parent branch " + branch_parent_name + + " instead of calculated proportion " + str(parent_location[0])) + parent_location = (0, 1.0) cxd2 = 2.0 * (parent_xi[1] - 0.5) cxd3 = 2.0 * (parent_xi[2] - 0.5) @@ -708,13 +747,6 @@ def generateBaseMesh(cls, region, options): basis_to = [bd1, bd2, bd3] coefs = matrix_mult(basis_to, matrix_inv(basis_from)) - # branch annotation groups - branch_box_group = AnnotationGroup(region, (branch_name, annotation_term_map[branch_name])) - annotation_groups.append(branch_box_group) - branch_box_mesh_group = branch_box_group.getMeshGroup(mesh3d) - branch_box_face_mesh_group = branch_box_group.getMeshGroup(mesh2d) - branch_box_line_mesh_group = branch_box_group.getMeshGroup(mesh1d) - # get side derivatives, minimising rotation from trunk # dir2 = normalize(bd2) dir3 = normalize(bd3) @@ -741,7 +773,12 @@ def generateBaseMesh(cls, region, options): if e == 0: # branch root 3D element - nids = [tnid[pn1], tnid[pn2], node_identifier] + if parent_location[0] == 0: + # special case of branch off first parent element + # doesn't use the first local node as parent_location[1] == 1.0 + nids = [tnid[pn2], tnid[pn2], node_identifier] + else: + nids = [tnid[pn1], tnid[pn2], node_identifier] scalefactors = [-1] + fns + dfns + [cxd2, cxd3] + coefs[0] + coefs[1] + coefs[2] element = mesh3d.createElement(element_identifier, elementtemplate_branch_root) element.setNodesByIdentifier(eft3dBR, nids) @@ -773,7 +810,12 @@ def generateBaseMesh(cls, region, options): if e == 0: # branch root 2D face facetemplate_branch_root, eft2dBR = facetemplate_and_eft_list_branch_root[f] - nids = [tnid[pn1], tnid[pn2], node_identifier] + if parent_location[0] == 0: + # special case of branch off first parent element + # doesn't use the first local node as parent_location[1] == 1.0 + nids = [tnid[pn2], tnid[pn2], node_identifier] + else: + nids = [tnid[pn1], tnid[pn2], node_identifier] scalefactors = scalefactors2d + fns + dfns + [cxd2, cxd3] + coefs[0] + coefs[1] + coefs[2] face = mesh3d.createElement(face_identifier, facetemplate_branch_root) face.setNodesByIdentifier(eft2dBR, nids) @@ -795,7 +837,14 @@ def generateBaseMesh(cls, region, options): child_branches = [branch for branch in branch_parent_map.keys() if branch_parent_map[branch] == branch_name] if child_branches: queue = child_branches + queue - parent_parameters[branch_name] = (cx, cd1, cd2, cd12, cd3, cd13, cnid) + existing_parent_parameters = parent_parameters.get(branch_name) + if existing_parent_parameters: + for dst, src in zip(existing_parent_parameters, (cx, cd1, cd2, cd12, cd3, cd13, cnid)): + dst += src + parent_elements_counts[branch_name].append(len(cx) - 1) + else: + parent_parameters[branch_name] = existing_parent_parameters = (cx, cd1, cd2, cd12, cd3, cd13, cnid) + parent_elements_counts[branch_name] = [len(cx) - 1] # ================================================= # Add material coordinates and straight coordinates @@ -987,17 +1036,21 @@ def generateBaseMesh(cls, region, options): for branch_common_name, branch_names in branch_common_groups.items(): term = get_vagus_term(branch_common_name) branch_common_group = findOrCreateAnnotationGroupForTerm(annotation_groups, region, term) - branch_common_mesh_group = branch_common_group.getMeshGroup(mesh3d) for branch_name in branch_names: branch_group = findAnnotationGroupByName(annotation_groups, branch_name) - branch_mesh_group = branch_group.getMeshGroup(mesh3d) - - el_iter = branch_mesh_group.createElementiterator() - element = el_iter.next() - while element.isValid(): - branch_common_mesh_group.addElement(element) + if not branch_group: + logger.warning("Nerve: Could not find annotation for branch " + branch_name + + ". Can't add to common branch group") + continue + for meshnd in [mesh1d, mesh2d, mesh3d]: + branch_common_mesh_group = branch_common_group.getMeshGroup(meshnd) + branch_mesh_group = branch_group.getMeshGroup(meshnd) + el_iter = branch_mesh_group.createElementiterator() element = el_iter.next() + while element.isValid(): + branch_common_mesh_group.addElement(element) + element = el_iter.next() # ============================================ # Add trunk section groups: cervical, thoracic @@ -1096,7 +1149,14 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, is_left = vagus_data.get_side_label() == 'left' raw_marker_data = vagus_data.get_level_markers() px = [e[0] for e in trunk_data_coordinates] - bx, bd1 = get_curve_from_points(px, number_of_elements=trunk_elements_count_prefit) + segment_trunk_info_list = vagus_data.get_segment_trunk_info_list() + if segment_trunk_info_list: + ax = [] + for segment_trunk_info in segment_trunk_info_list: + ax += segment_trunk_info['ordered_points'] + else: + ax = px + bx, bd1 = get_curve_from_points(ax, number_of_elements=trunk_elements_count_prefit) length = getCubicHermiteCurvesLength(bx, bd1) # outlier_length = 0.025 * length # # needs to be bigger if fewer elements: @@ -1221,13 +1281,45 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, zero_fibres = find_or_create_field_zero_fibres(fieldmodule) + # create group for applying higher stiffness for multi-path segments + # elements are in this group if either end nodes are in the range of a multi-path segment + multi_path_group = fieldmodule.createFieldGroup() + multi_path_group.setName('multi-path') + multi_path_group.setManaged(True) # if not managed, group settings are ignored + multi_path_group_name = multi_path_group.getName() # in case name was in use + multi_path_mesh_group = multi_path_group.createMeshGroup(mesh1d) + data_multi_path_ranges = [] + for segment_trunk_info in segment_trunk_info_list: + if not segment_trunk_info.get('ordered_coordinates'): + data_multi_path_ranges.append(segment_trunk_info['range']) + # print('segment', segment_trunk_info['name'], 'is multi-path') + + node_in_data_multi_path_range = [] + for x in ex: + in_data_multi_path_range = False + for data_multi_path_range in data_multi_path_ranges: + for c in range(3): + if (x[c] < data_multi_path_range[0][c]) or (x[c] > data_multi_path_range[1][c]): + break + else: + in_data_multi_path_range = True + break + node_in_data_multi_path_range.append(in_data_multi_path_range) + element_in_data_multi_path_range = [] + for element_identifier in range(1, trunk_elements_count + 1): + if (node_in_data_multi_path_range[element_identifier - 1] or + node_in_data_multi_path_range[element_identifier]): + multi_path_mesh_group.addElement(mesh1d.findElementByIdentifier(element_identifier)) + # print("Element", element_identifier,"is in multi-path range") + # note that fitting is very slow if done within ChangeManager as find mesh location is slow # this includes working with the user-supplied region which is called with ChangeManager on. fitter = GeometryFitter(region=fit_region) length = getCubicHermiteCurvesLength(ex, ed1) - outlier_length = 0.025 * length - fitter.getInitialFitterStepConfig().setGroupOutlierLength(None, outlierLength=outlier_length) - # fitter.setDiagnosticLevel(1) + outlier_length = 0.05 * length + config_step = fitter.getInitialFitterStepConfig() + config_step.setGroupOutlierLength(None, outlierLength=outlier_length) + fitter.setDiagnosticLevel(1) fitter.setModelCoordinatesField(coordinates) fitter.setFibreField(zero_fibres) del zero_fibres @@ -1241,30 +1333,29 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, points_count_calibration_factor = len(px) / 25000 # calibration_length = 27840.0 length_calibration_factor = length / 25000.0 - strain_penalty = 1000.0 * points_count_calibration_factor * length_calibration_factor - curvature_penalty = 1.0E+8 * points_count_calibration_factor * (length_calibration_factor ** 3) - marker_weight = 10.0 * points_count_calibration_factor - sliding_factor = 0.0001 - - if trunk_fit_iterations > 0: - fit1 = FitterStepFit() - fitter.addFitterStep(fit1) - fit1.setGroupDataWeight("marker", marker_weight) - fit1.setGroupStrainPenalty(None, [strain_penalty]) - fit1.setGroupCurvaturePenalty(None, [curvature_penalty]) - fit1.setGroupDataSlidingFactor(None, sliding_factor) - fit1.run() - del fit1 - - if trunk_fit_iterations > 1: - fit2 = FitterStepFit() - fitter.addFitterStep(fit2) - fit2.setGroupStrainPenalty(None, [0.1 * strain_penalty]) - fit2.setGroupCurvaturePenalty(None, [0.1 * curvature_penalty]) - fit2.setGroupDataSlidingFactor(None, 0.1 * sliding_factor) - fit2.setNumberOfIterations(trunk_fit_iterations - 1) - fit2.run() - del fit2 + strain_penalty = 1.0E+4 * points_count_calibration_factor * length_calibration_factor + curvature_penalty = 1.0E+9 * points_count_calibration_factor * (length_calibration_factor ** 3) + marker_weight = 20.0 * points_count_calibration_factor + sliding_factor = 0.01 + + for step in range(0, min(2, trunk_fit_iterations)): + fit_step = FitterStepFit() + fitter.addFitterStep(fit_step) + if step == 0: + fit_step.setGroupDataWeight("marker", marker_weight) + fit_step.setUpdateReferenceState(True) + weight = 1.0 if (step == 0) else 0.1 + fit_step.setGroupStrainPenalty(None, [weight * strain_penalty]) + fit_step.setGroupCurvaturePenalty(None, [weight * curvature_penalty]) + fit_step.setGroupDataSlidingFactor(None, weight * sliding_factor) + + fit_step.setGroupStrainPenalty(multi_path_group_name, [5.0 * weight * strain_penalty]) + fit_step.setGroupCurvaturePenalty(multi_path_group_name, [50.0 * weight * curvature_penalty]) + + if step > 0: + fit_step.setNumberOfIterations(trunk_fit_iterations - 1) + fit_step.run() + del fit_step datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) rms_error, max_error = fitter.getDataRMSAndMaximumProjectionError(trunk_group.getNodesetGroup(datapoints)) @@ -1275,6 +1366,16 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, # fit radius if pr: + # add projection distance to radius + trunk_location = fieldmodule.createFieldFindMeshLocation(coordinates, coordinates, mesh1d) + trunk_location.setSearchMode(trunk_location.SEARCH_MODE_NEAREST) + projected_coordinates = fieldmodule.createFieldEmbedded(coordinates, trunk_location) + projection_distance = fieldmodule.createFieldMagnitude(projected_coordinates - coordinates) + new_radius = radius + projection_distance + trunk_datapoints = trunk_group.getNodesetGroup(datapoints) + fieldassignment = radius.createFieldassignment(new_radius) + fieldassignment.setNodeset(trunk_datapoints) + fieldassignment.assign() gradient1_penalty = 1000.0 * points_count_calibration_factor * length_calibration_factor gradient2_penalty = 1.0E+8 * points_count_calibration_factor * (length_calibration_factor ** 3) rms_error, max_error = define_and_fit_field( @@ -1316,7 +1417,6 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) datapoints.destroyAllNodes() - segments_trunk_coordinates = vagus_data.get_segments_trunk_coordinates() segments_metadata = {} with ChangeManager(fieldmodule): # make a real field which increases down the trunk proportional to vagus coordinates @@ -1329,7 +1429,9 @@ def generate_trunk_1d(vagus_data, trunk_proportion, trunk_elements_count_prefit, datapoints_max_trunk_distance = fieldmodule.createFieldNodesetMaximum(host_trunk_distance, datapoints) fieldcache.clearLocation() distance_to_material = trunk_proportion / trunk_elements_count - for segment_name, sx in segments_trunk_coordinates.items(): + for segment_trunk_info in segment_trunk_info_list: + segment_name = segment_trunk_info['name'] + sx = segment_trunk_info['unordered_coordinates'] generate_datapoints(fit_region, sx, start_data_identifier=1) min_result, segment_min_trunk_distance = datapoints_min_trunk_distance.evaluateReal(fieldcache, 1) max_result, segment_max_trunk_distance = datapoints_max_trunk_distance.evaluateReal(fieldcache, 1) diff --git a/src/scaffoldmaker/utils/read_vagus_data.py b/src/scaffoldmaker/utils/read_vagus_data.py index 174acc21..b8abd535 100644 --- a/src/scaffoldmaker/utils/read_vagus_data.py +++ b/src/scaffoldmaker/utils/read_vagus_data.py @@ -3,16 +3,17 @@ import logging import tempfile -from cmlibs.maths.vectorops import distance +from cmlibs.maths.vectorops import add, distance, magnitude, mult, normalize, sub from cmlibs.utils.zinc.field import get_group_list from cmlibs.utils.zinc.finiteelement import get_element_node_identifiers +from cmlibs.utils.zinc.general import ChangeManager from cmlibs.utils.zinc.group import groups_have_same_local_contents from cmlibs.zinc.field import Field from cmlibs.zinc.node import Node from scaffoldmaker.annotation.vagus_terms import ( get_vagus_term, marker_name_in_terms, get_left_vagus_marker_locations_list, get_right_vagus_marker_locations_list) -from scaffoldmaker.utils.zinc_utils import get_nodeset_field_parameters +from scaffoldmaker.utils.zinc_utils import get_nodeset_field_parameters, get_mesh_node_identifier_sequences logger = logging.getLogger(__name__) @@ -30,16 +31,18 @@ def __init__(self, data_region): self._trunk_keywords = ['cervical vagus nerve', 'thoracic vagus nerve', 'cervical trunk', 'thoracic trunk', 'vagus x nerve trunk'] - self._branch_keywords = ['branch', 'nerve'] + self._branch_keywords = ['branch', 'nerve', 'ganglion'] self._non_branch_keywords = ['perineurium', 'epineurium'] self._term_keywords = ['fma:', 'fma_', 'ilx:', 'ilx_', 'uberon:', 'uberon_'] self._orientation_keywords = ['orientation'] self._annotation_term_map = {} self._branch_coordinates_data = {} + self._branch_radius_data = {} + self._branch_connectivity_data = {} + self._branch_sequences_data = {} self._branch_parent_map = {} self._branch_common_group_map = {} - self._branch_radius_data = {} self._datafile_path = None self._level_markers = {} self._orientation_data = {} @@ -54,6 +57,8 @@ def __init__(self, data_region): nodes = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) coordinates = fm.findFieldByName("coordinates").castFiniteElement() radius = fm.findFieldByName("radius").castFiniteElement() + if not radius.isValid(): + radius = None datapoints = fm.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) marker_names = fm.findFieldByName("marker_name") mesh = fm.findMeshByDimension(1) @@ -109,12 +114,20 @@ def __init__(self, data_region): marker_node = marker_node_iter.next() # extract orientation data + orientation_ignore_node_ids = [] + orientation_ignore_group_name = 'orientation ignore' + if orientation_ignore_group_name in orientation_group_names: + group = fm.findFieldByName(orientation_ignore_group_name).castGroup() + nodeset_group = group.getNodesetGroup(nodes) + _, values = get_nodeset_field_parameters(nodeset_group, coordinates, [Node.VALUE_LABEL_VALUE]) + orientation_ignore_node_ids = [value[0] for value in values] + orientation_group_names.remove(orientation_ignore_group_name) for orientation_group_name in orientation_group_names: group = fm.findFieldByName(orientation_group_name).castGroup() nodeset_group = group.getNodesetGroup(nodes) _, values = get_nodeset_field_parameters(nodeset_group, coordinates, [Node.VALUE_LABEL_VALUE]) - orientation_points = [value[1][0][0] for value in values] - self._orientation_data[orientation_group_name] = orientation_points[:] + orientation_points = [value[1][0][0] for value in values if (value[0] not in orientation_ignore_node_ids)] + self._orientation_data[orientation_group_name] = orientation_points # extract trunk data - coordinates, nodes, radius if len(found_trunk_group_names) > 0: @@ -126,144 +139,33 @@ def __init__(self, data_region): self._trunk_group_name = 'right vagus nerve' self._annotation_term_map[self._trunk_group_name] = get_vagus_term(self._trunk_group_name)[1] self._side_label = 'right' - - # build list of trunk centroid data associated with segment groups ending in .exf - # exported by segmentation stitcher, but not from connections which have .exf twice - # get map from segment group_name to (nodeset_group, []) - # where [] = coordinates list to be filled from trunk groups' coordinates - self._segment_groups_info = {} - for group in group_list: - group_name = group.getName() - if (group_name[-4:] == '.exf') and (1 == group_name.count('.exf')): - self._segment_groups_info[group_name] = (group.getNodesetGroup(nodes), []) + assert self._trunk_group_name if self._trunk_group_name: - trunk_group_count = 0 - trunk_nodes = [] + trunk_node_ids = [] trunk_coordinates = [] trunk_radius = [] - trunk_elements = [] for found_trunk_group_name in found_trunk_group_names: group = fm.findFieldByName(found_trunk_group_name).castGroup() nodeset_group = group.getNodesetGroup(nodes) mesh_group = group.getMeshGroup(mesh) - coordinate_values = get_nodeset_field_parameters(nodeset_group, coordinates, - [Node.VALUE_LABEL_VALUE])[1] - trunk_nodes += [value[0] for value in coordinate_values] - trunk_coordinates += [value[1][0] for value in coordinate_values] - if radius.isValid(): - radius_values = get_nodeset_field_parameters(nodeset_group, radius, [Node.VALUE_LABEL_VALUE])[1] - trunk_radius += [value[1][0][0] for value in radius_values] - - # get trunk elements - if mesh_group.getSize() > 0: - element_iterator = mesh_group.createElementiterator() - element = element_iterator.next() - while element.isValid(): - eft = element.getElementfieldtemplate(coordinates, -1) - local_node_identifiers = get_element_node_identifiers(element, eft) - trunk_elements.append({'id': element.getIdentifier(), - 'nodes': local_node_identifiers}) - element = element_iterator.next() - - trunk_group_count += 1 - - # fill segment groups with coordinates of trunk nodes contained in them - if self._segment_groups_info: - for n, node_identifier in enumerate(trunk_nodes): - node = nodes.findNodeByIdentifier(node_identifier) - for segment_nodeset_group, segment_points_list in self._segment_groups_info.values(): - if segment_nodeset_group.containsNode(node): - segment_points_list.append(trunk_coordinates[n][0]) - - # order trunk coordinates top to bottom in case trunk elements are available - if len(trunk_elements) > 0: - # build trunk graph - nid_coords = {node_id: n_coord[0] for node_id, n_coord in zip(trunk_nodes, trunk_coordinates)} - trunk_graph = {node_id: [] for node_id in nid_coords} - for element in trunk_elements: - local_node_1, local_node_2 = element['nodes'] - if local_node_1 in trunk_nodes and local_node_2 in trunk_nodes: - trunk_graph[local_node_1].append(local_node_2) - trunk_graph[local_node_2].append(local_node_1) - # add any isolated nodes - for node in trunk_nodes: - if node not in trunk_graph.keys(): - trunk_graph[node] = [] - unconnected_nodes = [] - for el in trunk_graph.keys(): - if len(trunk_graph[el]) <= 1: - unconnected_nodes.append(el) - - # choose start, not necessarily first in unconnected nodes - furthest_distance = 0 - for index_1, node_id_1 in enumerate(unconnected_nodes): - for index_2 in range(index_1, len(unconnected_nodes)): - node_id_2 = unconnected_nodes[index_2] - dist = distance(nid_coords[node_id_1], nid_coords[node_id_2]) - if dist > furthest_distance: - furthest_distance = dist - furthest_index_1 = index_1 - furthest_index_2 = index_2 - - start_index = furthest_index_1 if furthest_index_1 < furthest_index_2 else furthest_index_2 - start = unconnected_nodes[start_index] - - trunk_path_ids = [] - # BFS from first to next unconnected, all connected in one long path - while len(unconnected_nodes) > 0: - unconnected_nodes.pop(start_index) - if start not in trunk_path_ids: - local_trunk_path_ids = bfs_to_furthest(trunk_graph, start, trunk_path_ids) - trunk_path_ids.extend(local_trunk_path_ids) - - # find next closest unconnected node - closest_distance = math.inf - last_node_id_in_path = trunk_path_ids[-1] - for index, node_id in enumerate(unconnected_nodes): - dist = distance(nid_coords[node_id], nid_coords[last_node_id_in_path]) - if dist < closest_distance: - closest_distance = dist - start = node_id - start_index = index - - # get one of the top markers to check if trunk path needs to be reversed - if self._side_label == 'left': - markers_ordered_list = get_left_vagus_marker_locations_list() - else: - markers_ordered_list = get_right_vagus_marker_locations_list() - - for marker_name in markers_ordered_list.keys(): - if marker_name in self._level_markers.keys(): - top_marker = self._level_markers[marker_name] - break - start_dist = distance(nid_coords[trunk_path_ids[0]], top_marker) - end_dist = distance(nid_coords[trunk_path_ids[-1]], top_marker) - if trunk_path_ids and end_dist < start_dist: - trunk_path_ids.reverse() - - ordered_trunk_coordinates = [] - for trunk_path_id in trunk_path_ids: - index = trunk_nodes.index(trunk_path_id) - ordered_trunk_coordinates.append(trunk_coordinates[index]) - - if len(trunk_elements) > 0 and trunk_path_ids: - self._trunk_coordinates = ordered_trunk_coordinates[:] - else: - self._trunk_coordinates = trunk_coordinates[:] - - if radius.isValid() and not all(value == 0.0 for value in trunk_radius): - if len(trunk_elements) > 0 and trunk_path_ids: - ordered_trunk_radius = [] - for trunk_path_id in trunk_path_ids: - index = trunk_nodes.index(trunk_path_id) - ordered_trunk_radius.append(trunk_radius[index]) - self._trunk_radius = ordered_trunk_radius[:] - else: - self._trunk_radius = trunk_radius[:] + node_coordinate_values = get_nodeset_field_parameters( + nodeset_group, coordinates, [Node.VALUE_LABEL_VALUE])[1] + trunk_node_ids += [value[0] for value in node_coordinate_values] + for value in node_coordinate_values: + trunk_coordinates.append(value[1][0]) + if radius: + node_radius_values = get_nodeset_field_parameters(nodeset_group, radius, [Node.VALUE_LABEL_VALUE])[1] + for value in node_radius_values: + trunk_radius.append(value[1][0][0]) + + self._segment_trunk_info_list = make_segment_trunk_info( + fm, fc, coordinates, nodes, mesh, group_list, found_trunk_group_names, self._trunk_group_name) + + self._trunk_coordinates = trunk_coordinates + self._trunk_radius = trunk_radius # extract branch data - name, coordinates, nodes, radius - branch_nodes_data = {} for branch_name in branch_group_names: group = fm.findFieldByName(branch_name).castGroup() nodeset_group = group.getNodesetGroup(nodes) @@ -272,36 +174,97 @@ def __init__(self, data_region): # branch should have at least two nodes to be connected to parent continue _, values = get_nodeset_field_parameters(nodeset_group, coordinates, [Node.VALUE_LABEL_VALUE]) - branch_nodes = [value[0] for value in values] - branch_parameters = [value[1][0] for value in values] - self._branch_coordinates_data[branch_name] = branch_parameters - branch_nodes_data[branch_name] = branch_nodes - - # not used at the moment - if radius.isValid(): + # make above into a dict to look up by node id + node_coordinates_parameters = {} + for node_id, parameters in values: + node_coordinates_parameters[node_id] = parameters[0][0] + + mesh_group = group.getMeshGroup(mesh) + node_ids_list = get_mesh_node_identifier_sequences(mesh_group, coordinates) + if node_ids_list: + # check for old data which didn't make an element from the trunk node to the first branch node: + first_branch_node_id = values[0][0] + for node_ids in node_ids_list: + if first_branch_node_id in node_ids: + break + else: + logger.warning('Branch ' + branch_name + ' did not have parent node ' + str(first_branch_node_id) + + ' in branch elements. Including as first node.') + node_ids_list[0].insert(0, first_branch_node_id) + else: + # if no elements, fall back to branch nodes in order, allows only a single branch + node_ids_list = [[value[0] for value in values]] + # get node coordinates in order + branch_coordinates = [] + for node_ids in node_ids_list: + for node_id in node_ids: + branch_coordinates.append(node_coordinates_parameters[node_id]) + self._branch_coordinates_data[branch_name] = branch_coordinates + self._branch_connectivity_data[branch_name] = node_ids_list + self._branch_sequences_data[branch_name] = [len(node_ids) for node_ids in node_ids_list] + + if radius: _, values = get_nodeset_field_parameters(nodeset_group, radius, [Node.VALUE_LABEL_VALUE]) - branch_radius = [value[1][0][0] for value in values] - if not all(value == 0.0 for value in branch_radius): + if not all((value[1][0][0] == 0.0) for value in values): + # make above into a dict to look up by node id + node_radius_parameters = {} + for node_id, parameters in values: + node_radius_parameters[node_id] = parameters[0][0] + branch_radius = [] + for node_ids in node_ids_list: + for node_id in node_ids: + branch_radius.append(node_radius_parameters[node_id]) self._branch_radius_data[branch_name] = branch_radius - # find parent branch where it connects to - for branch_name, branch_nodes in branch_nodes_data.items(): - # assumes trunk and branch node identifiers are strictly increasing. - branch_first_node = branch_nodes[0] - - # first check if trunk is a parent by searching for a common node - parent_name = '' - if branch_first_node in trunk_nodes: - parent_name = self._trunk_group_name - else: - # check other branches if a common node exists - for parent_branch_name, parent_branch_nodes in branch_nodes_data.items(): - if parent_branch_name != branch_name: - parent_first_node = parent_branch_nodes[0] - if branch_first_node != parent_first_node and branch_first_node in parent_branch_nodes: - parent_name = parent_branch_name + # find parent branches where each branch connects to + # limitaton: parent is the same for each separate branch with the same name + for branch_name, node_ids_list in self._branch_connectivity_data.items(): + parent_name = None + parameters_index = 0 + for node_ids in node_ids_list: + next_parent_name = None + for start_node_index in (0, -1): # in case reversed order + branch_first_node_id = node_ids[start_node_index] + # first check if trunk is a parent by searching for a common node + if branch_first_node_id in trunk_node_ids: + next_parent_name = self._trunk_group_name + break + if not next_parent_name: + # check other branches if a common node exists + # issue: likely to have problems if branch is processed before child branch + for start_node_index in (0, -1): # in case reversed order + branch_first_node_id = node_ids[start_node_index] + for parent_branch_name, parent_node_ids_list in self._branch_connectivity_data.items(): + if parent_branch_name != branch_name: + for parent_node_ids in parent_node_ids_list: + if branch_first_node_id in parent_node_ids: + next_parent_name = parent_branch_name + break + if next_parent_name: + break + if next_parent_name: break - if parent_name == '': + if next_parent_name: + if parent_name: + if next_parent_name != parent_name: + logger.warning('Branches with name ' + branch_name + ' have both ' + parent_name + ' and ' + + next_parent_name + ' as parents. Using ' + parent_name) + else: + parent_name = next_parent_name + if start_node_index == -1: + # reverse order of branch nodes, parameters etc. so always heading away from parent + count = len(node_ids) + branch_x = self._branch_coordinates_data[branch_name] + branch_r = self._branch_radius_data.get(branch_name) + for i in range(count // 2): + i1 = parameters_index + i + i2 = parameters_index + count - i - 1 + branch_x[i1], branch_x[i2] = branch_x[i2], branch_x[i1] + if branch_r: + branch_r[i1], branch_r[i2] = branch_r[i2], branch_r[i1] + node_ids.reverse() # reverse in place + parameters_index += len(node_ids) + if not parent_name: # assume trunk is a parent by default, if no other is found parent_name = self._trunk_group_name self._branch_parent_map[branch_name] = parent_name @@ -355,7 +318,7 @@ def get_trunk_radius(self): """ return self._trunk_radius - def get_branch_data(self): + def get_branch_coordinates_data(self): """ Get all branch names and coordinates from the data. return: Dict mapping branch name to x, y, z data. @@ -370,6 +333,13 @@ def get_branch_radius_data(self): """ return self._branch_radius_data + def get_branch_sequences_data(self): + """ + Get information about how many distinct branches there are for each branch name. + return: Dict mapping branch name to list of numbers of nodes in each branch sequence. + """ + return self._branch_sequences_data + def get_annotation_term_map(self): """ Get all annotation names and terms. @@ -410,17 +380,16 @@ def reset_datafile_path(self): """ self._datafile_path = None - def get_segments_trunk_coordinates(self): + def get_segment_trunk_info_list(self): """ - Get coordinates of trunk nodes in each segment corresponding to each .exf file read into the - segmentations stitcher, recognized by group names ending in '.exf', but not containing '.exf' - multiple times as for connection groups. - :return: dict segment_name -> list of coordinates + Get segment trunk information gleaned from each .exf file read into segmentations stitcher, recognized by + group names ending in '.exf', but not containing '.exf' multiple times as for connection groups. + These are in order down the nerve. + :return: list of segment info dict with at least fields: 'name', 'unordered_coordinates', 'ordered_points', + 'centroid', 'range'. 'ordered_coordinates' is present iff a single polyline crosses the segment. """ - segments_trunk_coordinates = {} - for segment_name, rhs in self._segment_groups_info.items(): - segments_trunk_coordinates[segment_name] = rhs[1] # only the coordinates - return segments_trunk_coordinates + return self._segment_trunk_info_list + def group_common_branches(branch_names): """ @@ -460,10 +429,11 @@ def load_vagus_data(region): def bfs_to_furthest(graph, start, trunk_path_ids): """ - :param graph: - :param start: - :param trunk_path_ids: - return: Returns the furthest node and the path to it using BFS. + Breadth first search. + :param graph: Map of node identifiers to node identifiers they are connected to. + :param start: Start node identifier, a graph end point. + :param trunk_path_ids: List of previously added node identifiers down path. + return: List of node identifiers from start to the furthest connected end. """ visited = set() @@ -476,14 +446,205 @@ def bfs_to_furthest(graph, start, trunk_path_ids): visited.add(current) last = current for neighbor in graph[current]: - if neighbor not in visited and neighbor not in trunk_path_ids and neighbor not in queue: + if (neighbor not in visited) and (neighbor not in trunk_path_ids) and (neighbor not in queue): parent[neighbor] = current queue.append(neighbor) - # Trace path from furthest node back to start + # Trace path from the furthest node back to start path = [] while last is not None: path.append(last) last = parent[last] return list(reversed(path)) + +def make_segment_trunk_info(fieldmodule, fieldcache, coordinates, nodes, mesh1d, group_list, trunk_group_names, + trunk_group_name): + """ + Make ordered (from top to bottom of nerve) segment trunk information to help get an initial guess of path. + Segment data is in groups with names ending in .exf, but not containing .exf more than once as that is + used for connection groups from Segmentation Stitcher. + :param fieldmodule: Zinc Fieldmodule for region. + :param fieldcache: Zinc Fieldcache. + :param coordinates: Coordinate field. + :param nodes: Nodes in region. + :param mesh1d: 1-D mesh in region. + :param group_list: List of all zinc groups in region + :param trunk_group_names: Names of trunk groups in source data. + :param trunk_group_name: Name of whole trunk group. + :return: List of segment trunk information in order down nerve. Information is a dict with at least fields + 'name': segment name + 'unordered_coordinates': List of all node coordinates in segment. + 'ordered_coordinates': Optional ordered list of raw coordinates from top to bottom. Only present if there is + single polyline in segment; absent if not so. + 'ordered_points': Same as 'ordered_coordinates' if present, otherwise exactly 2 points in the mean direction + of trunk in segment; nerve ends have these points right to the end, but interior segments have 2 points at + 0.25, 0.75 proportion along. This is used to give initial path down trunk. + """ + segment_trunk_info_list = [] + with ChangeManager(fieldmodule): + fieldcache.clearLocation() + is_trunk = None + for trunk_group_name in trunk_group_names: + trunk_group = fieldmodule.findFieldByName(trunk_group_name).castGroup() + is_trunk = fieldmodule.createFieldOr(is_trunk, trunk_group) if is_trunk else trunk_group + + # get raw segment information + for group in group_list: + group_name = group.getName() + if (group_name[-4:] == '.exf') and (1 == group_name.count('.exf')): + segment_trunk_group = fieldmodule.createFieldGroup() + segment_trunk_group.setName(group_name + ' ' + trunk_group_name) + segment_trunk_nodeset_group = segment_trunk_group.createNodesetGroup(nodes) + is_segment_trunk = fieldmodule.createFieldAnd(group, is_trunk) + segment_trunk_nodeset_group.addNodesConditional(is_segment_trunk) + segment_trunk_mesh_group = segment_trunk_group.createMeshGroup(mesh1d) + segment_trunk_mesh_group.addElementsConditional(is_segment_trunk) + del is_segment_trunk + first_node_id = segment_trunk_nodeset_group.createNodeiterator().next().getIdentifier() + if first_node_id < 0: + continue # empty segment + unordered_coordinates = [] + mean_coordinates = fieldmodule.createFieldNodesetMean(coordinates, segment_trunk_nodeset_group) + result, centroid = mean_coordinates.evaluateReal(fieldcache, 3) + del mean_coordinates + minimum_coordinates = fieldmodule.createFieldNodesetMinimum(coordinates, segment_trunk_nodeset_group) + result, min_x = minimum_coordinates.evaluateReal(fieldcache, 3) + del minimum_coordinates + maximum_coordinates = fieldmodule.createFieldNodesetMaximum(coordinates, segment_trunk_nodeset_group) + result, max_x = maximum_coordinates.evaluateReal(fieldcache, 3) + del maximum_coordinates + segment_trunk_info = { + 'name': group_name, + 'first_node_id': first_node_id, + 'nodeset_group': segment_trunk_nodeset_group, + 'unordered_coordinates': unordered_coordinates, + 'centroid': centroid, + 'range': [min_x, max_x] + } + # ensure in order of lowest node in segment which should be from top to bottom of nerve + for i, other_segment_info in enumerate(segment_trunk_info_list): + if first_node_id < other_segment_info['first_node_id']: + segment_trunk_info_list.insert(i, segment_trunk_info) + break + else: + segment_trunk_info_list.append(segment_trunk_info) + node_coordinate_values = get_nodeset_field_parameters( + segment_trunk_nodeset_group, coordinates, [Node.VALUE_LABEL_VALUE])[1] + # build segment trunk graph: map of node identifiers to node identifiers they are connected to + node_id_coordinates = {} + trunk_graph = {} + for node_id, values in node_coordinate_values: + x = values[0][0] + node_id_coordinates[node_id] = x + trunk_graph[node_id] = [] + unordered_coordinates.append(x) + element_iterator = segment_trunk_mesh_group.createElementiterator() + element = element_iterator.next() + while element.isValid(): + eft = element.getElementfieldtemplate(coordinates, -1) + node_ids = get_element_node_identifiers(element, eft) + for n in range(len(node_ids) - 1): + trunk_graph[node_ids[n]].append(node_ids[n + 1]) + trunk_graph[node_ids[n + 1]].append(node_ids[n]) + element = element_iterator.next() + # check if single polyline from one end to the other + count0 = 0 + count1 = 0 + count3plus = 0 + start_node_id = None + for node_id, connected_node_ids in trunk_graph.items(): + count = len(connected_node_ids) + if count == 0: + count0 += 1 + elif count == 1: + count1 += 1 + if start_node_id is None: + start_node_id = node_id + elif count > 2: + count3plus += 1 + if (count0 == 0) and (count1 == 2) and (count3plus == 0): + ordered_coordinates = [] + node_id = start_node_id + last_node_id = None + while True: + x = node_id_coordinates[node_id] + ordered_coordinates.append(x) + connected_node_ids = trunk_graph[node_id] + for connected_node_id in connected_node_ids: + if connected_node_id != last_node_id: + break + else: + break + last_node_id = node_id + node_id = connected_node_id + segment_trunk_info['ordered_coordinates'] = ordered_coordinates + del segment_trunk_mesh_group + del segment_trunk_nodeset_group + del segment_trunk_group + del is_trunk + + # add segment coordinates/point order + s_count = len(segment_trunk_info_list) + prev_centroid = None + for s, segment_trunk_info in enumerate(segment_trunk_info_list): + unordered_coordinates = segment_trunk_info['unordered_coordinates'] + centroid = segment_trunk_info['centroid'] + next_centroid = segment_trunk_info_list[s + 1]['centroid'] if s < (s_count - 1) else None + nodeset_group = segment_trunk_info['nodeset_group'] + ordered_coordinates = segment_trunk_info.get('ordered_coordinates') + print(segment_trunk_info['name'], 'single path' if ordered_coordinates else 'complex path') + direction = None + if s_count == 1: + if ordered_coordinates: + pass # assume in correct order + elif len(unordered_coordinates) == 1: + ordered_coordinates = unordered_coordinates + else: + direction = normalize(sub(centroid, unordered_coordinates[0])) + else: + if ordered_coordinates: + # reverse ordered coordinates if wrong end is closer to prev/next_centroid + if prev_centroid: + far_distance = magnitude(sub(ordered_coordinates[-1], prev_centroid)) + near_distance = magnitude(sub(ordered_coordinates[0], prev_centroid)) + else: + far_distance = magnitude(sub(next_centroid, ordered_coordinates[0])) + near_distance = magnitude(sub(next_centroid, ordered_coordinates[-1])) + if far_distance < near_distance: + ordered_coordinates.reverse() + else: + next_point = next_centroid if next_centroid else centroid + prev_point = prev_centroid if prev_centroid else centroid + direction = normalize(sub(next_point, prev_point)) + if ordered_coordinates: + ordered_points = ordered_coordinates + else: + # get range of coordinates in direction, sample 2 points for straight line + direction_coordinate = fieldmodule.createFieldDotProduct( + coordinates - fieldmodule.createFieldConstant(centroid), + fieldmodule.createFieldConstant(direction)) + result, min_d = fieldmodule.createFieldNodesetMinimum( + direction_coordinate, nodeset_group).evaluateReal(fieldcache, 1) + result, max_d = fieldmodule.createFieldNodesetMaximum( + direction_coordinate, nodeset_group).evaluateReal(fieldcache, 1) + del direction_coordinate + min_x = add(centroid, mult(direction, min_d)) + max_x = add(centroid, mult(direction, max_d)) + # only go to the full min_x / max_x range for the first and last segments + if s_count == 1: + xi_list = [0.0, 1.0] + elif s == 0: + xi_list = [0.0, 0.75] + elif s < (s_count - 1): + xi_list = [0.25, 0.75] + else: + xi_list = [0.25, 1.0] + ordered_points = [add(mult(min_x, 1.0 - xi), mult(max_x, xi)) for xi in xi_list] + # don't want these left behind + del segment_trunk_info['nodeset_group'] + del nodeset_group + segment_trunk_info['ordered_points'] = ordered_points + prev_centroid = centroid + + return segment_trunk_info_list diff --git a/src/scaffoldmaker/utils/zinc_utils.py b/src/scaffoldmaker/utils/zinc_utils.py index 126d85aa..dc266a54 100644 --- a/src/scaffoldmaker/utils/zinc_utils.py +++ b/src/scaffoldmaker/utils/zinc_utils.py @@ -5,7 +5,8 @@ from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, find_or_create_field_stored_mesh_location, find_or_create_field_stored_string) -from cmlibs.utils.zinc.finiteelement import get_maximum_element_identifier, get_maximum_node_identifier +from cmlibs.utils.zinc.finiteelement import ( + get_element_node_identifiers, get_maximum_element_identifier, get_maximum_node_identifier) from cmlibs.utils.zinc.general import ChangeManager, HierarchicalChangeManager from cmlibs.zinc.context import Context from cmlibs.zinc.element import Element, Elementbasis, MeshGroup @@ -19,9 +20,13 @@ from scaffoldfitter.fitterstepfit import FitterStepFit from scaffoldmaker.utils import interpolation as interp import copy +import logging import math +logger = logging.getLogger(__name__) + + def interpolateNodesCubicHermite(cache, coordinates, xi, normal_scale, node1, derivative1, scale1, cross_derivative1, cross_scale1, node2, derivative2, scale2, cross_derivative2, cross_scale2): @@ -155,6 +160,32 @@ def get_mesh_first_element_with_node(mesh, field, node): return None +def get_mesh_node_identifier_sequences(mesh1d, field): + """ + Get sequences of connected nodes in 1D mesh which field is directly defined on. + Implementation expects mesh to consist of only polylines i.e. not a network and for elements to be + consecutive and in the same orientation along each sequence. + :param mesh1d: 1-D mesh or mesh group which field is defined on in every element. + :param field: The field to get connectivity for. Must be finite element type. + :return: List of lists of node identifiers in each sequence. + """ + node_ids_list = [] + elementiterator = mesh1d.createElementiterator() + element = elementiterator.next() + while element.isValid(): + eft = element.getElementfieldtemplate(field, -1) + if not eft.isValid(): + logger.error("mesh_get_connected_node_identifier_sequences. Field not defined on element") + return [] + node_ids = get_element_node_identifiers(element, eft) + if node_ids_list and (node_ids_list[-1][-1] == node_ids[0]): + node_ids_list[-1] += node_ids[1:] + else: + node_ids_list.append(node_ids) + element = elementiterator.next() + return node_ids_list + + def get_nodeset_field_parameters(nodeset, field, only_value_labels=None): """ Returns parameters of field from nodes in nodeset in identifier order. @@ -966,18 +997,17 @@ def fit_hermite_curve(bx, bd1, px, outlier_length=0.0, region=None, group_name=N fitter.defineCommonMeshFields() fitter.setDataCoordinatesField(coordinates) fitter.defineDataProjectionFields() - # aim for no more than 25 points per element: - points_per_element = 25 - data_proportion = min(1.0, points_per_element * elements_count / points_count) - if data_proportion < 1.0: - fitter.getInitialFitterStepConfig().setGroupDataProportion(None, data_proportion) fitter.initializeFit() + # calibrated for 25 points / element + data_weight = (25.0 * elements_count) / points_count + strain_penalty = 1.0E-6 * curve_length # calibrated by scaling the model: a power of 3 relationship - curvature_penalty = ((points_count * data_proportion) / (points_per_element * elements_count) * - 1.0E-6 * (curve_length ** 3)) + curvature_penalty = 1.0E-6 * (curve_length ** 3) fit1 = FitterStepFit() fitter.addFitterStep(fit1) + fit1.setGroupDataWeight(None, data_weight) + fit1.setGroupStrainPenalty(None, [strain_penalty]) fit1.setGroupCurvaturePenalty(None, [curvature_penalty]) fit1.run() del fit1 diff --git a/tests/resources/vagus_test_data1.exf b/tests/resources/vagus_test_data1.exf index dc2515f0..f27ffc69 100644 --- a/tests/resources/vagus_test_data1.exf +++ b/tests/resources/vagus_test_data1.exf @@ -1686,6 +1686,16 @@ Node: 335 -5.161084783385378e+03 -1.361564654132272e+02 0.000000000000000e+00 +Node: 336 + 8.773479487054547e+03 + -1.285509252460388e+03 + -2.958677308834912e+02 + 2.574722374853684e+02 +Node: 337 + 2.900467806037634e+04 + -4.161084783385378e+03 + -1.361564654132272e+02 + 0.000000000000000e+00 !#mesh mesh1d, dimension=1, nodeset=nodes Define element template: element1 Shape. Dimension=1, line @@ -2649,6 +2659,9 @@ Element: 309 Element: 310 Nodes: 314 315 +Element: 311 + Nodes: + 60 336 !#nodeset datapoints Define node template: node2 Shape. Dimension=0 @@ -2697,17 +2710,17 @@ Element group: Group name: left superior laryngeal nerve !#nodeset nodes Node group: -46,202..242 +46,60,202..242,336 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: -201..240 +201..240,311 Group name: http://uri.interlex.org/base/ilx_0788780 !#nodeset nodes Node group: -46,202..242 +46,60,202..242,336 !#mesh mesh1d, dimension=1, nodeset=nodes Element group: -201..240 +201..240,311 Group name: left A thoracic cardiopulmonary branch of vagus nerve !#nodeset nodes Node group: @@ -2757,7 +2770,11 @@ Node group: Group name: orientation anterior !#nodeset nodes Node group: -316,318..320,322,324,326,328,330,332,334..335 +316,318..320,322,324,326,328,330,332,334..335,337 +Group name: orientation ignore +!#nodeset nodes +Node group: +337 Group name: orientation left !#nodeset nodes Node group: diff --git a/tests/test_vagus.py b/tests/test_vagus.py index ecf233de..77f00957 100644 --- a/tests/test_vagus.py +++ b/tests/test_vagus.py @@ -1,12 +1,12 @@ -from cmlibs.maths.vectorops import mult from cmlibs.utils.zinc.field import ( find_or_create_field_coordinates, find_or_create_field_finite_element, find_or_create_field_group, find_or_create_field_stored_string) +from cmlibs.utils.zinc.finiteelement import get_element_node_identifiers from cmlibs.utils.zinc.general import ChangeManager from cmlibs.utils.zinc.group import mesh_group_to_identifier_ranges, nodeset_group_to_identifier_ranges from cmlibs.zinc.context import Context from cmlibs.zinc.element import Element, Elementbasis -from cmlibs.zinc.field import Field +from cmlibs.zinc.field import Field, FieldGroup from cmlibs.zinc.node import Node from cmlibs.zinc.result import RESULT_OK @@ -44,25 +44,73 @@ def reorder_vagus_test_data1(testcase, region): node_identifiers = [element.getNode(eft, ln).getIdentifier() for ln in range(1, local_node_count + 1)] node_identifiers.reverse() testcase.assertEqual(RESULT_OK, element.setNodesByIdentifier(eft, node_identifiers)) - for node_identifier in range(1, 51): - other_node_identifier = 101 - node_identifier + # can't renumber between segments as algorithm expects nodes in first segment to have lower numbers + # don't renumber nodes 46, 60 as they're the start of a branch + # currently rely on them being the first identifier in that branch for a workaround which includes it + for node_identifier in range(1, 14): + other_node_identifier = 75 - node_identifier node = nodes.findNodeByIdentifier(node_identifier) other_node = nodes.findNodeByIdentifier(other_node_identifier) testcase.assertEqual(RESULT_OK, other_node.setIdentifier(UNUSED_IDENTIFIER)) testcase.assertEqual(RESULT_OK, node.setIdentifier(other_node_identifier)) testcase.assertEqual(RESULT_OK, other_node.setIdentifier(node_identifier)) - for element_identifier in range(76, 101): - other_element_identifier = 151 - element_identifier + for element_identifier in range(1, 37): + other_element_identifier = 74 - element_identifier element = mesh1d.findElementByIdentifier(element_identifier) other_element = mesh1d.findElementByIdentifier(other_element_identifier) testcase.assertEqual(RESULT_OK, other_element.setIdentifier(UNUSED_IDENTIFIER)) testcase.assertEqual(RESULT_OK, element.setIdentifier(other_element_identifier)) testcase.assertEqual(RESULT_OK, other_element.setIdentifier(element_identifier)) - # for node_identifier in range(126, 151): - # node = nodes.findNodeByIdentifier(node_identifier) - # testcase.assertEqual(RESULT_OK, node.setIdentifier(node_identifier + IDENTIFIER_OFFSET)) - for element_identifier in range(101, 104): - mesh1d.destroyElement(mesh1d.findElementByIdentifier(element_identifier)) + + +def create_segment_groups_vagus_test_data1(testcase, data_region): + """ + Create segment groups dividing the data approximately in thirds over the x-span of the trunk. + """ + data_fieldmodule = data_region.getFieldmodule() + data_nodes = data_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) + data_mesh = data_fieldmodule.findMeshByDimension(1) + data_coordinates = data_fieldmodule.findFieldByName("coordinates") + with ChangeManager(data_fieldmodule): + data_x = data_fieldmodule.createFieldComponent(data_coordinates, 1) + conditions = [ + data_fieldmodule.createFieldLessThan(data_x, data_fieldmodule.createFieldConstant(10000.0)), + None, + data_fieldmodule.createFieldGreaterThan(data_x, data_fieldmodule.createFieldConstant(21000.0)) + ] + conditions[1] = data_fieldmodule.createFieldNot( + data_fieldmodule.createFieldOr(conditions[0], conditions[2])) + for s in range(3): + segment_group = data_fieldmodule.createFieldGroup() + segment_group.setName("segment" + str(s + 1) + ".exf") + segment_group.setManaged(True) + segment_nodeset_group = segment_group.createNodesetGroup(data_nodes) + segment_nodeset_group.addNodesConditional(conditions[s]) + # ensure elements with both nodes in group are in the mesh group + segment_mesh_group = segment_group.createMeshGroup(data_mesh) + data_fieldcache = data_fieldmodule.createFieldcache() + elementiterator = data_mesh.createElementiterator() + element = elementiterator.next() + while element.isValid(): + eft = element.getElementfieldtemplate(data_coordinates, -1) + local_nodes_count = eft.getNumberOfLocalNodes() + for ln in range(1, local_nodes_count + 1): + node = element.getNode(eft, ln) + data_fieldcache.setNode(node) + _, in_segment_group = segment_group.evaluateReal(data_fieldcache, 1) + if not in_segment_group: + break + else: + segment_mesh_group.addElement(element) + element = elementiterator.next() + if s == 0: + testcase.assertEqual(segment_mesh_group.getSize(), 134) + elif s == 1: + testcase.assertEqual(segment_mesh_group.getSize(), 98) + else: + testcase.assertEqual(segment_mesh_group.getSize(), 77) + del conditions + del data_x class VagusScaffoldTestCase(unittest.TestCase): @@ -94,6 +142,7 @@ def test_input_vagus_data(self): assert result == RESULT_OK if i == 1: reorder_vagus_test_data1(self, data_region) + create_segment_groups_vagus_test_data1(self, data_region) data_fieldmodule = data_region.getFieldmodule() data_mesh1d = data_fieldmodule.findMeshByDimension(1) @@ -104,10 +153,10 @@ def test_input_vagus_data(self): data_trunk_nodeset_group = trunk_group.getNodesetGroup(data_nodes) mesh_ranges = mesh_group_to_identifier_ranges(data_trunk_mesh_group) nodeset_ranges = nodeset_group_to_identifier_ranges(data_trunk_nodeset_group) + self.assertEqual([[1, 200]], mesh_ranges) + self.assertEqual([[1, 201]], nodeset_ranges) + self.assertEqual(200, data_trunk_mesh_group.getSize()) if i == 0: - self.assertEqual([[1, 200]], mesh_ranges) - self.assertEqual([[1, 201]], nodeset_ranges) - self.assertEqual(200, data_trunk_mesh_group.getSize()) expected_element_info = { 1: [1, 2], 2: [2, 3], @@ -117,25 +166,19 @@ def test_input_vagus_data(self): 103: [103, 104] } else: - self.assertEqual([[1, 100], [104, 200]], mesh_ranges) - # self.assertEqual([[1, 125], [151, 201], [1126, 1150]], nodeset_ranges) - self.assertEqual(197, data_trunk_mesh_group.getSize()) expected_element_info = { - 1: [99, 100], - 2: [98, 99], - 51: [101, 1], - 101: None, - 102: None, - 103: None + 1: [1, 2], + 2: [2, 3], + 51: [24, 23], + 101: [101, 102], + 102: [102, 103], + 103: [103, 104] } for element_id, expected_node_ids in expected_element_info.items(): element = data_mesh1d.findElementByIdentifier(element_id) eft = element.getElementfieldtemplate(data_coordinates, -1) - if expected_node_ids: - node_ids = [element.getNode(eft, n + 1).getIdentifier() for n in range(eft.getNumberOfLocalNodes())] - self.assertEqual(expected_node_ids, node_ids) - else: - self.assertFalse(eft.isValid()) + node_ids = [element.getNode(eft, n + 1).getIdentifier() for n in range(eft.getNumberOfLocalNodes())] + self.assertEqual(expected_node_ids, node_ids) vagus_data = VagusInputData(data_region) self.assertEqual(vagus_data.get_side_label(), 'left') @@ -165,27 +208,32 @@ def test_input_vagus_data(self): self.assertEqual(len(trunk_coordinates), 201) annotation_term_map = vagus_data.get_annotation_term_map() self.assertTrue(trunk_group_name in annotation_term_map) - # self.assertEqual(annotation_term_map[trunk_group_name], 'http://purl.obolibrary.org/obo/UBERON_0035020') + self.assertEqual(annotation_term_map[trunk_group_name], 'http://uri.interlex.org/base/ilx_0785628') # do a simple fit to the trunk data coordinates to check trunk ordering is working - trunk_data_coordinates = vagus_data.get_trunk_coordinates() - px = [e[0] for e in trunk_data_coordinates] + px = [] + segment_trunk_info_list = vagus_data.get_segment_trunk_info_list() + self.assertEqual(len(segment_trunk_info_list), 3) + for segment_trunk_info in segment_trunk_info_list: + px += segment_trunk_info['ordered_points'] self.assertEqual(201, len(px)) bx, bd1 = get_curve_from_points(px, number_of_elements=10) length = getCubicHermiteCurvesLength(bx, bd1) self.assertAlmostEqual(31726.825262197974, length, delta=1.0E-3) - branch_data = vagus_data.get_branch_data() - self.assertEqual(len(branch_data), 4) - self.assertTrue("left superior laryngeal nerve" in branch_data) - self.assertEqual(len(branch_data["left superior laryngeal nerve"]), 42) - self.assertTrue("left A branch of superior laryngeal nerve" in branch_data) - self.assertEqual(len(branch_data["left A branch of superior laryngeal nerve"]), 22) + branch_coordinates_data = vagus_data.get_branch_coordinates_data() + branch_sequences_data = vagus_data.get_branch_sequences_data() + self.assertEqual(len(branch_coordinates_data), 4) + self.assertTrue("left superior laryngeal nerve" in branch_coordinates_data) + self.assertEqual(len(branch_coordinates_data["left superior laryngeal nerve"]), 44) + self.assertEqual(len(branch_sequences_data["left superior laryngeal nerve"]), 2) + self.assertTrue("left A branch of superior laryngeal nerve" in branch_coordinates_data) + self.assertEqual(len(branch_coordinates_data["left A branch of superior laryngeal nerve"]), 22) left_thoracic_cardiopulmonary_branches = ( "left A thoracic cardiopulmonary branch of vagus nerve", "left B thoracic cardiopulmonary branch of vagus nerve") for branch_name in left_thoracic_cardiopulmonary_branches: - self.assertTrue(branch_name in branch_data) + self.assertTrue(branch_name in branch_coordinates_data) branch_parents = vagus_data.get_branch_parent_map() self.assertEqual(branch_parents["left superior laryngeal nerve"], trunk_group_name) @@ -235,7 +283,7 @@ def test_vagus_nerve_1(self): self.assertEqual(options.get('Trunk proportion'), 1.0) self.assertEqual(options.get('Trunk fit number of iterations'), 5) self.assertEqual(options.get('Default anterior direction'), [0.0, 1.0, 0.0]) - self.assertEqual(options.get('Default trunk diameter'), 3.0) + self.assertEqual(options.get('Default trunk diameter'), 3000.0) self.assertEqual(options.get('Branch diameter trunk proportion'), 0.5) # change options to make test fast and consistent, with minor effect on result: options['Number of elements along the trunk pre-fit'] = 10 @@ -254,28 +302,7 @@ def test_vagus_nerve_1(self): self.assertEqual(data_region.readFile(data_file), RESULT_OK) if i == 1: reorder_vagus_test_data1(self, data_region) - - # create segment groups dividing the data approximately in thirds over the x-span of the trunk - data_fieldmodule = data_region.getFieldmodule() - data_nodes = data_fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) - data_coordinates = data_fieldmodule.findFieldByName("coordinates") - with ChangeManager(data_fieldmodule): - data_x = data_fieldmodule.createFieldComponent(data_coordinates, 1) - conditions = [ - data_fieldmodule.createFieldLessThan(data_x, data_fieldmodule.createFieldConstant(10000.0)), - None, - data_fieldmodule.createFieldGreaterThan(data_x, data_fieldmodule.createFieldConstant(20000.0)) - ] - conditions[1] = data_fieldmodule.createFieldNot( - data_fieldmodule.createFieldOr(conditions[0], conditions[2])) - for s in range(3): - segment_group = data_fieldmodule.createFieldGroup() - segment_group.setName("segment" + str(s + 1) + ".exf") - segment_group.setManaged(True) - segment_nodeset_group = segment_group.createNodesetGroup(data_nodes) - segment_nodeset_group.addNodesConditional(conditions[s]) - del conditions - del data_x + create_segment_groups_vagus_test_data1(self, data_region) # check annotation groups annotation_groups, nerve_metadata = scaffold.generateMesh(region, options) @@ -284,19 +311,19 @@ def test_vagus_nerve_1(self): TOL = 1.0E-6 expected_metadata = { 'segments': { - 'segment1.exf': {'minimum vagus coordinate': 0.062179363163301214, - 'maximum vagus coordinate': 0.244237232602641}, - 'segment2.exf': {'minimum vagus coordinate': 0.24685186671128517, - 'maximum vagus coordinate': 0.40960681735379395}, - 'segment3.exf': {'minimum vagus coordinate': 0.41221671261995313, - 'maximum vagus coordinate': 0.5754599929406741} + 'segment1.exf': {'maximum vagus coordinate': 0.24612805437844187, + 'minimum vagus coordinate': 0.06141502232856895}, + 'segment2.exf': {'maximum vagus coordinate': 0.4241573651485895, + 'minimum vagus coordinate': 0.24875823454739732}, + 'segment3.exf': {'maximum vagus coordinate': 0.5725426975292824, + 'minimum vagus coordinate': 0.4267079428429691} }, - 'trunk centroid fit error rms': 1.6796999717877277, - 'trunk centroid fit error max': 6.004413110311745, - 'trunk radius fit error rms': 0.20126533544206293, - 'trunk radius fit error max': 1.0496575899143181, - 'trunk twist angle fit error degrees rms': 3.9094139417227405, - 'trunk twist angle fit error degrees max': 9.786303215289262} + 'trunk centroid fit error rms': 3.195304274611684, + 'trunk centroid fit error max': 12.380175719267326, + 'trunk radius fit error rms': 1.4392862015456782, + 'trunk radius fit error max': 5.275441468095039, + 'trunk twist angle fit error degrees rms': 3.9171753300051773, + 'trunk twist angle fit error degrees max': 9.782285739956329} self.assertEqual(len(metadata), len(expected_metadata)) for key, value in metadata.items(): expected_value = expected_metadata[key] @@ -314,39 +341,39 @@ def test_vagus_nerve_1(self): expected_group_info = { 'left vagus nerve': ( 'http://uri.interlex.org/base/ilx_0785628', None, 25, - [-1269.8048516184547, -6359.977051431916, -69.78642824721726], - [2163.657939271601, -1111.9771974322234, 121.45057496461462], - [49.68213484225328, 258.2220400479382, 1479.1356481323735], - 249152179.8529517, - 33286242951.84727), + [-1242.1408436110323, -6449.120634594644, -61.12471585811795], + [2266.498485565621, -981.8862708609084, 109.4776933847775], + [43.42178093509881, 272.5812192544454, 1545.779335665412], + 253031226.32177484, + 34120295538.69093), 'left superior laryngeal nerve': ( - 'http://uri.interlex.org/base/ilx_0788780', 'left vagus nerve', 3, - [5923.104657597034, -4450.2479197707235, -196.91175665569313], - [-1473.665051675919, 858.0807042974039, 37.618907343841], - [29.2408913560962, 24.815173194025647, 579.4389023716839], - 9798165.396244952, - 559746405.7287067), + 'http://uri.interlex.org/base/ilx_0788780', 'left vagus nerve', 4, + [5923.1038437858815, -4450.247296980159, -196.91168267106667], + [-1473.666589331894, 858.0767721707383, 37.61788533802071], + [29.49243092889992, 24.929254450240933, 586.7101701356864], + 14619671.696585286, + 898349351.6575073), 'left A branch of superior laryngeal nerve': ( 'http://uri.interlex.org/base/ilx_0795823', 'left superior laryngeal nerve', 2, - [5105.456364262518, -1456.268405569011, 0.1879309337306836], - [-1289.581295107282, 381.4601337342457, 17.493930561764717], - [2.990626464939851, -3.6461864740685996, 299.96293350603094], - 4696615.99004511, - 236649716.13535246), + [5105.456405352692, -1456.2684158902327, 0.1879302148130364], + [-1289.5813291965014, 381.4601031501815, 17.4939357310412], + [2.9906257013149116, -3.6461931422077214, 299.9629334325899], + 4696616.032520553, + 236649717.0007874), 'left A thoracic cardiopulmonary branch of vagus nerve': ( 'http://uri.interlex.org/base/ilx_0794192', 'left vagus nerve', 2, - [20637.123232118392, -2947.094130818923, -608.0143068866595], - [99.38115735940329, -1713.8817535655442, -61.058795544347106], - [-8.872203312143029, 11.926532324519485, -349.21088399635704], - 6203011.915679664, - 328721624.0619874), + [20637.1231811151, -2947.0943923264213, -608.0143165605032], + [99.37959607618936, -1713.8821062071527, -61.058814561237654], + [-8.791810160853856, 11.98817110434402, -350.80993036588853], + 6229138.1929114945, + 331466617.8992749), 'left B thoracic cardiopulmonary branch of vagus nerve': ( 'http://uri.interlex.org/base/ilx_0794193', 'left vagus nerve', 1, - [22164.372546340644, -3219.413785808189, -620.4335804280928], - [1775.1656728388964, 1620.6261382213868, -217.23677284320627], - [2.363562419413938, 43.37866675798887, 342.92682167353087], - 4658935.705149433, - 267763437.92570886) + [22164.37237177626, -3219.4138243419347, -620.4335665416426], + [1775.1658782860482, 1620.6243020068152, -217.2367115667926], + [2.2452165218564915, 43.82745017664547, 345.30748541161313], + 4687937.203481174, + 271049251.5313337) } groups_count = len(expected_group_info) @@ -357,12 +384,13 @@ def test_vagus_nerve_1(self): self.assertTrue(coordinates.isValid()) self.assertEqual(RESULT_OK, fieldmodule.defineAllFaces()) mesh3d = fieldmodule.findMeshByDimension(3) - expected_elements_count = 33 + expected_elements_count = 34 self.assertEqual(expected_elements_count, mesh3d.getSize()) mesh2d = fieldmodule.findMeshByDimension(2) - self.assertEqual(expected_elements_count * 9 + groups_count, mesh2d.getSize()) + # groups_count + 1 due to one group having 2 branches + self.assertEqual(expected_elements_count * 9 + groups_count + 1, mesh2d.getSize()) mesh1d = fieldmodule.findMeshByDimension(1) - self.assertEqual(expected_elements_count * 17 + groups_count * 8, mesh1d.getSize()) + self.assertEqual(expected_elements_count * 17 + (groups_count + 1) * 8, mesh1d.getSize()) nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) self.assertEqual(expected_elements_count + 1 + 8, nodes.getSize()) # including 6 marker points datapoints = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_DATAPOINTS) @@ -383,16 +411,22 @@ def test_vagus_nerve_1(self): annotation_group = findAnnotationGroupByName(annotation_groups, group_name) self.assertEqual(term_id, annotation_group.getId()) mesh_group3d = group.getMeshGroup(mesh3d) - self.assertEqual(expected_elements_count, mesh_group3d.getSize()) - mesh_group2d = group.getMeshGroup(mesh2d) + elements_count = mesh_group3d.getSize() expected_face_count = expected_elements_count * 9 + 1 - self.assertEqual(expected_face_count, mesh_group2d.getSize()) - mesh_group1d = group.getMeshGroup(mesh1d) expected_line_count = expected_elements_count * 17 + 8 - self.assertEqual(expected_line_count, mesh_group1d.getSize()) - nodeset_group = group.getNodesetGroup(nodes) expected_node_count = expected_elements_count + (2 if parent_group_name else 1) - self.assertEqual(expected_node_count, nodeset_group.getSize()) + if group_name == 'left superior laryngeal nerve': + # there are 2 branches with this name + expected_face_count += 1 + expected_line_count += 8 + expected_node_count += 1 + self.assertEqual(expected_elements_count, elements_count) + mesh_group2d = group.getMeshGroup(mesh2d) + self.assertEqual(expected_face_count, mesh_group2d.getSize(), msg=group_name) + mesh_group1d = group.getMeshGroup(mesh1d) + self.assertEqual(expected_line_count, mesh_group1d.getSize(), msg=group_name) + nodeset_group = group.getNodesetGroup(nodes) + self.assertEqual(expected_node_count, nodeset_group.getSize(), msg=group_name) branch_of_branch = False if parent_group_name: # check first 2 nodes are in parent nodeset group @@ -432,14 +466,14 @@ def test_vagus_nerve_1(self): xi_centre = [0.5, 0.5, 0.5] # (element_identifier, expected_d3) expected_d3_info = [ - (2, [-33.391218366765855, 268.39160153845626, 1158.860985674948]), - (4, [-501.3786015366995, 675.7031254761372, 793.3595901220758]), - (6, [-33.04316735076699, 629.5192377516283, 194.49196298196927]), - (8, [-23.822365306323434, 202.80333345607843, 665.523590669333]), - (10, [-25.625279982419272, -275.14752889054415, 641.9228605835165]), - (12, [-242.75605360012025, -550.3231979114498, 242.35581137281747]), - (14, [0.06310578889011254, -474.67423296131636, 117.90792203003063]), - (16, [-3.504130626277629, -465.9492020804986, 105.11830088131188])] + (2, [-27.841230048890253, 260.7466659639747, 1176.8751799214626]), + (4, [-514.0930254730785, 671.2935039396489, 822.8835226739518]), + (6, [-33.44088684233287, 645.9285494148206, 194.22531263541956]), + (8, [-23.73600988035605, 206.6258555865557, 665.2926097904865]), + (10, [-30.597506104686232, -283.59176921547873, 641.4881272444806]), + (12, [-238.74470808058186, -573.0695314634801, 238.27251895181752]), + (14, [-0.9631502817560147, -476.5043907385434, 117.34817168731932]), + (16, [-3.1500151690115956, -475.1247911810948, 108.02626785082279])] for element_identifier, expected_d3 in expected_d3_info: element = mesh3d.findElementByIdentifier(element_identifier) self.assertEqual(RESULT_OK, fieldcache.setMeshLocation(element, xi_centre)) @@ -460,9 +494,9 @@ def test_vagus_nerve_1(self): fieldcache.clearLocation() result, volume = volume_field.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) - expected_volume = 33286242951.84727 if (coordinate_field is coordinates) else 33282940849.74868 + expected_volume = 34120295538.69093 if (coordinate_field is coordinates) else 34133114193.810123 self.assertAlmostEqual(expected_volume, volume, delta=STOL) - expected_elements_count = 33 + expected_elements_count = 34 group = fieldmodule.findFieldByName("epineurium").castGroup() mesh_group2d = group.getMeshGroup(mesh2d) self.assertEqual(expected_elements_count * 4, mesh_group2d.getSize()) @@ -471,7 +505,7 @@ def test_vagus_nerve_1(self): fieldcache.clearLocation() result, surface_area = surface_area_field.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) - expected_surface_area = 72452883.40392067 if (coordinate_field is coordinates) else 72585973.86409168 + expected_surface_area = 74658732.66536702 if (coordinate_field is coordinates) else 74810823.39355227 self.assertAlmostEqual(expected_surface_area, surface_area, delta=STOL) group = fieldmodule.findFieldByName("vagus centroid").castGroup() mesh_group1d = group.getMeshGroup(mesh1d) @@ -480,7 +514,7 @@ def test_vagus_nerve_1(self): length_field.setNumbersOfPoints(4) result, length = length_field.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) - self.assertAlmostEqual(75894.09718530288, length, delta=LTOL) + self.assertAlmostEqual(77989.74712379556, length, delta=LTOL) # check all markers are added marker_group = fieldmodule.findFieldByName("marker").castGroup() @@ -516,32 +550,32 @@ def test_vagus_nerve_1(self): 0.07044881379783888, 0.00014399999999999916), 'left superior laryngeal nerve': ( - [0.00047730703517693016, 0.0001590104729754135, 0.13155226693210442], - [0.012766941239985318, 0.011729858195898774, -0.006056668690651825], - [-0.00403931404678569, 0.0043905839670676464, -7.155118975268882e-05], - 0.0019802618878763203, - 1.9797849471091192e-06), + [0.0004896991248311625, 0.0001718248196490253, 0.13358994717220393], + [0.012591688735249879, 0.011554907571233508, -0.005801519857524198], + [-0.004035969328504113, 0.004384632192912359, -9.01711847667297e-05], + 0.0029647324574537596, + 2.9681678935347134e-06), 'left A branch of superior laryngeal nerve': ( - [0.028688067705606772, 0.02535673458269107, 0.11727390888183663], - [-0.009599655303649913, -0.013266188786747662, -0.017300073369861106], - [-0.004249304133747675, 0.004150895449488073, -0.0008858783562686601], - 0.0015990563197801552, - 1.5260431704427458e-06), + [0.028322975480738796, 0.02498830599810807, 0.11987166123924467], + [-0.009335630171933725, -0.013097308942832018, -0.01748433385528671], + [-0.004243088648924635, 0.004153044805551016, -0.0009050664963775323], + 0.0015933505248969856, + 1.5184536519716786e-06), 'left A thoracic cardiopulmonary branch of vagus nerve': ( - [-0.00023275582415062705, -5.5790425955213165e-06, 0.3810973389155537], - [-0.026617743875947883, 0.010946968854188322, 0.006187567037017488], - [-0.0022754210043383436, -0.005550105193537561, -3.3829765141102364e-05], - 0.002071680752066034, - 2.119536598394646e-06), + [-0.00023095034344209296, -1.0213543563703398e-05, 0.38006390857609595], + [-0.026498073841997774, 0.01090206319896006, 0.006047544274512271], + [-0.002274865823058741, -0.0055501710871262095, -3.291278514777618e-05], + 0.0020613515506789195, + 2.1083349503909266e-06), 'left B thoracic cardiopulmonary branch of vagus nerve': ( - [0.0005128262687161793, -0.0009094452229105069, 0.4063493881429567], - [0.023578121466429198, -0.026751584774707175, 0.020367952991549275], - [0.004501879772565175, 0.003966530167337301, 1.7100648076362468e-05], - 0.001442332241718134, - 1.4795610548057572e-06)} - XTOL = 2.0E-7 # coordinates and derivatives - STOL = 1.0E-9 # surface area - VTOL = 1.0E-11 # volume + [0.0005187730319848923, -0.000913147031720288, 0.4048333718275829], + [0.023402426848619456, -0.026601176543883488, 0.01999575491995305], + [0.004504705667267902, 0.003963119934858472, 1.7331437428769192e-05], + 0.001429957375006481, + 1.4665142890587577e-06)} + XTOL = 1.0E-4 # coordinates and derivatives + STOL = 1.0E-5 # surface area + VTOL = 1.0E-8 # volume for group_name in expected_group_info.keys(): expected_start_x, expected_start_d1, expected_start_d3, expected_surface_area, expected_volume = \ expected_group_material_info[group_name] @@ -571,8 +605,8 @@ def test_vagus_nerve_1(self): self.assertEqual(result, RESULT_OK) result, volume = volume_field.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) - self.assertAlmostEqual(expected_surface_area, surface_area, delta=2.0E-7 if branch_of_branch else STOL) - self.assertAlmostEqual(expected_volume, volume, delta=2.0E-10 if branch_of_branch else VTOL) + self.assertAlmostEqual(expected_surface_area, surface_area, delta=2.0E-6 if branch_of_branch else STOL) + self.assertAlmostEqual(expected_volume, volume, delta=2.0E-9 if branch_of_branch else VTOL) # test combined groups branch_common_map = { @@ -626,7 +660,7 @@ def test_arc_vagus(self): self.assertEqual(14, len(annotation_groups)) fit_metadata = nerve_metadata.getMetadata()['vagus nerve'] self.assertAlmostEqual(fit_metadata['trunk centroid fit error rms'], 0.0, delta=1.0E-4) - self.assertAlmostEqual(fit_metadata['trunk radius fit error rms'], 0.0, delta=1.0E-12) + self.assertAlmostEqual(fit_metadata['trunk radius fit error rms'], 1.2555492226192078e-05, delta=1.0E-12) self.assertAlmostEqual(fit_metadata['trunk twist angle fit error degrees rms'], 0.0, delta=0.002) fieldmodule = region.getFieldmodule() fieldcache = fieldmodule.createFieldcache() @@ -642,12 +676,12 @@ def test_arc_vagus(self): length_field.setNumbersOfPoints(4) result, length = length_field.evaluateReal(fieldcache, 1) self.assertEqual(result, RESULT_OK) - self.assertAlmostEqual(math.pi, length, delta=1.0E-3) + self.assertAlmostEqual(math.pi, length, delta=1.0E-2) nodes = fieldmodule.findNodesetByFieldDomainType(Field.DOMAIN_TYPE_NODES) node = nodes.findNodeByIdentifier((elements_count // 2) + 1) fieldcache.setNode(node) - XTOL = 1.0E-6 + XTOL = 1.0E-3 expected_parameters = [ [1.000000940622472, -6.338102288830355e-06, 0.0], [1.8235912738924405e-07, 0.19637019676125334, 0.0],