Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions naplib/localization/freesurfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def __init__(
except Exception as e:
logger.warning(f'No {hemi}.sulc file found. No sulcus information will be used.')
self.sulc = None
Copy link
Collaborator

@gavinmischler gavinmischler Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add self.sulc_alpha = 1.0 after line 127 also, or just put it after and separate from the "try-except" block? Just in case someone calls plot_hemi, it will automatically fail because it tries to do surfdist_viz and plug in self.sulc_alpha but it won't exist, so it's best to set it to something I think

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, good call. Changed

self.sulc_alpha = 1.0


self.load_labels()
Expand Down
49 changes: 47 additions & 2 deletions naplib/utils/surfdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import gdist
import matplotlib.pyplot as plt
from matplotlib.colors import LightSource
import numpy as np
from nibabel.freesurfer.io import read_annot

Expand Down Expand Up @@ -124,17 +125,19 @@ def surfdist_viz(
alpha="auto",
bg_map=None,
bg_on_stat=False,
bg_alpha=1.0,
figsize=None,
ax=None,
vmin=None,
vmax=None,
light_source=None,
):
"""Visualize results on cortical surface using matplotlib.

Parameters
----------
coords : numpy array of shape (n_nodes,3), each row specifying the x,y,z
coordinates of one node of surface mesh
coordinates of one node of surface mesh
faces : numpy array of shape (n_faces, 3), each row specifying the indices
of the three nodes building one node of the surface mesh
stat_map : numpy array of shape (n_nodes,) containing the values to be
Expand All @@ -158,9 +161,16 @@ def surfdist_viz(
multiplied with the background map for shadowing. Otherwise,
only areas that are not covered by the statsitical map after
thresholding will show shadows.
bg_alpha : float, determines the opacity of the background map.
bg_alpha defaults to 1.0 and is only relevant if bg_on_stat
figsize : tuple of intergers, dimensions of the figure that is produced.
ax : Axis
Axis to plot on, with 3d projection.
light_source: None, bool, or tuple of int, optional
Whether to apply a light source for shading. If True, the light
source position is inferred from `elev` and `azim`. If a tuple of
(alt, az), these values will be used to specify the light source
position. If None or False, no shading is applied. Default is None.

Returns
-------
Expand Down Expand Up @@ -226,7 +236,7 @@ def surfdist_viz(
bg_faces = np.mean(bg_data[faces], axis=1)
bg_faces = bg_faces - bg_faces.min()
bg_faces = bg_faces / bg_faces.max()
face_colors = plt.cm.gray_r(bg_faces)
face_colors = plt.cm.gray_r(bg_faces * bg_alpha)

# modify alpha values of background
face_colors[:, 3] = alpha * face_colors[:, 3]
Expand Down Expand Up @@ -260,6 +270,41 @@ def surfdist_viz(
else:
face_colors = cmap(stat_map_faces)

if light_source:
if hasattr(light_source, '__len__'):
if len(light_source) == 2:
ls = LightSource(azdeg=light_source[1], altdeg=light_source[0])
else:
# Apply lighting to the face colors for shading
ls = LightSource(azdeg=azim, altdeg=elev)

# Manually calculate the light vector since the 'light_vector'
# attribute is not accessible in some matplotlib versions.
az = np.radians(ls.azdeg)
alt = np.radians(ls.altdeg)
light_vec = np.array([
np.cos(az) * np.cos(alt),
np.sin(az) * np.cos(alt),
np.sin(alt)
])

# Calculate face normals
v0 = coords[faces[:, 0]]
v1 = coords[faces[:, 1]]
v2 = coords[faces[:, 2]]
face_normals = np.cross(v1 - v0, v2 - v0)
face_normals /= np.linalg.norm(face_normals, axis=1)[:, np.newaxis]

# The shade is the dot product of the light vector and face normals
shade = np.dot(face_normals, light_vec)

# Modulate the RGB colors by the shade, keeping the alpha channel
# Use np.clip to keep shade values between 0 and 1
illuminated_rgb = face_colors[:, :3] * np.clip(shade, 0, 1)[:, np.newaxis]

# Combine illuminated RGB with the original alpha channel
face_colors = np.hstack((illuminated_rgb, face_colors[:, 3:]))

p3dcollec.set_facecolors(face_colors)

if not premade_ax:
Expand Down
17 changes: 12 additions & 5 deletions naplib/visualization/brain_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def _view(hemi, mode: str = "lateral", backend: str = "mpl"):
raise ValueError(f"Unknown `mode`: {mode}.")


def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin=None, vmax=None):
def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin=None, vmax=None, light_source=False):
if isinstance(view, tuple):
elev, azim = view
else:
Expand All @@ -105,16 +105,18 @@ def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin
alpha=hemi.alpha,
bg_map=hemi.sulc,
bg_on_stat=True,
bg_alpha=hemi.sulc_alpha,
ax=ax,
vmin=vmin,
vmax=vmax
vmax=vmax,
light_source=light_source
)
ax.axes.set_axis_off()
ax.grid(False)


def plot_brain_overlay(
brain, cmap="coolwarm", ax=None, hemi='both', view="best", vmin=None, vmax=None, cmap_quantile=1.0, threshold=None, **kwargs
brain, cmap="coolwarm", ax=None, hemi='both', view="best", vmin=None, vmax=None, cmap_quantile=1.0, threshold=None, light_source=False, **kwargs
):
"""
Plot brain overlay on the 3D cortical surface using matplotlib.
Expand Down Expand Up @@ -149,6 +151,11 @@ def plot_brain_overlay(
threshold : positive float, optional
If given, then only values on the overlay which are less -threshold or greater than threshold will
be shown.
light_source: None, bool, or tuple of int, optional
Whether to apply a light source for shading. If True, the light
source position is inferred from `elev` and `azim`. If a tuple of
(alt, az), these values will be used to specify the light source
position. If None or False, no shading is applied. Default is True.
**kwargs : kwargs
Any other kwargs to pass to matplotlib.pyplot.figure (such as figsize)

Expand Down Expand Up @@ -216,9 +223,9 @@ def plot_brain_overlay(


if ax[0] is not None:
_plot_hemi(brain.lh, cmap, ax[0], view=view, vmin=vmin, vmax=vmax, threshold=threshold)
_plot_hemi(brain.lh, cmap, ax[0], view=view, vmin=vmin, vmax=vmax, threshold=threshold, light_source=light_source)
if ax[1] is not None:
_plot_hemi(brain.rh, cmap, ax[1], view=view, vmin=vmin, vmax=vmax, threshold=threshold)
_plot_hemi(brain.rh, cmap, ax[1], view=view, vmin=vmin, vmax=vmax, threshold=threshold, light_source=light_source)

return fig, ax

Expand Down