11 changes: 11 additions & 0 deletions pysteps/tests/test_verification_probscores.py
@@ -46,3 +46,14 @@ def test_ROC_curve_area(X_f, X_o, X_min, n_prob_thrs, compute_area, expected):
    assert_array_almost_equal(
        probscores.ROC_curve(P_f, X_o, X_min, n_prob_thrs, compute_area)[2], expected
    )


@pytest.mark.parametrize(
    "X_f, X_o, X_min, n_prob_thrs, compute_area, expected", test_data
)
def test_PR_curve_area(X_f, X_o, X_min, n_prob_thrs, compute_area, expected):
    """Test the PR_curve."""
    P_f = excprob(X_f, X_min, ignore_nan=False)
    assert_array_almost_equal(
        probscores.PR_curve(P_f, X_o, X_min, n_prob_thrs, compute_area)[2], expected
    )
53 changes: 53 additions & 0 deletions pysteps/verification/plots.py
@@ -220,3 +220,56 @@ def plot_ROC(ROC, ax=None, opt_prob_thr=False):

    for p_thr_, x, y in zip(p_thr, POFD, POD):
        ax.text(x + 0.02, y - 0.02, "%.2f" % p_thr_, fontsize=7)


def plot_PR(PR, ax=None, opt_prob_thr=False):
    """
    Plot a Precision-Recall (PR) curve.

    Parameters
    ----------
    PR: dict
        A PR curve object created by probscores.PR_curve_init.
    ax: axis handle, optional
        Axis handle for the figure. If set to None, the handle is taken from
        the current figure (matplotlib.pylab.gca()).
    opt_prob_thr: bool, optional
        If set to True, mark the probability threshold that maximizes the
        F1 score (the harmonic mean of precision and recall).
    """
    if ax is None:
        ax = plt.gca()

    precision, recall, area = probscores.PR_curve_compute(PR, compute_area=True)
    p_thr = PR["prob_thrs"]

    # The no-skill baseline of a PR curve is the climatological event
    # frequency (prevalence): a random forecast attains this precision at
    # any recall.
    total_pos = PR["hits"][0] + PR["misses"][0]
    total_neg = PR["false_alarms"][0] + PR["corr_neg"][0]
    total = total_pos + total_neg
    prevalence = total_pos / total if total > 0 else 0.0

    ax.plot([0, 1], [prevalence, prevalence], "k--")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Recall (POD)")
    ax.set_ylabel("Precision")
    ax.grid(True, ls=":")

    ax.plot(recall, precision, "kD-")

    if opt_prob_thr:
        p = np.array(precision)
        r = np.array(recall)
        # F1 = 2PR / (P + R); the where-mask avoids division by zero when
        # precision and recall are both zero.
        f1 = np.divide(2 * p * r, p + r, out=np.zeros_like(p), where=(p + r) != 0)

        opt_idx = np.argmax(f1)
        # Mark the F1-optimal threshold with an open red circle; passing
        # c="r" as well would fill the marker and defeat facecolors="none".
        ax.scatter(
            [recall[opt_idx]],
            [precision[opt_idx]],
            s=150,
            facecolors="none",
            edgecolors="r",
        )

    for p_thr_, x, y in zip(p_thr, recall, precision):
        ax.text(x + 0.02, y + 0.02, "%.2f" % p_thr_, fontsize=7)
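
A minimal usage sketch for the new plot (not part of this diff; the synthetic data, the seed, and the 1.0 intensity threshold are illustrative assumptions):

import matplotlib.pyplot as plt
import numpy as np

from pysteps.verification import plots, probscores

# Synthetic inputs: forecast exceedance probabilities and observed intensities.
np.random.seed(0)
P_f = np.random.uniform(0.0, 1.0, size=10000)
X_o = np.random.exponential(scale=1.0, size=10000)

# Build and fill a PR object, then plot it with the F1-optimal threshold marked.
PR = probscores.PR_curve_init(X_min=1.0, n_prob_thrs=10)
probscores.PR_curve_accum(PR, P_f, X_o)

fig, ax = plt.subplots()
plots.plot_PR(PR, ax=ax, opt_prob_thr=True)
plt.show()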
169 changes: 159 additions & 10 deletions pysteps/verification/probscores.py
@@ -20,6 +20,10 @@
    ROC_curve_init
    ROC_curve_accum
    ROC_curve_compute
    PR_curve
    PR_curve_init
    PR_curve_accum
    PR_curve_compute
"""

import numpy as np
@@ -350,7 +354,8 @@ def ROC_curve_init(X_min, n_prob_thrs=10):


def ROC_curve_accum(ROC, P_f, X_o):
-    """Accumulate the given probability-observation pairs into the given ROC
    """
    Accumulate the given probability-observation pairs into the given ROC
    object.

    Parameters
@@ -367,16 +372,17 @@ def ROC_curve_accum(ROC, P_f, X_o):

    P_f = P_f[mask]
    X_o = X_o[mask]

    obs_yes = X_o >= ROC["X_min"]
    obs_no = ~obs_yes

    for i, p in enumerate(ROC["prob_thrs"]):
-        mask = np.logical_and(P_f >= p, X_o >= ROC["X_min"])
-        ROC["hits"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f < p, X_o >= ROC["X_min"])
-        ROC["misses"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f >= p, X_o < ROC["X_min"])
-        ROC["false_alarms"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f < p, X_o < ROC["X_min"])
-        ROC["corr_neg"][i] += np.sum(mask.astype(int))
        forecast_yes = P_f >= p
        forecast_no = ~forecast_yes

        ROC["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes))
        ROC["misses"][i] += np.sum(np.logical_and(forecast_no, obs_yes))
        ROC["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, obs_no))
        ROC["corr_neg"][i] += np.sum(np.logical_and(forecast_no, obs_no))


def ROC_curve_compute(ROC, compute_area=False):
@@ -421,3 +427,146 @@ def ROC_curve_compute(ROC, compute_area=False):
        return POFD_vals, POD_vals, area
    else:
        return POFD_vals, POD_vals


def PR_curve(P_f, X_o, X_min, n_prob_thrs=10, compute_area=False):
    """
    Compute the Precision-Recall (PR) curve and optionally the area under it.

    Parameters
    ----------
    P_f: array_like
        Forecast probabilities for exceeding the threshold X_min.
        Non-finite values are ignored.
    X_o: array_like
        Observed values. Non-finite values are ignored.
    X_min: float
        Precipitation intensity threshold for yes/no prediction.
    n_prob_thrs: int, optional
        The number of probability thresholds to evaluate. The interval
        [0, 1] is divided into n_prob_thrs evenly spaced values.
    compute_area: bool, optional
        If True, compute the area under the PR curve by trapezoidal
        integration.

    Returns
    -------
    out: tuple
        A two-element tuple (precision_vals, recall_vals) containing the
        precision and recall values for each probability threshold. If
        compute_area is True, a third element is appended: the trapezoidal
        estimate of the area under the PR curve.
    """
    P_f = P_f.copy()
    X_o = X_o.copy()
    pr = PR_curve_init(X_min, n_prob_thrs)
    PR_curve_accum(pr, P_f, X_o)
    return PR_curve_compute(pr, compute_area)
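

# --- Usage sketch (illustration only; not part of this diff) ---
# The synthetic inputs and the X_min=1.0 threshold below are assumptions
# chosen for demonstration.
import numpy as np
from pysteps.verification import probscores

np.random.seed(42)
P_f = np.random.uniform(0.0, 1.0, size=5000)  # forecast exceedance probabilities
X_o = np.random.exponential(1.0, size=5000)   # synthetic observed intensities

precision, recall, area = probscores.PR_curve(
    P_f, X_o, X_min=1.0, n_prob_thrs=10, compute_area=True
)
print("PR area: %.3f" % area)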


def PR_curve_init(X_min, n_prob_thrs=10):
    """
    Initialize a Precision-Recall curve object.

    Parameters
    ----------
    X_min: float
        Precipitation intensity threshold for yes/no prediction.
    n_prob_thrs: int, optional
        The number of probability thresholds to evaluate.

    Returns
    -------
    PR: dict
        Dictionary containing counters for hits, misses, false alarms and
        correct negatives, plus the probability thresholds. Keys:

        - "X_min": the intensity threshold
        - "hits", "misses", "false_alarms", "corr_neg": arrays of counts
        - "prob_thrs": array of evenly spaced probability thresholds in [0, 1]
    """
    PR = {}
    PR["X_min"] = X_min
    PR["hits"] = np.zeros(n_prob_thrs, dtype=int)
    PR["misses"] = np.zeros(n_prob_thrs, dtype=int)
    PR["false_alarms"] = np.zeros(n_prob_thrs, dtype=int)
    PR["corr_neg"] = np.zeros(n_prob_thrs, dtype=int)
    PR["prob_thrs"] = np.linspace(0.0, 1.0, int(n_prob_thrs))
    return PR
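

# For example (an illustrative note, not part of this diff), with five
# thresholds the evenly spaced grid produced by np.linspace is:
#   PR_curve_init(1.0, n_prob_thrs=5)["prob_thrs"]  ->  [0., 0.25, 0.5, 0.75, 1.]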


def PR_curve_accum(PR, P_f, X_o):
    """
    Accumulate forecast-observation pairs into the given PR object.

    Parameters
    ----------
    PR: dict
        A PR curve object created with PR_curve_init.
    P_f: array_like
        Forecast probabilities for exceeding X_min.
    X_o: array_like
        Observed values.
    """
    mask = np.logical_and(np.isfinite(P_f), np.isfinite(X_o))
    P_f = P_f[mask]
    X_o = X_o[mask]

    obs_yes = X_o >= PR["X_min"]
    obs_no = ~obs_yes

    for i, p in enumerate(PR["prob_thrs"]):
        forecast_yes = P_f >= p
        forecast_no = ~forecast_yes

        PR["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes))
        PR["misses"][i] += np.sum(np.logical_and(forecast_no, obs_yes))
        PR["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, obs_no))
        PR["corr_neg"][i] += np.sum(np.logical_and(forecast_no, obs_no))


def PR_curve_compute(PR, compute_area=False):
    """
    Compute the precision and recall values from the given PR object.

    Parameters
    ----------
    PR: dict
        A PR curve object created with PR_curve_init.
    compute_area: bool, optional
        If True, compute the area under the PR curve.

    Returns
    -------
    out: tuple
        (precision_vals, recall_vals) or (precision_vals, recall_vals, area).

    Notes
    -----
    - Precision is computed as hits / (hits + false alarms). When a
      threshold yields no predicted positives, the denominator is zero and
      precision is undefined; following the standard convention, it is set
      to 1.0. A forecast that predicts no positives produces no false
      alarms, so this choice anchors the curve at high precision where
      recall is zero.
    - Recall is computed as hits / (hits + misses). When the data contain
      no observed positives, recall is undefined and is set to 0.0 for all
      thresholds. In that case the PR curve carries no meaningful
      information, but the output remains well defined.
    """
    precision_vals = []
    recall_vals = []

    for i in range(len(PR["prob_thrs"])):
        hits = PR["hits"][i]
        misses = PR["misses"][i]
        false_alarms = PR["false_alarms"][i]

        # See the Notes section for the conventions applied when a
        # denominator is zero.
        recall = hits / (hits + misses) if (hits + misses) > 0 else 0.0
        precision = hits / (hits + false_alarms) if (hits + false_alarms) > 0 else 1.0

[Review comment, Member] The convention using "if (hits + false_alarms) > 0 else 1.0" should be documented in the docstrings, as it can lead to some surprising results when thresholds yield zero predicted positives.

[Review comment, Member] Could you comment on how this is handled in scikit-learn, for example?

        recall_vals.append(recall)
        precision_vals.append(precision)

    if compute_area:
        # Sort by recall so the trapezoidal rule integrates along a
        # monotonically increasing x axis.
        recall_sorted, precision_sorted = zip(*sorted(zip(recall_vals, precision_vals)))
        area = np.trapz(precision_sorted, recall_sorted)
        return precision_vals, recall_vals, area
    else:
        return precision_vals, recall_vals
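
Regarding the scikit-learn question raised in review: its precision_recall_curve applies the same convention, appending a final point with precision 1 and recall 0 (with no corresponding threshold) so the curve is anchored at full precision where nothing is predicted positive. A hedged comparison sketch (scikit-learn is not a pysteps dependency; shown for illustration only):

import numpy as np
from sklearn.metrics import precision_recall_curve

# Toy data from the scikit-learn documentation.
y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.4, 0.35, 0.8])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
print(precision[-1], recall[-1])  # 1.0 0.0 -- the appended anchor point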