Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions morph_utils/ccf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import matplotlib.pyplot as plt
from morph_utils.query import get_id_by_name, get_structures, query_pinning_info_cell_locator
from morph_utils.measurements import get_node_spacing
from morph_utils.modifications import resample_morphology


NAME_MAP_FILE = files('morph_utils') / 'data/ccf_structure_name_map.json'
with open(NAME_MAP_FILE, "r") as fn:
Expand Down Expand Up @@ -289,7 +291,8 @@ def get_ccf_structure(voxel, name_map=None, annotation=None, coordinate_to_voxel
def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
tip_count = False, annotation=None,
annotation_path = None, volume_shape=(1320, 800, 1140),
resolution=10, node_type_list=[2]):
resolution=10, node_type_list=[2],
resample_spacing=None):
"""
Given a swc file, quantify the projection matrix. That is the amount of axon in each structure. This function assumes
there is equivalent internode spacing (i.e. the input swc file should be resampled prior to running this code).
Expand All @@ -307,6 +310,8 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
volume_shape (tuple, optional): the size in voxels of the ccf atlas (annotation volume). Defaults to (1320, 800, 1140).
resolution (int, optional): resolution (um/pixel) of the annotation volume
node_type_list (list of ints): node type to extract projection data for, typically axon (2)
resample_spacing (float or None): if not None, will resample the input morphology to the designated
internode spacing

Returns:
filename (str)
Expand Down Expand Up @@ -340,7 +345,10 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
z_midline = z_size / 2

morph = morphology_from_swc(input_swc_file)
morph = move_soma_to_left_hemisphere(morph, resolution, volume_shape, z_midline)
morph = move_soma_to_left_hemisphere(morph, resolution, volume_shape, z_midline)
if resample_spacing is not None:
morph = resample_morphology(morph, resample_spacing)

spacing = get_node_spacing(morph)[0]

morph_df = pd.DataFrame(morph.nodes())
Expand All @@ -349,6 +357,10 @@ def projection_matrix_for_swc(input_swc_file, mask_method = "tip_and_branch",
morph_df = morph_df[morph_df['type'].isin(node_type_list)]

# annotate each node
if morph_df.empty:
print("Its empty")
return input_swc_file, {}

morph_df['ccf_structure'] = morph_df.apply(lambda rw: full_name_to_abbrev_dict[get_ccf_structure( np.array([rw.x, rw.y, rw.z]) , name_map, annotation, True)], axis=1)

# roll up fiber tracts
Expand Down
111 changes: 111 additions & 0 deletions morph_utils/executable_scripts/projection_matrix_for_single_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
from tqdm import tqdm
import pandas as pd
import argschema as ags
from morph_utils.ccf import projection_matrix_for_swc

class IO_Schema(ags.ArgSchema):
input_swc_file = ags.fields.InputFile(description='directory with micron resolution ccf registered files')
output_projection_csv = ags.fields.OutputFile(description="output projection csv")
projection_threshold = ags.fields.Int(default=0)
normalize_proj_mat = ags.fields.Boolean(default=True)
mask_method = ags.fields.Str(default="tip_and_branch",description = " 'tip_and_branch', 'branch', 'tip', or 'tip_or_branch' ")
tip_count = ags.fields.Boolean(default=False, description="when true, this will measure a matrix of number of tips instead of number of nodes")
annotation_path = ags.fields.Str(default="",description = "Optional. Path to annotation .nrrd file. Defaults to 10um ccf atlas")
resolution = ags.fields.Int(default=10, description="Optional. ccf resolution (micron/pixel")
volume_shape = ags.fields.List(ags.fields.Int, default=[1320, 800, 1140], description = "Optional. Size of input annotation")
resample_spacing = ags.fields.Float(allow_none=True, default=None, description = 'internode spacing to resample input morphology with')


def normalize_projection_columns_per_cell(input_df, projection_column_identifiers=['ipsi', 'contra']):
"""
:param input_df: input projection df
:param projection_column_identifiers: list of identifiers for projection columns. i.e. strings that identify projection columns from metadata columns
:return: normalized projection matrix
"""
proj_cols = [c for c in input_df.columns if any([ider in c for ider in projection_column_identifiers])]
input_df[proj_cols] = input_df[proj_cols].fillna(0)

res = input_df[proj_cols].T / input_df[proj_cols].sum(axis=1)
input_df[proj_cols] = res.T

return input_df


def main(input_swc_file,
output_projection_csv,
resolution,
projection_threshold,
normalize_proj_mat,
mask_method,
tip_count,
annotation_path,
volume_shape,
resample_spacing,
**kwargs):

if annotation_path == "":
annotation_path = None

results = []
res = projection_matrix_for_swc(input_swc_file=input_swc_file,
tip_count = tip_count,
mask_method = mask_method,
annotation=None,
annotation_path = annotation_path,
volume_shape=volume_shape,
resolution=resolution,
resample_spacing=resample_spacing)
results = [res]

output_projection_csv = output_projection_csv.replace(".csv", f"_{mask_method}.csv")
projection_records = {}
# branch_and_tip_projection_records = {}
for res in results:
fn = os.path.abspath(res[0])
proj_records = res[1]
# brnch_tip_records = res[1]

projection_records[fn] = proj_records
# branch_and_tip_projection_records[fn] = brnch_tip_records

proj_df = pd.DataFrame(projection_records).T.fillna(0)
# proj_df_mask = pd.DataFrame(branch_and_tip_projection_records).T.fillna(0)

proj_df.to_csv(output_projection_csv)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

if projection_threshold != 0:
output_projection_csv = output_projection_csv.replace(".csv",
"{}thresh.csv".format(projection_threshold))
# output_projection_csv_tip_branch_mask = output_projection_csv_tip_branch_mask.replace(".csv",
# "{}thresh.csv".format(
# projection_threshold))

proj_df_arr = proj_df.values
proj_df_arr[proj_df_arr < projection_threshold] = 0
proj_df = pd.DataFrame(proj_df_arr, columns=proj_df.columns, index=proj_df.index)
proj_df.to_csv(output_projection_csv)

# proj_df_mask_arr = proj_df_mask.values
# proj_df_mask_arr[proj_df_mask_arr < projection_threshold] = 0
# proj_df_mask = pd.DataFrame(proj_df_mask_arr, columns=proj_df_mask.columns, index=proj_df_mask.index)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

if normalize_proj_mat:
output_projection_csv = output_projection_csv.replace(".csv", "_norm.csv")
# output_projection_csv_tip_branch_mask = output_projection_csv_tip_branch_mask.replace(".csv", "_norm.csv")

proj_df = normalize_projection_columns_per_cell(proj_df)
proj_df.to_csv(output_projection_csv)

# proj_df_mask = normalize_projection_columns_per_cell(proj_df_mask)
# proj_df_mask.to_csv(output_projection_csv_tip_branch_mask)

def console_script():
module = ags.ArgSchemaParser(schema_type=IO_Schema)
main(**module.args)

if __name__ == "__main__":
module = ags.ArgSchemaParser(schema_type=IO_Schema)
main(**module.args)
Loading