Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 27 additions & 7 deletions src/spac/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,9 +403,11 @@ def tsne_plot(adata, color_column=None, ax=None, **kwargs):

return fig, ax


def histogram(adata, feature=None, annotation=None, layer=None,
group_by=None, together=False, ax=None,
x_log_scale=False, y_log_scale=False, **kwargs):
x_log_scale=False, y_log_scale=False,
defined_color_map=None, **kwargs):
"""
Plot the histogram of cells based on a specific feature from adata.X
or annotation from adata.obs.
Expand Down Expand Up @@ -447,6 +449,10 @@ def histogram(adata, feature=None, annotation=None, layer=None,
y_log_scale : bool, default False
If True, the y-axis will be set to log scale.

defined_color_map : str, optional, default=None
Key in adata.uns used to retrieve a color mapping dictionary to
color code the histogram.

**kwargs
Additional keyword arguments passed to seaborn histplot function.
Key arguments include:
Expand Down Expand Up @@ -495,7 +501,6 @@ def histogram(adata, feature=None, annotation=None, layer=None,
DataFrame containing the data used for plotting the histogram.

"""

# If no feature or annotation is specified, apply default behavior
if feature is None and annotation is None:
# Default to the first feature in adata.var_names
Expand Down Expand Up @@ -530,6 +535,10 @@ def histogram(adata, feature=None, annotation=None, layer=None,

df = pd.concat([df, adata.obs], axis=1)

if defined_color_map:
color_dict = get_defined_color_map(adata, defined_color_map)
kwargs.setdefault("palette", color_dict)

if feature and annotation:
raise ValueError("Cannot pass both feature and annotation,"
" choose one.")
Expand Down Expand Up @@ -567,7 +576,7 @@ def cal_bin_num(
):
bins = max(int(2*(num_rows ** (1/3))), 1)
print(f'Automatically calculated number of bins is: {bins}')
return(bins)
return (bins)

num_rows = plot_data.shape[0]

Expand Down Expand Up @@ -657,8 +666,7 @@ def calculate_histogram(data, bins, bin_edges=None):
# Set default values if not provided in kwargs
kwargs.setdefault("multiple", "stack")
kwargs.setdefault("element", "bars")



sns.histplot(data=hist_data, x='bin_center', weights='count',
hue=group_by, ax=ax, **kwargs)
# If plotting feature specify which layer
Expand All @@ -682,8 +690,14 @@ def calculate_histogram(data, bins, bin_edges=None):
groups[i]][data_column]
hist_data = calculate_histogram(group_data, kwargs['bins'])

# If defined_color_map provided, retrieves color map
group_color = None
if defined_color_map:
group_color = color_dict.get(group, None)

sns.histplot(data=hist_data, x="bin_center", ax=ax_i,
weights='count', **kwargs)
weights='count', color=group_color, **kwargs)

# If plotting feature specify which layer
if feature:
ax_i.set_title(f'{groups[i]} with Layer: {layer}')
Expand Down Expand Up @@ -720,7 +734,13 @@ def calculate_histogram(data, bins, bin_edges=None):
hist_data = calculate_histogram(plot_data[data_column], kwargs['bins'])
if pd.api.types.is_numeric_dtype(plot_data[data_column]):
ax.set_xlim(hist_data['bin_left'].min(),
hist_data['bin_right'].max())
hist_data['bin_right'].max())

# Set default color from custom color map if available
if defined_color_map:
color_dict = get_defined_color_map(adata,defined_color_map)
default_color = list(color_dict.values())[0]
kwargs['color'] = default_color

sns.histplot(
data=hist_data,
Expand Down