diff --git a/.coveragerc_omit b/.coveragerc_omit index e70d702e..6c75dc6f 100644 --- a/.coveragerc_omit +++ b/.coveragerc_omit @@ -12,5 +12,6 @@ omit = src/vitessce/data_utils/ome.py src/vitessce/data_utils/entities.py src/vitessce/data_utils/multivec.py + src/vitessce/data_utils/spatialdata_points_zorder.py src/vitessce/widget_plugins/demo_plugin.py src/vitessce/widget_plugins/spatial_query.py \ No newline at end of file diff --git a/docs/notebooks/spatial_data_xenium_morton.ipynb b/docs/notebooks/spatial_data_xenium_morton.ipynb new file mode 100644 index 00000000..e5efa995 --- /dev/null +++ b/docs/notebooks/spatial_data_xenium_morton.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "nbsphinx": "hidden" + }, + "source": [ + "# Vitessce Widget Tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Visualization of a SpatialData object" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import dependencies\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from os.path import join, isfile, isdir\n", + "from urllib.request import urlretrieve\n", + "import zipfile\n", + "import shutil\n", + "\n", + "from vitessce import (\n", + " VitessceConfig,\n", + " ViewType as vt,\n", + " CoordinationType as ct,\n", + " CoordinationLevel as CL,\n", + " SpatialDataWrapper,\n", + " get_initial_coordination_scope_prefix\n", + ")\n", + "\n", + "from vitessce.data_utils import (\n", + " sdata_morton_sort_points,\n", + " sdata_points_process_columns,\n", + " sdata_points_write_bounding_box_attrs,\n", + " sdata_points_modify_row_group_size,\n", + " sdata_morton_query_rect,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from spatialdata import read_zarr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_dir = \"data\"\n", + "zip_filepath = join(data_dir, \"xenium_rep1_io.spatialdata.zarr.zip\")\n", + "spatialdata_filepath = join(data_dir, \"xenium_rep1_io.spatialdata.zarr\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if not isdir(spatialdata_filepath):\n", + " if not isfile(zip_filepath):\n", + " os.makedirs(data_dir, exist_ok=True)\n", + " urlretrieve('https://s3.embl.de/spatialdata/spatialdata-sandbox/xenium_rep1_io.zip', zip_filepath)\n", + " with zipfile.ZipFile(zip_filepath,\"r\") as zip_ref:\n", + " zip_ref.extractall(data_dir)\n", + " os.rename(join(data_dir, \"data.zarr\"), spatialdata_filepath)\n", + " \n", + " # This Xenium dataset has an AnnData \"raw\" element.\n", + " # Reference: https://github.com/giovp/spatialdata-sandbox/issues/55\n", + " raw_dir = join(spatialdata_filepath, \"tables\", \"table\", \"raw\")\n", + " if isdir(raw_dir):\n", + " shutil.rmtree(raw_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata = read_zarr(spatialdata_filepath)\n", + "sdata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata[\"transcripts\"].shape[0].compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata.tables[\"table\"].X = sdata.tables[\"table\"].X.toarray()\n", + "sdata.tables[\"dense_table\"] = sdata.tables[\"table\"]\n", + "sdata.write_element(\"dense_table\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# TODO: store the two separate images as a single image with two channels.\n", + "# Similar to https://github.com/EricMoerthVis/tissue-map-tools/pull/12" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata.tables['table'].obs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata.points['transcripts'].head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sorting Points and creating a new Points element in the SpatialData object" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 1. Sort rows with `sdata_morton_sort_points`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata = sdata_morton_sort_points(sdata, \"transcripts\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 2. Clean up columns with `sdata_points_process_columns`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add feature_index column to dataframe, and reorder columns so that feature_name (dict column) is the rightmost column.\n", + "ddf = sdata_points_process_columns(sdata, \"transcripts\", var_name_col=\"feature_name\", table_name=\"table\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ddf.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 3. Save sorted dataframe to new Points element" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata[\"transcripts_with_morton_codes\"] = ddf\n", + "sdata.write_element(\"transcripts_with_morton_codes\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 4. Write bounding box metadata with `sdata_points_write_bounding_box_attrs`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata_points_write_bounding_box_attrs(sdata, \"transcripts_with_morton_codes\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Step 5. Modify the row group sizes of the Parquet files with `sdata_points_modify_row_group_size`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sdata_points_modify_row_group_size(sdata, \"transcripts_with_morton_codes\", row_group_size=25_000)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Done" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Optionally, check the number of row groups in one of the parquet file parts.\n", + "import pyarrow.parquet as pq\n", + "from os.path import join\n", + "\n", + "parquet_file = pq.ParquetFile(join(sdata.path, \"points\", \"transcripts_with_morton_codes\", \"points.parquet\", \"part.0.parquet\"))\n", + "\n", + "# Get the number of row groups in this part-0 file.\n", + "num_groups = parquet_file.num_row_groups\n", + "num_groups" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Configure Vitessce\n", + "\n", + "Vitessce needs to know which pieces of data we are interested in visualizing, the visualization types we would like to use, and how we want to coordinate (or link) the views." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vc = VitessceConfig(\n", + " schema_version=\"1.0.18\",\n", + " name='MERFISH SpatialData Demo',\n", + ")\n", + "# Add data to the configuration:\n", + "wrapper = SpatialDataWrapper(\n", + " sdata_path=spatialdata_filepath,\n", + " # The following paths are relative to the root of the SpatialData zarr store on-disk.\n", + " image_path=\"images/rasterized\",\n", + " table_path=\"tables/table\",\n", + " obs_feature_matrix_path=\"tables/table/X\",\n", + " obs_spots_path=\"shapes/cells\",\n", + " coordinate_system=\"global\",\n", + " coordination_values={\n", + " # The following tells Vitessce to consider each observation as a \"spot\"\n", + " \"obsType\": \"cell\",\n", + " }\n", + ")\n", + "dataset = vc.add_dataset(name='MERFISH').add_object(wrapper)\n", + "\n", + "# Add views (visualizations) to the configuration:\n", + "spatial = vc.add_view(\"spatialBeta\", dataset=dataset)\n", + "feature_list = vc.add_view(\"featureList\", dataset=dataset)\n", + "layer_controller = vc.add_view(\"layerControllerBeta\", dataset=dataset)\n", + "obs_sets = vc.add_view(\"obsSets\", dataset=dataset)\n", + "\n", + "vc.link_views_by_dict([spatial, layer_controller], {\n", + " 'spotLayer': CL([{\n", + " 'obsType': 'cell',\n", + " }]),\n", + "}, scope_prefix=get_initial_coordination_scope_prefix(\"A\", \"obsSpots\"))\n", + "\n", + "vc.link_views([spatial, layer_controller, feature_list, obs_sets], ['obsType'], [wrapper.obs_type_label])\n", + "\n", + "# Layout the views\n", + "vc.layout(spatial | (feature_list / layer_controller / obs_sets));" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Render the widget" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vw = vc.widget()\n", + "vw" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyproject.toml b/pyproject.toml index b1642fc0..50974bd2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,6 +114,8 @@ dev = [ 'boto3>=1.16.30', 'scikit-misc>=0.1.3', 'autopep8>=2.0.2', + 'spatialdata>=0.3.0', + 'dask[dataframe]==2024.11.1', ] [tool.uv] diff --git a/src/vitessce/data_utils/__init__.py b/src/vitessce/data_utils/__init__.py index 0b1d4cce..77490c90 100644 --- a/src/vitessce/data_utils/__init__.py +++ b/src/vitessce/data_utils/__init__.py @@ -17,3 +17,14 @@ from .multivec import ( adata_to_multivec_zarr, ) +from .spatialdata_points_zorder import ( + # Function for computing codes and sorting + sdata_morton_sort_points, + # Other helper functions + sdata_points_process_columns, + sdata_points_write_bounding_box_attrs, + sdata_points_modify_row_group_size, + # Functions for querying + sdata_morton_query_rect, + row_ranges_to_row_indices, +) diff --git a/src/vitessce/data_utils/spatialdata_points_zorder.py b/src/vitessce/data_utils/spatialdata_points_zorder.py new file mode 100644 index 00000000..b5aa9e4b --- /dev/null +++ b/src/vitessce/data_utils/spatialdata_points_zorder.py @@ -0,0 +1,514 @@ +from typing import Tuple, List, Optional + +import os +from os.path import join +from bisect import bisect_left, bisect_right +import pandas as pd +import numpy as np + + +from spatialdata import get_element_annotators +import dask.dataframe as dd +import zarr + + +MORTON_CODE_NUM_BITS = 32 # Resulting morton codes will be stored as uint32. +MORTON_CODE_VALUE_MIN = 0 +MORTON_CODE_VALUE_MAX = 2**(MORTON_CODE_NUM_BITS / 2) - 1 + +# -------------------------- +# Functions for computing Morton codes for SpatialData points (2D). +# -------------------------- + + +def norm_series_to_uint(series, v_min, v_max): + """ + Scale numeric Series (int or float) to integer grid [0, 2^bits-1], handling NaNs. + """ + # Cast to float64 + series_f64 = series.astype("float64") + # Normalize the array values to be between 0.0 and 1.0 + norm_series_f64 = (series_f64 - v_min) / (v_max - v_min) + # Clip to ensure no values are outside 0/1 range + clipped_norm_series_f64 = np.clip(norm_series_f64, 0.0, 1.0) + # Multiply by the morton code max-value to scale from [0,1] to [0,65535] + out = (clipped_norm_series_f64 * MORTON_CODE_VALUE_MAX).astype(np.uint32) + # Set NaNs to 0. + out = out.fillna(0) + return out + + +def norm_ddf_to_uint(ddf): + [x_min, x_max, y_min, y_max] = [ddf["x"].min().compute(), ddf["x"].max().compute(), ddf["y"].min().compute(), ddf["y"].max().compute()] + ddf["x_uint"] = norm_series_to_uint(ddf["x"], x_min, x_max) + ddf["y_uint"] = norm_series_to_uint(ddf["y"], y_min, y_max) + + # Insert the bounding box as metadata for the sdata.points[element] Points element dataframe. + # TODO: does anything special need to be done to ensure this is saved to disk? + ddf.attrs["bounding_box"] = { + "x_min": float(x_min), + "x_max": float(x_max), + "y_min": float(y_min), + "y_max": float(y_max), + } + + return ddf + + +def _part1by1_16(x): + """ + Spread each 16-bit value into 32 bits by inserting zeros between bits. + Input: uint32 array (values must fit in 16 bits) + Output: uint32 array (bit-spread) + """ + + assert x.dtype.name == 'uint32' + + # Mask away any bits above 16 (just in case input wasn't clean). + x = x & np.uint32(0x0000FFFF) + + # First spread: shift left by 8 bits, OR with original, then mask. + # After this, groups of 8 bits are separated by 8 zeros. + x = (x | np.left_shift(x, 8)) & np.uint32(0x00FF00FF) + + # Spread further: now groups of 4 bits separated by 4 zeros. + x = (x | np.left_shift(x, 4)) & np.uint32(0x0F0F0F0F) + + # Spread further: groups of 2 bits separated by 2 zeros. + x = (x | np.left_shift(x, 2)) & np.uint32(0x33333333) + + # Final spread: single bits separated by a zero bit. + # Now each original bit is in every other position (positions 0,2,4,...). + x = (x | np.left_shift(x, 1)) & np.uint32(0x55555555) + + return x + + +def _part1by1_32(x): + """ + Spread each 32-bit value into 64 bits by inserting zeros between bits. + Input: uint64 array (values must fit in 32 bits) + Output: uint64 array (bit-spread) + """ + + assert x.dtype.name == 'uint64' + + # Mask away any bits above 32 (safety). + x = x.astype(np.uint64) & np.uint64(0x00000000FFFFFFFF) + + # First spread: separate into 16-bit chunks spaced out. + x = (x | np.left_shift(x, 16)) & np.uint64(0x0000FFFF0000FFFF) + + # Spread further: each 8-bit chunk separated. + x = (x | np.left_shift(x, 8)) & np.uint64(0x00FF00FF00FF00FF) + + # Spread further: each 4-bit nibble separated. + x = (x | np.left_shift(x, 4)) & np.uint64(0x0F0F0F0F0F0F0F0F) + + # Spread further: 2-bit groups separated. + x = (x | np.left_shift(x, 2)) & np.uint64(0x3333333333333333) + + # Final spread: single bits separated by zeros. + # Now each original bit occupies every other position (0,2,4,...). + x = (x | np.left_shift(x, 1)) & np.uint64(0x5555555555555555) + + return x + + +def morton_interleave(ddf): + """ + Vectorized Morton interleave for integer arrays xi, yi + already scaled to [0, 2^bits - 1]. + Returns Morton codes as uint32 (if bits<=16) or uint64 (if bits<=32). + """ + + xi = ddf["x_uint"] + yi = ddf["y_uint"] + + # Spread x and y bits into even (x) and odd (y) positions. + xs = _part1by1_16(xi) + ys = _part1by1_16(yi) + + # Interleave: shift y bits left by 1 so they go into odd positions, + # then OR with x bits in even positions. + code = np.left_shift(ys.astype(np.uint64), 1) | xs.astype(np.uint64) + + # Fits in 32 bits since we only had 16+16 input bits. + return code.astype(np.uint32) + + +def sdata_morton_sort_points(sdata, element): + ddf = sdata.points[element] + + # Compute morton codes + ddf = norm_ddf_to_uint(ddf) + ddf["morton_code_2d"] = morton_interleave(ddf) + + if "z" in ddf.columns: + num_unique_z = ddf["z"].unique().shape[0].compute() + if num_unique_z < 100: + # Heuristic for interpreting the 3D data as 2.5D + # Reference: https://github.com/scverse/spatialdata/issues/961 + sorted_ddf = ddf.sort_values(by=["z", "morton_code_2d"], ascending=True) + else: + # TODO: include z as a dimension in the morton code in the 3D case? + + # For now, just return the data sorted by 2D code. + sorted_ddf = ddf.sort_values(by="morton_code_2d", ascending=True) + else: + sorted_ddf = ddf.sort_values(by="morton_code_2d", ascending=True) + sdata.points[element] = sorted_ddf + + # annotating_tables = get_element_annotators(sdata, element) + + # TODO: Sort any annotating table(s) as well. + + return sdata + + +def sdata_morton_query_rect_aux(sdata, element, orig_rect): + # orig_rect = [[50, 50], [100, 150]] # [[x0, y0], [x1, y1]] + # norm_rect = [ + # orig_coord_to_norm_coord(orig_rect[0], orig_x_min=0, orig_x_max=100, orig_y_min=0, orig_y_max=200), + # orig_coord_to_norm_coord(orig_rect[1], orig_x_min=0, orig_x_max=100, orig_y_min=0, orig_y_max=200) + # ] + + sorted_ddf = sdata.points[element] + + # TODO: fail if no morton_code_2d column + # TODO: fail if not sorted as expected + # TODO: fail if no bounding box metadata + + bounding_box = sorted_ddf.attrs["bounding_box"] + x_min = bounding_box["x_min"] + x_max = bounding_box["x_max"] + y_min = bounding_box["y_min"] + y_max = bounding_box["y_max"] + + norm_rect = [ + orig_coord_to_norm_coord(orig_rect[0], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max), + orig_coord_to_norm_coord(orig_rect[1], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max) + ] + + # Get a list of morton code intervals that cover this rectangle region + # [ (morton_start, morton_end), ... ] + morton_intervals = zcover_rectangle( + rx0=norm_rect[0][0], ry0=norm_rect[0][1], + rx1=norm_rect[1][0], ry1=norm_rect[1][1], + bits=16, + stop_level=None, + merge=True, + ) + + return morton_intervals + + +def sdata_morton_query_rect(sdata, element, orig_rect): + sorted_ddf = sdata.points[element] + + # TODO: generalize to 3D morton codes + + morton_intervals = sdata_morton_query_rect_aux(sdata, element, orig_rect) + + # Get morton code column as a list of integers + morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist() + + # Get a list of row ranges that match the morton intervals. + # (This uses binary searches internally to find the matching row indices). + # [ (row_start, row_end), ... ] + matching_row_ranges = zquery_rows(morton_sorted, morton_intervals, merge=True) + + return matching_row_ranges + + +def sdata_morton_query_rect_debug(sdata, element, orig_rect): + # This is the same as the above sdata_morton_query_rect function, + # but it also returns the list of row indices that were checked + # during the binary searches. + sorted_ddf = sdata.points[element] + morton_intervals = sdata_morton_query_rect_aux(sdata, element, orig_rect) + morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist() + matching_row_ranges, rows_checked = zquery_rows_aux(morton_sorted, morton_intervals, merge=True) + return matching_row_ranges, rows_checked + +# -------------------------- +# Functions for rectangle queries. +# -------------------------- + +# Convert a coordinate from the normalized [0, 65535] space to the original space. + + +def norm_coord_to_orig_coord(norm_coord, orig_x_min, orig_x_max, orig_y_min, orig_y_max): + [norm_x, norm_y] = norm_coord + orig_x_range = orig_x_max - orig_x_min + orig_y_range = orig_y_max - orig_y_min + return [ + (orig_x_min + (norm_x / MORTON_CODE_VALUE_MAX) * orig_x_range), + (orig_y_min + (norm_y / MORTON_CODE_VALUE_MAX) * orig_y_range), + ] + +# Convert a coordinate from the original space to the [0, 65535] normalized space. + + +def orig_coord_to_norm_coord(orig_coord, orig_x_min, orig_x_max, orig_y_min, orig_y_max): + [orig_x, orig_y] = orig_coord + orig_x_range = orig_x_max - orig_x_min + orig_y_range = orig_y_max - orig_y_min + return [ + np.float64(((orig_x - orig_x_min) / orig_x_range) * MORTON_CODE_VALUE_MAX).astype(np.uint32), + np.float64(((orig_y - orig_y_min) / orig_y_range) * MORTON_CODE_VALUE_MAX).astype(np.uint32), + ] + +# -------------------------- +# Quadtree / Z-interval helpers +# -------------------------- + + +def intersects(ax0: int, ay0: int, ax1: int, ay1: int, + bx0: int, by0: int, bx1: int, by1: int) -> bool: + """Axis-aligned box intersection (inclusive integer bounds).""" + return not (ax1 < bx0 or bx1 < ax0 or ay1 < by0 or by1 < ay0) + + +def contained(ix0: int, iy0: int, ix1: int, iy1: int, + ox0: int, oy0: int, ox1: int, oy1: int) -> bool: + """Is inner box entirely inside outer box? (inclusive integer bounds)""" + return (ox0 <= ix0 <= ix1 <= ox1) and (oy0 <= iy0 <= iy1 <= oy1) + + +def point_inside(x: int, y: int, rx0: int, ry0: int, rx1: int, ry1: int) -> bool: + return (rx0 <= x <= rx1) and (ry0 <= y <= ry1) + + +def cell_range(prefix: int, level: int, bits: int) -> Tuple[int, int]: + """ + All Morton codes in a quadtree cell share the same prefix (2*level bits). + Fill the remaining lower bits with 0s (lo) or 1s (hi). + """ + shift = 2 * (bits - level) + lo = prefix << shift + hi = ((prefix + 1) << shift) - 1 + return lo, hi + + +def merge_adjacent(intervals: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """Merge overlapping or directly adjacent intervals.""" + if not intervals: + return [] + intervals.sort(key=lambda t: t[0]) + merged = [intervals[0]] + for lo, hi in intervals[1:]: + mlo, mhi = merged[-1] + if lo <= mhi + 1: + merged[-1] = (mlo, max(mhi, hi)) + else: + merged.append((lo, hi)) + return merged + +# -------------------------- +# Rectangle -> list of Morton intervals +# -------------------------- + + +def zcover_rectangle(rx0: int, ry0: int, rx1: int, ry1: int, bits: int, stop_level: Optional[int] = None, merge: bool = True) -> List[Tuple[int, int]]: + """ + Compute a (near-)minimal set of Morton code ranges covering the rectangle + [rx0..rx1] x [ry0..ry1] on an integer grid [0..2^bits-1]^2. + + - If stop_level is None: exact cover (descend to exact containment). + - If stop_level is set (0..bits): stop descending at that level, adding + partially-overlapping cells as whole ranges (superset cover). + """ + if not (0 <= rx0 <= rx1 <= (1 << bits) - 1 and 0 <= ry0 <= ry1 <= (1 << bits) - 1): + raise ValueError("Rectangle out of bounds for given bits.") + + intervals: List[Tuple[int, int]] = [] + + # stack entries: (prefix, level, xmin, ymin, xmax, ymax) + stack = [(0, 0, 0, 0, (1 << bits) - 1, (1 << bits) - 1)] + + while stack: + prefix, level, xmin, ymin, xmax, ymax = stack.pop() + + if not intersects(xmin, ymin, xmax, ymax, rx0, ry0, rx1, ry1): + continue + + # If we stop at this level for a loose cover, add full cell range. + if stop_level is not None and level == stop_level: + intervals.append(cell_range(prefix, level, bits)) + continue + + # Fully contained: add full cell range. + if contained(xmin, ymin, xmax, ymax, rx0, ry0, rx1, ry1): + intervals.append(cell_range(prefix, level, bits)) + continue + + # Leaf cell: single lattice point (only happens when level==bits) + if level == bits: + if point_inside(xmin, ymin, rx0, ry0, rx1, ry1): + intervals.append(cell_range(prefix, level, bits)) + continue + + # Otherwise, split into 4 children (Morton order: 00,01,10,11) + midx = (xmin + xmax) // 2 + midy = (ymin + ymax) // 2 + + # q0: (x<=midx, y<=midy) -> child code 0b00 + stack.append(((prefix << 2) | 0, + level + 1, + xmin, ymin, midx, midy)) + # q1: (x>midx, y<=midy) -> child code 0b01 + stack.append(((prefix << 2) | 1, + level + 1, + midx + 1, ymin, xmax, midy)) + # q2: (x<=midx, y>midy) -> child code 0b10 + stack.append(((prefix << 2) | 2, + level + 1, + xmin, midy + 1, midx, ymax)) + # q3: (x>midx, y>midy) -> child code 0b11 + stack.append(((prefix << 2) | 3, + level + 1, + midx + 1, midy + 1, xmax, ymax)) + + return merge_adjacent(intervals) if merge else intervals + + +# -------------------------- +# Morton intervals -> row ranges in a Morton-sorted column +# -------------------------- + +def zquery_rows_aux(morton_sorted: List[int], intervals: List[Tuple[int, int]], merge: bool = True) -> Tuple[List[Tuple[int, int]], List[int]]: + """ + For each Z-interval [zlo, zhi], binary-search in the sorted Morton column + and return row index half-open ranges [i, j) to scan. + """ + + # Keep track of which keys were looked at during the binary searches. + # This is used for analysis / debugging, for instance, to enable + # evaluating how many HTTP requests would be needed in network-based case + # (which will also depend on Arrow row group size). + recorded_keys = [] + + def record_key_check(k: int) -> int: + # TODO: Does recorded_keys need to be marked as a global here? + recorded_keys.append(k) + return k + + ranges: List[Tuple[int, int]] = [] + # TODO: can these multiple binary searches be optimized? + # Since we are doing many searches in the same array, and in each search we learn where more elements are located. + for zlo, zhi in intervals: + i = bisect_left(morton_sorted, zlo, key=record_key_check) + # TODO: use lo=i in bisect_right to limit the search range? + # TODO: can the second binary search be further optimized since we just did a binary search via bisect_left? + j = bisect_right(morton_sorted, zhi, key=record_key_check) + if i < j: + ranges.append((i, j)) + + result = merge_adjacent(ranges) if merge else ranges + return result, recorded_keys + + +def zquery_rows(morton_sorted: List[int], intervals: List[Tuple[int, int]], merge: bool = True) -> List[Tuple[int, int]]: + """ + For each Z-interval [zlo, zhi], binary-search in the sorted Morton column + and return row index half-open ranges [i, j) to scan. + """ + return zquery_rows_aux(morton_sorted, intervals, merge=merge)[0] + + +def row_ranges_to_row_indices(intervals: List[Tuple[int, int]]) -> List[int]: + """ + Convert row ranges [i, j) to a list of row indices. + Then, can index into pandas DataFrame using df.iloc[indices, :] + """ + indices: List[int] = [] + for i, j in intervals: + indices.extend(list(range(i, j))) + return indices + + +# More helper functions. +def sdata_points_process_columns(sdata, element, var_name_col=None, table_name=None) -> dd.DataFrame: + ddf = sdata.points[element] + + if var_name_col is None: + # We can try to get it from the spatialdata_attrs metadata. + var_name_col = sdata.points[element].attrs["spatialdata_attrs"].get("feature_key") + + # Appending codes for dictionary-encoded feature_name column. + if table_name is None and var_name_col is not None: + annotating_tables = get_element_annotators(sdata, element) + if len(annotating_tables) == 1: + table_name = annotating_tables[0] + elif len(annotating_tables) == 0: + raise ValueError(f"No annotating table found for Points element {element}, please specify table_name explicitly.") + else: + raise ValueError(f"Multiple annotating tables found for Points element {element}, please specify table_name explicitly.") + + if var_name_col is not None: + var_df = sdata.tables[table_name].var + var_index = var_df.index.values.tolist() + + def try_index(gene_name): + try: + return var_index.index(gene_name) + except BaseException: + return -1 + ddf[f"{var_name_col}_codes"] = ddf[var_name_col].apply(try_index).astype('int32') + + # Identify dictionary-encoded columns (categorical/string) + orig_columns = ddf.columns.tolist() + dict_encoded_cols = [col for col in orig_columns if pd.api.types.is_categorical_dtype(ddf[col].dtype) or pd.api.types.is_string_dtype(ddf[col].dtype)] + + # Dictionary-encoded columns (i.e., categorical and string) must be stored as the rightmost columns of the dataframe. + ordered_columns = sorted(orig_columns, key=lambda colname: orig_columns.index(colname) if colname not in dict_encoded_cols else len(orig_columns)) + + # Reorder the columns of the dataframe + ddf = ddf[ordered_columns] + + return ddf + + +def sdata_points_write_bounding_box_attrs(sdata, element) -> dd.DataFrame: + ddf = sdata.points[element] + + [x_min, x_max, y_min, y_max] = [ddf["x"].min().compute(), ddf["x"].max().compute(), ddf["y"].min().compute(), ddf["y"].max().compute()] + bounding_box = { + "x_min": float(x_min), + "x_max": float(x_max), + "y_min": float(y_min), + "y_max": float(y_max), + } + + sdata_path = sdata.path + # TODO: error if no path + + # Insert the bounding box as metadata for the sdata.points[element] Points element dataframe. + z = zarr.open(sdata_path, mode='a') + group = z[f'points/{element}'] + group.attrs['bounding_box'] = bounding_box + + # TODO: does anything special need to be done to ensure this is saved to disk? + + +def sdata_points_modify_row_group_size(sdata, element, row_group_size: int = 50_000): + import pyarrow.parquet as pq + + sdata_path = sdata.path + # TODO: error if no path + + # List the parts of the parquet file. + parquet_path = join(sdata_path, "points", element, "points.parquet") + + # Read the number of "part.*.parquet" files on disk. + part_files = [f for f in os.listdir(parquet_path) if f.startswith("part.") and f.endswith(".parquet")] + num_parts = len(part_files) + + # Update the row group size in each .parquet file part. + for i in range(num_parts): + part_path = join(parquet_path, f"part.{i}.parquet") + table_read = pq.read_table(part_path) + + # Write the table to a new Parquet file with the desired row group size. + pq.write_table(table_read, part_path, row_group_size=row_group_size) diff --git a/tests/create_xenium_filtered_points.py b/tests/create_xenium_filtered_points.py new file mode 100644 index 00000000..b2a94297 --- /dev/null +++ b/tests/create_xenium_filtered_points.py @@ -0,0 +1,51 @@ +import os +from os.path import join, isfile, isdir +from urllib.request import urlretrieve +import zipfile +import shutil + +# Used spatialdata==0.4.0 on October 30, 2025 +from spatialdata import read_zarr, SpatialData + + +data_dir = "data" +zip_filepath = join(data_dir, "xenium_rep1_io.spatialdata.zarr.zip") +spatialdata_filepath = join(data_dir, "xenium_rep1_io.spatialdata.zarr") + + +if not isdir(spatialdata_filepath): + if not isfile(zip_filepath): + os.makedirs(data_dir, exist_ok=True) + urlretrieve('https://s3.embl.de/spatialdata/spatialdata-sandbox/xenium_rep1_io.zip', zip_filepath) + with zipfile.ZipFile(zip_filepath, "r") as zip_ref: + zip_ref.extractall(data_dir) + os.rename(join(data_dir, "data.zarr"), spatialdata_filepath) + + # This Xenium dataset has an AnnData "raw" element. + # Reference: https://github.com/giovp/spatialdata-sandbox/issues/55 + raw_dir = join(spatialdata_filepath, "tables", "table", "raw") + if isdir(raw_dir): + shutil.rmtree(raw_dir) + +sdata = read_zarr(spatialdata_filepath) + +ddf = sdata.points["transcripts"] + +# 2. Define a function to take every 100th row from a partition + + +def select_every_200th(partition): + # Each 'partition' is a Pandas DataFrame + # .iloc[::100] is the efficient pandas way to get every 100th row + return partition.iloc[::200] + + +# 3. Apply this function to every partition in the Dask DataFrame +result = ddf.map_partitions(select_every_200th) + +# 4. Compute the result to see it +filtered_ddf = result[["x", "y", "z", "feature_name", "cell_id"]] + +small_sdata = SpatialData(points={"transcripts": filtered_ddf}) + +small_sdata.write("xenium_rep1_io.points_only.spatialdata.zarr", overwrite=True) diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zattrs b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zattrs new file mode 100644 index 00000000..058b792f --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zattrs @@ -0,0 +1,6 @@ +{ + "spatialdata_attrs": { + "spatialdata_software_version": "0.4.0", + "version": "0.1" + } +} \ No newline at end of file diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zgroup b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/.zgroup b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zattrs b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zattrs new file mode 100644 index 00000000..189b0502 --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zattrs @@ -0,0 +1,63 @@ +{ + "axes": [ + "x", + "y", + "z" + ], + "coordinateTransformations": [ + { + "input": { + "axes": [ + { + "name": "x", + "type": "space", + "unit": "unit" + }, + { + "name": "y", + "type": "space", + "unit": "unit" + }, + { + "name": "z", + "type": "space", + "unit": "unit" + } + ], + "name": "xyz" + }, + "output": { + "axes": [ + { + "name": "x", + "type": "space", + "unit": "unit" + }, + { + "name": "y", + "type": "space", + "unit": "unit" + }, + { + "name": "z", + "type": "space", + "unit": "unit" + } + ], + "name": "global" + }, + "scale": [ + 4.705882352941177, + 4.705882352941177, + 1.0 + ], + "type": "scale" + } + ], + "encoding-type": "ngff:points", + "spatialdata_attrs": { + "feature_key": "feature_name", + "instance_key": "cell_id", + "version": "0.1" + } +} \ No newline at end of file diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zgroup b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zgroup new file mode 100644 index 00000000..3b7daf22 --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/.zgroup @@ -0,0 +1,3 @@ +{ + "zarr_format": 2 +} \ No newline at end of file diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.0.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.0.parquet new file mode 100644 index 00000000..e04eef85 Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.0.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.1.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.1.parquet new file mode 100644 index 00000000..8c9d61af Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.1.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.2.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.2.parquet new file mode 100644 index 00000000..ce0b0c09 Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.2.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.3.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.3.parquet new file mode 100644 index 00000000..00623083 Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.3.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.4.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.4.parquet new file mode 100644 index 00000000..73433b0f Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.4.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.5.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.5.parquet new file mode 100644 index 00000000..89a9a105 Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.5.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.6.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.6.parquet new file mode 100644 index 00000000..7eddaaff Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.6.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.7.parquet b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.7.parquet new file mode 100644 index 00000000..832f3470 Binary files /dev/null and b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/points/transcripts/points.parquet/part.7.parquet differ diff --git a/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/zmetadata b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/zmetadata new file mode 100644 index 00000000..658cac3a --- /dev/null +++ b/tests/data/xenium_rep1_io.points_only.spatialdata.zarr/zmetadata @@ -0,0 +1,83 @@ +{ + "metadata": { + ".zattrs": { + "spatialdata_attrs": { + "spatialdata_software_version": "0.4.0", + "version": "0.1" + } + }, + ".zgroup": { + "zarr_format": 2 + }, + "points/.zgroup": { + "zarr_format": 2 + }, + "points/transcripts/.zattrs": { + "axes": [ + "x", + "y", + "z" + ], + "coordinateTransformations": [ + { + "input": { + "axes": [ + { + "name": "x", + "type": "space", + "unit": "unit" + }, + { + "name": "y", + "type": "space", + "unit": "unit" + }, + { + "name": "z", + "type": "space", + "unit": "unit" + } + ], + "name": "xyz" + }, + "output": { + "axes": [ + { + "name": "x", + "type": "space", + "unit": "unit" + }, + { + "name": "y", + "type": "space", + "unit": "unit" + }, + { + "name": "z", + "type": "space", + "unit": "unit" + } + ], + "name": "global" + }, + "scale": [ + 4.705882352941177, + 4.705882352941177, + 1.0 + ], + "type": "scale" + } + ], + "encoding-type": "ngff:points", + "spatialdata_attrs": { + "feature_key": "feature_name", + "instance_key": "cell_id", + "version": "0.1" + } + }, + "points/transcripts/.zgroup": { + "zarr_format": 2 + } + }, + "zarr_consolidated_format": 1 +} \ No newline at end of file diff --git a/tests/test_sdata_points_zorder.py b/tests/test_sdata_points_zorder.py new file mode 100644 index 00000000..783569de --- /dev/null +++ b/tests/test_sdata_points_zorder.py @@ -0,0 +1,183 @@ +import pytest +from pathlib import Path +from spatialdata import read_zarr + +from vitessce.data_utils.spatialdata_points_zorder import ( + # Function for computing codes and sorting + sdata_morton_sort_points, + # Functions for querying + sdata_morton_query_rect_debug, + row_ranges_to_row_indices, + orig_coord_to_norm_coord, +) + + +def _is_sorted(arr): + return all(arr[i] <= arr[i + 1] for i in range(len(arr) - 1)) + + +data_path = Path('tests/data') + + +@pytest.fixture +def sdata_with_points(): + sdata = read_zarr(data_path / "xenium_rep1_io.points_only.spatialdata.zarr") + return sdata + + +def test_zorder_sorting(sdata_with_points): + sdata = sdata_with_points + + sdata_morton_sort_points(sdata, "transcripts") + + # Check that the morton codes are sorted + sorted_ddf = sdata.points["transcripts"] + morton_sorted = sorted_ddf["morton_code_2d"].compute().values.tolist() + + assert _is_sorted(morton_sorted) + + +def test_zorder_query(sdata_with_points): + sdata = sdata_with_points + + sdata_morton_sort_points(sdata, "transcripts") + + # Query a rectangle that should return some points + orig_rect = [[50.0, 50.0], [100.0, 150.0]] # x0, y0, x1, y1 + matching_row_ranges, rows_checked = sdata_morton_query_rect_debug(sdata, "transcripts", orig_rect) + rect_row_indices = row_ranges_to_row_indices(matching_row_ranges) + + # Cannot use df.iloc on a dask dataframe, so convert it to pandas first + ddf = sdata.points["transcripts"] + df = ddf.compute() + df = df.reset_index(drop=True) + estimated_row_indices = df.iloc[rect_row_indices].index.tolist() + + assert df.shape[0] == 213191 + + # Do the same query the "dumb" way, by checking all points + + # We need an epsilon for the "dumb" query since the normalization + # introduces rounding issues. We can instead verify that a slightly + # smaller rectangle is fully contained in the morton code query + # estimated results. + EXACT_BOUNDARY_EPSILON = 1 + + in_rect = ( + (df["x"] >= orig_rect[0][0] + EXACT_BOUNDARY_EPSILON) + & (df["x"] <= orig_rect[1][0] - EXACT_BOUNDARY_EPSILON) + & (df["y"] >= orig_rect[0][1] + EXACT_BOUNDARY_EPSILON) + & (df["y"] <= orig_rect[1][1] - EXACT_BOUNDARY_EPSILON) + ) + dumb_df_subset = df.loc[in_rect] + # Get the row indices of the points in the rectangle + # (these are the indices in the original dataframe) + exact_row_indices = dumb_df_subset.index.tolist() + + # Check that the estimated rows 100% contain the exact rows. + # A.issubset(B) checks that all elements of A are in B ("A is a subset of B"). + assert set(exact_row_indices).issubset(set(estimated_row_indices)) + assert len(exact_row_indices) == 4 + assert len(estimated_row_indices) <= 4 + + # Check that the number of rows checked is less than the total number of points + assert len(rows_checked) <= 19858 + assert len(matching_row_ranges) == 2 # Kind of an implementation detail. + + # Do a second check, this time against x_uint/y_uint (the normalized coordinates) + # TODO: does this ensure that estimated == exact? + + bounding_box = ddf.attrs["bounding_box"] + x_min = bounding_box["x_min"] + x_max = bounding_box["x_max"] + y_min = bounding_box["y_min"] + y_max = bounding_box["y_max"] + norm_rect = [ + orig_coord_to_norm_coord(orig_rect[0], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max), + orig_coord_to_norm_coord(orig_rect[1], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max) + ] + + in_rect_norm = ( + (df["x_uint"] >= norm_rect[0][0]) + & (df["x_uint"] <= norm_rect[1][0]) + & (df["y_uint"] >= norm_rect[0][1]) + & (df["y_uint"] <= norm_rect[1][1]) + ) + dumb_df_subset_norm = df.loc[in_rect_norm] + # Get the row indices of the points in the rectangle + # (these are the indices in the original dataframe) + exact_row_indices_norm = dumb_df_subset_norm.index.tolist() + + # A.issubset(B) + # True if A is a subset of B and False otherwise. + assert set(exact_row_indices_norm).issubset(set(estimated_row_indices)) + + assert len(exact_row_indices_norm) == 4 + assert len(estimated_row_indices) <= 4 + + # ========= Another query ========== + orig_rect = [[500.0, 500.0], [600.0, 600.0]] # x0, y0, x1, y1 + + # Query using z-order + matching_row_ranges, rows_checked = sdata_morton_query_rect_debug(sdata, "transcripts", orig_rect) + rect_row_indices = row_ranges_to_row_indices(matching_row_ranges) + estimated_row_indices = df.iloc[rect_row_indices].index.tolist() + + # Do the same query the "dumb" way, by checking all points + in_rect = ( + (df["x"] >= orig_rect[0][0] + EXACT_BOUNDARY_EPSILON) + & (df["x"] <= orig_rect[1][0] - EXACT_BOUNDARY_EPSILON) + & (df["y"] >= orig_rect[0][1] + EXACT_BOUNDARY_EPSILON) + & (df["y"] <= orig_rect[1][1] - EXACT_BOUNDARY_EPSILON) + ) + dumb_df_subset = df.loc[in_rect] + # Get the row indices of the points in the rectangle + # (these are the indices in the original dataframe) + exact_row_indices = dumb_df_subset.index.tolist() + + # Check that the estimated rows 100% contain the exact rows. + # A.issubset(B) checks that all elements of A are in B ("A is a subset of B"). + assert set(exact_row_indices).issubset(set(estimated_row_indices)) + assert len(exact_row_indices) == 85 + assert len(estimated_row_indices) <= 95 + + # Check that the number of rows checked is less than the total number of points + assert len(rows_checked) <= 71675 + assert len(matching_row_ranges) == 13 # Kind of an implementation detail. + + # Do the same query the "dumb" way, by checking all points + in_rect = ( + (df["x"] >= orig_rect[0][0] + EXACT_BOUNDARY_EPSILON) + & (df["x"] <= orig_rect[1][0] - EXACT_BOUNDARY_EPSILON) + & (df["y"] >= orig_rect[0][1] + EXACT_BOUNDARY_EPSILON) + & (df["y"] <= orig_rect[1][1] - EXACT_BOUNDARY_EPSILON) + ) + dumb_df_subset = df.loc[in_rect] + # Get the row indices of the points in the rectangle + # (these are the indices in the original dataframe) + exact_row_indices = dumb_df_subset.index.tolist() + + # Query 2: Do a second check, this time against x_uint/y_uint (the normalized coordinates) + norm_rect = [ + orig_coord_to_norm_coord(orig_rect[0], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max), + orig_coord_to_norm_coord(orig_rect[1], orig_x_min=x_min, orig_x_max=x_max, orig_y_min=y_min, orig_y_max=y_max) + ] + + in_rect_norm = ( + (df["x_uint"] >= norm_rect[0][0]) + & (df["x_uint"] <= norm_rect[1][0]) + & (df["y_uint"] >= norm_rect[0][1]) + & (df["y_uint"] <= norm_rect[1][1]) + ) + dumb_df_subset_norm = df.loc[in_rect_norm] + # Get the row indices of the points in the rectangle + # (these are the indices in the original dataframe) + exact_row_indices_norm = dumb_df_subset_norm.index.tolist() + + # A.issubset(B) + # True if A is a subset of B and False otherwise. + assert set(exact_row_indices_norm).issubset(set(estimated_row_indices)) + + # Check that the estimated rows contain all of the exact rows. + assert len(exact_row_indices_norm) == 91 + assert len(estimated_row_indices) <= 95