diff --git a/pysteps/tests/test_verification_probscores.py b/pysteps/tests/test_verification_probscores.py
index c7f9990b8..b13ff091c 100644
--- a/pysteps/tests/test_verification_probscores.py
+++ b/pysteps/tests/test_verification_probscores.py
@@ -46,3 +46,14 @@ def test_ROC_curve_area(X_f, X_o, X_min, n_prob_thrs, compute_area, expected):
     assert_array_almost_equal(
         probscores.ROC_curve(P_f, X_o, X_min, n_prob_thrs, compute_area)[2], expected
     )
+
+
+@pytest.mark.parametrize(
+    "X_f, X_o, X_min, n_prob_thrs, compute_area, expected", test_data
+)
+def test_PR_curve_area(X_f, X_o, X_min, n_prob_thrs, compute_area, expected):
+    """Test the PR_curve."""
+    P_f = excprob(X_f, X_min, ignore_nan=False)
+    assert_array_almost_equal(
+        probscores.PR_curve(P_f, X_o, X_min, n_prob_thrs, compute_area)[2], expected
+    )
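The new test exercises the full pipeline: excprob turns an ensemble into exceedance probabilities, which PR_curve then verifies against observations. As a quick orientation, here is a minimal standalone sketch of that pipeline; the random fields, the array shapes, and the 0.5 threshold are illustrative assumptions rather than the actual test_data entries, and the PR_curve call assumes this patch is applied:

import numpy as np

from pysteps.postprocessing.ensemblestats import excprob
from pysteps.verification import probscores

# Synthetic 20-member ensemble forecast and a matching observation field.
np.random.seed(42)
X_f = np.random.exponential(size=(20, 100, 100))
X_o = np.random.exponential(size=(100, 100))

# Probability of each pixel exceeding the (assumed) 0.5 intensity threshold.
P_f = excprob(X_f, 0.5, ignore_nan=False)

precision, recall, area = probscores.PR_curve(
    P_f, X_o, 0.5, n_prob_thrs=10, compute_area=True
)
print("area under the PR curve: %.3f" % area)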
diff --git a/pysteps/verification/plots.py b/pysteps/verification/plots.py
index 21974b013..6fa9e9664 100755
--- a/pysteps/verification/plots.py
+++ b/pysteps/verification/plots.py
@@ -220,3 +220,56 @@ def plot_ROC(ROC, ax=None, opt_prob_thr=False):
 
     for p_thr_, x, y in zip(p_thr, POFD, POD):
         ax.text(x + 0.02, y - 0.02, "%.2f" % p_thr_, fontsize=7)
+
+
+def plot_PR(PR, ax=None, opt_prob_thr=False):
+    """
+    Plot a precision-recall (PR) curve.
+
+    Parameters
+    ----------
+    PR: dict
+        A PR curve object created with probscores.PR_curve_init.
+    ax: axis handle, optional
+        Axis handle for the figure. If set to None, the handle is taken from
+        the current figure (matplotlib.pylab.gca()).
+    opt_prob_thr: bool, optional
+        If set to True, mark the probability threshold that maximizes the F1
+        score (the harmonic mean of precision and recall).
+    """
+    if ax is None:
+        ax = plt.gca()
+
+    precision, recall, area = probscores.PR_curve_compute(PR, compute_area=True)
+    p_thr = PR["prob_thrs"]
+
+    # The no-skill baseline of a PR curve is the climatological event
+    # frequency (prevalence), not the diagonal as in the ROC curve.
+    total_pos = PR["hits"][0] + PR["misses"][0]
+    total_neg = PR["false_alarms"][0] + PR["corr_neg"][0]
+    n_total = total_pos + total_neg
+    prevalence = total_pos / n_total if n_total > 0 else 0.0
+
+    ax.plot([0, 1], [prevalence, prevalence], "k--")
+    ax.set_xlim(0, 1)
+    ax.set_ylim(0, 1)
+    ax.set_xlabel("Recall (POD)")
+    ax.set_ylabel("Precision")
+    ax.grid(True, ls=":")
+
+    ax.plot(recall, precision, "kD-")
+
+    if opt_prob_thr:
+        p = np.array(precision)
+        r = np.array(recall)
+        f1 = np.divide(2 * p * r, p + r, out=np.zeros_like(p), where=(p + r) != 0)
+
+        opt_idx = np.argmax(f1)
+        ax.scatter(
+            [recall[opt_idx]],
+            [precision[opt_idx]],
+            s=150,
+            facecolors="none",
+            edgecolors="r",
+        )
+
+    for p_thr_, x, y in zip(p_thr, recall, precision):
+        ax.text(x + 0.02, y + 0.02, "%.2f" % p_thr_, fontsize=7)
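For completeness, a short sketch of how the new plot function might be driven end to end; the uniform probabilities and the 0.1 threshold are made up for illustration, and plot_PR together with the PR_curve_* helpers are the functions added in this patch:

import matplotlib.pyplot as plt
import numpy as np

from pysteps.verification import plots, probscores

# Synthetic probability-observation pairs for illustration only.
np.random.seed(0)
P_f = np.random.rand(10000)
X_o = np.random.exponential(size=10000)

PR = probscores.PR_curve_init(0.1, n_prob_thrs=10)
probscores.PR_curve_accum(PR, P_f, X_o)

plots.plot_PR(PR, opt_prob_thr=True)
plt.show()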
diff --git a/pysteps/verification/probscores.py b/pysteps/verification/probscores.py
index f1eb398ba..dd6376965 100644
--- a/pysteps/verification/probscores.py
+++ b/pysteps/verification/probscores.py
@@ -20,6 +20,10 @@
     ROC_curve_init
     ROC_curve_accum
     ROC_curve_compute
+    PR_curve
+    PR_curve_init
+    PR_curve_accum
+    PR_curve_compute
 """
 
 import numpy as np
@@ -350,7 +354,8 @@ def ROC_curve_init(X_min, n_prob_thrs=10):
 
 
 def ROC_curve_accum(ROC, P_f, X_o):
-    """Accumulate the given probability-observation pairs into the given ROC
+    """
+    Accumulate the given probability-observation pairs into the given ROC
     object.
 
     Parameters
@@ -367,16 +372,17 @@ def ROC_curve_accum(ROC, P_f, X_o):
 
     P_f = P_f[mask]
     X_o = X_o[mask]
-
+    obs_yes = X_o >= ROC["X_min"]
+    obs_no = ~obs_yes
+
     for i, p in enumerate(ROC["prob_thrs"]):
-        mask = np.logical_and(P_f >= p, X_o >= ROC["X_min"])
-        ROC["hits"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f < p, X_o >= ROC["X_min"])
-        ROC["misses"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f >= p, X_o < ROC["X_min"])
-        ROC["false_alarms"][i] += np.sum(mask.astype(int))
-        mask = np.logical_and(P_f < p, X_o < ROC["X_min"])
-        ROC["corr_neg"][i] += np.sum(mask.astype(int))
+        forecast_yes = P_f >= p
+        forecast_no = ~forecast_yes
+
+        ROC["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes))
+        ROC["misses"][i] += np.sum(np.logical_and(forecast_no, obs_yes))
+        ROC["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, obs_no))
+        ROC["corr_neg"][i] += np.sum(np.logical_and(forecast_no, obs_no))
 
 
 def ROC_curve_compute(ROC, compute_area=False):
@@ -421,3 +427,146 @@
         return POFD_vals, POD_vals, area
     else:
         return POFD_vals, POD_vals
+
+
+def PR_curve(P_f, X_o, X_min, n_prob_thrs=10, compute_area=False):
+    """
+    Compute the precision-recall (PR) curve and optionally the area under it.
+
+    Parameters
+    ----------
+    P_f: array_like
+        Forecast probabilities for exceeding the threshold X_min. Non-finite
+        values are ignored.
+    X_o: array_like
+        Observed values. Non-finite values are ignored.
+    X_min: float
+        Precipitation intensity threshold for yes/no prediction.
+    n_prob_thrs: int, optional
+        The number of probability thresholds to evaluate. The interval [0, 1]
+        is divided into n_prob_thrs evenly spaced values.
+    compute_area: bool, optional
+        If True, compute the area under the PR curve by trapezoidal
+        integration.
+
+    Returns
+    -------
+    out: tuple
+        A two-element tuple containing the precision and recall values for
+        each probability threshold. If compute_area is True, a three-element
+        tuple (precision_vals, recall_vals, area) is returned, where area is
+        the trapezoidal estimate of the area under the PR curve.
+    """
+    pr = PR_curve_init(X_min, n_prob_thrs)
+    PR_curve_accum(pr, P_f, X_o)
+
+    return PR_curve_compute(pr, compute_area)
+
+
+def PR_curve_init(X_min, n_prob_thrs=10):
+    """
+    Initialize a precision-recall (PR) curve object.
+
+    Parameters
+    ----------
+    X_min: float
+        Precipitation intensity threshold for yes/no prediction.
+    n_prob_thrs: int, optional
+        The number of probability thresholds to evaluate.
+
+    Returns
+    -------
+    out: dict
+        The PR curve object. The keys are "X_min" (the intensity threshold),
+        "hits", "misses", "false_alarms" and "corr_neg" (contingency-table
+        counts for each probability threshold) and "prob_thrs" (the evenly
+        spaced probability thresholds in [0, 1]).
+    """
+    PR = {}
+
+    PR["X_min"] = X_min
+    PR["hits"] = np.zeros(n_prob_thrs, dtype=int)
+    PR["misses"] = np.zeros(n_prob_thrs, dtype=int)
+    PR["false_alarms"] = np.zeros(n_prob_thrs, dtype=int)
+    PR["corr_neg"] = np.zeros(n_prob_thrs, dtype=int)
+    PR["prob_thrs"] = np.linspace(0.0, 1.0, int(n_prob_thrs))
+
+    return PR
+
+
+def PR_curve_accum(PR, P_f, X_o):
+    """
+    Accumulate the given probability-observation pairs into the given PR
+    object.
+
+    Parameters
+    ----------
+    PR: dict
+        A PR curve object created with PR_curve_init.
+    P_f: array_like
+        Forecast probabilities for exceeding X_min.
+    X_o: array_like
+        Observed values.
+    """
+    mask = np.logical_and(np.isfinite(P_f), np.isfinite(X_o))
+
+    P_f = P_f[mask]
+    X_o = X_o[mask]
+
+    obs_yes = X_o >= PR["X_min"]
+    obs_no = ~obs_yes
+
+    for i, p in enumerate(PR["prob_thrs"]):
+        forecast_yes = P_f >= p
+        forecast_no = ~forecast_yes
+
+        PR["hits"][i] += np.sum(np.logical_and(forecast_yes, obs_yes))
+        PR["misses"][i] += np.sum(np.logical_and(forecast_no, obs_yes))
+        PR["false_alarms"][i] += np.sum(np.logical_and(forecast_yes, obs_no))
+        PR["corr_neg"][i] += np.sum(np.logical_and(forecast_no, obs_no))
+
+
+def PR_curve_compute(PR, compute_area=False):
+    """
+    Compute the precision and recall values from the given PR object.
+
+    Parameters
+    ----------
+    PR: dict
+        A PR curve object created with PR_curve_init.
+    compute_area: bool, optional
+        If True, compute the area under the PR curve.
+
+    Returns
+    -------
+    out: tuple
+        A two-element tuple (precision_vals, recall_vals) or, if compute_area
+        is True, a three-element tuple (precision_vals, recall_vals, area).
+
+    Notes
+    -----
+    - Precision is computed as hits / (hits + false alarms). When a threshold
+      produces no predicted positives, the denominator is zero and precision
+      is undefined. By the usual convention it is set to 1.0: a classifier
+      that issues no positive predictions makes no false alarms, which keeps
+      the curve anchored at high precision when recall is zero.
+    - Recall is computed as hits / (hits + misses). When the data contain no
+      observed positives, recall is undefined and is set to 0.0 for all
+      thresholds. The PR curve carries no meaningful information in this
+      case, but the output remains well defined.
+    """
+    precision_vals = []
+    recall_vals = []
+
+    for i in range(len(PR["prob_thrs"])):
+        hits = PR["hits"][i]
+        misses = PR["misses"][i]
+        false_alarms = PR["false_alarms"][i]
+
+        recall = hits / (hits + misses) if (hits + misses) > 0 else 0.0
+        precision = hits / (hits + false_alarms) if (hits + false_alarms) > 0 else 1.0
+
+        recall_vals.append(recall)
+        precision_vals.append(precision)
+
+    if compute_area:
+        # Sort by recall so that np.trapz integrates along a monotonically
+        # increasing x axis.
+        recall_sorted, precision_sorted = zip(
+            *sorted(zip(recall_vals, precision_vals))
+        )
+        area = np.trapz(precision_sorted, recall_sorted)
+
+        return precision_vals, recall_vals, area
+    else:
+        return precision_vals, recall_vals
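To make the incremental accumulation and the degenerate-case conventions concrete, here is a small worked sketch; the hand-built arrays are illustrative assumptions. With n_prob_thrs=5, the top threshold is p = 1.0, where no positives are predicted, so precision falls back to the 1.0 convention described in the notes of PR_curve_compute:

import numpy as np

from pysteps.verification import probscores

PR = probscores.PR_curve_init(X_min=1.0, n_prob_thrs=5)

# First verification time: four probability-observation pairs.
P_f_t0 = np.array([0.0, 0.2, 0.6, 0.9])
X_o_t0 = np.array([0.0, 0.5, 1.5, 2.0])
probscores.PR_curve_accum(PR, P_f_t0, X_o_t0)

# Second verification time: the two pairs containing NaN are skipped.
P_f_t1 = np.array([0.1, 0.8, np.nan, 0.7])
X_o_t1 = np.array([0.0, 3.0, 1.0, np.nan])
probscores.PR_curve_accum(PR, P_f_t1, X_o_t1)

precision, recall, area = probscores.PR_curve_compute(PR, compute_area=True)
print(precision, recall, area)

# At p = 1.0 there are no hits and no false alarms, hence precision == 1.0.
assert precision[-1] == 1.0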