diff --git a/naplib/localization/freesurfer.py b/naplib/localization/freesurfer.py index 51be56e..eb89c27 100644 --- a/naplib/localization/freesurfer.py +++ b/naplib/localization/freesurfer.py @@ -124,6 +124,7 @@ def __init__( except Exception as e: logger.warning(f'No {hemi}.sulc file found. No sulcus information will be used.') self.sulc = None + self.sulc_alpha = 1.0 self.load_labels() diff --git a/naplib/utils/surfdist.py b/naplib/utils/surfdist.py index d8e1897..cc7e1b1 100644 --- a/naplib/utils/surfdist.py +++ b/naplib/utils/surfdist.py @@ -5,6 +5,7 @@ import gdist import matplotlib.pyplot as plt +from matplotlib.colors import LightSource import numpy as np from nibabel.freesurfer.io import read_annot @@ -124,17 +125,19 @@ def surfdist_viz( alpha="auto", bg_map=None, bg_on_stat=False, + bg_alpha=1.0, figsize=None, ax=None, vmin=None, vmax=None, + light_source=None, ): """Visualize results on cortical surface using matplotlib. Parameters ---------- coords : numpy array of shape (n_nodes,3), each row specifying the x,y,z - coordinates of one node of surface mesh + coordinates of one node of surface mesh faces : numpy array of shape (n_faces, 3), each row specifying the indices of the three nodes building one node of the surface mesh stat_map : numpy array of shape (n_nodes,) containing the values to be @@ -158,9 +161,16 @@ def surfdist_viz( multiplied with the background map for shadowing. Otherwise, only areas that are not covered by the statsitical map after thresholding will show shadows. + bg_alpha : float, determines the opacity of the background map. + bg_alpha defaults to 1.0 and is only relevant if bg_on_stat figsize : tuple of intergers, dimensions of the figure that is produced. ax : Axis Axis to plot on, with 3d projection. + light_source: None, bool, or tuple of int, optional + Whether to apply a light source for shading. If True, the light + source position is inferred from `elev` and `azim`. If a tuple of + (alt, az), these values will be used to specify the light source + position. If None or False, no shading is applied. Default is None. Returns ------- @@ -226,7 +236,7 @@ def surfdist_viz( bg_faces = np.mean(bg_data[faces], axis=1) bg_faces = bg_faces - bg_faces.min() bg_faces = bg_faces / bg_faces.max() - face_colors = plt.cm.gray_r(bg_faces) + face_colors = plt.cm.gray_r(bg_faces * bg_alpha) # modify alpha values of background face_colors[:, 3] = alpha * face_colors[:, 3] @@ -260,6 +270,41 @@ def surfdist_viz( else: face_colors = cmap(stat_map_faces) + if light_source: + if hasattr(light_source, '__len__'): + if len(light_source) == 2: + ls = LightSource(azdeg=light_source[1], altdeg=light_source[0]) + else: + # Apply lighting to the face colors for shading + ls = LightSource(azdeg=azim, altdeg=elev) + + # Manually calculate the light vector since the 'light_vector' + # attribute is not accessible in some matplotlib versions. + az = np.radians(ls.azdeg) + alt = np.radians(ls.altdeg) + light_vec = np.array([ + np.cos(az) * np.cos(alt), + np.sin(az) * np.cos(alt), + np.sin(alt) + ]) + + # Calculate face normals + v0 = coords[faces[:, 0]] + v1 = coords[faces[:, 1]] + v2 = coords[faces[:, 2]] + face_normals = np.cross(v1 - v0, v2 - v0) + face_normals /= np.linalg.norm(face_normals, axis=1)[:, np.newaxis] + + # The shade is the dot product of the light vector and face normals + shade = np.dot(face_normals, light_vec) + + # Modulate the RGB colors by the shade, keeping the alpha channel + # Use np.clip to keep shade values between 0 and 1 + illuminated_rgb = face_colors[:, :3] * np.clip(shade, 0, 1)[:, np.newaxis] + + # Combine illuminated RGB with the original alpha channel + face_colors = np.hstack((illuminated_rgb, face_colors[:, 3:])) + p3dcollec.set_facecolors(face_colors) if not premade_ax: diff --git a/naplib/visualization/brain_plots.py b/naplib/visualization/brain_plots.py index 5ec0de4..a9808a0 100644 --- a/naplib/visualization/brain_plots.py +++ b/naplib/visualization/brain_plots.py @@ -90,7 +90,7 @@ def _view(hemi, mode: str = "lateral", backend: str = "mpl"): raise ValueError(f"Unknown `mode`: {mode}.") -def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin=None, vmax=None): +def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin=None, vmax=None, light_source=False): if isinstance(view, tuple): elev, azim = view else: @@ -105,16 +105,18 @@ def _plot_hemi(hemi, cmap="coolwarm", ax=None, view="best", threshold=None, vmin alpha=hemi.alpha, bg_map=hemi.sulc, bg_on_stat=True, + bg_alpha=hemi.sulc_alpha, ax=ax, vmin=vmin, - vmax=vmax + vmax=vmax, + light_source=light_source ) ax.axes.set_axis_off() ax.grid(False) def plot_brain_overlay( - brain, cmap="coolwarm", ax=None, hemi='both', view="best", vmin=None, vmax=None, cmap_quantile=1.0, threshold=None, **kwargs + brain, cmap="coolwarm", ax=None, hemi='both', view="best", vmin=None, vmax=None, cmap_quantile=1.0, threshold=None, light_source=False, **kwargs ): """ Plot brain overlay on the 3D cortical surface using matplotlib. @@ -149,6 +151,11 @@ def plot_brain_overlay( threshold : positive float, optional If given, then only values on the overlay which are less -threshold or greater than threshold will be shown. + light_source: None, bool, or tuple of int, optional + Whether to apply a light source for shading. If True, the light + source position is inferred from `elev` and `azim`. If a tuple of + (alt, az), these values will be used to specify the light source + position. If None or False, no shading is applied. Default is True. **kwargs : kwargs Any other kwargs to pass to matplotlib.pyplot.figure (such as figsize) @@ -216,9 +223,9 @@ def plot_brain_overlay( if ax[0] is not None: - _plot_hemi(brain.lh, cmap, ax[0], view=view, vmin=vmin, vmax=vmax, threshold=threshold) + _plot_hemi(brain.lh, cmap, ax[0], view=view, vmin=vmin, vmax=vmax, threshold=threshold, light_source=light_source) if ax[1] is not None: - _plot_hemi(brain.rh, cmap, ax[1], view=view, vmin=vmin, vmax=vmax, threshold=threshold) + _plot_hemi(brain.rh, cmap, ax[1], view=view, vmin=vmin, vmax=vmax, threshold=threshold, light_source=light_source) return fig, ax