diff --git a/docs/api/transect.rst b/docs/api/transect.rst index a4c58cf..d2b664c 100644 --- a/docs/api/transect.rst +++ b/docs/api/transect.rst @@ -1,24 +1,64 @@ -.. module:: emsarray.transect ================= emsarray.transect ================= -.. currentmodule:: emsarray.transect +.. module:: emsarray.transect -Plot transects through your dataset. -Transects are vertical slices along some path through your dataset. +This module provides methods for extracting and plotting data +along transects through your datasets. +A transect path is represented as a :class:`shapely.LineString`. +Data along the transect can be extracted into a new :class:`xarray.Dataset`, +or plotted using :meth:`Transect.make_artist`. + +Currently it is only possible to take transects through grids with polygonal geometry. +Taking transects through other kinds of geometry is a planned future enhancement. Examples --------- +======== + +.. minigallery:: -.. minigallery:: ../examples/plot-kgari-transect.py + ../examples/plot-kgari-transect.py + ../examples/plot-animated-transect.py -.. autofunction:: plot +Transects +========= + +These classes find the intersection of a :class:`shapely.LineString` with a dataset +and provide methods to introspect this intersection, plot data along this path, +and extract data along this path. .. autoclass:: Transect :members: -.. autoclass:: TransectPoint +.. autoclass:: TransectPoint() + :members: + +.. autoclass:: TransectSegment() + :members: + +Artists +======= + +These classes plot data along a transect. +Transect artists are normally created by calling :meth:`.Transect.make_artist`. + +.. module:: emsarray.transect.artists + +.. autoclass:: TransectArtist() + :members: set_data_array + +.. autoclass:: CrossSectionArtist() + :members: from_transect + +.. autoclass:: TransectStepArtist() + :members: from_transect + +Utilities +========= + +.. currentmodule:: emsarray.transect -.. autoclass:: TransectSegment +.. autofunction:: setup_distance_axis +.. 
autofunction:: setup_depth_axis diff --git a/docs/conf.py b/docs/conf.py index 5d8d49d..694bfca 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -111,6 +111,7 @@ def setup(app): 'examples_dirs': '../examples', 'gallery_dirs': './examples', 'filename_pattern': '/plot-', - 'matplotlib_animations': True, + 'matplotlib_animations': (True, 'jshtml'), 'backreferences_dir': './examples/backreferences', + 'remove_config_comments': True, } diff --git a/examples/plot-animated-transect.py b/examples/plot-animated-transect.py new file mode 100644 index 0000000..c715792 --- /dev/null +++ b/examples/plot-animated-transect.py @@ -0,0 +1,110 @@ +""" +================= +Animated transect +================= + +Transect and cross section plots can be animated using +:meth:`.TransectArtist.set_data_array()` to update the data. +""" + +import shapely +import matplotlib.pyplot as plt +import pandas +import xarray +from matplotlib.artist import Artist +from matplotlib.animation import FuncAnimation +from matplotlib.colors import LogNorm +from matplotlib.ticker import ScalarFormatter + +from emsarray import transect, utils +from emsarray.operations import depth + +# The eReefs GBR4 Biogeochemistry and Sediments v4.2 baseline catchment scenario dataset +# contains one daily timestep per file. +# For an animation we need to open multiple files using xarray.open_mfdataset(). +dataset_url_template = ( + 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_B4p2_Cq5b_Dhnd/' + 'gbr4_H4p0_ABARRAr2_OBRAN2020_FG2G_B4p2_Cq5b_Dhnd_simple_{date:%Y-%m-%d}.nc' +) +dataset_urls = [ + dataset_url_template.format(date=date) + for date in pandas.date_range('2022-10-01', '2022-10-14') +] +dataset = xarray.open_mfdataset( + dataset_urls, combine='by_coords', coords=['time'], + compat='override', data_vars='minimal') + +# Select only the variables we want to plot. +dataset = dataset.ems.select_variables(['botz', 'DIN']) + +# Fetch all the data. 
+# This will only fetch the variables we selected above, +# plus all the geometry and coordinate variables. +dataset.load() + +# The depth coordinate has positive=up, while the bathymetry has positive=down. +# This causes issues when drawing the ocean floor. +# Let's fix the depth coordinate. +dataset = depth.normalize_depth_variables(dataset, ['zc'], positive_down=True) + +# Cross section plots need bounds information, so let's invent some +dataset = utils.estimate_bounds_1d(dataset, 'zc') + +# %% +# Define a :class:`transect <emsarray.transect.Transect>` +# following a path starting near Bowen, passing Mackay, and ending past Gladstone. +reef_transect = transect.Transect(dataset, shapely.LineString([ + [149.19358465107092, -19.745298160117585], + [150.16742824240822, -21.444631352206670], + [151.33964738012760, -22.175746836082112], + [152.21129750817641, -23.780650909462680] +])) + + +# %% +# Now we set up a figure and add a cross section artist. + +# sphinx_gallery_defer_figures + +figure = plt.figure(figsize=(7.8, 3), layout='constrained', dpi=100) +axes = figure.add_subplot() + +axes.set_title('Dissolved inorganic nitrogen') +date_labels = [ + utils.datetime_from_np_time(date).strftime('%Y-%m-%d') + for date in dataset['time'].values +] + +din = dataset['DIN'] +din_artist = reef_transect.make_cross_section_artist( + axes, din.isel(time=0), cmap='plasma', + norm=LogNorm(0.01, .7), colorbar=False) +figure.colorbar( + din_artist, ticks=[0.01, 0.025, 0.05, 0.1, 0.25, 0.5], + format=ScalarFormatter()) + +reef_transect.make_ocean_floor_artist( + axes, dataset['botz']) + +date_annotation = axes.annotate( + date_labels[0], + xy=(5, 5), xycoords='axes points', + verticalalignment='bottom', horizontalalignment='left') + +transect.setup_distance_axis( + reef_transect, axes) +transect.setup_depth_axis( + reef_transect, axes, data_array='DIN', + label='Depth', ylim=(80, -1.5)) + +# %% +# Finally we set up the animation. 
+# The ``update()`` function is called every frame to update the plot with new data. +# The :meth:`.TransectArtist.set_data_array()` function does all the hard work here. + +def update(frame: int) -> list[Artist]: + din_artist.set_data_array(din.isel(time=frame)) + date_annotation.set_text(date_labels[frame]) + return [din_artist, date_annotation] + +animation = FuncAnimation(figure, update, frames=din.sizes['time']) diff --git a/examples/plot-kgari-transect.py b/examples/plot-kgari-transect.py index 136597c..e00780c 100644 --- a/examples/plot-kgari-transect.py +++ b/examples/plot-kgari-transect.py @@ -5,20 +5,31 @@ """ import shapely -from matplotlib import pyplot +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.colors import LogNorm, PowerNorm import emsarray -from emsarray import plot, transect +from emsarray import plot, transect, utils +from emsarray.operations import depth -dataset_url = 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_Dhnd/gbr4_simple_2022-10-31.nc' -dataset = emsarray.open_dataset(dataset_url).isel(time=-1) -dataset = dataset.ems.select_variables(['botz', 'temp']) +dataset_url = 'https://thredds.nci.org.au/thredds/dodsC/fx3/gbr4_H4p0_ABARRAr2_OBRAN2020_FG2Gv3_Dhnd/gbr4_simple_2022-10-01.nc' +# dataset_url = '~/example-datasets/gbr4_simple_2022-10-31.nc' +dataset = emsarray.open_dataset(dataset_url).isel(time=12) +# Select only the variables we want to plot. +dataset = dataset.ems.select_variables(['botz', 'temp', 'eta']) +# The depth coordinate has positive=up, while the bathymetry has positive=down. +# This causes issues when drawing the ocean floor. +# Let's fix the depth coordinate. 
+dataset = depth.normalize_depth_variables(dataset, ['zc'], positive_down=True) +# Cross section plots need bounds information, so let's invent some +dataset = utils.estimate_bounds_1d(dataset, 'zc') # %% # The following is a :mod:`transect ` path # starting in the Great Sandy Strait near K'gari, # heading roughly North out to deeper waters: -line = shapely.LineString([ +north_transect = transect.Transect(dataset, shapely.LineString([ [152.9768944, -25.4827962], [152.9701996, -25.4420345], [152.9727745, -25.3967620], @@ -33,41 +44,94 @@ [152.7607727, -24.3521012], [152.6392365, -24.1906056], [152.4792480, -24.0615124], -]) +])) landmarks = [ ('Round Island', shapely.Point(152.9262543, -25.2878719)), ('Lady Elliot Island', shapely.Point(152.7145958, -24.1129146)), ] + # %% -# Plot a transect showing temperature along this path. +# Set up three axes: one showing the transect path, +# one showing a temperature cross section along the transect, +# and one showing the sea surface height along the transect. + +# sphinx_gallery_defer_figures -figure = transect.plot( - dataset, line, dataset['temp'], - figsize=(7.9, 3), - bathymetry=dataset['botz'], - landmarks=landmarks, - title="Temperature", - cmap='Oranges_r') -pyplot.show() +figure = plt.figure(figsize=(7.8, 8), layout='constrained', dpi=100) +gs_root = gridspec.GridSpec(3, 1, figure=figure, height_ratios=[3, 1, 1]) +path_axes = figure.add_subplot(gs_root[0], projection=dataset.ems.data_crs) +temp_axes = figure.add_subplot(gs_root[1]) +eta_axes = figure.add_subplot(gs_root[2], sharex=temp_axes) # %% -# The path of the transect can be plotted using matplotlib. 
+# First make a plot showing the path of the transect overlaid on the bathymetry + +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () -# Plot the path of the transect -figure = pyplot.figure(figsize=(5, 5), dpi=100) -axes = figure.add_subplot(projection=dataset.ems.data_crs) -axes.set_aspect(aspect='equal', adjustable='datalim') -axes.set_title('Transect path') +path_axes.set_aspect(aspect='equal', adjustable='datalim') +path_axes.set_title('Transect path') dataset.ems.make_artist( - axes, 'botz', cmap='Blues', clim=(0, 2000), edgecolor='face', + path_axes, 'botz', cmap='Blues', clim=(0, 2000), edgecolor='face', + norm=PowerNorm(gamma=0.5), linewidth=0.5, zorder=0) -axes = figure.axes[0] -axes.set_extent(plot.bounds_to_extent(line.envelope.buffer(0.2).bounds)) -axes.plot(*line.coords.xy, zorder=2, c='orange', linewidth=4) +path_axes.set_extent(plot.bounds_to_extent(north_transect.line.envelope.buffer(0.2).bounds)) +path_axes.plot(*north_transect.line.coords.xy, zorder=1, c='orange', linewidth=4) + +plot.add_coast(path_axes, zorder=1) +plot.add_gridlines(path_axes) +plot.add_landmarks(path_axes, landmarks) + +# %% +# Now plot a cross section along the transect showing the ocean temperature. +# As the temperature variable has a depth axis the cross section is two-dimensional. + +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () + +temp_axes.set_title('Temperature') +dataset['temp'].attrs['units'] = '°C' +dataset['zc'].attrs['long_name'] = 'Depth' + +north_transect.make_artist( + temp_axes, 'temp', cmap='plasma') +north_transect.make_ocean_floor_artist( + temp_axes, dataset['botz']) +# yaxis +transect.setup_depth_axis( + north_transect, temp_axes, data_array='temp', + label='Depth', ylim=(50, -1.5)) + + +# %% +# Now plot the sea surface height along the transect. +# As the sea surface height does not have a depth axis +# the transect is one-dimensional. 
+ +# sphinx_gallery_defer_figures +# sphinx_gallery_capture_repr_block = () + +eta_axes.set_title('Sea surface height') +eta_artist = north_transect.make_artist( + eta_axes, data_array=dataset['eta']) +# xaxis +transect.setup_distance_axis(north_transect, eta_axes) +# yaxis +eta_axes.set_ylim(-0.5, 1.5) +eta_axes.set_ylabel('Height above\nmean sea level') +eta_axes.axhline(0, linestyle='--', color='lightgrey') +eta_axes.yaxis.set_major_formatter("{x:.2g} m") + +# %% +# The last step is to add some landmarks along the top border of the axes +# to help viewers link the distance along the transect path to geographic locations. + +top_axis = temp_axes.secondary_xaxis('top') +top_axis.set_ticks( + [north_transect.distance_along_line(point) for label, point in landmarks], + [label for label, point in landmarks], +) -plot.add_coast(axes, zorder=1) -plot.add_gridlines(axes) -plot.add_landmarks(axes, landmarks) -pyplot.show() +plt.show() diff --git a/src/emsarray/conventions/_base.py b/src/emsarray/conventions/_base.py index c54e7ab..d38bd93 100644 --- a/src/emsarray/conventions/_base.py +++ b/src/emsarray/conventions/_base.py @@ -251,9 +251,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. Returns ------- @@ -989,9 +991,11 @@ def wind( axis : int, optional The axis number that should be wound. Optional, defaults to the last axis. + Mutually exclusive with the `linear_dimension` parameter. linear_dimension : Hashable, optional - The axis number that should be wound. + The name of the dimension in the data array that should be wound. Optional, defaults to the last dimension. + Mutually exclusive with the `axis` parameter. 
Returns ------- diff --git a/src/emsarray/conventions/arakawa_c.py b/src/emsarray/conventions/arakawa_c.py index 69a67d9..b92f71d 100644 --- a/src/emsarray/conventions/arakawa_c.py +++ b/src/emsarray/conventions/arakawa_c.py @@ -372,7 +372,7 @@ def _make_geometry_centroid(self, grid_kind: ArakawaCGridKind) -> numpy.ndarray: return cast(numpy.ndarray, points) def get_all_geometry_names(self) -> list[Hashable]: - return [ + return utils.geometry_plus_bounds(self.dataset, [ self.face.longitude.name, self.face.latitude.name, self.node.longitude.name, @@ -381,7 +381,7 @@ def get_all_geometry_names(self) -> list[Hashable]: self.left.latitude.name, self.back.longitude.name, self.back.latitude.name, - ] + ]) def make_clip_mask( self, diff --git a/src/emsarray/conventions/grid.py b/src/emsarray/conventions/grid.py index 2599e5b..12e27c9 100644 --- a/src/emsarray/conventions/grid.py +++ b/src/emsarray/conventions/grid.py @@ -270,20 +270,10 @@ def grid_dimensions(self) -> dict[CFGridKind, Sequence[Hashable]]: def get_all_geometry_names(self) -> list[Hashable]: # Grid datasets contain latitude and longitude variables # plus optional bounds variables. 
- names = [ + return utils.geometry_plus_bounds(self.dataset, [ self.topology.longitude_name, self.topology.latitude_name, - ] - - bounds_names: list[Hashable | None] = [ - self.topology.longitude.attrs.get('bounds', None), - self.topology.latitude.attrs.get('bounds', None), - ] - for bounds_name in bounds_names: - if bounds_name is not None and bounds_name in self.dataset.variables: - names.append(bounds_name) - - return names + ]) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/conventions/ugrid.py b/src/emsarray/conventions/ugrid.py index ddc4f43..6e1f097 100644 --- a/src/emsarray/conventions/ugrid.py +++ b/src/emsarray/conventions/ugrid.py @@ -1410,7 +1410,7 @@ def get_all_geometry_names(self) -> list[Hashable]: names.append(topology.face_x.name) if topology.face_y is not None: names.append(topology.face_y.name) - return names + return utils.geometry_plus_bounds(self.dataset, names) def drop_geometry(self) -> xarray.Dataset: dataset = super().drop_geometry() diff --git a/src/emsarray/transect.py b/src/emsarray/transect.py deleted file mode 100644 index 507f22a..0000000 --- a/src/emsarray/transect.py +++ /dev/null @@ -1,807 +0,0 @@ -import dataclasses -from collections.abc import Callable, Iterable -from functools import cached_property -from typing import Any, cast - -import cfunits -import numpy -import shapely -import xarray -from cartopy import crs -from matplotlib import animation, pyplot -from matplotlib.artist import Artist -from matplotlib.axes import Axes -from matplotlib.collections import PolyCollection -from matplotlib.colors import Colormap -from matplotlib.figure import Figure -from matplotlib.ticker import EngFormatter, Formatter - -from emsarray.conventions import Convention -from emsarray.plot import _requires_plot, make_plot_title -from emsarray.types import DataArrayOrName, Landmark -from emsarray.utils import move_dimensions_to_end, name_to_data_array - -# Useful for calculating distances 
in a AzimuthalEquidistant projection -# centred on some point: -# -# az = crs.AzimuthalEquidistant(p1.x, p1.y) -# distance = az.project_geometry(p2).distance(ORIGIN) -ORIGIN = shapely.Point(0, 0) - - -def plot( - dataset: xarray.Dataset, - line: shapely.LineString, - data_array: xarray.DataArray, - *, - figsize: tuple = (12, 3), - **kwargs: Any, -) -> Figure: - """ - Plot a transect of a dataset. - - This is convenience function that handles the most common use cases. - For more options refer to the :class:`.Transect` class. - - Parameters - ---------- - dataset : xarray.Dataset - The dataset to transect. - line : shapely.LineString - The transect path to plot. - data_array : xarray.DataArray - A variable from the dataset to plot. - figsize : tuple of int, int - The size of the figure. - **kwargs - Passed to :meth:`Transect.plot_on_figure()`. - """ - figure = pyplot.figure(layout="constrained", figsize=figsize) - depth_coordinate = dataset.ems.get_depth_coordinate_for_data_array(data_array) - transect = Transect(dataset, line, depth=depth_coordinate) - transect.plot_on_figure(figure, data_array, **kwargs) - pyplot.show() - return figure - - -@dataclasses.dataclass -class TransectPoint: - """ - A TransectPoint holds information about each vertex along a transect path. - """ - #: The original point, in the CRS of the line string / dataset. - point: shapely.Point - - #: An AzimuthalEquidistant CRS centred on this point. - crs: crs.AzimuthalEquidistant - - #: The distance in metres of this point along the line. - distance_metres: float - - #: The projected distance along the line of this point. - #: This is normalised to [0, 1]. - #: The actual value is meaningless but can be used to find - #: the closest vertex on the line string for any other projected point. - distance_normalised: float - - -@dataclasses.dataclass -class TransectSegment: - """ - A TransectSegment holds information about each intersecting segment of the - transect path and the dataset cells. 
- """ - start_point: shapely.Point - end_point: shapely.Point - intersection: shapely.LineString - start_distance: float - end_distance: float - linear_index: int - polygon: shapely.Polygon - - -class Transect: - """ - """ - #: The dataset to plot a transect through - dataset: xarray.Dataset - - #: The transect path to plot - line: shapely.LineString - - #: The depth coordinate (or the name of the depth coordinate) for the dataset. - depth: xarray.DataArray - - def __init__( - self, - dataset: xarray.Dataset, - line: shapely.LineString, - depth: DataArrayOrName | None = None, - ): - self.dataset = dataset - self.convention = dataset.ems - self.line = line - if depth is not None: - self.depth = name_to_data_array(dataset, depth) - else: - self.depth = self.convention.depth_coordinate - - @cached_property - def convention(self) -> Convention: - convention: Convention = self.dataset.ems - return convention - - @cached_property - def transect_dataset(self) -> xarray.Dataset: - """ - A :class:`~xarray.Dataset` containing all the transect geometry. - This includes the depth data, path lengths, - and the linear index of each intersecting cell in the source dataset. - This transect dataset contains all the information necessary to generate a plot, - except for the actual variable data being plotted. - """ - depth = self.depth - - depth_dimension = depth.dims[0] - - depth_bounds = None - try: - depth_bounds = self.convention.dataset[depth.attrs['bounds']].values - except KeyError: - # Make up some depth bounds data from the depth values - # The top/bottom values will be the first/last depth values, - # all other points are the midpoints between the neighbouring points. 
- depth_midpoints = numpy.concatenate([ - [depth.values[0]], - (depth.values[1:] + depth.values[:-1]) / 2, - [depth.values[-1]] - ]) - depth_bounds = numpy.column_stack(( - depth_midpoints[:-1], - depth_midpoints[1:], - )) - - try: - positive_down = depth.attrs['positive'] == 'down' - except KeyError as err: - raise ValueError( - f'Depth variable {depth.name!r} must have a `positive` attribute' - ) from err - - linear_indexes = [segment.linear_index for segment in self.segments] - depth = xarray.DataArray( - data=depth.values, - dims=(depth_dimension,), - attrs={ - 'bounds': 'depth_bounds', - 'positive': 'down' if positive_down else 'up', - 'long_name': depth.attrs.get('long_name'), - 'description': depth.attrs.get('description'), - 'units': depth.attrs.get('units'), - }, - ) - depth_bounds = xarray.DataArray( - data=depth_bounds, - dims=(depth_dimension, 'bounds'), - ) - distance_bounds = xarray.DataArray( - data=numpy.fromiter( - ( - [segment.start_distance, segment.end_distance] - for segment in self.segments - ), - # Be explicit here, to handle the case when len(self.segments) == 0. - # This happens when the transect line does not intersect the dataset. - # This will result in an empty transect plot. 
- count=len(self.segments), - dtype=numpy.dtype((float, 2)), - ), - dims=('index', 'bounds'), - attrs={ - 'long_name': 'Distance along transect', - 'units': 'm', - 'start_distance': self.points[0].distance_metres, - 'end_distance': self.points[-1].distance_metres, - }, - ) - linear_index = xarray.DataArray( - data=linear_indexes, - dims=('index',) - ) - - return xarray.Dataset( - data_vars={ - 'depth_bounds': depth_bounds, - 'distance_bounds': distance_bounds, - }, - coords={ - 'depth': depth, - 'linear_index': linear_index, - }, - ) - - def _set_up_axis(self, variable: xarray.DataArray) -> tuple[str, Formatter]: - title = str(variable.attrs.get('long_name')) - units: str | None = variable.attrs.get('units') - - if units is not None: - # Use cfunits to normalize the units to their short symbol form. - # EngFormatter will write 'k{unit}', 'G{unit}', etc - # so unit symbols are required. - formatted_units = cfunits.Units(units).formatted() - formatter = EngFormatter(unit=formatted_units) - - return title, formatter - - def _crs_for_point( - self, - point: shapely.Point, - globe: crs.Globe | None = None, - ) -> crs.Projection: - return crs.AzimuthalEquidistant( - central_longitude=point.x, central_latitude=point.y, globe=globe) - - @cached_property - def points( - self, - ) -> list[TransectPoint]: - """ - A list of :class:`TransectPoints `, - one for each point in the transect :attr:`.line`. - """ - data_crs = self.convention.data_crs - globe = data_crs.globe - - # Make the TransectPoint for the first point by hand. - point = shapely.Point(self.line.coords[0]) - points = [TransectPoint( - point=point, - crs=self._crs_for_point(point, globe), - distance_metres=0, - distance_normalised=0, - )] - - # Make a TransectPoint for each subsequent point along the line. - for point in map(shapely.Point, self.line.coords[1:]): - previous = points[-1] - - # Calculate the distance from the previous point - # by using the AzimuthalEquidistant CRS centred on the previous point. 
- distance_from_previous = ORIGIN.distance( - previous.crs.project_geometry(point, src_crs=data_crs)) - - points.append(TransectPoint( - point=point, - crs=self._crs_for_point(point, globe), - distance_metres=previous.distance_metres + distance_from_previous, - distance_normalised=self.line.project(point, normalized=True) - )) - - return points - - @cached_property - def segments(self) -> list[TransectSegment]: - """ - A list of :class:`.TransectSegmens` for each intersecting segment of the transect line and the dataset geometry. - Segments are listed in order from the start of the line to the end of the line. - """ - segments = [] - - grid = self.convention.grids[self.convention.default_grid_kind] - polygons = grid.geometry - - # Find all the cell polygons that intersect the line - intersecting_indexes = grid.strtree.query(self.line, predicate='intersects') - - for linear_index in intersecting_indexes: - polygon = polygons[linear_index] - for intersection in self._intersect_polygon(polygon): - # The line will have two ends. - # The intersection starts and ends at these points. - # Project those points alone the original line to find - # the start and end distance of the intersection along the line. 
- points = [ - shapely.Point(intersection.coords[0]), - shapely.Point(intersection.coords[-1]) - ] - projections: Iterable[tuple[shapely.Point, float]] = ( - (point, self.distance_along_line(point)) - for point in points) - start, end = sorted(projections, key=lambda pair: pair[1]) - - segments.append(TransectSegment( - start_point=start[0], - end_point=end[0], - intersection=intersection, - start_distance=start[1], - end_distance=end[1], - linear_index=linear_index, - polygon=polygon, - )) - - return sorted(segments, key=lambda i: (i.start_distance, i.end_distance)) - - def _intersect_polygon( - self, - polygon: shapely.Polygon, - ) -> list[shapely.LineString]: - """ - Intersect a cell of the dataset geometry with the transect line, - and return a list of all LineString segments of the intersection. - This assumes that the cell does intersect the transect line. - A line and a polygon can intersect in a number of ways: - - * a simple cut through the polygon - * the line starts and/or stops in the polygon - * the line intersects the polygon at a point - * the line intersects the polygon multiple times - - Only the intersections that are line segments are returned. - Multiple intersections (represented as a GeometryCollection) - are decomposed in to the component geometries. - Points are ignored. - - Parameters - ---------- - polygon : shapely.Polygon - The cell geometry to intersect - - Returns - ------- - list of shapely.LineString - All intersecting line strings - """ - intersection = polygon.intersection(self.line) - if isinstance(intersection, (shapely.GeometryCollection, shapely.MultiLineString)): - geoms = intersection.geoms - else: - geoms = [intersection] - return [geom for geom in geoms if isinstance(geom, shapely.LineString)] - - def distance_along_line(self, point: shapely.Point) -> float: - """ - Calculate the distance in metres that the point - falls along the :attr:`transect line <.line>`. 
- If the point is not on the line, - the point is projected on to the line - and the distance is calculated to this point instead. - - This can be used to calculate the distance along the transect line - to landmark features. - These landmark features can be added as tick points along the transect. - The landmark features need not fall directly on the line. - - Parameters - ---------- - point : shapely.Point - The point to calculate the distance to - - Returns - ------- - float - The distance the point is along the line in meters. - If the point does not fall on the line, - the point is first projected to the line. - """ - data_crs = self.convention.data_crs - distance_normalised = self.line.project(point, normalized=True) - if distance_normalised < 0 or distance_normalised > 1: - raise ValueError("Point is not on the line!") - - # Find the TransectPoint for the vertex before this point on the line - line_point = next( - lp for lp in reversed(self.points) - if lp.distance_normalised <= distance_normalised) - - distance_from_point: float = ORIGIN.distance( - line_point.crs.project_geometry(point, src_crs=data_crs)) - return line_point.distance_metres + distance_from_point - - def make_poly_collection( - self, - **kwargs: Any, - ) -> PolyCollection: - """ - Make a :class:`matplotlib.collections.PolyCollection` - representing the transect geometry. - - Parameters - ---------- - **kwargs - Any keyword arguments are passed to the PolyCollection constructor. - - Returns - ------- - matplotlib.collections.PolyCollection - A PolyCollection representing all the cells - and all the depths the transect line intesected. 
- """ - transect_dataset = self.transect_dataset - distance_bounds = transect_dataset['distance_bounds'].values - depth_bounds = transect_dataset['depth_bounds'].values - vertices = [ - [ - (distance_bounds[index, 0], depth_bounds[depth_index][0]), - (distance_bounds[index, 0], depth_bounds[depth_index][1]), - (distance_bounds[index, 1], depth_bounds[depth_index][1]), - (distance_bounds[index, 1], depth_bounds[depth_index][0]), - ] - for depth_index in range(transect_dataset.coords['depth'].size) - for index in range(transect_dataset.sizes['index']) - ] - return PolyCollection(vertices, **kwargs) - - def make_ocean_floor_poly_collection( - self, - bathymetry: xarray.DataArray, - **kwargs: Any - ) -> PolyCollection: - """ - Make a :class:`matplotlib.collections.PolyCollection` - representing the ocean floor. - This can be overlayed on a transect plot to mask out values below the sea floor. - - Parameters - ---------- - bathymetry : xarray.Dataset - A data array containing bathymetry data for the dataset. - **kwargs - Any keyword arguments are passed on to the - :class:`~matplotlib.collections.PolyCollection` constructor - - Returns - ------- - matplotlib.collections.PolyCollection - A collection of polygons representing - the ocean floor along the transect path. - """ - transect_dataset = self.transect_dataset - depth = transect_dataset['depth'] - - bathymetry_values = self.convention.ravel(bathymetry) - # The bathymetry data can be oriented differently to the depth coordinate. - # Correct for this if so. 
- if 'positive' in bathymetry.attrs: - if bathymetry.attrs['positive'] != depth.attrs['positive']: - bathymetry_values = -bathymetry_values - - positive_down = depth.attrs['positive'] == 'down' - deepest_fn = numpy.nanmax if positive_down else numpy.nanmin - deepest = deepest_fn(bathymetry_values.values) - - distance_bounds = transect_dataset['distance_bounds'].values - linear_indexes = transect_dataset['linear_index'].values - - vertices = [ - [ - (distance_bounds[index, 0], bathymetry_values[linear_indexes[index]]), - (distance_bounds[index, 0], deepest), - (distance_bounds[index, 1], deepest), - (distance_bounds[index, 1], bathymetry_values[linear_indexes[index]]), - ] - for index in range(transect_dataset.sizes['index']) - ] - return PolyCollection(vertices, **kwargs) - - def prepare_data_array_for_transect(self, data_array: xarray.DataArray) -> xarray.DataArray: - """ - Prepare a data array for being used as the data in a transect plot. - - Parameters - ---------- - data_array : xarray.DataArray - The data array that will be plotted - - Returns - ------- - xarray.DataArray - The input data array transformed to have the correct shape - for plotting on the transect. - """ - # Some of the following operations drop attrs, - # so keep a reference to the original ones - attrs = data_array.attrs - - data_array = self.convention.ravel(data_array) - - depth_dimension = self.transect_dataset.coords['depth'].dims[0] - index_dimension = data_array.dims[-1] - data_array = move_dimensions_to_end(data_array, [depth_dimension, index_dimension]) - - linear_indexes = self.transect_dataset['linear_index'].values - data_array = data_array.isel({index_dimension: linear_indexes}) - - # Restore attrs after reformatting - data_array.attrs.update(attrs) - - return data_array - - def _find_depth_bounds(self, data_array: xarray.DataArray) -> tuple[int, int]: - """ - Find the shallowest and deepest layers of the data array - where there is at least one value per depth. 
- - Most ocean models represent cells that are below the sea floor as nans. - Some ocean models do the same for layers above the sea surface, - which can vary due to tides. - If a transect covers mostly shallow regions - but the dataset includes very deep layers - the shallow regions become very small on the final plot. - - This function finds the indexes of the deepest and shallowest layers - where the values are not entirely nan - along the transect path. - The transect plot can use these to only plot depth values that have data, - trimming off layers that are nothing but ocean floor. - """ - transect_dataset = self.transect_dataset - dim = transect_dataset['depth'].dims[0] - - start = 0 - for index in range(transect_dataset['depth'].size): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}).values)): - start = index - break - - stop = -1 - for index in reversed(range(transect_dataset['depth'].size)): - if numpy.any(numpy.isfinite(data_array.isel({dim: index}))): - stop = index - break - - return start, stop - - @_requires_plot - def plot_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | None = None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - ) -> None: - """ - Plot the data array along this transect. - - Parameters - ---------- - figure : matplotlib.figure.Figure - The figure to plot on - data_array : xarray.DataArray - The data array to plot. - This should be a data array from the dataset provided to the - Transect constructor, - or a data array of compatible shape. - title : str, optional - The title of the plot. - Defaults to the 'long_name' attribute of the data array. - trim_nans : bool, default True - Whether to trim layers containing all nans. 
- Layers that are entirely under the ocean floor are often represented as nans. - Without trimming, transects through shallow areas mostly look like ocean floor. - clamp_to_surface : bool, default True - If true, clamp the y-axis to 0 m. - Some datasets define an upper depth bound of some large number - which rather spoils the plot. - bathymetry : xarray.DataArray, optional - A data array containing bathymetry information for the dataset. - This will be used to draw a more detailed ocean floor mask. - ocean_floor_colour : str, default 'grey' - The colour to draw the ocean floor in. - This is used to draw cells containing nan values, - and the bathymetry data. - landmarks : list of str, :class:`shapely.Point` tuples - A list of (name, point) tuples. - These will be added as tick marks along the top of the plot. - """ - axes, collection, data_array = self._plot_on_figure( - figure=figure, - data_array=data_array, - title=title, - trim_nans=trim_nans, - clamp_to_surface=clamp_to_surface, - bathymetry=bathymetry, - cmap=cmap, - clim=clim, - ocean_floor_colour=ocean_floor_colour, - landmarks=landmarks, - ) - collection.set_array(data_array.values.flatten()) - - def animate_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | Callable[[Any], str] | None = None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - coordinate: xarray.DataArray | None = None, - interval: int = 200, - ) -> animation.FuncAnimation: - """ - Plot the data array along this transect. - - Parameters - ---------- - figure : matplotlib.figure.Figure - The figure to plot on - data_array : xarray.DataArray - The data array to plot. - This should be a data array from the dataset provided to the - Transect constructor, - or a data array of compatible shape. 
- title : str or callable - The title of the plot. - coordinate : xarray.DataArray - The coordinate to animate along. - Defaults to the time coordinate. - interval : int - Time in milliseconds between frames. - **kwargs - See :meth:`.plot_on_figure` for available keyword arguments - """ - if coordinate is None: - coordinate = self.convention.time_coordinate - coordinate_indexes = numpy.arange(coordinate.size) - animation_dimension = coordinate.dims[0] - - coordinate_callable: Callable[[Any], str] - if title is None: - title = data_array.attrs.get('long_name') - if title is not None: - coordinate_callable = lambda c: f'{title}\n{c}' - else: - coordinate_callable = str - - elif isinstance(title, str): - coordinate_callable = title.format - - else: - coordinate_callable = title - - first_frame = data_array.isel({animation_dimension: 0}) - first_frame.load() - axes, collection, _prepared_frame = self._plot_on_figure( - figure=figure, - data_array=first_frame, - title=None, - trim_nans=trim_nans, - clamp_to_surface=clamp_to_surface, - bathymetry=bathymetry, - cmap=cmap, - clim=clim, - ocean_floor_colour=ocean_floor_colour, - landmarks=landmarks, - ) - - def animate(index: int) -> Iterable[Artist]: - changes: list[Artist] = [] - - coordinate_value = coordinate.values[index] - axes.set_title(coordinate_callable(coordinate_value)) - changes.append(axes) - - frame_data = data_array.isel({animation_dimension: index}) - frame_data.load() - prepared_data = self.prepare_data_array_for_transect(frame_data) - collection.set_array(prepared_data.values.flatten()) - changes.append(collection) - return changes - - # Draw the figure to force everything to compute its size - figure.draw_without_rendering() - - # Set the first frame of data - animate(0) - - # Make the animation - return animation.FuncAnimation( - figure, animate, frames=coordinate_indexes, - interval=interval) - - def _plot_on_figure( - self, - figure: Figure, - data_array: xarray.DataArray, - *, - title: str | None = 
None, - trim_nans: bool = True, - clamp_to_surface: bool = True, - bathymetry: xarray.DataArray | None = None, - cmap: str | Colormap | None = None, - clim: tuple[float, float] | None = None, - ocean_floor_colour: str = 'black', - landmarks: list[Landmark] | None = None, - ) -> tuple[Axes, PolyCollection, xarray.DataArray]: - """ - Construct the axes and PolyCollections on a plot, - and reformat the data array to the correct shape for plotting. - Assigning the data is left to the caller, - to support both static and animated plots. - """ - transect_dataset = self.transect_dataset - depth = transect_dataset.coords['depth'] - distance_bounds = transect_dataset.data_vars['distance_bounds'] - - data_array = data_array.load() - data_array = self.prepare_data_array_for_transect(data_array) - - positive_down = depth.attrs['positive'] == 'down' - d1, d2 = depth.values[0:2] - deep_to_shallow = (d1 > d2) == positive_down - - if trim_nans: - depth_start, depth_stop = self._find_depth_bounds(data_array) - else: - depth_start, depth_stop = 0, -1 - if deep_to_shallow: - depth_start, depth_stop = depth_stop, depth_start - - down, up = ( - (numpy.nanmax, numpy.nanmin) - if positive_down - else (numpy.nanmin, numpy.nanmax)) - if clamp_to_surface: - depth_limit_shallow = 0 - else: - depth_limit_shallow = up(transect_dataset['depth_bounds'][depth_start]) - depth_limit_deep = down(transect_dataset['depth_bounds'][depth_stop]) - - axes = cast(Axes, figure.subplots()) - x_title, x_formatter = self._set_up_axis(distance_bounds) - y_title, y_formatter = self._set_up_axis(depth) - axes.set_xlabel(x_title) - axes.set_ylabel(y_title) - axes.xaxis.set_major_formatter(x_formatter) - axes.yaxis.set_major_formatter(y_formatter) - axes.set_xlim( - distance_bounds.attrs['start_distance'], - distance_bounds.attrs['end_distance'], - ) - axes.set_ylim(depth_limit_deep, depth_limit_shallow) - - if title is None: - title = make_plot_title(self.dataset, data_array) - if title is not None: - 
axes.set_title(title) - - cmap = pyplot.get_cmap(cmap).copy() - cmap.set_bad(ocean_floor_colour) - - # Find a min/max from the data if clim isn't provided and the data array is not empty. - # An empty data array happens when the transect line does not intersect - # the dataset geometry. - if clim is None and data_array.size != 0: - clim = (numpy.nanmin(data_array), numpy.nanmax(data_array)) - - collection = self.make_poly_collection(cmap=cmap, clim=clim, edgecolor='face') - axes.add_collection(collection) - - if bathymetry is not None: - ocean_floor = self.make_ocean_floor_poly_collection( - bathymetry, facecolor=ocean_floor_colour) - axes.add_collection(ocean_floor) - - units = data_array.attrs.get('units') - figure.colorbar(collection, ax=axes, location='right', label=units) - - if landmarks is not None: - top_axis = axes.secondary_xaxis('top') - top_axis.set_ticks( - [self.distance_along_line(point) for label, point in landmarks], - [label for label, point in landmarks], - ) - - return axes, collection, data_array diff --git a/src/emsarray/transect/__init__.py b/src/emsarray/transect/__init__.py new file mode 100644 index 0000000..34808d5 --- /dev/null +++ b/src/emsarray/transect/__init__.py @@ -0,0 +1,8 @@ +from .base import Transect, TransectPoint, TransectSegment +from .utils import setup_depth_axis, setup_distance_axis + + +__all__ = [ + 'Transect', 'TransectPoint', 'TransectSegment', + 'setup_depth_axis', 'setup_distance_axis', +] diff --git a/src/emsarray/transect/artists.py b/src/emsarray/transect/artists.py new file mode 100644 index 0000000..2dc3d67 --- /dev/null +++ b/src/emsarray/transect/artists.py @@ -0,0 +1,126 @@ +from typing import Any + +import numpy +import xarray +from matplotlib.artist import Artist +from matplotlib.collections import QuadMesh +from matplotlib.patches import StepPatch + +from . 
import base + + +class TransectArtist(Artist): + """ + A matplotlib Artist subclass that knows what Transect it is associated with, + and has a `set_data_array()` method. + Users can call `TransectArtist.set_data_array()` to update the data in a plot. + This is useful when making animations, for example. + """ + _transect: 'base.Transect' + + def set_transect(self, transect: 'base.Transect') -> None: + if hasattr(self, '_transect'): + raise ValueError("_transect can not be changed once set") + self._transect = transect + + def get_transect(self) -> 'base.Transect': + return self._transect + + def set_data_array(self, data_array: Any) -> None: + """ + Update the data this artist is plotting. + """ + raise NotImplementedError("Subclasses must implement this") + + +class CrossSectionArtist(QuadMesh, TransectArtist): + @classmethod + def from_transect( + cls, + transect: "base.Transect", + *, + data_array: xarray.DataArray | None = None, + depth_coordinate: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "CrossSectionArtist": + """ + Construct a :class:`CrossSectionArtist` for a transect. + """ + distance_bounds = transect.intersection_bounds + + if depth_coordinate is None and data_array is None: + raise ValueError( + "At least one of data_array and depth_coordinate must be not None") + if depth_coordinate is None: + depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + + holes = transect.holes + xs = numpy.concat([distance_bounds[:, 0], distance_bounds[-1:, 1]]) + xs = numpy.insert(xs, holes, distance_bounds[holes - 1, 1]) + ys = numpy.concat([depth_bounds[:, 0], depth_bounds[-1:, 1]]) + coordinates = numpy.stack(numpy.meshgrid(xs, ys), axis=-1) + + # There are issues with passing both transect and data array to the constructor + # where the `set_data_array()` is called before `set_transect()`. + # Doing it this way is safe but kinda gross. 
+ artist = cls(coordinates, transect=transect, **kwargs) + if data_array is not None: + artist.set_data_array(data_array) + + return artist + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_array(self.prepare_data_array(self._transect, data_array)) + + @staticmethod + def prepare_data_array(transect: "base.Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + values = numpy.insert(values, transect.holes, numpy.nan, axis=-1) + return values + + +class TransectStepArtist(StepPatch, TransectArtist): + _edge_default = True + + @classmethod + def from_transect( + cls, + transect: "base.Transect", + *, + data_array: xarray.DataArray | None = None, + **kwargs: Any, + ) -> "TransectStepArtist": + """ + Construct a :class:`TransectStepArtist` for a transect. + """ + holes = transect.holes + x_bounds = transect.intersection_bounds + edges = x_bounds[:, 0] + edges = numpy.append(edges, x_bounds[-1, 1]) + edges = numpy.insert(edges, holes, x_bounds[holes - 1, 1]) + + if data_array is not None: + values = cls.prepare_data_array(transect, data_array) + else: + values = numpy.full(shape=(len(edges) - 1,), fill_value=numpy.nan) + + return cls(values, edges, transect=transect, **kwargs) + + def set_data_array(self, data_array: xarray.DataArray) -> None: + self.set_data(self.prepare_data_array(self._transect, data_array)) + + @staticmethod + def prepare_data_array(transect: "base.Transect", data_array: xarray.DataArray) -> numpy.ndarray: + values = transect.extract(data_array).values + assert len(values.shape) == 1 + + # If a transect path is not fully contained within the dataset geometry + # the path will have gaps. We can represent these gaps using nans. 
+ values = numpy.insert( + values.astype(float), # Upcast to float in case this was an integer array + transect.holes, + numpy.nan, + ) + + return values diff --git a/src/emsarray/transect/base.py b/src/emsarray/transect/base.py new file mode 100644 index 0000000..ce71629 --- /dev/null +++ b/src/emsarray/transect/base.py @@ -0,0 +1,608 @@ +import dataclasses +from collections.abc import Iterable +from functools import cached_property +from typing import Any, cast + +import numpy +import shapely +import xarray +from cartopy import crs +from matplotlib.axes import Axes +from matplotlib.typing import ColorType + +from emsarray.conventions import Convention, Grid +from emsarray.exceptions import NoSuchCoordinateError +from emsarray.types import DataArrayOrName +from emsarray.utils import move_dimensions_to_end, name_to_data_array + +from . import artists + + + +# Useful for calculating distances in a AzimuthalEquidistant projection +# centred on some point: +# +# az = crs.AzimuthalEquidistant(p1.x, p1.y) +# distance = az.project_geometry(p2).distance(ORIGIN) +ORIGIN = shapely.Point(0, 0) + + +@dataclasses.dataclass +class TransectPoint: + """ + A TransectPoint holds information about each vertex along a transect path. + """ + #: The original point, in the CRS of the line string / dataset. + point: shapely.Point + + #: An AzimuthalEquidistant CRS centred on this point. + crs: crs.AzimuthalEquidistant + + #: The distance in metres of this point along the line. + distance_metres: float + + #: The projected distance along the line of this point. + #: This is normalised to [0, 1]. + #: The actual value is meaningless but can be used to find + #: the closest vertex on the line string for any other projected point. + distance_normalised: float + + +@dataclasses.dataclass +class TransectSegment: + """ + A TransectSegment holds information about each intersecting segment of the + transect path and the dataset cells. 
+ """ + #: The point where the transect path first intersects this dataset cell + start_point: shapely.Point + #: The point where the transect exits this dataset cell + end_point: shapely.Point + #: The entire intersection between the transect path and this dataset cell + intersection: shapely.LineString + #: The distance along the line in metres to the :attr:`.start_point` + start_distance: float + #: The distance along the line in metres to the :attr:`.end_point` + end_distance: float + #: The linear index of this dataset cell + linear_index: int + #: The polygon of the dataset cell + polygon: shapely.Polygon + + +class Transect: + """ + """ + #: The dataset to plot a transect through + dataset: xarray.Dataset + + #: The transect path to plot + line: shapely.LineString + + #: The dataset grid to transect. + grid: Grid + + def __init__( + self, + dataset: xarray.Dataset, + line: shapely.LineString, + *, + grid: Grid | None = None, + ): + self.dataset = dataset + self.convention = cast(Convention, dataset.ems) + self.line = line + if grid is None: + grid = self.convention.default_grid + self.grid = grid + + @cached_property + def intersection_bounds( + self, + ) -> numpy.ndarray: + """ + A numpy array of shape `(len(segments), 2)` + indicating the distance to the start and end of each intersection segment. + This is a shortcut to :attr:`TransectSegment.start_distance` + and :attr:`~TransectSegment.end_distance` from :attr:`Transect.segments`. + """ + return numpy.fromiter( + ( + [segment.start_distance, segment.end_distance] + for segment in self.segments + ), + # Be explicit here, to handle the case when len(self.segments) == 0. + # This happens when the transect line does not intersect the dataset. + # This will result in an empty transect plot. 
+ count=len(self.segments), + dtype=numpy.dtype((float, 2)), + ) + + @cached_property + def linear_indexes(self) -> numpy.ndarray: + """ + A numpy array of length `len(segments)` + of the linear indexes of each intersecting polygon, in order. + This is a shortcut to :attr:`TransectSegment.linear_index` + from :attr:`Transect.segments`. + """ + return numpy.fromiter( + (segment.linear_index for segment in self.segments), + count=len(self.segments), + dtype=numpy.dtype(int), + ) + + @cached_property + def holes(self) -> numpy.ndarray: + """ + An array with the index of any discontinuities in the transect segments. + For transect paths that are entirely within the dataset geometry this will be empty. + For paths that pass in and out of the dataset geometry + this will be the index of the segment just after the discontinuity. + Two segments are not contiguous if `segment[n].end_distance != segment[n+1].start_distance` + """ + bounds = self.intersection_bounds + return numpy.flatnonzero(bounds[:-1, 1] != bounds[1:, 0]) + 1 + + def _crs_for_point( + self, + point: shapely.Point, + globe: crs.Globe | None = None, + ) -> crs.Projection: + return crs.AzimuthalEquidistant( + central_longitude=point.x, central_latitude=point.y, globe=globe) + + @cached_property + def points( + self, + ) -> list[TransectPoint]: + """ + A list of :class:`TransectPoints <.TransectPoint>`, + one for each point in the transect :attr:`.line`. + """ + data_crs = self.convention.data_crs + globe = data_crs.globe + + # Make the TransectPoint for the first point by hand. + point = shapely.Point(self.line.coords[0]) + points = [TransectPoint( + point=point, + crs=self._crs_for_point(point, globe), + distance_metres=0, + distance_normalised=0, + )] + + # Make a TransectPoint for each subsequent point along the line. + for point in map(shapely.Point, self.line.coords[1:]): + previous = points[-1] + + # Calculate the distance from the previous point + # by using the AzimuthalEquidistant CRS centred on the previous point.
+ distance_from_previous = ORIGIN.distance( + previous.crs.project_geometry(point, src_crs=data_crs)) + + points.append(TransectPoint( + point=point, + crs=self._crs_for_point(point, globe), + distance_metres=previous.distance_metres + distance_from_previous, + distance_normalised=self.line.project(point, normalized=True) + )) + + return points + + @cached_property + def segments(self) -> list[TransectSegment]: + """ + A list of :class:`TransectSegments <.TransectSegment>` + for each intersecting segment of the transect line and the dataset geometry. + Segments are listed in order from the start of the line to the end of the line. + """ + segments = [] + + grid = self.convention.grids[self.convention.default_grid_kind] + polygons = grid.geometry + + # Find all the cell polygons that intersect the line + intersecting_indexes = grid.strtree.query(self.line, predicate='intersects') + + for linear_index in intersecting_indexes: + polygon = polygons[linear_index] + for intersection in self._intersect_polygon(polygon): + # The line will have two ends. + # The intersection starts and ends at these points. + # Project those points along the original line to find + # the start and end distance of the intersection along the line. + points = [ + shapely.Point(intersection.coords[0]), + shapely.Point(intersection.coords[-1]) + ] + projections: Iterable[tuple[shapely.Point, float]] = ( + (point, self.distance_along_line(point)) + for point in points) + start, end = sorted(projections, key=lambda pair: pair[1]) + + segments.append(TransectSegment( + start_point=start[0], + end_point=end[0], + intersection=intersection, + start_distance=start[1], + end_distance=end[1], + linear_index=linear_index, + polygon=polygon, + )) + + return sorted(segments, key=lambda i: (i.start_distance, i.end_distance)) + + @cached_property + def coordinates(self) -> xarray.Dataset: + """ + A :class:`xarray.Dataset` containing coordinate information + for data extracted along the transect.
+ """ + index_dim = 'index' + coordinates = xarray.Dataset( + coords={ + 'distance': xarray.DataArray( + data=numpy.average(self.intersection_bounds, axis=1), + dims=index_dim, + attrs={ + 'long_name': 'Distance along transect', + 'units': 'm', + 'bounds': 'distance_bounds', + }, + ), + 'distance_bounds': xarray.DataArray( + data=self.intersection_bounds, + dims=(index_dim, 'Two'), + ), + } + ) + return coordinates + + def _intersect_polygon( + self, + polygon: shapely.Polygon, + ) -> list[shapely.LineString]: + """ + Intersect a cell of the dataset geometry with the transect line, + and return a list of all LineString segments of the intersection. + This assumes that the cell does intersect the transect line. + A line and a polygon can intersect in a number of ways: + + * a simple cut through the polygon + * the line starts and/or stops in the polygon + * the line intersects the polygon at a point + * the line intersects the polygon multiple times + + Only the intersections that are line segments are returned. + Multiple intersections (represented as a GeometryCollection) + are decomposed in to the component geometries. + Points are ignored. + + Parameters + ---------- + polygon : shapely.Polygon + The cell geometry to intersect + + Returns + ------- + list of shapely.LineString + All intersecting line strings + """ + intersection = polygon.intersection(self.line) + if isinstance(intersection, (shapely.GeometryCollection, shapely.MultiLineString)): + geoms = intersection.geoms + else: + geoms = [intersection] + return [geom for geom in geoms if isinstance(geom, shapely.LineString)] + + def distance_along_line(self, point: shapely.Point) -> float: + """ + Calculate the distance in metres that the point + falls along the :attr:`transect line <.line>`. + If the point is not on the line, + the point is projected on to the line + and the distance is calculated to this point instead. 
+ + This can be used to calculate the distance along the transect line + to landmark features. + These landmark features can be added as tick points along the transect. + The landmark features need not fall directly on the line. + + Parameters + ---------- + point : shapely.Point + The point to calculate the distance to + + Returns + ------- + float + The distance the point is along the line in meters. + If the point does not fall on the line, + the point is first projected to the line. + """ + data_crs = self.convention.data_crs + distance_normalised = self.line.project(point, normalized=True) + if distance_normalised < 0 or distance_normalised > 1: + raise ValueError("Point is not on the line!") + + # Find the TransectPoint for the vertex before this point on the line + line_point = next( + lp for lp in reversed(self.points) + if lp.distance_normalised <= distance_normalised) + + distance_from_point: float = ORIGIN.distance( + line_point.crs.project_geometry(point, src_crs=data_crs)) + return line_point.distance_metres + distance_from_point + + def extract(self, data_array: xarray.DataArray) -> xarray.DataArray: + """ + Extract data from a data array along a transect. + + Parameters + ---------- + data_array : xarray.DataArray + The data array to extract data from. + + Returns + ------- + xarray.DataArray + A new :class:`xarray.DataArray` containing data from the input data array + extracted along the path of the transect. 
+ """ + # Some of the following operations drop attrs, + # so keep a reference to the original ones + attrs = data_array.attrs + + data_array = self.convention.ravel(data_array) + + index_dimension = data_array.dims[-1] + data_array = move_dimensions_to_end(data_array, [index_dimension]) + + data_array = data_array.isel({index_dimension: self.linear_indexes}) + + # Restore attrs after reformatting + data_array.attrs.update(attrs) + + return data_array + + def make_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + **kwargs: Any, + ) -> 'artists.TransectArtist': + """ + Make an artist to plot values extracted from a data array along this transect. + The kind of artist used depends on the dimensionality of the data array. + + To be plotted along a transect the data array must be defined on a supported :ref:`grid `. + Currently only polygonal grids are supported. + + If a data array has a depth axis, :meth:`.make_cross_section_artist` is called, + otherwise :meth:`.make_transect_step_artist` is called. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this artist to. + data_array : DataArrayOrName + The data array to plot + **kwargs + Passed on to the artist, can be used to customise the plot style. + + Returns + ======= + :class:`.artists.TransectArtist` + The artist that will plot the data. + This artist will already have been added to the axes. + + See also + ======== + :func:`~.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + :func:`~.utils.setup_depth_axis` + Setup the y-axis of an :class:`~matplotlib.axes.Axes` + for plotting down a depth coordinate. 
+ """ + data_array = name_to_data_array(self.dataset, data_array) + grid = self.convention.get_grid(data_array) + try: + depth_coordinate = self.convention.get_depth_coordinate_for_data_array(data_array) + except NoSuchCoordinateError: + depth_coordinate = None + + if grid.geometry_type is not shapely.Polygon: + raise ValueError( + f"I don't know how to plot transects across {grid.geometry_type.__name__} geometry.") + + if depth_coordinate is not None: + return self.make_cross_section_artist(axes, data_array, **kwargs) + else: + return self.make_transect_step_artist(axes, data_array, **kwargs) + + def make_cross_section_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + colorbar: bool = True, + **kwargs: Any, + ) -> 'artists.CrossSectionArtist': + """ + Make an artist that plots a vertical slice along the length of the transect. + The data must be three dimensional with a depth axis. + The data are plotted as a grid of values, + with distance along the transect as the x-axis + and depth represented as the y-axis. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this line to. + data_array : DataArrayOrName + The data array to plot. + This data array must be defined on a polygonal :class:`~emsarray.conventions.Grid` + and must have a depth coordinate with bounds. + colorbar : bool, default True + Whether to add a colorbar for this artist. + Sensible defaults are used for the colorbar, but if more customisation is required + set `colorbar=False` and configure a colorbar manually. + edgecolor : color, optional + The colour of the line. + Optional, defaults to the next available colour in the matplotlib plot colours. + fill : bool, default False + Whether to fill in values between the line and the baseline. + Defaults to False. + **kwargs + Passed on to the :class:`~.artists.CrossSectionArtist`, + can be used to customise the plot.
+ + See also + ======== + :func:`~.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + :func:`~.utils.setup_depth_axis` + Setup the y-axis of an :class:`~matplotlib.axes.Axes` + for plotting down a depth coordinate. + :func:`~emsarray.utils.estimate_bounds_1d` + Estimate some bounds for a coordinate variable. + """ + data_array = name_to_data_array(self.dataset, data_array) + artist = artists.CrossSectionArtist.from_transect( + self, data_array=data_array, + **kwargs) + axes.add_artist(artist) + if colorbar: + units = data_array.attrs.get('units', None) + axes.figure.colorbar(artist, label=units) + return artist + + def make_transect_step_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + edgecolor: ColorType | None = 'auto', + fill: bool = False, + **kwargs: Any, + ) -> 'artists.TransectStepArtist': + """ + Make an artist that plots values along the length of the transect. + The data must be two dimensional - it must have no depth axis. + The data are plotted as a stepped line. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add this line to. + data_array : DataArrayOrName + The data array to plot. + This data array must be defined on a polygonal :class:`~emsarray.conventions.Grid` + and must not have any other dimensions such as time or depth. + edgecolor : color, optional + The colour of the line. + Optional, defaults to the next available colour in the matplotlib plot colours. + fill : bool, default False + Whether to fill in values between the line and the baseline. + Defaults to False. + **kwargs + Passed on to the :class:`~.artists.TransectStepArtist`, + can be used to customise the plot. + + Returns + ======= + :class:`~.artists.TransectStepArtist` + The artist that will plot the data. + This artist will already have been added to the axes. 
+ + See also + ======== + :func:`.utils.setup_distance_axis` + Setup the x-axis of an :class:`~matplotlib.axes.Axes` + for plotting distance along a transect. + """ + data_array = name_to_data_array(self.dataset, data_array) + if edgecolor == 'auto': + edgecolor = axes._get_lines.get_next_color() + artist = artists.TransectStepArtist.from_transect( + self, data_array=data_array, + fill=fill, edgecolor=edgecolor, **kwargs) + axes.add_artist(artist) + return artist + + def make_ocean_floor_artist( + self, + axes: Axes, + data_array: DataArrayOrName, + fill: bool = True, + facecolor: ColorType | None = 'lightgrey', + edgecolor: ColorType | None = 'none', + baseline: float | None = None, + **kwargs: Any, + ) -> 'artists.TransectStepArtist': + """ + Make an artist that renders a solid polygon following a bathymetry variable. + This can be drawn in front of a cross section artist to mask out values below the ocean floor. + + Parameters + ========== + axes : Axes + The :class:`matplotlib.axes.Axes` to add the ocean floor artist to + data_array : DataArrayOrName + The data array or name of a data array with the ocean floor data + baseline : float, optional + The deepest part of the ocean floor to render. + The ocean floor will be filled in from the bathymetry value down to the baseline. + Optional, if not provided the deepest value in the data array is used instead. + **kwargs + Passed on to the :class:`.artists.TransectStepArtist` for styling. + Set `facecolor` to change the colour of the ocean floor polygon. + + Returns + ======= + .artists.TransectStepArtist + The artist that will render the ocean floor. + This artist will already have been added to the axes. + + Notes + ===== + The `sign convention `_ + of the bathymetry variable and the depth coordinate must match. + If they differ the ocean floor polygon is likely to be either + entirely outside of the plot extent or to cover the entire plot extent. 
+ :func:`~emsarray.operations.depth.normalize_depth_variables` + can be used to change the sign convention of a depth coordinate variable. + + See also + ======== + `CF Conventions on Vertical Coordinates `_ + More information on the `positive` attribute. + :func:`emsarray.operations.depth.normalize_depth_variables` + Update the sign convention of a depth coordinate variable. + + .. _CF-vertical-coordinates: https://cfconventions.org/Data/cf-conventions/cf-conventions-1.8/cf-conventions.html#vertical-coordinate + + """ + data_array = name_to_data_array(self.dataset, data_array) + if baseline is None: + data_min, data_max = numpy.nanmin(data_array.values), numpy.nanmax(data_array.values) + if 'positive' in data_array.attrs: + if data_array.attrs['positive'] == 'down': + baseline = data_max + else: + baseline = data_min + else: + # Take a guess by using the most extreme value + if numpy.abs(data_min) < numpy.abs(data_max): + baseline = data_max + else: + baseline = data_min + + artist = artists.TransectStepArtist.from_transect( + self, data_array=data_array, + fill=fill, baseline=baseline, + facecolor=facecolor, edgecolor=edgecolor, + **kwargs) + axes.add_artist(artist) + return artist diff --git a/src/emsarray/transect/utils.py b/src/emsarray/transect/utils.py new file mode 100644 index 0000000..a4a7b47 --- /dev/null +++ b/src/emsarray/transect/utils.py @@ -0,0 +1,96 @@ +import cfunits +import numpy +from matplotlib.axes import Axes +from matplotlib.ticker import EngFormatter + +from emsarray.types import DataArrayOrName +from emsarray.utils import name_to_data_array +from .base import Transect + + +def setup_distance_axis(transect: Transect, axes: Axes) -> None: + """ + Configure the x-axis of a :class:`~matplotlib.axes.Axes` for values along a transect. 
+ + Parameters + ========== + transect : emsarray.transect.Transect + The transect being plotted + axes : matplotlib.axes.Axes + The axes to configure + """ + axis = axes.xaxis + + axes.set_xlim(transect.points[0].distance_metres, transect.points[-1].distance_metres) + axis.set_label_text("Distance along transect") + axis.set_major_formatter(EngFormatter(unit='m')) + + +def setup_depth_axis( + transect: "Transect", + axes: Axes, + data_array: DataArrayOrName | None = None, + depth_coordinate: DataArrayOrName | None = None, + ylim: tuple[float, float] | bool = True, + label: str | None | bool = True, + units: str | None | bool = True, +) -> None: + """ + Configure the y-axis of a :class:`~matplotlib.axes.Axes` for values along a depth coordinate. + + Parameters + ========== + transect : emsarray.transect.Transect + The transect being plotted + axes : matplotlib.axes.Axes + The axes to configure + data_array : DataArrayOrName, optional + depth_coordinate : DataArrayOrName, optional + One of `data_array` or `depth_coordinate` must be provided. + The y-axis is configured to show values along this depth coordinate. + If data_array is provided, the depth coordinate for this data array is used. + ylim : tuple of float, float, optional + The ylim of the axes. If not provided the limit is calculated from the depth coordinate. + label : str or None, optional + The label for the y-axis. + Optional, defaults to the `long_name` attribute of the depth coordinate. + Set to `None` to disable the label. + units : str or None, optional + The units for the y-axis. + Optional, defaults to the `units` attribute of the depth coordinate. + Set to `None` to disable the units and formatting of tick labels.
+ """ + if data_array is None and depth_coordinate is None: + raise ValueError("Either data_array or depth_coordinate must be provided") + if data_array is not None and depth_coordinate is not None: + raise ValueError("Only one of data_array or depth_bounds must be provided") + + if data_array is not None: + depth_coordinate = transect.convention.get_depth_coordinate_for_data_array(data_array) + else: + depth_coordinate = name_to_data_array(transect.dataset, depth_coordinate) + + axis = axes.yaxis + + if ylim is True: + depth_bounds = transect.dataset[depth_coordinate.attrs['bounds']].values + positive_down = depth_coordinate.attrs['positive'].lower() == 'down' + depth_min, depth_max = numpy.nanmin(depth_bounds), numpy.nanmax(depth_bounds) + + if positive_down: + axes.set_ylim(depth_max, depth_min) + else: + axes.set_ylim(depth_min, depth_max) + elif ylim not in {False, None}: + axes.set_ylim(ylim) + + if label is True: + label = depth_coordinate.attrs.get('long_name') + if label not in {False, None}: + axis.set_label_text(label) + + if units is True: + units = depth_coordinate.attrs.get('units') + if units not in {False, None}: + formatted_units = cfunits.Units(units).formatted() + axis.set_major_formatter(EngFormatter(unit=formatted_units)) diff --git a/src/emsarray/utils.py b/src/emsarray/utils.py index 17aa2f8..d3542cc 100644 --- a/src/emsarray/utils.py +++ b/src/emsarray/utils.py @@ -545,6 +545,18 @@ def find_unused_dimension( if candidate not in existing_dims) +def find_unused_name( + dataset: xarray.Dataset, + candidate: str, +) -> str: + if candidate not in dataset.variables.keys(): + return candidate + candidates = (f'{candidate}_{suffix}' for suffix in itertools.count(start=0)) + return next( + candidate for candidate in candidates + if candidate not in dataset.variables.keys()) + + def ravel_dimensions( data_array: xarray.DataArray, dimensions: list[Hashable], @@ -936,3 +948,66 @@ def data_array_to_name(dataset: xarray.Dataset, data_array: 
def _find_bounds_dimension(dataset: xarray.Dataset, coordinate_dim: Hashable) -> Hashable:
    """
    Pick a dimension name for the (lower, upper) axis of a bounds variable.
    An existing dimension is reused only if it already has size 2.
    """
    candidates = itertools.chain(
        ['bounds'], (f'bounds_{suffix}' for suffix in itertools.count()))
    return next(
        name for name in candidates
        if name != coordinate_dim and dataset.sizes.get(name, 2) == 2)


def estimate_bounds_1d(
    dataset: xarray.Dataset,
    coordinate: DataArrayOrName,
) -> xarray.Dataset:
    """
    Estimate the bounds of a one dimensional coordinate variable.
    The bounds between two coordinates is the average of the two values,
    while the bounds on each end are the first and last coordinate values.
    This is a crude approach.

    The new bounds variable has two dimensions:
    the dimension of the coordinate, followed by a dimension of size two
    holding the lower and upper bound of each cell.
    The name of the bounds variable is stored in the 'bounds' attribute
    of the coordinate variable in the returned dataset.

    Parameters
    ==========
    dataset : xarray.Dataset
        The dataset containing the coordinate.
    coordinate : xarray.DataArray or str
        The coordinate variable to estimate the bounds of.

    Returns
    =======
    xarray.Dataset
        A copy of the original dataset including the new estimated bounds.

    Raises
    ======
    ValueError
        Raised if the coordinate variable already has a 'bounds' attribute,
        or if the coordinate variable is not one dimensional.
    """
    dataset = dataset.copy()
    coordinate = name_to_data_array(dataset, coordinate)
    if 'bounds' in coordinate.attrs:
        raise ValueError("Coordinate already has a bounds attribute")
    if coordinate.ndim != 1:
        raise ValueError("Coordinate must be one dimensional")

    bounds_name = find_unused_name(dataset, f'{coordinate.name}_bounds')
    coordinate_dim = coordinate.dims[0]
    bounds_dim = _find_bounds_dimension(dataset, coordinate_dim)

    values = coordinate.values
    midpoints = (values[:-1] + values[1:]) / 2
    # numpy.concatenate is used instead of numpy.concat,
    # as the latter is only available from numpy 2.0.
    edges = numpy.concatenate([values[:1], midpoints, values[-1:]])
    dataset[bounds_name] = xarray.DataArray(
        name=bounds_name,
        data=numpy.c_[edges[:-1], edges[1:]],
        # Without explicit dims xarray would name these dim_0 and dim_1,
        # which are not linked to the coordinate and which clash
        # between coordinates of different lengths.
        dims=[coordinate_dim, bounds_dim],
    )
    dataset = dataset.set_coords(bounds_name)
    dataset[coordinate.name].attrs['bounds'] = bounds_name
    return dataset


def geometry_plus_bounds(dataset: xarray.Dataset, names: list[Hashable]) -> list[Hashable]:
    """
    Extend a list of variable names with the names of their bounds variables.

    Parameters
    ==========
    dataset : xarray.Dataset
        The dataset containing the named variables.
    names : list of Hashable
        The names of variables in the dataset.

    Returns
    =======
    list of Hashable
        A new list containing every name in `names`,
        followed by the value of the 'bounds' attribute
        of each named variable that has one,
        as long as that value names a variable in the dataset.
        Variables without a 'bounds' attribute,
        or whose bounds variable does not exist, contribute nothing extra.
    """
    names = list(names)
    bounds_names: list[Hashable] = []
    for name in names:
        bounds_name = dataset[name].attrs.get('bounds')
        if bounds_name is not None and bounds_name in dataset.variables:
            bounds_names.append(bounds_name)
    return names + bounds_names