diff --git a/scripts/plot_gsi_stat_exp.py b/scripts/plot_gsi_stat_exp.py index be8894e..b280add 100644 --- a/scripts/plot_gsi_stat_exp.py +++ b/scripts/plot_gsi_stat_exp.py @@ -1,218 +1,245 @@ -#!/usr/bin/env python -import argparse -import datetime -import matplotlib -from matplotlib import pyplot as plt -from matplotlib import rcParams, ticker -from matplotlib import gridspec as gspec -import os -import numpy as np -from pyGSI.gsi_stat import GSIstat +#!/usr/bin/env python3 +""" +Plot O-F statistics from GSI stat files for multiple experiments. -it = 'it == 1' -plottype = 'mean' -obtypes = [120, 220] -obvars = ['t', 'uv', 'q'] -levels = [1000, 900, 800, 600, 400, 300, 250, 200, 150, 100, 50, 0] +Modernized for Python 3.11. +""" +from __future__ import annotations -def gen_figure(datadict, datatypestr, stattype, labels, sdate, edate, save, plotdir): - # Line/marker colors for experiments ('k' is the first) - mc = ['k', 'r', 'g', 'b', 'm', 'c', 'y'] +import argparse +from dataclasses import dataclass +from datetime import datetime, timedelta +from pathlib import Path +from typing import Dict, List, Tuple - # set figure params one time only. - rcParams['figure.subplot.left'] = 0.1 - rcParams['figure.subplot.top'] = 0.85 - rcParams['legend.fontsize'] = 12 - rcParams['axes.grid'] = True +import numpy as np +from matplotlib import rcParams, ticker, pyplot as plt +from matplotlib import gridspec as gspec - fig1 = plt.figure(figsize=(10, 8)) +from pyGSI.gsi_stat import GSIstat + + +# ---- Configurable constants ---- +IT_QUERY = "it == 1" +PLOT_TYPE = "mean" +OB_TYPES = [120, 220] +OB_VARS = ["t", "uv", "q"] +LEVELS = np.array([1000, 900, 800, 600, 400, 300, 250, 200, 150, 100, 50, 0]) + + +@dataclass(frozen=True) +class DateRange: + """Represents a start/end date range and produces 6-hourly cycle strings.""" + start: datetime + end: datetime + + def cycles_6h(self) -> Tuple[List[str], List[str]]: + """Return lists of expected gsistat filenames and cycle timestamps.""" + statfiles, cycles = [], [] + cur = self.start + while cur <= self.end: + cyc = cur.strftime("%Y%m%d%H") + statfiles.append(f"gsistat.gdas.{cyc}") + cycles.append(cyc) + cur += timedelta(hours=6) + return statfiles, cycles + + +def _validate_inputs(gsistat_dirs: List[Path], labels: List[str]) -> None: + """Ensure each experiment directory exists and labels align with them.""" + if len(gsistat_dirs) != len(labels): + raise ValueError( + f"Number of --gsistats ({len(gsistat_dirs)}) must match number of --label ({len(labels)})." + ) + missing = [p for p in gsistat_dirs if not p.is_dir()] + if missing: + raise FileNotFoundError(f"The following directories do not exist: {', '.join(map(str, missing))}") + + +def _set_matplotlib_defaults() -> None: + """Apply consistent plot aesthetics for all figures.""" + rcParams["figure.subplot.left"] = 0.1 + rcParams["figure.subplot.top"] = 0.85 + rcParams["legend.fontsize"] = 12 + rcParams["axes.grid"] = True + + +def gen_figure( + datadict: Dict[str, Dict[str, Dict[str, np.ndarray]]], + datatypestr: str, + stattype: str, + labels: List[str], + sdate: datetime, + edate: datetime, + save: bool, + plotdir: Path, +) -> None: + """ + Generate a 1x3 figure of t / uv / q vs pressure (log scale). + + Args: + datadict: Nested dictionary of experiment, stat type, variable arrays. + datatypestr: Label for plot title (e.g., "RMSE", "Bias"). + stattype: Which statistic aggregation to plot ("mean", "aggr", or "sum"). + labels: List of experiment identifiers for the legend. + sdate: Start datetime for annotation. + edate: End datetime for annotation. + save: If True, saves plots as PDF/PNG instead of showing interactively. + plotdir: Directory path where figures are saved if `save` is True. + """ + _set_matplotlib_defaults() + colors = ["k", "r", "g", "b", "m", "c", "y"] + fig = plt.figure(figsize=(10, 8)) plt.subplots_adjust(hspace=0.3) gs = gspec.GridSpec(1, 3) + y_levels = LEVELS[:-1] - for v, var in enumerate(obvars): - xmin = 999 - xmax = 0 + for v, var in enumerate(OB_VARS): ax = plt.subplot(gs[v]) + xmin, xmax = np.inf, -np.inf + for e, expid in enumerate(labels): profile = datadict[expid][stattype][var][:-1] - ax.plot(profile, levels[:-1], marker='o', color=mc[e], - mfc=mc[e], mec=mc[e], label=labels[e]) - if (var in ['q']): - xmin_, xmax_ = np.min(profile[:-1]), np.max(profile[:-1]) - else: - xmin_, xmax_ = np.min(profile), np.max(profile) - if (xmin_ < xmin): - xmin = xmin_ - if (xmax_ > xmax): - xmax = xmax_ - if (v in [0]): - plt.legend(loc=0, numpoints=1) - if (v in [0]): - plt.ylabel('pressure (hPa)') - - if (var == 'uv'): - var_unit = 'm/s' - var_name = 'Winds' - elif (var == 't'): - var_unit = 'K' - var_name = 'Temperature' - elif (var == 'q'): - var_unit = '%' - var_name = 'Relative Humidity' - - if (stattype == 'sum'): - plt.xlabel('count') - else: - plt.xlabel('magnitude (%s)' % var_unit) - - plt.title(var_name, fontsize=14) - plt.ylim(1020, 50) - ax.set_yscale('log') - ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, - subs=np.arange(1, 10))) + c = colors[e % len(colors)] + ax.plot(profile, y_levels, marker="o", color=c, mfc=c, mec=c, label=expid) + + # Track min/max for consistent x-limits + valid = profile[:-1] if var == "q" else profile + pmin, pmax = np.nanmin(valid), np.nanmax(valid) + xmin, xmax = min(xmin, pmin), max(xmax, pmax) + + # Axis labels and titles + if v == 0: + ax.legend(loc=0, numpoints=1) + ax.set_ylabel("pressure (hPa)") + + var_labels = {"uv": ("m/s", "Winds"), "t": ("K", "Temperature"), "q": ("%", "Relative Humidity")} + var_unit, var_name = var_labels.get(var, ("", var)) + + ax.set_title(var_name, fontsize=14) + ax.set_yscale("log") + ax.set_ylim(1020, 50) + ax.yaxis.set_major_locator(ticker.LogLocator(base=10.0, subs=np.arange(1, 10))) ax.yaxis.set_major_formatter(ticker.FormatStrFormatter("%g")) - xmin = xmin - (xmax-xmin)*0.1 - xmax = xmax + (xmax-xmin)*0.1 - plt.xlim(xmin, xmax) - - sdatestr = sdate.strftime('%Y%m%d%H') - edatestr = edate.strftime('%Y%m%d%H') - plt.figtext(0.5, 0.93, '%s O-F (%s-%s)' - % (datatypestr, sdatestr, edatestr), - horizontalalignment='center', fontsize=18) - - if (save_figure): - fname = 'gsistat_uvtq' - plt.savefig(plotdir+'/%s_%s.pdf' % (fname, datatypestr)) - plt.savefig(plotdir+'/%s_%s.png' % (fname, datatypestr)) + ax.set_xlabel("count" if stattype == "sum" else f"magnitude ({var_unit})") + + # Add horizontal padding to x-limits + if np.isfinite(xmin) and np.isfinite(xmax) and xmax > xmin: + pad = (xmax - xmin) * 0.1 + ax.set_xlim(xmin - pad, xmax + pad) + + # Annotate figure header + sdatestr, edatestr = sdate.strftime("%Y%m%d%H"), edate.strftime("%Y%m%d%H") + plt.figtext(0.5, 0.93, f"{datatypestr} O-F ({sdatestr}-{edatestr})", ha="center", fontsize=18) + + # Save or show + fname = f"gsistat_uvtq_{datatypestr}" + if save: + plotdir.mkdir(parents=True, exist_ok=True) + plt.savefig(plotdir / f"{fname}.pdf", bbox_inches="tight") + plt.savefig(plotdir / f"{fname}.png", bbox_inches="tight", dpi=150) + plt.close(fig) else: plt.show() -def get_gsistat_list(startdate, enddate): - statfiles = [] - cycles = [] - mydate = startdate - while mydate <= enddate: - cycle = mydate.strftime('%Y%m%d%H') - fname = 'gsistat.gdas.' + cycle - statfiles.append(fname) - cycles.append(cycle) - mydate = mydate + datetime.timedelta(hours=6) - return statfiles, cycles - - -if __name__ == '__main__': - # get command line arguments - parser = argparse.ArgumentParser(description=('Plots a comparison of GSI ', - 'statistics for ', - 'experiments compared to a ', - 'reference control run.')) - parser.add_argument('-d', '--gsistats', - help='list of directories containing GSI stat files', - nargs='+', required=True) - parser.add_argument('-l', '--label', - help='list of labels for experiment IDs', - nargs='+', required=False) - parser.add_argument('-f', '--save_figure', - help='save figures as png and pdf', - action='store_true', required=False) - parser.add_argument('-p', '--plotdir', - help='path to where to save figures', - default='./', required=False) - parser.add_argument('-s', '--start_date', help='starting date', - type=str, metavar='YYYYMMDDHH', required=True) - parser.add_argument('-e', '--end_date', help='ending date', - type=str, metavar='YYYYMMDDHH', required=True) +def main() -> None: + """Main CLI entry point: read data, compute aggregates, and plot.""" + # 1. Parse command-line args + parser = argparse.ArgumentParser( + description="Plot comparison of GSI O-F statistics (RMSE, Bias, Count) for multiple experiments." + ) + parser.add_argument("-d", "--gsistats", nargs="+", required=True, + help="Directories containing GSI stat files (one per experiment).") + parser.add_argument("-l", "--label", nargs="+", required=False, + help="Labels for experiments (must match --gsistats order).") + parser.add_argument("-f", "--save-figure", action="store_true", dest="save_figure", + help="Save figures as PNG/PDF instead of showing interactively.") + parser.add_argument("-p", "--plotdir", default="./", help="Output directory for saved figures.") + parser.add_argument("-s", "--start-date", required=True, metavar="YYYYMMDDHH", + help="Start date/time (e.g., 2025010100).", dest="start_date") + parser.add_argument("-e", "--end-date", required=True, metavar="YYYYMMDDHH", + help="End date/time (e.g., 2025010700).", dest="end_date") args = parser.parse_args() + # 2. Prepare paths and dates save_figure = args.save_figure - if (save_figure): - matplotlib.use('Agg') - - sdate = datetime.datetime.strptime(args.start_date, '%Y%m%d%H') - edate = datetime.datetime.strptime(args.end_date, '%Y%m%d%H') - - statfiles, cycles = get_gsistat_list(sdate, edate) - - if args.label: - labels = args.label - else: - labels = [g.rstrip('/').split('/')[-1] for g in args.gsistats] - - # loop through all files and variables and grab statistics - rmses = {} - counts = {} - biases = {} - for exp, gsistats in zip(labels, args.gsistats): - rmses[exp] = {} - counts[exp] = {} - biases[exp] = {} + plotdir = Path(args.plotdir).expanduser().resolve() + sdate = datetime.strptime(args.start_date, "%Y%m%d%H") + edate = datetime.strptime(args.end_date, "%Y%m%d%H") + date_range = DateRange(sdate, edate) + statfiles, cycles = date_range.cycles_6h() + + gsistat_dirs = [Path(p).expanduser().resolve() for p in args.gsistats] + labels = args.label if args.label else [p.name for p in gsistat_dirs] + _validate_inputs(gsistat_dirs, labels) + + # 3. Read gsistat data + rmses: Dict[str, dict] = {} + counts: Dict[str, dict] = {} + biases: Dict[str, dict] = {} + + for exp, gsistats_dir in zip(labels, gsistat_dirs): + rmses[exp], counts[exp], biases[exp] = {}, {}, {} for gsistat, cycle in zip(statfiles, cycles): - rmses[exp][cycle] = {} - counts[exp][cycle] = {} - biases[exp][cycle] = {} - inputfile = os.path.join(gsistats, gsistat) - try: - gdas = GSIstat(inputfile, cycle) - except FileNotFoundError: - raise FileNotFoundError( - f'Unable to find {inputfile} for cycle {cycle}') - # now loop through variables - for var in obvars: - stat = gdas.extract(var) # t, uv, q, etc. - stat = stat.query(it) # ges (1) or anl (3) ? - tmpstat = stat[stat.index.isin(obtypes, level='typ')] - tmpstat = tmpstat[tmpstat.index.isin(['asm'], level='use')] - rmses[exp][cycle][var] = tmpstat[tmpstat.index.isin( - ['rms'], - level='stat')] - counts[exp][cycle][var] = tmpstat[tmpstat.index.isin( - ['count'], - level='stat')] - biases[exp][cycle][var] = tmpstat[tmpstat.index.isin( - ['bias'], - level='stat')] - - # now aggregate stats + rmses[exp][cycle], counts[exp][cycle], biases[exp][cycle] = {}, {}, {} + inputfile = gsistats_dir / gsistat + if not inputfile.is_file(): + raise FileNotFoundError(f"Unable to find {inputfile} for cycle {cycle}") + + # Read gsistat file using pyGSI + gdas = GSIstat(str(inputfile), cycle) + + # Extract desired variables and QC subsets + for var in OB_VARS: + stat = gdas.extract(var).query(IT_QUERY) + tmp = stat[stat.index.isin(OB_TYPES, level="typ")] + tmp = tmp[tmp.index.isin(["asm"], level="use")] + rmses[exp][cycle][var] = tmp[tmp.index.isin(["rms"], level="stat")] + counts[exp][cycle][var] = tmp[tmp.index.isin(["count"], level="stat")] + biases[exp][cycle][var] = tmp[tmp.index.isin(["bias"], level="stat")] + + # 4. Aggregate across cycles + n_cycles = len(cycles) + n_levels = len(LEVELS) for exp in labels: - rmses[exp]['mean'] = {} - rmses[exp]['aggr'] = {} - biases[exp]['mean'] = {} - biases[exp]['aggr'] = {} - counts[exp]['sum'] = {} - for var in obvars: - rmse_var = np.empty([len(cycles), len(levels)]) - bias_var = np.empty([len(cycles), len(levels)]) - counts_var = np.empty([len(cycles), len(levels)], dtype=int) - for i, cycle in enumerate(cycles): - rmse_var[i, :] = rmses[exp][cycle][var].values[0] - bias_var[i, :] = biases[exp][cycle][var].values[0] - counts_var[i, :] = counts[exp][cycle][var].values[0] - # Compute mean rms, bias - rmses[exp]['mean'][var] = rmse_var.mean(axis=0) - biases[exp]['mean'][var] = bias_var.mean(axis=0) - # Compute aggregate rms, bias - ar = np.asarray([]) - ab = np.asarray([]) - for j in range(np.ma.size(rmse_var, axis=1)): - r = rmse_var[:, j].squeeze() - b = bias_var[:, j].squeeze() - c = counts_var[:, j].squeeze() - if (np.sum(c) > 0): - ar = np.append(ar, np.sqrt(np.sum(np.multiply(c, r**2.))/np.sum(c))) - ab = np.append(ab, np.sum(np.multiply(c, b))/np.sum(c)) - else: - ar = np.append(ar, np.nan) - ab = np.append(ab, np.nan) - rmses[exp]['aggr'][var] = ar - biases[exp]['aggr'][var] = ab - # Compute summed counts - counts[exp]['sum'][var] = counts_var.sum(axis=0) - - # make figures - gen_figure(rmses, 'RMSE', plottype, labels, sdate, edate, save_figure, args.plotdir) - gen_figure(biases, 'Bias', plottype, labels, sdate, edate, save_figure, args.plotdir) - gen_figure(counts, 'Count', 'sum', labels, sdate, edate, - save_figure, args.plotdir) + rmses[exp]["mean"], rmses[exp]["aggr"] = {}, {} + biases[exp]["mean"], biases[exp]["aggr"] = {}, {} + counts[exp]["sum"] = {} + + for var in OB_VARS: + # Build matrices across all cycles + rmse_mat = np.empty((n_cycles, n_levels), dtype=float) + bias_mat = np.empty((n_cycles, n_levels), dtype=float) + cnt_mat = np.empty((n_cycles, n_levels), dtype=float) + + for i, cyc in enumerate(cycles): + rmse_mat[i, :] = rmses[exp][cyc][var].values[0] + bias_mat[i, :] = biases[exp][cyc][var].values[0] + cnt_mat[i, :] = counts[exp][cyc][var].values[0] + + # Simple means across time + rmses[exp]["mean"][var] = np.nanmean(rmse_mat, axis=0) + biases[exp]["mean"][var] = np.nanmean(bias_mat, axis=0) + + # Weighted aggregates across time (using counts as weights) + sum_c = np.nansum(cnt_mat, axis=0) + with np.errstate(invalid="ignore", divide="ignore"): + w_rmse = np.sqrt(np.nansum(cnt_mat * (rmse_mat ** 2.0), axis=0) / sum_c) + w_bias = np.nansum(cnt_mat * bias_mat, axis=0) / sum_c + w_rmse[sum_c <= 0] = np.nan + w_bias[sum_c <= 0] = np.nan + + rmses[exp]["aggr"][var] = w_rmse + biases[exp]["aggr"][var] = w_bias + counts[exp]["sum"][var] = np.nansum(cnt_mat, axis=0) + + # 5. Plot results + gen_figure(rmses, "RMSE", PLOT_TYPE, labels, sdate, edate, save_figure, plotdir) + gen_figure(biases, "Bias", PLOT_TYPE, labels, sdate, edate, save_figure, plotdir) + gen_figure(counts, "Count", "sum", labels, sdate, edate, save_figure, plotdir) + + +if __name__ == "__main__": + main() diff --git a/src/pyGSI/gsi_stat.py b/src/pyGSI/gsi_stat.py index 3f60a01..8226115 100644 --- a/src/pyGSI/gsi_stat.py +++ b/src/pyGSI/gsi_stat.py @@ -153,9 +153,9 @@ def extract_instrument(self, obtype, instrument): df = pd.DataFrame(data=tmp, columns=columns) df.drop(['col1', 'col2', 'col3'], inplace=True, axis=1) df[['channel', 'nassim', 'nrej']] = df[[ - 'channel', 'nassim', 'nrej']].astype(np.int) + 'channel', 'nassim', 'nrej']].astype(int) df[['oberr', 'OmF_bc', 'OmF_wobc']] = df[[ - 'oberr', 'OmF_bc', 'OmF_wobc']].astype(np.float) + 'oberr', 'OmF_bc', 'OmF_wobc']].astype(float) # Since iteration number is not readily available, make one lendf = len(df) @@ -230,9 +230,9 @@ def _get_ps_tpw(self, name): columns = header.split() df = pd.DataFrame(data=tmp, columns=columns) - df[['it', 'typ', 'count']] = df[['it', 'typ', 'count']].astype(np.int) + df[['it', 'typ', 'count']] = df[['it', 'typ', 'count']].astype(int) df[['bias', 'rms', 'cpen', 'qcpen']] = df[[ - 'bias', 'rms', 'cpen', 'qcpen']].astype(np.float) + 'bias', 'rms', 'cpen', 'qcpen']].astype(float) df.set_index(columns[:5], inplace=True) return df @@ -251,7 +251,7 @@ def _get_conv(self, name): if re.search(pattern, line): if re.search(' '+name+' ', self._lines[i+2]): header = line.strip() - ptops = np.array(header.split()[2:-1], dtype=np.float) + ptops = np.array(header.split()[2:-1], dtype=float) break if ptops is []: print(f'No matching ptop for {name}') @@ -267,7 +267,7 @@ def _get_conv(self, name): header = line.strip() header = re.sub('pbot', 'stat', header) header = re.sub('0.200E' + r'\+04', 'column', header) - pbots = np.array(header.split()[7:-1], dtype=np.float) + pbots = np.array(header.split()[7:-1], dtype=float) break if pbots is []: print(f'No matching pbot for {name}') @@ -291,9 +291,9 @@ def _get_conv(self, name): columns = header.split() df = pd.DataFrame(data=tmp, columns=columns) - df[['it', 'typ']] = df[['it', 'typ']].astype(np.int) + df[['it', 'typ']] = df[['it', 'typ']].astype(int) df.set_index(columns[:7], inplace=True) - df = df.astype(np.float) + df = df.astype(float) return df @@ -328,9 +328,9 @@ def _get_ozone(self): columns = header.split() df = pd.DataFrame(data=tmp, columns=columns) df[['it', 'read', 'keep', 'assim']] = df[[ - 'it', 'read', 'keep', 'assim']].astype(np.int) + 'it', 'read', 'keep', 'assim']].astype(int) df[['penalty', 'cpen', 'qcpen', 'qcfail']] = df[[ - 'penalty', 'cpen', 'qcpen', 'qcfail']].astype(np.float) + 'penalty', 'cpen', 'qcpen', 'qcfail']].astype(float) df.set_index(columns[:4], inplace=True) df = df.swaplevel('sat', 'inst') df.index.rename(['satellite', 'instrument'], level=[ @@ -369,9 +369,9 @@ def _get_radiance(self): columns = header.split() df = pd.DataFrame(data=tmp, columns=columns) df[['it', 'read', 'keep', 'assim']] = df[[ - 'it', 'read', 'keep', 'assim']].astype(np.int) + 'it', 'read', 'keep', 'assim']].astype(int) df[['penalty', 'qcpnlty', 'cpen', 'qccpen']] = df[[ - 'penalty', 'qcpnlty', 'cpen', 'qccpen']].astype(np.float) + 'penalty', 'qcpnlty', 'cpen', 'qccpen']].astype(float) df.set_index(columns[:4], inplace=True) df = df.swaplevel('satellite', 'instrument') @@ -405,8 +405,8 @@ def _get_cost(self): columns = ['Outer', 'Inner', 'J', 'gJ'] df = pd.DataFrame(data=tmp, columns=columns) - df[['Outer', 'Inner', ]] = df[['Outer', 'Inner']].astype(np.int) + df[['Outer', 'Inner', ]] = df[['Outer', 'Inner']].astype(int) df.set_index(columns[:2], inplace=True) - df = df.astype(np.float) + df = df.astype(float) return df