diff --git a/src/spac/visualization.py b/src/spac/visualization.py index 9003c163..f49dce6a 100644 --- a/src/spac/visualization.py +++ b/src/spac/visualization.py @@ -403,9 +403,11 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs): return fig, ax + def histogram(adata, feature=None, annotation=None, layer=None, group_by=None, together=False, ax=None, - x_log_scale=False, y_log_scale=False, **kwargs): + x_log_scale=False, y_log_scale=False, + defined_color_map=None, **kwargs): """ Plot the histogram of cells based on a specific feature from adata.X or annotation from adata.obs. @@ -447,6 +449,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, y_log_scale : bool, default False If True, the y-axis will be set to log scale. + defined_color_map : str, optional, default=None + Key in adata.uns used to retrieve a color mapping dictionary to + color code the histogram. + **kwargs Additional keyword arguments passed to seaborn histplot function. Key arguments include: @@ -495,7 +501,6 @@ def histogram(adata, feature=None, annotation=None, layer=None, DataFrame containing the data used for plotting the histogram. """ - # If no feature or annotation is specified, apply default behavior if feature is None and annotation is None: # Default to the first feature in adata.var_names @@ -530,6 +535,10 @@ def histogram(adata, feature=None, annotation=None, layer=None, df = pd.concat([df, adata.obs], axis=1) + if defined_color_map: + color_dict = get_defined_color_map(adata, defined_color_map) + kwargs.setdefault("palette", color_dict) + if feature and annotation: raise ValueError("Cannot pass both feature and annotation," " choose one.") @@ -567,7 +576,7 @@ def cal_bin_num( ): bins = max(int(2*(num_rows ** (1/3))), 1) print(f'Automatically calculated number of bins is: {bins}') - return(bins) + return (bins) num_rows = plot_data.shape[0] @@ -657,8 +666,7 @@ def calculate_histogram(data, bins, bin_edges=None): # Set default values if not provided in kwargs kwargs.setdefault("multiple", "stack") kwargs.setdefault("element", "bars") - - + sns.histplot(data=hist_data, x='bin_center', weights='count', hue=group_by, ax=ax, **kwargs) # If plotting feature specify which layer @@ -682,8 +690,14 @@ def calculate_histogram(data, bins, bin_edges=None): groups[i]][data_column] hist_data = calculate_histogram(group_data, kwargs['bins']) + # If defined_color_map provided, retrieves color map + group_color = None + if defined_color_map: + group_color = color_dict.get(group, None) + sns.histplot(data=hist_data, x="bin_center", ax=ax_i, - weights='count', **kwargs) + weights='count', color=group_color, **kwargs) + # If plotting feature specify which layer if feature: ax_i.set_title(f'{groups[i]} with Layer: {layer}') @@ -720,7 +734,13 @@ def calculate_histogram(data, bins, bin_edges=None): hist_data = calculate_histogram(plot_data[data_column], kwargs['bins']) if pd.api.types.is_numeric_dtype(plot_data[data_column]): ax.set_xlim(hist_data['bin_left'].min(), - hist_data['bin_right'].max()) + hist_data['bin_right'].max()) + + # Set default color from custom color map if available + if defined_color_map: + color_dict = get_defined_color_map(adata,defined_color_map) + default_color = list(color_dict.values())[0] + kwargs['color'] = default_color sns.histplot( data=hist_data,