From 8375d09136f5ac271e3692da6947e0a5977c1127 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Thu, 19 Feb 2026 15:18:13 +1100 Subject: [PATCH 1/5] Update transect implementation The Transect class can now transect 2D surface variables, not just 3d volumes. The plotting functions have been moved to the module level. New CrossSectionArtist and TransectStepArtists have been added, much like the new Artists for surface plotting. Tests and docs to come. --- src/emsarray/transect.py | 827 ++++++++++++++++----------------------- 1 file changed, 329 insertions(+), 498 deletions(-) diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py index 507f22a..ae9ea62 100644 --- a/src/emsarray/transect.py +++ b/src/emsarray/transect.py @@ -1,4 +1,6 @@ import dataclasses +import itertools +import warnings from collections.abc import Callable, Iterable from functools import cached_property from typing import Any, cast @@ -11,16 +13,26 @@ from matplotlib import animation, pyplot from matplotlib.artist import Artist from matplotlib.axes import Axes -from matplotlib.collections import PolyCollection +from matplotlib.axis import Axis +from matplotlib.collections import QuadMesh from matplotlib.colors import Colormap from matplotlib.figure import Figure +from matplotlib.patches import StepPatch from matplotlib.ticker import EngFormatter, Formatter -from emsarray.conventions import Convention -from emsarray.plot import _requires_plot, make_plot_title +from emsarray.conventions import Convention, Grid +from emsarray.plot import _requires_plot, make_plot_title, GridArtist from emsarray.types import DataArrayOrName, Landmark +from emsarray.operations import depth from emsarray.utils import move_dimensions_to_end, name_to_data_array +""" +Things to do: +* Make 1d transects possible +* enhance plotting routines to make `plot_on_axes` and `make_artist` +""" + + # Useful for calculating distances in a AzimuthalEquidistant projection # centred on some point: # @@ -29,7 +41,102 @@ ORIGIN = 
shapely.Point(0, 0) -def plot( +class CrossSectionArtist(QuadMesh): + transect: "Transect" + + def set_transect(self, transect: "Transect") -> None: + self.transect = transect + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_array(self.prepare_data_array(self.transect, data_array)) + + @classmethod + def from_transect( + cls, + transect: "Transect", + *, + data_array: xarray.DataArray | None = None, + depth_coordinate: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "CrossSectionArtist": + distance_bounds = transect.intersection_bounds + + if depth_coordinate is None and data_array is None: + raise ValueError( + "At least one of data_array and depth_coordinate must be not None") + if depth_coordinate is None: + depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + + holes = transect.holes + xs = numpy.concat([distance_bounds[:, 0], distance_bounds[-1:, 1]]) + xs = numpy.insert(xs, holes, distance_bounds[holes - 1, 1]) + ys = numpy.concat([depth_bounds[:, 0], depth_bounds[-1:, 1]]) + coordinates = numpy.stack(numpy.meshgrid(xs, ys), axis=-1) + + # There are issues with passing both transect and data array to the constructor + # where the `set_data_array()` is called before `set_transect()`. + # Doing it this way is safe but kinda gross. 
+ artist = cls(coordinates, transect=transect, **kwargs) + if data_array is not None: + artist.set_data_array(data_array) + + return artist + + @staticmethod + def prepare_data_array(transect: "Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + values = numpy.insert(values, transect.holes, numpy.nan, axis=-1) + return values + + +class TransectStepArtist(StepPatch): + transect: "Transect" + + def set_transect(self, transect: "Transect") -> None: + self.transect = transect + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_data(self.prepare_data_array(self.transect, data_array)) + + @classmethod + def from_transect( + cls, + transect: "Transect", + *, + data_array: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "TransectStepArtist": + holes = transect.holes + x_bounds = transect.intersection_bounds + edges = x_bounds[:, 0] + edges = numpy.append(edges, x_bounds[-1, 1]) + edges = numpy.insert(edges, holes, x_bounds[holes - 1, 1]) + + if data_array is not None: + values = cls.prepare_data_array(transect, data_array) + else: + values = numpy.full(shape=(len(edges) - 1,), fill_value=numpy.nan) + + return cls(values, edges, transect=transect, **kwargs) + + @staticmethod + def prepare_data_array(transect: "Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + assert len(values.shape) == 1 + + # If a transect path is not fully contained within the dataset geometry + # the path will have gaps. We can represent these gaps using nans. + values = numpy.insert( + values.astype(float), # Upcast to float in case this was an integer array + transect.holes, + numpy.nan, + ) + + return values + + +def plot_2d_transect( dataset: xarray.Dataset, line: shapely.LineString, data_array: xarray.DataArray, @@ -57,13 +164,174 @@ def plot( Passed to :meth:`Transect.plot_on_figure()`. 
""" figure = pyplot.figure(layout="constrained", figsize=figsize) - depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) - transect = Transect(dataset, line, depth=depth_coordinate) - transect.plot_on_figure(figure, data_array, **kwargs) + data_array = name_to_data_array(dataset, data_array) + + grid = dataset.ems.get_grid(data_array) + transect = Transect(dataset, line, grid=grid) + + axes = figure.subplots() + collection = _plot_on_axes(dataset, data_array, transect, axes, **kwargs) + + units = data_array.attrs.get('units') + figure.colorbar(collection, ax=axes, location='right', label=units) + pyplot.show() return figure +def _plot_on_axes( + dataset: xarray.Dataset, + data_array: xarray.DataArray, + transect: "Transect", + axes: Axes, + *, + title: str | None = None, + bathymetry: xarray.DataArray | None = None, + cmap: str | Colormap | None = None, + clim: tuple[float, float] | None = None, + ocean_floor_colour: str = 'black', + landmarks: list[Landmark] | None = None, + **kwargs: Any, +) -> CrossSectionArtist: + """ + Construct the axes and PolyCollections on a plot, + and reformat the data array to the correct shape for plotting. + Assigning the data is left to the caller, + to support both static and animated plots. 
+ """ + depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) + depth_bounds = dataset[depth_coordinate.attrs['bounds']].values + + distance_bounds = transect.intersection_bounds + + data_array = data_array.load() + + positive_down = depth_coordinate.attrs['positive'] == 'down' + d1, d2 = depth_coordinate.values[0:2] + deep_to_shallow = (d1 > d2) == positive_down + + depth_start, depth_stop = 0, -1 + if deep_to_shallow: + depth_start, depth_stop = depth_stop, depth_start + + down, up = ( + (numpy.nanmax, numpy.nanmin) + if positive_down + else (numpy.nanmin, numpy.nanmax)) + depth_limit_shallow = up(depth_bounds[depth_start]) + depth_limit_deep = down(depth_bounds[depth_stop]) + + _setup_distance_axis(axes.xaxis, distance_bounds) + + axes.yaxis.set_label_text(depth_coordinate.attrs.get('long_name')) + axes.set_xlim( + transect.points[0].distance_metres, + transect.points[-1].distance_metres, + ) + axes.set_ylim(depth_limit_deep, depth_limit_shallow) + + if title is None: + title = make_plot_title(dataset, data_array) + if title is not None: + axes.set_title(title) + + cmap = pyplot.get_cmap(cmap).copy() + cmap.set_bad(ocean_floor_colour) + + collection = CrossSectionArtist.from_transect( + transect, depth_coordinate=depth_coordinate, data_array=data_array, cmap=cmap, **kwargs) + axes.add_collection(collection) + + if bathymetry is not None: + bathymetry_artist = TransectStepArtist.from_transect( + transect, data_array=bathymetry, + facecolor=ocean_floor_colour, + fill=True, baseline=depth_limit_deep) + axes.add_patch(bathymetry_artist) + + if landmarks is not None: + top_axis = axes.secondary_xaxis('top') + top_axis.set_ticks( + [transect.distance_along_line(point) for label, point in landmarks], + [label for label, point in landmarks], + ) + + return collection + + +def setup_distance_axis(transect: "Transect", axes: Axes) -> None: + axis = axes.xaxis + + axes.set_xlim(transect.points[0].distance_metres, 
transect.points[-1].distance_metres) + axis.set_label_text("Distance along transect") + axis.set_major_formatter(EngFormatter(unit='m')) + + +def setup_depth_axis( + transect: "Transect", + depth_coordinate: xarray.DataArray, + axes: Axes, +) -> None: + axis = axes.yaxis + + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + positive_down = depth_coordinate.attrs['positive'] == 'down' + depth_min, depth_max = numpy.nanmin(depth_bounds), numpy.nanmax(depth_bounds) + + if positive_down: + axes.set_ylim(depth_max, depth_min) + else: + axes.set_ylim(depth_min, depth_max) + + label = depth_coordinate.attrs.get('long_name') + if label is not None: + axis.set_label_text(label) + + units = depth_coordinate.attrs.get('units') + if units is not None: + formatted_units = cfunits.Units(units).formatted() + axis.set_major_formatter(EngFormatter(unit=formatted_units)) + + +def _find_depth_bounds( + dataset: xarray.Dataset, + data_array: xarray.DataArray, +) -> tuple[int, int]: + """ + Find the shallowest and deepest layers of the data array + where there is at least one value per depth. + + Most ocean models represent cells that are below the sea floor as nans. + Some ocean models do the same for layers above the sea surface, + which can vary due to tides. + If a transect covers mostly shallow regions + but the dataset includes very deep layers + the shallow regions become very small on the final plot. + + This function finds the indexes of the deepest and shallowest layers + where the values are not entirely nan + along the transect path. + The transect plot can use these to only plot depth values that have data, + trimming off layers that are nothing but ocean floor. 
+ """ + depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) + dim = depth_coordinate.dims[0] + + start = 0 + for index in range(depth_coordinate.size): + if numpy.any(numpy.isfinite(data_array.isel({dim: index}).values)): + start = index + break + + stop = -1 + for index in reversed(range(depth_coordinate.size)): + if numpy.any(numpy.isfinite(data_array.isel({dim: index}))): + stop = index + break + + return start, stop + + @dataclasses.dataclass class TransectPoint: """ @@ -109,129 +377,70 @@ class Transect: #: The transect path to plot line: shapely.LineString - #: The depth coordinate (or the name of the depth coordinate) for the dataset. - depth: xarray.DataArray + #: The dataset grid to transect. + grid: Grid def __init__( self, dataset: xarray.Dataset, line: shapely.LineString, - depth: DataArrayOrName | None = None, + *, + grid: Grid | None = None, ): self.dataset = dataset - self.convention = dataset.ems + self.convention = cast(Convention, dataset.ems) self.line = line - if depth is not None: - self.depth = name_to_data_array(dataset, depth) - else: - self.depth = self.convention.depth_coordinate - - @cached_property - def convention(self) -> Convention: - convention: Convention = self.dataset.ems - return convention + if grid is None: + grid = self.convention.default_grid + self.grid = grid @cached_property - def transect_dataset(self) -> xarray.Dataset: + def intersection_bounds( + self, + ) -> numpy.ndarray: """ - A :class:`~xarray.Dataset` containing all the transect geometry. - This includes the depth data, path lengths, - and the linear index of each intersecting cell in the source dataset. - This transect dataset contains all the information necessary to generate a plot, - except for the actual variable data being plotted. + A numpy array of shape (len(segments), 2) + indicating the distance to the start and end of each intersection segment. 
+ This is a shortcut to :attr:`TransectSegment.start_distance` + and :attr:`~TransectSegment.end_distance` from :attr:`Transect.segments`. """ - depth = self.depth - - depth_dimension = depth.dims[0] - - depth_bounds = None - try: - depth_bounds = self.convention.dataset[depth.attrs['bounds']].values - except KeyError: - # Make up some depth bounds data from the depth values - # The top/bottom values will be the first/last depth values, - # all other points are the midpoints between the neighbouring points. - depth_midpoints = numpy.concatenate([ - [depth.values[0]], - (depth.values[1:] + depth.values[:-1]) / 2, - [depth.values[-1]] - ]) - depth_bounds = numpy.column_stack(( - depth_midpoints[:-1], - depth_midpoints[1:], - )) - - try: - positive_down = depth.attrs['positive'] == 'down' - except KeyError as err: - raise ValueError( - f'Depth variable {depth.name!r} must have a `positive` attribute' - ) from err - - linear_indexes = [segment.linear_index for segment in self.segments] - depth = xarray.DataArray( - data=depth.values, - dims=(depth_dimension,), - attrs={ - 'bounds': 'depth_bounds', - 'positive': 'down' if positive_down else 'up', - 'long_name': depth.attrs.get('long_name'), - 'description': depth.attrs.get('description'), - 'units': depth.attrs.get('units'), - }, - ) - depth_bounds = xarray.DataArray( - data=depth_bounds, - dims=(depth_dimension, 'bounds'), - ) - distance_bounds = xarray.DataArray( - data=numpy.fromiter( - ( - [segment.start_distance, segment.end_distance] - for segment in self.segments - ), - # Be explicit here, to handle the case when len(self.segments) == 0. - # This happens when the transect line does not intersect the dataset. - # This will result in an empty transect plot. 
- count=len(self.segments), - dtype=numpy.dtype((float, 2)), + return numpy.fromiter( + ( + [segment.start_distance, segment.end_distance] + for segment in self.segments ), - dims=('index', 'bounds'), - attrs={ - 'long_name': 'Distance along transect', - 'units': 'm', - 'start_distance': self.points[0].distance_metres, - 'end_distance': self.points[-1].distance_metres, - }, - ) - linear_index = xarray.DataArray( - data=linear_indexes, - dims=('index',) + # Be explicit here, to handle the case when len(self.segments) == 0. + # This happens when the transect line does not intersect the dataset. + # This will result in an empty transect plot. + count=len(self.segments), + dtype=numpy.dtype((float, 2)), ) - return xarray.Dataset( - data_vars={ - 'depth_bounds': depth_bounds, - 'distance_bounds': distance_bounds, - }, - coords={ - 'depth': depth, - 'linear_index': linear_index, - }, + @cached_property + def linear_indexes(self) -> numpy.ndarray: + """ + A numpy array of shape (len(segments)) + of the linear index of each intersecting polygon, in order. + This is a shortcut to :attr:`TransectSegment.linear_index` + from :attr:`Transect.segments`. + """ + return numpy.fromiter( + (segment.linear_index for segment in self.segments), + count=len(self.segments), + dtype=numpy.dtype(int), ) - def _set_up_axis(self, variable: xarray.DataArray) -> tuple[str, Formatter]: - title = str(variable.attrs.get('long_name')) - units: str | None = variable.attrs.get('units') - - if units is not None: - # Use cfunits to normalize the units to their short symbol form. - # EngFormatter will write 'k{unit}', 'G{unit}', etc - # so unit symbols are required. - formatted_units = cfunits.Units(units).formatted() - formatter = EngFormatter(unit=formatted_units) - - return title, formatter + @cached_property + def holes(self) -> numpy.ndarray: + """ + An array with the index of any discontinuities in the transect segments. 
+ For transect paths that are entirely within the dataset geometry this will be empty. + For paths that pass in and out of the dataset geometry + this will be the index of the segment just after the discontinuity. + Two segments are not contiguous if `segment[n].end_distance != segment[n+1].start_distance` + """ + bounds = self.intersection_bounds + return numpy.flatnonzero(bounds[:-1, 1] != bounds[1:, 0]) + 1 def _crs_for_point( self, @@ -397,106 +606,20 @@ def distance_along_line(self, point: shapely.Point) -> float: line_point.crs.project_geometry(point, src_crs=data_crs)) return line_point.distance_metres + distance_from_point - def make_poly_collection( - self, - **kwargs: Any, - ) -> PolyCollection: - """ - Make a :class:`matplotlib.collections.PolyCollection` - representing the transect geometry. - - Parameters - ---------- - **kwargs - Any keyword arguments are passed to the PolyCollection constructor. - - Returns - ------- - matplotlib.collections.PolyCollection - A PolyCollection representing all the cells - and all the depths the transect line intesected. - """ - transect_dataset = self.transect_dataset - distance_bounds = transect_dataset['distance_bounds'].values - depth_bounds = transect_dataset['depth_bounds'].values - vertices = [ - [ - (distance_bounds[index, 0], depth_bounds[depth_index][0]), - (distance_bounds[index, 0], depth_bounds[depth_index][1]), - (distance_bounds[index, 1], depth_bounds[depth_index][1]), - (distance_bounds[index, 1], depth_bounds[depth_index][0]), - ] - for depth_index in range(transect_dataset.coords['depth'].size) - for index in range(transect_dataset.sizes['index']) - ] - return PolyCollection(vertices, **kwargs) - - def make_ocean_floor_poly_collection( - self, - bathymetry: xarray.DataArray, - **kwargs: Any - ) -> PolyCollection: - """ - Make a :class:`matplotlib.collections.PolyCollection` - representing the ocean floor. - This can be overlayed on a transect plot to mask out values below the sea floor. 
- - Parameters - ---------- - bathymetry : xarray.Dataset - A data array containing bathymetry data for the dataset. - **kwargs - Any keyword arguments are passed on to the - :class:`~matplotlib.collections.PolyCollection` constructor - - Returns - ------- - matplotlib.collections.PolyCollection - A collection of polygons representing - the ocean floor along the transect path. + def extract(self, data_array: xarray.DataArray) -> xarray.DataArray: """ - transect_dataset = self.transect_dataset - depth = transect_dataset['depth'] - - bathymetry_values = self.convention.ravel(bathymetry) - # The bathymetry data can be oriented differently to the depth coordinate. - # Correct for this if so. - if 'positive' in bathymetry.attrs: - if bathymetry.attrs['positive'] != depth.attrs['positive']: - bathymetry_values = -bathymetry_values - - positive_down = depth.attrs['positive'] == 'down' - deepest_fn = numpy.nanmax if positive_down else numpy.nanmin - deepest = deepest_fn(bathymetry_values.values) - - distance_bounds = transect_dataset['distance_bounds'].values - linear_indexes = transect_dataset['linear_index'].values - - vertices = [ - [ - (distance_bounds[index, 0], bathymetry_values[linear_indexes[index]]), - (distance_bounds[index, 0], deepest), - (distance_bounds[index, 1], deepest), - (distance_bounds[index, 1], bathymetry_values[linear_indexes[index]]), - ] - for index in range(transect_dataset.sizes['index']) - ] - return PolyCollection(vertices, **kwargs) - - def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarray.DataArray: - """ - Prepare a data array for being used as the data in a transect plot. + Extract data from a data array along a transect. Parameters ---------- data_array : xarray.DataArray - The data array that will be plotted + The data array to extract data from. Returns ------- xarray.DataArray - The input data array transformed to have the correct shape - for plotting on the transect. 
+ A new :class:`xarray.DataArray` containing data from the input data array + extracted along the path of the transect. """ # Some of the following operations drop attrs, # so keep a reference to the original ones @@ -504,304 +627,12 @@ def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarra data_array = self.convention.ravel(data_array) - depth_dimension = self.transect_dataset.coords['depth'].dims[0] index_dimension = data_array.dims[-1] - data_array = move_dimensions_to_end(data_array, [depth_dimension, index_dimension]) + data_array = move_dimensions_to_end(data_array, [index_dimension]) - linear_indexes = self.transect_dataset['linear_index'].values - data_array = data_array.isel({index_dimension: linear_indexes}) + data_array = data_array.isel({index_dimension: self.linear_indexes}) # Restore attrs after reformatting data_array.attrs.update(attrs) return data_array - - def _find_depth_bounds(self, data_array: xarray.DataArray) -> tuple[int, int]: - """ - Find the shallowest and deepest layers of the data array - where there is at least one value per depth. - - Most ocean models represent cells that are below the sea floor as nans. - Some ocean models do the same for layers above the sea surface, - which can vary due to tides. - If a transect covers mostly shallow regions - but the dataset includes very deep layers - the shallow regions become very small on the final plot. - - This function finds the indexes of the deepest and shallowest layers - where the values are not entirely nan - along the transect path. - The transect plot can use these to only plot depth values that have data, - trimming off layers that are nothing but ocean floor. 
- """ - transect_dataset = self.transect_dataset - dim = transect_dataset['depth'].dims[0] - - start = 0 - for index in range(transect_dataset['depth'].size): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}).values)): - start = index - break - - stop = -1 - for index in reversed(range(transect_dataset['depth'].size)): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}))): - stop = index - break - - return start, stop - - @_requires_plot - def plot_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | None = None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - ) -> None: - """ - Plot the data array along this transect. - - Parameters - ---------- - figure : matplotlib.figure.Figure - The figure to plot on - data_array : xarray.DataArray - The data array to plot. - This should be a data array from the dataset provided to the - Transect constructor, - or a data array of compatible shape. - title : str, optional - The title of the plot. - Defaults to the 'long_name' attribute of the data array. - trim_nans : bool, default True - Whether to trim layers containing all nans. - Layers that are entirely under the ocean floor are often represented as nans. - Without trimming, transects through shallow areas mostly look like ocean floor. - clamp_to_surface : bool, default True - If true, clamp the y-axis to 0 m. - Some datasets define an upper depth bound of some large number - which rather spoils the plot. - bathymetry : xarray.DataArray, optional - A data array containing bathymetry information for the dataset. - This will be used to draw a more detailed ocean floor mask. - ocean_floor_colour : str, default 'grey' - The colour to draw the ocean floor in. 
- This is used to draw cells containing nan values, - and the bathymetry data. - landmarks : list of str, :class:`shapely.Point` tuples - A list of (name, point) tuples. - These will be added as tick marks along the top of the plot. - """ - axes, collection, data_array = self._plot_on_figure( - figure=figure, - data_array=data_array, - title=title, - trim_nans=trim_nans, - clamp_to_surface=clamp_to_surface, - bathymetry=bathymetry, - cmap=cmap, - clim=clim, - ocean_floor_colour=ocean_floor_colour, - landmarks=landmarks, - ) - collection.set_array(data_array.values.flatten()) - - def animate_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | Callable[[Any], str] | None = None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - coordinate: xarray.DataArray | None = None, - interval: int = 200, - ) -> animation.FuncAnimation: - """ - Plot the data array along this transect. - - Parameters - ---------- - figure : matplotlib.figure.Figure - The figure to plot on - data_array : xarray.DataArray - The data array to plot. - This should be a data array from the dataset provided to the - Transect constructor, - or a data array of compatible shape. - title : str or callable - The title of the plot. - coordinate : xarray.DataArray - The coordinate to animate along. - Defaults to the time coordinate. - interval : int - Time in milliseconds between frames. 
- **kwargs - See :meth:`.plot_on_figure` for available keyword arguments - """ - if coordinate is None: - coordinate = self.convention.time_coordinate - coordinate_indexes = numpy.arange(coordinate.size) - animation_dimension = coordinate.dims[0] - - coordinate_callable: Callable[[Any], str] - if title is None: - title = data_array.attrs.get('long_name') - if title is not None: - coordinate_callable = lambda c: f'{title}\n{c}' - else: - coordinate_callable = str - - elif isinstance(title, str): - coordinate_callable = title.format - - else: - coordinate_callable = title - - first_frame = data_array.isel({animation_dimension: 0}) - first_frame.load() - axes, collection, _prepared_frame = self._plot_on_figure( - figure=figure, - data_array=first_frame, - title=None, - trim_nans=trim_nans, - clamp_to_surface=clamp_to_surface, - bathymetry=bathymetry, - cmap=cmap, - clim=clim, - ocean_floor_colour=ocean_floor_colour, - landmarks=landmarks, - ) - - def animate(index: int) -> Iterable[Artist]: - changes: list[Artist] = [] - - coordinate_value = coordinate.values[index] - axes.set_title(coordinate_callable(coordinate_value)) - changes.append(axes) - - frame_data = data_array.isel({animation_dimension: index}) - frame_data.load() - prepared_data = self.prepare_data_array_for_transect(frame_data) - collection.set_array(prepared_data.values.flatten()) - changes.append(collection) - return changes - - # Draw the figure to force everything to compute its size - figure.draw_without_rendering() - - # Set the first frame of data - animate(0) - - # Make the animation - return animation.FuncAnimation( - figure, animate, frames=coordinate_indexes, - interval=interval) - - def _plot_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | None = None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - 
ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - ) -> tuple[Axes, PolyCollection, xarray.DataArray]: - """ - Construct the axes and PolyCollections on a plot, - and reformat the data array to the correct shape for plotting. - Assigning the data is left to the caller, - to support both static and animated plots. - """ - transect_dataset = self.transect_dataset - depth = transect_dataset.coords['depth'] - distance_bounds = transect_dataset.data_vars['distance_bounds'] - - data_array = data_array.load() - data_array = self.prepare_data_array_for_transect(data_array) - - positive_down = depth.attrs['positive'] == 'down' - d1, d2 = depth.values[0:2] - deep_to_shallow = (d1 > d2) == positive_down - - if trim_nans: - depth_start, depth_stop = self._find_depth_bounds(data_array) - else: - depth_start, depth_stop = 0, -1 - if deep_to_shallow: - depth_start, depth_stop = depth_stop, depth_start - - down, up = ( - (numpy.nanmax, numpy.nanmin) - if positive_down - else (numpy.nanmin, numpy.nanmax)) - if clamp_to_surface: - depth_limit_shallow = 0 - else: - depth_limit_shallow = up(transect_dataset['depth_bounds'][depth_start]) - depth_limit_deep = down(transect_dataset['depth_bounds'][depth_stop]) - - axes = cast(Axes, figure.subplots()) - x_title, x_formatter = self._set_up_axis(distance_bounds) - y_title, y_formatter = self._set_up_axis(depth) - axes.set_xlabel(x_title) - axes.set_ylabel(y_title) - axes.xaxis.set_major_formatter(x_formatter) - axes.yaxis.set_major_formatter(y_formatter) - axes.set_xlim( - distance_bounds.attrs['start_distance'], - distance_bounds.attrs['end_distance'], - ) - axes.set_ylim(depth_limit_deep, depth_limit_shallow) - - if title is None: - title = make_plot_title(self.dataset, data_array) - if title is not None: - axes.set_title(title) - - cmap = pyplot.get_cmap(cmap).copy() - cmap.set_bad(ocean_floor_colour) - - # Find a min/max from the data if clim isn't provided and the data array is not empty. 
- # An empty data array happens when the transect line does not intersect - # the dataset geometry. - if clim is None and data_array.size != 0: - clim = (numpy.nanmin(data_array), numpy.nanmax(data_array)) - - collection = self.make_poly_collection(cmap=cmap, clim=clim, edgecolor='face') - axes.add_collection(collection) - - if bathymetry is not None: - ocean_floor = self.make_ocean_floor_poly_collection( - bathymetry, facecolor=ocean_floor_colour) - axes.add_collection(ocean_floor) - - units = data_array.attrs.get('units') - figure.colorbar(collection, ax=axes, location='right', label=units) - - if landmarks is not None: - top_axis = axes.secondary_xaxis('top') - top_axis.set_ticks( - [self.distance_along_line(point) for label, point in landmarks], - [label for label, point in landmarks], - ) - - return axes, collection, data_array From e9799d973974cb65c8f89aae2640ce38bfc230f0 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Mon, 23 Feb 2026 15:29:38 +1100 Subject: [PATCH 2/5] Update transect example in the docs --- docs/api/transect.rst | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/api/transect.rst b/docs/api/transect.rst index a4c58cf..cba3617 100644 --- a/docs/api/transect.rst +++ b/docs/api/transect.rst @@ -14,8 +14,6 @@ Examples .. minigallery:: ../examples/plot-kgari-transect.py -.. autofunction:: plot - .. 
autoclass:: Transect :members: From 6861316d8c592e21515cfcd9fb514af6d6884502 Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Wed, 25 Feb 2026 16:47:37 +1100 Subject: [PATCH 3/5] More transect updates --- docs/api/transect.rst | 58 ++- docs/conf.py | 3 +- examples/plot-animated-transect.py | 110 +++++ examples/plot-kgari-transect.py | 124 ++++-- src/emsarray/transect.py | 638 ----------------------------- src/emsarray/transect/__init__.py | 8 + src/emsarray/transect/artists.py | 126 ++++++ src/emsarray/transect/base.py | 608 +++++++++++++++++++++++++++ src/emsarray/transect/utils.py | 96 +++++ src/emsarray/utils.py | 57 +++ 10 files changed, 1151 insertions(+), 677 deletions(-) create mode 100644 examples/plot-animated-transect.py delete mode 100644 src/emsarray/transect.py create mode 100644 src/emsarray/transect/__init__.py create mode 100644 src/emsarray/transect/artists.py create mode 100644 src/emsarray/transect/base.py create mode 100644 src/emsarray/transect/utils.py diff --git a/docs/api/transect.rst b/docs/api/transect.rst index cba3617..d2b664c 100644 --- a/docs/api/transect.rst +++ b/docs/api/transect.rst @@ -1,22 +1,64 @@ -.. module:: emsarray.transect ================= emsarray.transect ================= -.. currentmodule:: emsarray.transect +.. module:: emsarray.transect -Plot transects through your dataset. -Transects are vertical slices along some path through your dataset. +This module provides methods for extracting and plotting data +along transects through your datasets. +A transect path is represented as a :class:`shapely.LineString`. +Data along the transect can be extracted in to a new :class:`xarray.Dataset`, +or plotted using :meth:`Transect.make_artist`. + +Currently it is only possible to take transects through grids with polygonal geometry. +Taking transects through other kinds of geometry is a planned future enhancement. Examples --------- +======== + +.. minigallery:: -.. 
minigallery:: ../examples/plot-kgari-transect.py + ../examples/plot-kgari-transect.py + ../examples/plot-animated-transect.py + +Transects +========= + +These classes find the intersection of a :class:`shapely.LineString` with a dataset +and provide methods to introspect this intersection, plot data along this path, +and extract data along this path. .. autoclass:: Transect :members: -.. autoclass:: TransectPoint +.. autoclass:: TransectPoint() + :members: + +.. autoclass:: TransectSegment() + :members: + +Artists +======= + +These classes plot data along a transect. +Transect artists are normally created by calling :meth:`.Transect.make_artist`. + +.. module:: emsarray.transect.artists + +.. autoclass:: TransectArtist() + :members: set_data_array + +.. autoclass:: CrossSectionArtist() + :members: from_transect + +.. autoclass:: TransectStepArtist() + :members: from_transect + +Utilities +========= + +.. currentmodule:: emsarray.transect -.. autoclass:: TransectSegment +.. autofunction:: setup_distance_axis +.. autofunction:: setup_depth_axis diff --git a/docs/conf.py b/docs/conf.py index 5d8d49d..694bfca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -111,6 +111,7 @@ def setup(app): 'examples_dirs': '../examples', 'gallery_dirs': './examples', 'filename_pattern': '/plot-', - 'matplotlib_animations': True, + 'matplotlib_animations': (True, 'jshtml'), 'backreferences_dir': './examples/backreferences', + 'remove_config_comments': True, } diff --git a/examples/plot-animated-transect.py b/examples/plot-animated-transect.py new file mode 100644 index 0000000..c715792 --- /dev/null +++ b/examples/plot-animated-transect.py @@ -0,0 +1,110 @@ +""" +================= +Animated transect +================= + +Transect and cross section plots can be animated using +:meth:`.TransectArtist.set_data_array()` to update the data. 
+""" + +import shapely +import matplotlib.pyplot as plt +import pandas +import xarray +from matplotlib.artist import Artist +from matplotlib.animation import FuncAnimation +from matplotlib.colors import LogNorm +from matplotlib.ticker import ScalarFormatter + +from emsarray import transect, utils +from emsarray.operations import depth + +# The eReefs GBR4 Biogeochemistry and Sediments v4.2 baseline catchment scenario dataset +# contains one daily timestep per file. +# For an animation we need to open multiple files using xarray.open_mfdataset(). +dataset_url_template = ( + 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_B4p2_Cq5b_Dhnd/' + 'gbr4_H4p0_ABARRAr2_OBRAN2020_FG2G_B4p2_Cq5b_Dhnd_simple_{date:%Y-%m-%d}.nc' +) +dataset_urls = [ + dataset_url_template.format(date=date) + for date in pandas.date_range('2022-10-01', '2022-10-14') +] +dataset = xarray.open_mfdataset( + dataset_urls, combine='by_coords', coords=['time'], + compat='override', data_vars='minimal') + +# Select only the variables we want to plot. +dataset = dataset.ems.select_variables(['botz', 'DIN']) + +# Fetch all the data. +# This will only fetch the variables we selected above, +# plus all the geometry and coordinate variables. +dataset.load() + +# The depth coordinate has positive=up, while the bathymetry has positive=down. +# This causes issues when drawing the ocean floor. +# Let's fix the depth coordinate. +dataset = depth.normalize_depth_variables(dataset, ['zc'], positive_down=True) + +# Cross section plots need bounds information, so let's invent some +dataset = utils.estimate_bounds_1d(dataset, 'zc') + +# %% +# Define a :class:`transect <emsarray.transect.Transect>` +# following a path starting near Bowen, passing Mackay, and ending past Gladstone. 
+reef_transect = transect.Transect(dataset, shapely.LineString([ + [149.19358465107092, -19.745298160117585], + [150.16742824240822, -21.444631352206670], + [151.33964738012760, -22.175746836082112], + [152.21129750817641, -23.780650909462680] +])) + + +# %% +# Now we set up a figure and add a cross section artist. + +# sphinx_gallery_defer_figures + +figure = plt.figure(figsize=(7.8, 3), layout='constrained', dpi=100) +axes = figure.add_subplot() + +axes.set_title('Dissolved inorganic nitrogen') +date_labels = [ + utils.datetime_from_np_time(date).strftime('%Y-%m-%d') + for date in dataset['time'].values +] + +din = dataset['DIN'] +din_artist = reef_transect.make_cross_section_artist( + axes, din.isel(time=0), cmap='plasma', + norm=LogNorm(0.01, .7), colorbar=False) +figure.colorbar( + din_artist, ticks=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5], + format=ScalarFormatter()) + +reef_transect.make_ocean_floor_artist( + axes, dataset['botz']) + +date_annotation = axes.annotate( + date_labels[0], + xy=(5, 5), xycoords='axes points', + verticalalignment='bottom', horizontalalignment='left') + +transect.setup_distance_axis( + reef_transect, axes) +transect.setup_depth_axis( + reef_transect, axes, data_array='DIN', + label='Depth', ylim=(80, -1.5)) + +# %% +# Finally we set up the animation. +# The ``update()`` function is called every frame to update the plot with new data. +# The :meth:`.TransectArtist.set_data_array()` function does all the hard work here. 
+ +def update(frame: int) -> list[Artist]: + din_artist.set_data_array(din.isel(time=frame)) + date_annotation.set_text(date_labels[frame]) + return [din_artist, date_annotation] + +animation = FuncAnimation(figure, update, frames=din.sizes['time']) diff --git a/examples/plot-kgari-transect.py b/examples/plot-kgari-transect.py index 136597c..e00780c 100644 --- a/examples/plot-kgari-transect.py +++ b/examples/plot-kgari-transect.py @@ -5,20 +5,31 @@ """ import shapely -from matplotlib import pyplot +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.colors import LogNorm, PowerNorm import emsarray -from emsarray import plot, transect +from emsarray import plot, transect, utils +from emsarray.operations import depth -dataset_url = 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_Dhnd/gbr4_simple_2022-10-31.nc' -dataset = emsarray.open_dataset(dataset_url).isel(time=-1) -dataset = dataset.ems.select_variables(['botz', 'temp']) +dataset_url = 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_Dhnd/gbr4_simple_2022-10-01.nc' +# dataset_url = '~/example-datasets/gbr4_simple_2022-10-31.nc' +dataset = emsarray.open_dataset(dataset_url).isel(time=12) +# Select only the variables we want to plot. +dataset = dataset.ems.select_variables(['botz', 'temp', 'eta']) +# The depth coordinate has positive=up, while the bathymetry has positive=down. +# This causes issues when drawing the ocean floor. +# Let's fix the depth coordinate. 
+dataset = depth.normalize_depth_variables(dataset, ['zc'], positive_down=True) +# Cross section plots need bounds information, so let's invent some +dataset = utils.estimate_bounds_1d(dataset, 'zc') # %% # The following is a :mod:`transect <emsarray.transect>` path # starting in the Great Sandy Strait near K'gari, # heading roughly North out to deeper waters: -line = shapely.LineString([ +north_transect = transect.Transect(dataset, shapely.LineString([ [152.9768944, -25.4827962], [152.9701996, -25.4420345], [152.9727745, -25.3967620], @@ -33,41 +44,94 @@ [152.7607727, -24.3521012], [152.6392365, -24.1906056], [152.4792480, -24.0615124], -]) +])) landmarks = [ ('Round Island', shapely.Point(152.9262543, -25.2878719)), ('Lady Elliot Island', shapely.Point(152.7145958, -24.1129146)), ] + # %% -# Plot a transect showing temperature along this path. +# Set up three axes: one showing the transect path, +# one showing a temperature cross section along the transect, +# and one showing the sea surface height along the transect. + +# sphinx_gallery_defer_figures -figure = transect.plot( - dataset, line, dataset['temp'], - figsize=(7.9, 3), - bathymetry=dataset['botz'], - landmarks=landmarks, - title="Temperature", - cmap='Oranges_r') -pyplot.show() +figure = plt.figure(figsize=(7.8, 8), layout='constrained', dpi=100) +gs_root = gridspec.GridSpec(3, 1, figure=figure, height_ratios=[3, 1, 1]) +path_axes = figure.add_subplot(gs_root[0], projection=dataset.ems.data_crs) +temp_axes = figure.add_subplot(gs_root[1]) +eta_axes = figure.add_subplot(gs_root[2], sharex=temp_axes) # %% -# The path of the transect can be plotted using matplotlib. 
+# First make a plot showing the path of the transect overlaid on the bathymetry + +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () -# Plot the path of the transect -figure = pyplot.figure(figsize=(5, 5), dpi=100) -axes = figure.add_subplot(projection=dataset.ems.data_crs) -axes.set_aspect(aspect='equal', adjustable='datalim') -axes.set_title('Transect path') +path_axes.set_aspect(aspect='equal', adjustable='datalim') +path_axes.set_title('Transect path') dataset.ems.make_artist( - axes, 'botz', cmap='Blues', clim=(0, 2000), edgecolor='face', + path_axes, 'botz', cmap='Blues', clim=(0, 2000), edgecolor='face', + norm=PowerNorm(gamma=0.5), linewidth=0.5, zorder=0) -axes = figure.axes[0] -axes.set_extent(plot.bounds_to_extent(line.envelope.buffer(0.2).bounds)) -axes.plot(*line.coords.xy, zorder=2, c='orange', linewidth=4) +path_axes.set_extent(plot.bounds_to_extent(north_transect.line.envelope.buffer(0.2).bounds)) +path_axes.plot(*north_transect.line.coords.xy, zorder=1, c='orange', linewidth=4) + +plot.add_coast(path_axes, zorder=1) +plot.add_gridlines(path_axes) +plot.add_landmarks(path_axes, landmarks) + +# %% +# Now plot a cross section along the transect showing the ocean temperature. +# As the temperature variable has a depth axis the cross section is two dimensional. + +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () + +temp_axes.set_title('Temperature') +dataset['temp'].attrs['units'] = '°C' +dataset['zc'].attrs['long_name'] = 'Depth' + +north_transect.make_artist( + temp_axes, 'temp', cmap='plasma') +north_transect.make_ocean_floor_artist( + temp_axes, dataset['botz']) +# yaxis +transect.setup_depth_axis( + north_transect, temp_axes, data_array='temp', + label='Depth', ylim=(50, -1.5)) + + +# %% +# Now plot the sea surface height along the transect. +# As the sea surface height does not have a depth axis +# the transect is one dimensional. 
+ +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () + +eta_axes.set_title('Sea surface height') +eta_artist = north_transect.make_artist( + eta_axes, data_array=dataset['eta']) +# xaxis +transect.setup_distance_axis(north_transect, eta_axes) +# yaxis +eta_axes.set_ylim(-0.5, 1.5) +eta_axes.set_ylabel('Height above\nmean sea level') +eta_axes.axhline(0, linestyle='--', color='lightgrey') +eta_axes.yaxis.set_major_formatter("{x:.2g} m") + +# %% +# The last step is to add some landmarks along the top border of the axes +# to help viewers link the distance along transect path to geographic locations. + +top_axis = temp_axes.secondary_xaxis('top') +top_axis.set_ticks( + [north_transect.distance_along_line(point) for label, point in landmarks], + [label for label, point in landmarks], +) -plot.add_coast(axes, zorder=1) -plot.add_gridlines(axes) -plot.add_landmarks(axes, landmarks) -pyplot.show() +plt.show() diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py deleted file mode 100644 index ae9ea62..0000000 --- a/src/emsarray/transect.py +++ /dev/null @@ -1,638 +0,0 @@ -import dataclasses -import itertools -import warnings -from collections.abc import Callable, Iterable -from functools import cached_property -from typing import Any, cast - -import cfunits -import numpy -import shapely -import xarray -from cartopy import crs -from matplotlib import animation, pyplot -from matplotlib.artist import Artist -from matplotlib.axes import Axes -from matplotlib.axis import Axis -from matplotlib.collections import QuadMesh -from matplotlib.colors import Colormap -from matplotlib.figure import Figure -from matplotlib.patches import StepPatch -from matplotlib.ticker import EngFormatter, Formatter - -from emsarray.conventions import Convention, Grid -from emsarray.plot import _requires_plot, make_plot_title, GridArtist -from emsarray.types import DataArrayOrName, Landmark -from emsarray.operations import depth -from emsarray.utils import 
move_dimensions_to_end, name_to_data_array - -""" -Things to do: -* Make 1d transects possible -* enhance plotting routines to make `plot_on_axes` and `make_artist` -""" - - -# Useful for calculating distances in a AzimuthalEquidistant projection -# centred on some point: -# -# az = crs.AzimuthalEquidistant(p1.x, p1.y) -# distance = az.project_geometry(p2).distance(ORIGIN) -ORIGIN = shapely.Point(0, 0) - - -class CrossSectionArtist(QuadMesh): - transect: "Transect" - - def set_transect(self, transect: "Transect") -> None: - self.transect = transect - - def set_data_array(self, data_array: xarray.DataArray) -> None: - self.set_array(self.prepare_data_array(self.transect, data_array)) - - @classmethod - def from_transect( - cls, - transect: "Transect", - *, - data_array: xarray.DataArray | None = None, - depth_coordinate: xarray.DataArray | None = None, - **kwargs: Any, - ) -> "CrossSectionArtist": - distance_bounds = transect.intersection_bounds - - if depth_coordinate is None and data_array is None: - raise ValueError( - "At least one of data_array and depth_coordinate must be not None") - if depth_coordinate is None: - depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) - depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values - - holes = transect.holes - xs = numpy.concat([distance_bounds[:, 0], distance_bounds[-1:, 1]]) - xs = numpy.insert(xs, holes, distance_bounds[holes - 1, 1]) - ys = numpy.concat([depth_bounds[:, 0], depth_bounds[-1:, 1]]) - coordinates = numpy.stack(numpy.meshgrid(xs, ys), axis=-1) - - # There are issues with passing both transect and data array to the constructor - # where the `set_data_array()` is called before `set_transect()`. - # Doing it this way is safe but kinda gross. 
- artist = cls(coordinates, transect=transect, **kwargs) - if data_array is not None: - artist.set_data_array(data_array) - - return artist - - @staticmethod - def prepare_data_array(transect: "Transect", data_array: xarray.DataArray) -> numpy.ndarray: - values = transect.extract(data_array).values - values = numpy.insert(values, transect.holes, numpy.nan, axis=-1) - return values - - -class TransectStepArtist(StepPatch): - transect: "Transect" - - def set_transect(self, transect: "Transect") -> None: - self.transect = transect - - def set_data_array(self, data_array: xarray.DataArray) -> None: - self.set_data(self.prepare_data_array(self.transect, data_array)) - - @classmethod - def from_transect( - cls, - transect: "Transect", - *, - data_array: xarray.DataArray | None = None, - **kwargs: Any, - ) -> "TransectStepArtist": - holes = transect.holes - x_bounds = transect.intersection_bounds - edges = x_bounds[:, 0] - edges = numpy.append(edges, x_bounds[-1, 1]) - edges = numpy.insert(edges, holes, x_bounds[holes - 1, 1]) - - if data_array is not None: - values = cls.prepare_data_array(transect, data_array) - else: - values = numpy.full(shape=(len(edges) - 1,), fill_value=numpy.nan) - - return cls(values, edges, transect=transect, **kwargs) - - @staticmethod - def prepare_data_array(transect: "Transect", data_array: xarray.DataArray) -> numpy.ndarray: - values = transect.extract(data_array).values - assert len(values.shape) == 1 - - # If a transect path is not fully contained within the dataset geometry - # the path will have gaps. We can represent these gaps using nans. - values = numpy.insert( - values.astype(float), # Upcast to float in case this was an integer array - transect.holes, - numpy.nan, - ) - - return values - - -def plot_2d_transect( - dataset: xarray.Dataset, - line: shapely.LineString, - data_array: xarray.DataArray, - *, - figsize: tuple = (12, 3), - **kwargs: Any, -) -> Figure: - """ - Plot a transect of a dataset. 
- - This is convenience function that handles the most common use cases. - For more options refer to the :class:`.Transect` class. - - Parameters - ---------- - dataset : xarray.Dataset - The dataset to transect. - line : shapely.LineString - The transect path to plot. - data_array : xarray.DataArray - A variable from the dataset to plot. - figsize : tuple of int, int - The size of the figure. - **kwargs - Passed to :meth:`Transect.plot_on_figure()`. - """ - figure = pyplot.figure(layout="constrained", figsize=figsize) - data_array = name_to_data_array(dataset, data_array) - - grid = dataset.ems.get_grid(data_array) - transect = Transect(dataset, line, grid=grid) - - axes = figure.subplots() - collection = _plot_on_axes(dataset, data_array, transect, axes, **kwargs) - - units = data_array.attrs.get('units') - figure.colorbar(collection, ax=axes, location='right', label=units) - - pyplot.show() - return figure - - -def _plot_on_axes( - dataset: xarray.Dataset, - data_array: xarray.DataArray, - transect: "Transect", - axes: Axes, - *, - title: str | None = None, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - **kwargs: Any, -) -> CrossSectionArtist: - """ - Construct the axes and PolyCollections on a plot, - and reformat the data array to the correct shape for plotting. - Assigning the data is left to the caller, - to support both static and animated plots. 
- """ - depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) - depth_bounds = dataset[depth_coordinate.attrs['bounds']].values - - distance_bounds = transect.intersection_bounds - - data_array = data_array.load() - - positive_down = depth_coordinate.attrs['positive'] == 'down' - d1, d2 = depth_coordinate.values[0:2] - deep_to_shallow = (d1 > d2) == positive_down - - depth_start, depth_stop = 0, -1 - if deep_to_shallow: - depth_start, depth_stop = depth_stop, depth_start - - down, up = ( - (numpy.nanmax, numpy.nanmin) - if positive_down - else (numpy.nanmin, numpy.nanmax)) - depth_limit_shallow = up(depth_bounds[depth_start]) - depth_limit_deep = down(depth_bounds[depth_stop]) - - _setup_distance_axis(axes.xaxis, distance_bounds) - - axes.yaxis.set_label_text(depth_coordinate.attrs.get('long_name')) - axes.set_xlim( - transect.points[0].distance_metres, - transect.points[-1].distance_metres, - ) - axes.set_ylim(depth_limit_deep, depth_limit_shallow) - - if title is None: - title = make_plot_title(dataset, data_array) - if title is not None: - axes.set_title(title) - - cmap = pyplot.get_cmap(cmap).copy() - cmap.set_bad(ocean_floor_colour) - - collection = CrossSectionArtist.from_transect( - transect, depth_coordinate=depth_coordinate, data_array=data_array, cmap=cmap, **kwargs) - axes.add_collection(collection) - - if bathymetry is not None: - bathymetry_artist = TransectStepArtist.from_transect( - transect, data_array=bathymetry, - facecolor=ocean_floor_colour, - fill=True, baseline=depth_limit_deep) - axes.add_patch(bathymetry_artist) - - if landmarks is not None: - top_axis = axes.secondary_xaxis('top') - top_axis.set_ticks( - [transect.distance_along_line(point) for label, point in landmarks], - [label for label, point in landmarks], - ) - - return collection - - -def setup_distance_axis(transect: "Transect", axes: Axes) -> None: - axis = axes.xaxis - - axes.set_xlim(transect.points[0].distance_metres, 
transect.points[-1].distance_metres) - axis.set_label_text("Distance along transect") - axis.set_major_formatter(EngFormatter(unit='m')) - - -def setup_depth_axis( - transect: "Transect", - depth_coordinate: xarray.DataArray, - axes: Axes, -) -> None: - axis = axes.yaxis - - depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values - positive_down = depth_coordinate.attrs['positive'] == 'down' - depth_min, depth_max = numpy.nanmin(depth_bounds), numpy.nanmax(depth_bounds) - - if positive_down: - axes.set_ylim(depth_max, depth_min) - else: - axes.set_ylim(depth_min, depth_max) - - label = depth_coordinate.attrs.get('long_name') - if label is not None: - axis.set_label_text(label) - - units = depth_coordinate.attrs.get('units') - if units is not None: - formatted_units = cfunits.Units(units).formatted() - axis.set_major_formatter(EngFormatter(unit=formatted_units)) - - -def _find_depth_bounds( - dataset: xarray.Dataset, - data_array: xarray.DataArray, -) -> tuple[int, int]: - """ - Find the shallowest and deepest layers of the data array - where there is at least one value per depth. - - Most ocean models represent cells that are below the sea floor as nans. - Some ocean models do the same for layers above the sea surface, - which can vary due to tides. - If a transect covers mostly shallow regions - but the dataset includes very deep layers - the shallow regions become very small on the final plot. - - This function finds the indexes of the deepest and shallowest layers - where the values are not entirely nan - along the transect path. - The transect plot can use these to only plot depth values that have data, - trimming off layers that are nothing but ocean floor. 
- """ - depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) - dim = depth_coordinate.dims[0] - - start = 0 - for index in range(depth_coordinate.size): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}).values)): - start = index - break - - stop = -1 - for index in reversed(range(depth_coordinate.size)): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}))): - stop = index - break - - return start, stop - - -@dataclasses.dataclass -class TransectPoint: - """ - A TransectPoint holds information about each vertex along a transect path. - """ - #: The original point, in the CRS of the line string / dataset. - point: shapely.Point - - #: An AzimuthalEquidistant CRS centred on this point. - crs: crs.AzimuthalEquidistant - - #: The distance in metres of this point along the line. - distance_metres: float - - #: The projected distance along the line of this point. - #: This is normalised to [0, 1]. - #: The actual value is meaningless but can be used to find - #: the closest vertex on the line string for any other projected point. - distance_normalised: float - - -@dataclasses.dataclass -class TransectSegment: - """ - A TransectSegment holds information about each intersecting segment of the - transect path and the dataset cells. - """ - start_point: shapely.Point - end_point: shapely.Point - intersection: shapely.LineString - start_distance: float - end_distance: float - linear_index: int - polygon: shapely.Polygon - - -class Transect: - """ - """ - #: The dataset to plot a transect through - dataset: xarray.Dataset - - #: The transect path to plot - line: shapely.LineString - - #: The dataset grid to transect. 
- grid: Grid - - def __init__( - self, - dataset: xarray.Dataset, - line: shapely.LineString, - *, - grid: Grid | None = None, - ): - self.dataset = dataset - self.convention = cast(Convention, dataset.ems) - self.line = line - if grid is None: - grid = self.convention.default_grid - self.grid = grid - - @cached_property - def intersection_bounds( - self, - ) -> numpy.ndarray: - """ - A numpy array of shape (len(segments), 2) - indicating the distance to the start and end of each intersection segment. - This is a shortcut to :attr:`TransectSegment.start_distance` - and :attr:`~TransectSegment.end_distance` from :attr:`Transect.segments`. - """ - return numpy.fromiter( - ( - [segment.start_distance, segment.end_distance] - for segment in self.segments - ), - # Be explicit here, to handle the case when len(self.segments) == 0. - # This happens when the transect line does not intersect the dataset. - # This will result in an empty transect plot. - count=len(self.segments), - dtype=numpy.dtype((float, 2)), - ) - - @cached_property - def linear_indexes(self) -> numpy.ndarray: - """ - A numpy array of shape (len(segments)) - of the linear index of each intersecting polygon, in order. - This is a shortcut to :attr:`TransectSegment.linear_index` - from :attr:`Transect.segments`. - """ - return numpy.fromiter( - (segment.linear_index for segment in self.segments), - count=len(self.segments), - dtype=numpy.dtype(int), - ) - - @cached_property - def holes(self) -> numpy.ndarray: - """ - An array with the index of any discontinuities in the transect segments. - For transect paths that are entirely within the dataset geometry this will be empty. - For paths that pass in and out of the dataset geometry - this will be the index of the segment just after the discontinuity. 
- Two segments are not contiguous if `segment[n].end_distance != segment[n+1].start_distance` - """ - bounds = self.intersection_bounds - return numpy.flatnonzero(bounds[:-1, 1] != bounds[1:, 0]) + 1 - - def _crs_for_point( - self, - point: shapely.Point, - globe: crs.Globe | None = None, - ) -> crs.Projection: - return crs.AzimuthalEquidistant( - central_longitude=point.x, central_latitude=point.y, globe=globe) - - @cached_property - def points( - self, - ) -> list[TransectPoint]: - """ - A list of :class:`TransectPoints `, - one for each point in the transect :attr:`.line`. - """ - data_crs = self.convention.data_crs - globe = data_crs.globe - - # Make the TransectPoint for the first point by hand. - point = shapely.Point(self.line.coords[0]) - points = [TransectPoint( - point=point, - crs=self._crs_for_point(point, globe), - distance_metres=0, - distance_normalised=0, - )] - - # Make a TransectPoint for each subsequent point along the line. - for point in map(shapely.Point, self.line.coords[1:]): - previous = points[-1] - - # Calculate the distance from the previous point - # by using the AzimuthalEquidistant CRS centred on the previous point. - distance_from_previous = ORIGIN.distance( - previous.crs.project_geometry(point, src_crs=data_crs)) - - points.append(TransectPoint( - point=point, - crs=self._crs_for_point(point, globe), - distance_metres=previous.distance_metres + distance_from_previous, - distance_normalised=self.line.project(point, normalized=True) - )) - - return points - - @cached_property - def segments(self) -> list[TransectSegment]: - """ - A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry. - Segments are listed in order from the start of the line to the end of the line. 
- """ - segments = [] - - grid = self.convention.grids[self.convention.default_grid_kind] - polygons = grid.geometry - - # Find all the cell polygons that intersect the line - intersecting_indexes = grid.strtree.query(self.line, predicate='intersects') - - for linear_index in intersecting_indexes: - polygon = polygons[linear_index] - for intersection in self._intersect_polygon(polygon): - # The line will have two ends. - # The intersection starts and ends at these points. - # Project those points alone the original line to find - # the start and end distance of the intersection along the line. - points = [ - shapely.Point(intersection.coords[0]), - shapely.Point(intersection.coords[-1]) - ] - projections: Iterable[tuple[shapely.Point, float]] = ( - (point, self.distance_along_line(point)) - for point in points) - start, end = sorted(projections, key=lambda pair: pair[1]) - - segments.append(TransectSegment( - start_point=start[0], - end_point=end[0], - intersection=intersection, - start_distance=start[1], - end_distance=end[1], - linear_index=linear_index, - polygon=polygon, - )) - - return sorted(segments, key=lambda i: (i.start_distance, i.end_distance)) - - def _intersect_polygon( - self, - polygon: shapely.Polygon, - ) -> list[shapely.LineString]: - """ - Intersect a cell of the dataset geometry with the transect line, - and return a list of all LineString segments of the intersection. - This assumes that the cell does intersect the transect line. - A line and a polygon can intersect in a number of ways: - - * a simple cut through the polygon - * the line starts and/or stops in the polygon - * the line intersects the polygon at a point - * the line intersects the polygon multiple times - - Only the intersections that are line segments are returned. - Multiple intersections (represented as a GeometryCollection) - are decomposed in to the component geometries. - Points are ignored. 
- - Parameters - ---------- - polygon : shapely.Polygon - The cell geometry to intersect - - Returns - ------- - list of shapely.LineString - All intersecting line strings - """ - intersection = polygon.intersection(self.line) - if isinstance(intersection, (shapely.GeometryCollection, shapely.MultiLineString)): - geoms = intersection.geoms - else: - geoms = [intersection] - return [geom for geom in geoms if isinstance(geom, shapely.LineString)] - - def distance_along_line(self, point: shapely.Point) -> float: - """ - Calculate the distance in metres that the point - falls along the :attr:`transect line <.line>`. - If the point is not on the line, - the point is projected on to the line - and the distance is calculated to this point instead. - - This can be used to calculate the distance along the transect line - to landmark features. - These landmark features can be added as tick points along the transect. - The landmark features need not fall directly on the line. - - Parameters - ---------- - point : shapely.Point - The point to calculate the distance to - - Returns - ------- - float - The distance the point is along the line in meters. - If the point does not fall on the line, - the point is first projected to the line. - """ - data_crs = self.convention.data_crs - distance_normalised = self.line.project(point, normalized=True) - if distance_normalised < 0 or distance_normalised > 1: - raise ValueError("Point is not on the line!") - - # Find the TransectPoint for the vertex before this point on the line - line_point = next( - lp for lp in reversed(self.points) - if lp.distance_normalised <= distance_normalised) - - distance_from_point: float = ORIGIN.distance( - line_point.crs.project_geometry(point, src_crs=data_crs)) - return line_point.distance_metres + distance_from_point - - def extract(self, data_array: xarray.DataArray) -> xarray.DataArray: - """ - Extract data from a data array along a transect. 
- - Parameters - ---------- - data_array : xarray.DataArray - The data array to extract data from. - - Returns - ------- - xarray.DataArray - A new :class:`xarray.DataArray` containing data from the input data array - extracted along the path of the transect. - """ - # Some of the following operations drop attrs, - # so keep a reference to the original ones - attrs = data_array.attrs - - data_array = self.convention.ravel(data_array) - - index_dimension = data_array.dims[-1] - data_array = move_dimensions_to_end(data_array, [index_dimension]) - - data_array = data_array.isel({index_dimension: self.linear_indexes}) - - # Restore attrs after reformatting - data_array.attrs.update(attrs) - - return data_array diff --git a/src/emsarray/transect/__init__.py b/src/emsarray/transect/__init__.py new file mode 100644 index 0000000..34808d5 --- /dev/null +++ b/src/emsarray/transect/__init__.py @@ -0,0 +1,8 @@ +from .base import Transect, TransectPoint, TransectSegment +from .utils import setup_depth_axis, setup_distance_axis + + +__all__ = [ + 'Transect', 'TransectPoint', 'TransectSegment', + 'setup_depth_axis', 'setup_distance_axis', +] diff --git a/src/emsarray/transect/artists.py b/src/emsarray/transect/artists.py new file mode 100644 index 0000000..2dc3d67 --- /dev/null +++ b/src/emsarray/transect/artists.py @@ -0,0 +1,126 @@ +from typing import Any + +import numpy +import xarray +from matplotlib.artist import Artist +from matplotlib.collections import QuadMesh +from matplotlib.patches import StepPatch + +from . import base + + +class TransectArtist(Artist): + """ + A matplotlib Artist subclass that knows what Transect it is associated with, + and has a `set_data_array()` method. + Users can call `TransectArtist.set_data_array()` to update the data in a plot. + This is useful when making animations, for example. 
+ """ + _transect: 'base.Transect' + + def set_transect(self, transect: 'base.Transect') -> None: + if hasattr(self, '_transect'): + raise ValueError("_transect can not be changed once set") + self._transect = transect + + def get_transect(self) -> 'base.Transect': + return self._transect + + def set_data_array(self, data_array: Any) -> None: + """ + Update the data this artist is plotting. + """ + raise NotImplementedError("Subclasses must implement this") + + +class CrossSectionArtist(QuadMesh, TransectArtist): + @classmethod + def from_transect( + cls, + transect: "base.Transect", + *, + data_array: xarray.DataArray | None = None, + depth_coordinate: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "CrossSectionArtist": + """ + Construct a :class:`CrossSectionArtist` for a transect. + """ + distance_bounds = transect.intersection_bounds + + if depth_coordinate is None and data_array is None: + raise ValueError( + "At least one of data_array and depth_coordinate must be not None") + if depth_coordinate is None: + depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + + holes = transect.holes + xs = numpy.concat([distance_bounds[:, 0], distance_bounds[-1:, 1]]) + xs = numpy.insert(xs, holes, distance_bounds[holes - 1, 1]) + ys = numpy.concat([depth_bounds[:, 0], depth_bounds[-1:, 1]]) + coordinates = numpy.stack(numpy.meshgrid(xs, ys), axis=-1) + + # There are issues with passing both transect and data array to the constructor + # where the `set_data_array()` is called before `set_transect()`. + # Doing it this way is safe but kinda gross. 
+ artist = cls(coordinates, transect=transect, **kwargs) + if data_array is not None: + artist.set_data_array(data_array) + + return artist + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_array(self.prepare_data_array(self._transect, data_array)) + + @staticmethod + def prepare_data_array(transect: "base.Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + values = numpy.insert(values, transect.holes, numpy.nan, axis=-1) + return values + + +class TransectStepArtist(StepPatch, TransectArtist): + _edge_default = True + + @classmethod + def from_transect( + cls, + transect: "base.Transect", + *, + data_array: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "TransectStepArtist": + """ + Construct a :class:`TransectStepArtist` for a transect. + """ + holes = transect.holes + x_bounds = transect.intersection_bounds + edges = x_bounds[:, 0] + edges = numpy.append(edges, x_bounds[-1, 1]) + edges = numpy.insert(edges, holes, x_bounds[holes - 1, 1]) + + if data_array is not None: + values = cls.prepare_data_array(transect, data_array) + else: + values = numpy.full(shape=(len(edges) - 1,), fill_value=numpy.nan) + + return cls(values, edges, transect=transect, **kwargs) + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_data(self.prepare_data_array(self._transect, data_array)) + + @staticmethod + def prepare_data_array(transect: "base.Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + assert len(values.shape) == 1 + + # If a transect path is not fully contained within the dataset geometry + # the path will have gaps. We can represent these gaps using nans. 
+ values = numpy.insert( + values.astype(float), # Upcast to float in case this was an integer array + transect.holes, + numpy.nan, + ) + + return values diff --git a/src/emsarray/transect/base.py b/src/emsarray/transect/base.py new file mode 100644 index 0000000..ce71629 --- /dev/null +++ b/src/emsarray/transect/base.py @@ -0,0 +1,608 @@ +import dataclasses +from collections.abc import Iterable +from functools import cached_property +from typing import Any, cast + +import numpy +import shapely +import xarray +from cartopy import crs +from matplotlib.axes import Axes +from matplotlib.typing import ColorType + +from emsarray.conventions import Convention, Grid +from emsarray.exceptions import NoSuchCoordinateError +from emsarray.types import DataArrayOrName +from emsarray.utils import move_dimensions_to_end, name_to_data_array + +from . import artists + + + +# Useful for calculating distances in a AzimuthalEquidistant projection +# centred on some point: +# +# az = crs.AzimuthalEquidistant(p1.x, p1.y) +# distance = az.project_geometry(p2).distance(ORIGIN) +ORIGIN = shapely.Point(0, 0) + + +@dataclasses.dataclass +class TransectPoint: + """ + A TransectPoint holds information about each vertex along a transect path. + """ + #: The original point, in the CRS of the line string / dataset. + point: shapely.Point + + #: An AzimuthalEquidistant CRS centred on this point. + crs: crs.AzimuthalEquidistant + + #: The distance in metres of this point along the line. + distance_metres: float + + #: The projected distance along the line of this point. + #: This is normalised to [0, 1]. + #: The actual value is meaningless but can be used to find + #: the closest vertex on the line string for any other projected point. + distance_normalised: float + + +@dataclasses.dataclass +class TransectSegment: + """ + A TransectSegment holds information about each intersecting segment of the + transect path and the dataset cells. 
+ """ + #: The point where the transect path first intersects this dataset cell + start_point: shapely.Point + #: The point where the transect exits this dataset cell + end_point: shapely.Point + #: The entire intersection between the transect path and this dataset cell + intersection: shapely.LineString + #: The distance along the line in metres to the :attr:`.start_point` + start_distance: float + #: The distance along the line in metres to the :attr:`.end_point` + end_distance: float + #: The linear index of this dataset cell + linear_index: int + #: The polygon of the dataset cell + polygon: shapely.Polygon + + +class Transect: + """ + """ + #: The dataset to plot a transect through + dataset: xarray.Dataset + + #: The transect path to plot + line: shapely.LineString + + #: The dataset grid to transect. + grid: Grid + + def __init__( + self, + dataset: xarray.Dataset, + line: shapely.LineString, + *, + grid: Grid | None = None, + ): + self.dataset = dataset + self.convention = cast(Convention, dataset.ems) + self.line = line + if grid is None: + grid = self.convention.default_grid + self.grid = grid + + @cached_property + def intersection_bounds( + self, + ) -> numpy.ndarray: + """ + A numpy array of shape `(len(segments), 2)` + indicating the distance to the start and end of each intersection segment. + This is a shortcut to :attr:`TransectSegment.start_distance` + and :attr:`~TransectSegment.end_distance` from :attr:`Transect.segments`. + """ + return numpy.fromiter( + ( + [segment.start_distance, segment.end_distance] + for segment in self.segments + ), + # Be explicit here, to handle the case when len(self.segments) == 0. + # This happens when the transect line does not intersect the dataset. + # This will result in an empty transect plot. 
+ count=len(self.segments), + dtype=numpy.dtype((float, 2)), + ) + + @cached_property + def linear_indexes(self) -> numpy.ndarray: + """ + A numpy array of length `len(segments)` + of the linear indexes of each intersecting polygon, in order. + This is a shortcut to :attr:`TransectSegment.linear_index` + from :attr:`Transect.segments`. + """ + return numpy.fromiter( + (segment.linear_index for segment in self.segments), + count=len(self.segments), + dtype=numpy.dtype(int), + ) + + @cached_property + def holes(self) -> numpy.ndarray: + """ + An array with the index of any discontinuities in the transect segments. + For transect paths that are entirely within the dataset geometry this will be empty. + For paths that pass in and out of the dataset geometry + this will be the index of the segment just after the discontinuity. + Two segments are not contiguous if `segment[n].end_distance != segment[n+1].start_distance` + """ + bounds = self.intersection_bounds + return numpy.flatnonzero(bounds[:-1, 1] != bounds[1:, 0]) + 1 + + def _crs_for_point( + self, + point: shapely.Point, + globe: crs.Globe | None = None, + ) -> crs.Projection: + return crs.AzimuthalEquidistant( + central_longitude=point.x, central_latitude=point.y, globe=globe) + + @cached_property + def points( + self, + ) -> list[TransectPoint]: + """ + A list of :class:`TransectPoints `, + one for each point in the transect :attr:`.line`. + """ + data_crs = self.convention.data_crs + globe = data_crs.globe + + # Make the TransectPoint for the first point by hand. + point = shapely.Point(self.line.coords[0]) + points = [TransectPoint( + point=point, + crs=self._crs_for_point(point, globe), + distance_metres=0, + distance_normalised=0, + )] + + # Make a TransectPoint for each subsequent point along the line. + for point in map(shapely.Point, self.line.coords[1:]): + previous = points[-1] + + # Calculate the distance from the previous point + # by using the AzimuthalEquidistant CRS centred on the previous point. 
+ distance_from_previous = ORIGIN.distance( + previous.crs.project_geometry(point, src_crs=data_crs)) + + points.append(TransectPoint( + point=point, + crs=self._crs_for_point(point, globe), + distance_metres=previous.distance_metres + distance_from_previous, + distance_normalised=self.line.project(point, normalized=True) + )) + + return points + + @cached_property + def segments(self) -> list[TransectSegment]: + """ + A list of :class:`TransectSegments <.TransectSegment>` + for each intersecting segment of the transect line and the dataset geometry. + Segments are listed in order from the start of the line to the end of the line. + """ + segments = [] + + grid = self.grid + polygons = grid.geometry + + # Find all the cell polygons that intersect the line + intersecting_indexes = grid.strtree.query(self.line, predicate='intersects') + + for linear_index in intersecting_indexes: + polygon = polygons[linear_index] + for intersection in self._intersect_polygon(polygon): + # The line will have two ends. + # The intersection starts and ends at these points. + # Project those points along the original line to find + # the start and end distance of the intersection along the line. + points = [ + shapely.Point(intersection.coords[0]), + shapely.Point(intersection.coords[-1]) + ] + projections: Iterable[tuple[shapely.Point, float]] = ( + (point, self.distance_along_line(point)) + for point in points) + start, end = sorted(projections, key=lambda pair: pair[1]) + + segments.append(TransectSegment( + start_point=start[0], + end_point=end[0], + intersection=intersection, + start_distance=start[1], + end_distance=end[1], + linear_index=linear_index, + polygon=polygon, + )) + + return sorted(segments, key=lambda i: (i.start_distance, i.end_distance)) + + @cached_property + def coordinates(self) -> xarray.Dataset: + """ + A :class:`xarray.Dataset` containing coordinate information + for data extracted along the transect.
+ """ + index_dim = 'index' + coordinates = xarray.Dataset( + coords={ + 'distance': xarray.DataArray( + data=numpy.average(self.intersection_bounds, axis=1), + dims=index_dim, + attrs={ + 'long_name': 'Distance along transect', + 'units': 'm', + 'bounds': 'distance_bounds', + }, + ), + 'distance_bounds': xarray.DataArray( + data=self.intersection_bounds, + dims=(index_dim, 'Two'), + ), + } + ) + return coordinates + + def _intersect_polygon( + self, + polygon: shapely.Polygon, + ) -> list[shapely.LineString]: + """ + Intersect a cell of the dataset geometry with the transect line, + and return a list of all LineString segments of the intersection. + This assumes that the cell does intersect the transect line. + A line and a polygon can intersect in a number of ways: + + * a simple cut through the polygon + * the line starts and/or stops in the polygon + * the line intersects the polygon at a point + * the line intersects the polygon multiple times + + Only the intersections that are line segments are returned. + Multiple intersections (represented as a GeometryCollection) + are decomposed in to the component geometries. + Points are ignored. + + Parameters + ---------- + polygon : shapely.Polygon + The cell geometry to intersect + + Returns + ------- + list of shapely.LineString + All intersecting line strings + """ + intersection = polygon.intersection(self.line) + if isinstance(intersection, (shapely.GeometryCollection, shapely.MultiLineString)): + geoms = intersection.geoms + else: + geoms = [intersection] + return [geom for geom in geoms if isinstance(geom, shapely.LineString)] + + def distance_along_line(self, point: shapely.Point) -> float: + """ + Calculate the distance in metres that the point + falls along the :attr:`transect line <.line>`. + If the point is not on the line, + the point is projected on to the line + and the distance is calculated to this point instead. 
+ + This can be used to calculate the distance along the transect line + to landmark features. + These landmark features can be added as tick points along the transect. + The landmark features need not fall directly on the line. + + Parameters + ---------- + point : shapely.Point + The point to calculate the distance to + + Returns + ------- + float + The distance the point is along the line in meters. + If the point does not fall on the line, + the point is first projected to the line. + """ + data_crs = self.convention.data_crs + distance_normalised = self.line.project(point, normalized=True) + if distance_normalised < 0 or distance_normalised > 1: + raise ValueError("Point is not on the line!") + + # Find the TransectPoint for the vertex before this point on the line + line_point = next( + lp for lp in reversed(self.points) + if lp.distance_normalised <= distance_normalised) + + distance_from_point: float = ORIGIN.distance( + line_point.crs.project_geometry(point, src_crs=data_crs)) + return line_point.distance_metres + distance_from_point + + def extract(self, data_array: xarray.DataArray) -> xarray.DataArray: + """ + Extract data from a data array along a transect. + + Parameters + ---------- + data_array : xarray.DataArray + The data array to extract data from. + + Returns + ------- + xarray.DataArray + A new :class:`xarray.DataArray` containing data from the input data array + extracted along the path of the transect. 
+ """ + # Some of the following operations drop attrs, + # so keep a reference to the original ones + attrs = data_array.attrs + + data_array = self.convention.ravel(data_array) + + index_dimension = data_array.dims[-1] + data_array = move_dimensions_to_end(data_array, [index_dimension]) + + data_array = data_array.isel({index_dimension: self.linear_indexes}) + + # Restore attrs after reformatting + data_array.attrs.update(attrs) + + return data_array + + def make_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + **kwargs: Any, + ) -> 'artists.TransectArtist': + """ + Make an artist to plot values extracted from a data array along this transect. + The kind of artist used depends on the dimensionality of the data array. + + To be plotted along a transect the data array must be defined on a supported :ref:`grid `. + Currently only polygonal grids are supported. + + If a data array has a depth axis, :meth:`.make_cross_section_artist` is called, + otherwise :meth:`.make_transect_step_artist` is called. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this artist to. + data_array : DataArrayOrName + The data array to plot + **kwargs + Passed on to the artist, can be used to customise the plot style. + + Returns + ======= + :class:`.artists.TransectArtist` + The artist that will plot the data. + This artist will already have been added to the axes. + + See also + ======== + :func:`~.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + :func:`~.utils.setup_depth_axis` + Setup the y-axis of an :class:`~matplotlib.axes.Axes` + for plotting down a depth coordinate. 
+ """ + data_array = name_to_data_array(self.dataset, data_array) + grid = self.convention.get_grid(data_array) + try: + depth_coordinate = self.convention.get_depth_coordinate_for_data_array(data_array) + except NoSuchCoordinateError: + depth_coordinate = None + + if grid.geometry_type is not shapely.Polygon: + raise ValueError( + f"I don't know how to plot transects across {grid.geometry_type.__name__} geometry.") + + if depth_coordinate is not None: + return self.make_cross_section_artist(axes, data_array, **kwargs) + else: + return self.make_transect_step_artist(axes, data_array, **kwargs) + + def make_cross_section_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + colorbar: bool = True, + **kwargs: Any, + ) -> 'artists.CrossSectionArtist': + """ + Make an artist that plots a vertical slice along the length of the transect. + The data must be three dimensional with a depth axis. + The data are plotted as a grid of values, + with distance along the transect as the x-axis + and depth represented as the y-axis. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this artist to. + data_array : DataArrayOrName + The data array to plot. + This data array must be defined on a polygonal :class:`~emsarray.conventions.Grid` + and must have a depth coordinate with bounds. + colorbar : bool, default True + Whether to add a colorbar for this artist. + Sensible defaults are used for the colorbar, but if more customisation is required + set `colorbar=False` and configure a colorbar manually. + **kwargs + Passed on to the :class:`~.artists.CrossSectionArtist`, + can be used to customise the plot. + + Returns + ======= + :class:`~.artists.CrossSectionArtist` + The artist that will plot the data. + This artist will already have been added to the axes.
+ + See also + ======== + :func:`~.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + :func:`~.utils.setup_depth_axis` + Setup the y-axis of an :class:`~matplotlib.axes.Axes` + for plotting down a depth coordinate. + :func:`~emsarray.utils.estimate_bounds_1d` + Estimate some bounds for a coordinate variable. + """ + data_array = name_to_data_array(self.dataset, data_array) + artist = artists.CrossSectionArtist.from_transect( + self, data_array=data_array, + **kwargs) + axes.add_artist(artist) + if colorbar: + units = data_array.attrs.get('units', None) + axes.figure.colorbar(artist, label=units) + return artist + + def make_transect_step_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + edgecolor: ColorType | None = 'auto', + fill: bool = False, + **kwargs: Any, + ) -> 'artists.TransectStepArtist': + """ + Make an artist that plots values along the length of the transect. + The data must be two dimensional - it must have no depth axis. + The data are plotted as a stepped line. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this line to. + data_array : DataArrayOrName + The data array to plot. + This data array must be defined on a polygonal :class:`~emsarray.conventions.Grid` + and must not have any other dimensions such as time or depth. + edgecolor : color, optional + The colour of the line. + Optional, defaults to the next available colour in the matplotlib plot colours. + fill : bool, default False + Whether to fill in values between the line and the baseline. + Defaults to False. + **kwargs + Passed on to the :class:`~.artists.TransectStepArtist`, + can be used to customise the plot. + + Returns + ======= + :class:`~.artists.TransectStepArtist` + The artist that will plot the data. + This artist will already have been added to the axes. 
+ + See also + ======== + :func:`.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + """ + data_array = name_to_data_array(self.dataset, data_array) + if edgecolor == 'auto': + edgecolor = axes._get_lines.get_next_color() + artist = artists.TransectStepArtist.from_transect( + self, data_array=data_array, + fill=fill, edgecolor=edgecolor, **kwargs) + axes.add_artist(artist) + return artist + + def make_ocean_floor_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + fill: bool = True, + facecolor: ColorType | None = 'lightgrey', + edgecolor: ColorType | None = 'none', + baseline: float | None = None, + **kwargs: Any, + ) -> 'artists.TransectStepArtist': + """ + Make an artist that renders a solid polygon following a bathymetry variable. + This can be drawn in front of a cross section artist to mask out values below the ocean floor. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add the ocean floor artist to + data_array : DataArrayOrName + The data array or name of a data array with the ocean floor data + baseline : float, optional + The deepest part of the ocean floor to render. + The ocean floor will be filled in from the bathymetry value down to the baseline. + Optional, if not provided the deepest value in the data array is used instead. + **kwargs + Passed on to the :class:`.artists.TransectStepArtist` for styling. + Set `facecolor` to change the colour of the ocean floor polygon. + + Returns + ======= + .artists.TransectStepArtist + The artist that will render the ocean floor. + This artist will already have been added to the axes. + + Notes + ===== + The `sign convention `_ + of the bathymetry variable and the depth coordinate must match. + If they differ the ocean floor polygon is likely to be either + entirely outside of the plot extent or to cover the entire plot extent. 
+ :func:`~emsarray.operations.depth.normalize_depth_variables` + can be used to change the sign convention of a depth coordinate variable. + + See also + ======== + `CF Conventions on Vertical Coordinates `_ + More information on the `positive` attribute. + :func:`emsarray.operations.depth.normalize_depth_variables` + Update the sign convention of a depth coordinate variable. + + .. _CF-vertical-coordinates: https://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#vertical-coordinate + + """ + data_array = name_to_data_array(self.dataset, data_array) + if baseline is None: + data_min, data_max = numpy.nanmin(data_array.values), numpy.nanmax(data_array.values) + if 'positive' in data_array.attrs: + if data_array.attrs['positive'] == 'down': + baseline = data_max + else: + baseline = data_min + else: + # Take a guess by using the most extreme value + if numpy.abs(data_min) < numpy.abs(data_max): + baseline = data_max + else: + baseline = data_min + + artist = artists.TransectStepArtist.from_transect( + self, data_array=data_array, + fill=fill, baseline=baseline, + facecolor=facecolor, edgecolor=edgecolor, + **kwargs) + axes.add_artist(artist) + return artist diff --git a/src/emsarray/transect/utils.py b/src/emsarray/transect/utils.py new file mode 100644 index 0000000..a4a7b47 --- /dev/null +++ b/src/emsarray/transect/utils.py @@ -0,0 +1,96 @@ +import cfunits +import numpy +from matplotlib.axes import Axes +from matplotlib.ticker import EngFormatter + +from emsarray.types import DataArrayOrName +from emsarray.utils import name_to_data_array +from .base import Transect + + +def setup_distance_axis(transect: Transect, axes: Axes) -> None: + """ + Configure the x-axis of a :class:`~matplotlib.axes.Axes` for values along a transect. 
+ + Parameters + ========== + transect : emsarray.transect.Transect + The transect being plotted + axes : matplotlib.axes.Axes + The axes to configure + """ + axis = axes.xaxis + + axes.set_xlim(transect.points[0].distance_metres, transect.points[-1].distance_metres) + axis.set_label_text("Distance along transect") + axis.set_major_formatter(EngFormatter(unit='m')) + + +def setup_depth_axis( + transect: "Transect", + axes: Axes, + data_array: DataArrayOrName | None = None, + depth_coordinate: DataArrayOrName | None = None, + ylim: tuple[float, float] | bool = True, + label: str | None | bool = True, + units: str | None | bool = True, +) -> None: + """ + Configure the y-axis of a :class:`~matplotlib.axes.Axes` for values along a depth coordinate. + + Parameters + ========== + transect : emsarray.transect.Transect + The transect being plotted + axes : matplotlib.axes.Axes + The axes to configure + data_array : DataArrayOrName, optional + depth_coordinate : DataArrayOrName, optional + Exactly one of `data_array` or `depth_coordinate` must be provided. + The y-axis is configured to show values along this depth coordinate. + If data_array is provided, the depth coordinate for this data array is used. + ylim : tuple of float, float, optional + The ylim of the axes. If not provided the limit is calculated from the depth coordinate. + label : str or None, optional + The label for the y-axis. + Optional, defaults to the `long_name` attribute of the depth coordinate. + Set to `None` to disable the label. + units : str or None, optional + The units for the y-axis. + Optional, defaults to the `units` attribute of the depth coordinate. + Set to `None` to disable the units and formatting of tick labels.
+ """ + if data_array is None and depth_coordinate is None: + raise ValueError("Either data_array or depth_coordinate must be provided") + if data_array is not None and depth_coordinate is not None: + raise ValueError("Only one of data_array or depth_coordinate must be provided") + + if data_array is not None: + depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) + else: + depth_coordinate = name_to_data_array(transect.dataset, depth_coordinate) + + axis = axes.yaxis + + if ylim is True: + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + positive_down = depth_coordinate.attrs.get('positive', 'up').lower() == 'down' + depth_min, depth_max = numpy.nanmin(depth_bounds), numpy.nanmax(depth_bounds) + + if positive_down: + axes.set_ylim(depth_max, depth_min) + else: + axes.set_ylim(depth_min, depth_max) + elif ylim not in {False, None}: + axes.set_ylim(ylim) + + if label is True: + label = depth_coordinate.attrs.get('long_name') + if label not in {False, None}: + axis.set_label_text(label) + + if units is True: + units = depth_coordinate.attrs.get('units') + if units not in {False, None}: + formatted_units = cfunits.Units(units).formatted() + axis.set_major_formatter(EngFormatter(unit=formatted_units)) diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index 17aa2f8..68e903b 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -545,6 +545,18 @@ def find_unused_dimension( if candidate not in existing_dims) +def find_unused_name( + dataset: xarray.Dataset, + candidate: str, +) -> str: + if candidate not in dataset.variables.keys(): + return candidate + candidates = (f'{candidate}_{suffix}' for suffix in itertools.count(start=0)) + return next( + candidate for candidate in candidates + if candidate not in dataset.variables.keys()) + + def ravel_dimensions( data_array: xarray.DataArray, dimensions: list[Hashable], @@ -936,3 +948,48 @@ def data_array_to_name(dataset: xarray.Dataset, data_array:
DataArrayOrName) -> if data_array not in dataset.variables: raise ValueError(f"Data array {data_array!r} is not in the dataset") return data_array + + +def estimate_bounds_1d( + dataset: xarray.Dataset, + coordinate: DataArrayOrName, +) -> xarray.Dataset: + """ + Estimate the bounds of a one dimensional coordinate variable. + The bounds between two coordinates is the average of the two values, + while the bounds on each end are the first and last coordinate values. + This is a crude approach. + + Parameters + ========== + dataset : xarray.Dataset + The dataset containing the coordinate. + coordinate : xarray.DataArray or str + The coordinate variable to estimate the bounds of. + + Returns + ======= + xarray.Dataset + A copy of the original dataset including the new estimated bounds. + + Raises + ====== + ValueError + Raised if the coordinate variable already has a 'bounds' attribute. + """ + dataset = dataset.copy() + coordinate = name_to_data_array(dataset, coordinate) + if 'bounds' in coordinate.attrs: + raise ValueError("Coordinate already has a bounds attribute") + + bounds_name = find_unused_name(dataset, f'{coordinate.name}_bounds') + values = coordinate.values + midpoints = (values[:-1] + values[1:]) / 2 + midpoints = numpy.concat([[values[0]], midpoints, [values[-1]]]) + dataset[bounds_name] = xarray.DataArray( + name=bounds_name, + data=numpy.c_[midpoints[:-1], midpoints[1:]], + ) + dataset = dataset.set_coords(bounds_name) + dataset[coordinate.name].attrs['bounds'] = bounds_name + return dataset From 17c894e00e1abc36d961cfbca0d18948dd9c8a9f Mon Sep 17 00:00:00 2001 From: Tim Heap Date: Wed, 25 Feb 2026 16:48:09 +1100 Subject: [PATCH 4/5] Include bounds variables in get_all_geometry_names() --- src/emsarray/conventions/arakawa_c.py | 4 ++-- src/emsarray/conventions/grid.py | 14 ++------------ src/emsarray/conventions/ugrid.py | 2 +- src/emsarray/utils.py | 18 ++++++++++++++++++ 4 files changed, 23 insertions(+), 15 deletions(-) diff --git 
a/src/emsarray/conventions/arakawa_c.py b/src/emsarray/conventions/arakawa_c.py index 69a67d9..b92f71d 100644 --- a/src/emsarray/conventions/arakawa_c.py +++ b/src/emsarray/conventions/arakawa_c.py @@ -372,7 +372,7 @@ def _make_geometry_centroid(self, grid_kind: ArakawaCGridKind) -> numpy.ndarray: return cast(numpy.ndarray, points) def get_all_geometry_names(self) -> list[Hashable]: - return [ + return utils.geometry_plus_bounds(self.dataset, [ self.face.longitude.name, self.face.latitude.name, self.node.longitude.name, @@ -381,7 +381,7 @@ def get_all_geometry_names(self) -> list[Hashable]: self.left.latitude.name, self.back.longitude.name, self.back.latitude.name, - ] + ]) def make_clip_mask( self, diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index 2599e5b..12e27c9 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -270,20 +270,10 @@ def grid_dimensions(self) -> dict[CFGridKind, Sequence[Hashable]]: def get_all_geometry_names(self) -> list[Hashable]: # Grid datasets contain latitude and longitude variables # plus optional bounds variables. 
- names = [ + return utils.geometry_plus_bounds(self.dataset, [ self.topology.longitude_name, self.topology.latitude_name, - ] - - bounds_names: list[Hashable | None] = [ - self.topology.longitude.attrs.get('bounds', None), - self.topology.latitude.attrs.get('bounds', None), - ] - for bounds_name in bounds_names: - if bounds_name is not None and bounds_name in self.dataset.variables: - names.append(bounds_name) - - return names + ]) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index ddc4f43..6e1f097 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -1410,7 +1410,7 @@ def get_all_geometry_names(self) -> list[Hashable]: names.append(topology.face_x.name) if topology.face_y is not None: names.append(topology.face_y.name) - return names + return utils.geometry_plus_bounds(self.dataset, names) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index 68e903b..d3542cc 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -993,3 +993,21 @@ def estimate_bounds_1d( dataset = dataset.set_coords(bounds_name) dataset[coordinate.name].attrs['bounds'] = bounds_name return dataset + + +def geometry_plus_bounds(dataset: xarray.Dataset, names: list[Hashable]) -> list[Hashable]: + """ + Return `names` followed by the names of any bounds variables they reference. + A bounds variable is included when a named variable has a `bounds` attribute + and a variable with that name exists in the dataset. + """ + bounds_names: list[Hashable] = [] + for name in names: + data_array = dataset[name] + if 'bounds' not in data_array.attrs: + continue + bounds_name = data_array.attrs['bounds'] + if bounds_name not in dataset.variables.keys(): + continue + bounds_names.append(bounds_name) + return names + bounds_names From 14a8cd9117bcf15be81491f8b0d19ecd37316b9b Mon Sep 17 00:00:00 2001 From: Tim Heap Date:
Wed, 25 Feb 2026 16:49:20 +1100 Subject: [PATCH 5/5] Fix some copy-paste typos in docstrings --- src/emsarray/conventions/_base.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index c54e7ab..d38bd93 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -251,9 +251,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. Returns ------- @@ -989,9 +991,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. Returns -------