From 28cb69b3784fbced0cfed1bceab3183c84154fa7 Mon Sep 17 00:00:00 2001 From: "google-labs-jules[bot]" <161369871+google-labs-jules[bot]@users.noreply.github.com> Date: Sun, 28 Dec 2025 08:15:34 +0000 Subject: [PATCH 1/2] feat: Ensure consistent order and transparency in calibration plots This commit addresses two issues in the calibration plots: 1. The order of reference groups and colors is now consistent with the input order. This is achieved by deriving the reference groups from the keys of the input `probs` dictionary, which preserves the intended order. 2. The Plotly figures are now fully transparent. This is achieved by setting both `plot_bgcolor` and `paper_bgcolor` to `rgba(0, 0, 0, 0)`. --- src/rtichoke/calibration/calibration.py | 33 ++++++++++++++----- src/rtichoke/processing/exported_functions.py | 1 + .../processing/plotly_helper_functions.py | 32 +++++++++--------- src/rtichoke/processing/transforms.py | 16 ++++----- 4 files changed, 51 insertions(+), 31 deletions(-) diff --git a/src/rtichoke/calibration/calibration.py b/src/rtichoke/calibration/calibration.py index 0e03df0..567c05c 100644 --- a/src/rtichoke/calibration/calibration.py +++ b/src/rtichoke/calibration/calibration.py @@ -2,7 +2,7 @@ A module for Calibration Curves """ -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, cast # import pandas as pd import plotly.graph_objects as go @@ -247,6 +247,7 @@ def _create_plotly_curve_from_calibration_curve_list_times( }, barmode="overlay", plot_bgcolor="rgba(0, 0, 0, 0)", + paper_bgcolor="rgba(0, 0, 0, 0)", legend={ "orientation": "h", "xanchor": "center", @@ -285,6 +286,7 @@ def _create_plotly_curve_from_calibration_curve_list( "yaxis": {"showgrid": False}, "barmode": "overlay", "plot_bgcolor": "rgba(0, 0, 0, 0)", + "paper_bgcolor": "rgba(0, 0, 0, 0)", "legend": { "orientation": "h", "xanchor": "center", @@ -470,7 +472,7 @@ def _make_deciles_dat_binary( if isinstance(reals, dict): reference_groups_keys = list(reals.keys()) y_list = [ - np.asarray(reals[reference_group]).ravel() + np.asarray(reals[str(reference_group)]).ravel() for reference_group in reference_groups_keys ] lengths = np.array([len(y) for y in y_list], dtype=np.int64) @@ -533,7 +535,7 @@ def _make_deciles_dat_binary( ( (pl.col("prob").rank("ordinal").over(["reference_group", "model"]) - 1) * n_bins - // pl.count().over(["reference_group", "model"]) + // pl.len().over(["reference_group", "model"]) + 1 ).alias("decile"), ] @@ -602,7 +604,7 @@ def _create_calibration_curve_list( reference_data = _create_reference_data_for_calibration_curve() - reference_groups = deciles_data["reference_group"].unique().to_list() + reference_groups = list(probs.keys()) colors_dictionary = _create_colors_dictionary_for_calibration( reference_groups, color_values, performance_type @@ -689,7 +691,9 @@ def process_single_array(p, r, group_name): for group_name in reals.keys(): if group_name in probs: frame = process_single_array( - probs[group_name], reals[group_name], group_name + probs[str(group_name)], + reals[str(group_name)], + str(group_name), ) smooth_frames.append(frame) @@ -856,8 +860,21 @@ def _define_limits_for_calibration_plot(deciles_dat: pl.DataFrame) -> List[float if deciles_dat.height == 1: lower_bound, upper_bound = 0.0, 1.0 else: - lower_bound = float(max(0, min(deciles_dat["x"].min(), deciles_dat["y"].min()))) - upper_bound = float(max(deciles_dat["x"].max(), deciles_dat["y"].max())) + lower_bound = float( + max( + 0, + min( + cast(float, deciles_dat["x"].min()), + cast(float, deciles_dat["y"].min()), + ), + ) + ) + upper_bound = float( + max( + cast(float, deciles_dat["x"].max()), + cast(float, deciles_dat["y"].max()), + ) + ) return [ lower_bound - (upper_bound - lower_bound) * 0.05, @@ -1101,7 +1118,7 @@ def _create_calibration_curve_list_times( ) reference_data = _create_reference_data_for_calibration_curve() - reference_groups = deciles_dat_final["reference_group"].unique().to_list() + reference_groups = list(probs.keys()) colors_dictionary = _create_colors_dictionary_for_calibration( reference_groups, color_values, performance_type ) diff --git a/src/rtichoke/processing/exported_functions.py b/src/rtichoke/processing/exported_functions.py index 778ad91..f9aaddb 100644 --- a/src/rtichoke/processing/exported_functions.py +++ b/src/rtichoke/processing/exported_functions.py @@ -148,6 +148,7 @@ def create_plotly_curve(rtichoke_curve_dict): "y": 0, "steps": [], } + sliders_dict["steps"] = [] for k in range( len( diff --git a/src/rtichoke/processing/plotly_helper_functions.py b/src/rtichoke/processing/plotly_helper_functions.py index 074fc52..07b6bc2 100644 --- a/src/rtichoke/processing/plotly_helper_functions.py +++ b/src/rtichoke/processing/plotly_helper_functions.py @@ -5,7 +5,7 @@ import plotly.graph_objects as go import polars as pl import math -from typing import Any, Dict, Union, Sequence +from typing import Any, Dict, Union, Sequence, cast import numpy as np from rtichoke.performance_data.performance_data import prepare_performance_data from rtichoke.performance_data.performance_data_times import ( @@ -329,8 +329,8 @@ def _create_reference_lines_data( # random-guess (y=1 unless all p==0 -> NaN) all_zero = ( aj_df["p"].len() > 0 - and float(aj_df["p"].max()) == 0.0 - and float(aj_df["p"].min()) == 0.0 + and float(cast(float, aj_df["p"].max())) == 0.0 + and float(cast(float, aj_df["p"].min())) == 0.0 ) rand_y = pl.Series( np.full(len(x_s), np.nan) if all_zero else np.ones(len(x_s)), @@ -992,7 +992,7 @@ def _check_if_multiple_populations_are_being_validated_times( ] .max() ) - return max_val is not None and max_val > 1 + return max_val is not None and float(cast(float, max_val)) > 1 def _check_if_multiple_populations_are_being_validated( @@ -1977,10 +1977,21 @@ def _create_curve_layout( "b": max(80, base_pad.get("b", 0)), **base_pad, } + xaxis: dict[str, Any] = {"showgrid": False} + yaxis: dict[str, Any] = {"showgrid": False} + + if axes_ranges is not None: + xaxis["range"] = axes_ranges["xaxis"] + yaxis["range"] = axes_ranges["yaxis"] + + if x_label: + xaxis["title"] = {"text": x_label} + if y_label: + yaxis["title"] = {"text": y_label} curve_layout = { - "xaxis": {"showgrid": False}, - "yaxis": {"showgrid": False}, + "xaxis": xaxis, + "yaxis": yaxis, "template": "plotly", "plot_bgcolor": "rgba(0, 0, 0, 0)", "paper_bgcolor": "rgba(0, 0, 0, 0)", @@ -2014,15 +2025,6 @@ def _create_curve_layout( "modebar": {"remove": list(DEFAULT_MODEBAR_BUTTONS_TO_REMOVE)}, } - if axes_ranges is not None: - curve_layout["xaxis"]["range"] = axes_ranges["xaxis"] - curve_layout["yaxis"]["range"] = axes_ranges["yaxis"] - - if x_label: - curve_layout["xaxis"]["title"] = {"text": x_label} - if y_label: - curve_layout["yaxis"]["title"] = {"text": y_label} - return curve_layout diff --git a/src/rtichoke/processing/transforms.py b/src/rtichoke/processing/transforms.py index 4e4339a..ba162ac 100644 --- a/src/rtichoke/processing/transforms.py +++ b/src/rtichoke/processing/transforms.py @@ -67,13 +67,13 @@ def transform_group(group: pl.DataFrame, by: float) -> pl.DataFrame: def pivot_longer_strata(data: pl.DataFrame) -> pl.DataFrame: # Identify id_vars and value_vars - id_vars = [col for col in data.columns if not col.startswith("strata_")] - value_vars = [col for col in data.columns if col.startswith("strata_")] + index_cols = [col for col in data.columns if not col.startswith("strata_")] + on_cols = [col for col in data.columns if col.startswith("strata_")] - # Perform the melt (equivalent to pandas.melt) - data_long = data.melt( - id_vars=id_vars, - value_vars=value_vars, + # Perform the unpivot (equivalent to pandas.melt) + data_long = data.unpivot( + index=index_cols, + on=on_cols, variable_name="stratified_by", value_name="strata", ) @@ -257,12 +257,12 @@ def _create_list_data_to_adjust( probs_array = np.asarray(probs_dict[reference_group_labels[0]]) if isinstance(reals_dict, dict): - reals_array = np.asarray(reals_dict[0]) + reals_array = np.asarray(reals_dict[reference_group_labels[0]]) else: reals_array = np.asarray(reals_dict) if isinstance(times_dict, dict): - times_array = np.asarray(times_dict[0]) + times_array = np.asarray(times_dict[reference_group_labels[0]]) else: times_array = np.asarray(times_dict) From 568607535a7d1eccdc0c1a5a8b80d0cb6d7a4830 Mon Sep 17 00:00:00 2001 From: Uriah Finkel Date: Sun, 28 Dec 2025 10:30:39 +0200 Subject: [PATCH 2/2] build: bump version --- pyproject.toml | 2 +- uv.lock | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6a17fef..5645f67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "polars>=1.31.0", ] name = "rtichoke" -version = "0.1.27" +version = "0.1.28" description = "interactive visualizations for performance of predictive models" readme = "README.md" diff --git a/uv.lock b/uv.lock index a345597..971a9da 100644 --- a/uv.lock +++ b/uv.lock @@ -5143,7 +5143,7 @@ wheels = [ [[package]] name = "rtichoke" -version = "0.1.27" +version = "0.1.28" source = { editable = "." } dependencies = [ { name = "marimo", version = "0.17.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" },