"""Data containers, per-model defaults, derived-variable registry and data cache for gcmprocpy."""
from dataclasses import dataclass, field
import functools
from typing import Optional

import numpy as np
import xarray as xr
# Model-specific default variable names and configurations.
#
# Each model entry provides:
#   * 'wind_u', 'wind_v', 'wind_w', 'temperature', 'electron_density':
#     the dataset variable name used for each of those roles.
#   * 'density', 'temperature_type', 'wind', 'electric' (and, for WACCM-X,
#     'radiation'): variable groups, each listing its member 'vars' together
#     with a 'cmap' and a 'line_color' (presumably the default colormap and
#     contour-line colour used when plotting those variables -- confirm in
#     the plotting code).
#   * 'wind_scale': scale factor for the model's wind values (per the inline
#     comments, it converts them to m/s).
#   * 'species': canonical role name -> dataset variable name; this mapping
#     is what get_species_names() returns.
MODEL_DEFAULTS = {
    'TIE-GCM': {
        'wind_u': 'UN',
        'wind_v': 'VN',
        'wind_w': 'WN',
        'temperature': 'TN',
        'electron_density': 'NE',
        'density': {
            'vars': ['NE', 'DEN', 'O2', 'O1', 'N2', 'NO', 'N4S', 'HE', 'OP',
                     'NMF2', 'TEC'],
            'cmap': 'viridis',
            'line_color': 'white',
        },
        'temperature_type': {
            'vars': ['TN', 'TE', 'TI', 'QJOULE'],
            'cmap': 'inferno',
            'line_color': 'white',
        },
        'wind': {
            'vars': ['UN', 'VN', 'WN', 'UI_ExB', 'VI_ExB', 'WI_ExB'],
            'cmap': 'bwr',
            'line_color': 'black',
        },
        'electric': {
            'vars': ['POTEN'],
            'cmap': 'bwr',
            'line_color': 'black',
        },
        'wind_scale': 0.01,  # cm/s → m/s
        'species': {
            'temp': 'TN', 'o': 'O1', 'o2': 'O2', 'n2': 'N2',
            'no': 'NO', 'co2': 'CO2', 'h': 'H', 'o3': 'O3', 'ho2': 'HO2',
        },
    },
    'WACCM-X': {
        'wind_u': 'U',
        'wind_v': 'V',
        # NOTE(review): in CAM-based models OMEGA is conventionally a
        # pressure velocity (Pa/s), so wind_scale=1.0 may not yield m/s for
        # this component -- confirm against the dataset's units attribute.
        'wind_w': 'OMEGA',
        'temperature': 'T',
        'electron_density': 'EDens',
        'density': {
            'vars': ['EDens', 'OpDens', 'O2p', 'NOp', 'N2p', 'Op',
                     'ElecColDens', 'O3', 'NO', 'NO2', 'N2O', 'CO', 'CO2',
                     'CH4', 'H2O', 'HE', 'O', 'O2', 'N2', 'HNO3', 'NOY',
                     'CLOY', 'BROY'],
            'cmap': 'viridis',
            'line_color': 'white',
        },
        'temperature_type': {
            'vars': ['T', 'TREFHT', 'THETA'],
            'cmap': 'inferno',
            'line_color': 'white',
        },
        'wind': {
            'vars': ['U', 'V', 'OMEGA', 'UTGW_TOTAL', 'VTGW_TOTAL'],
            'cmap': 'bwr',
            'line_color': 'black',
        },
        'electric': {
            'vars': ['ED1', 'ED2', 'POTEN'],
            'cmap': 'bwr',
            'line_color': 'black',
        },
        'radiation': {
            'vars': ['FSDS', 'FSNS', 'FSNT', 'FLDS', 'FLNS', 'FLNT', 'FLUT',
                     'QRL_TOT', 'QRS_TOT', 'QRS_EUV', 'QRS_AUR', 'QTHERMAL',
                     'SWCF', 'LWCF'],
            'cmap': 'plasma',
            'line_color': 'white',
        },
        'wind_scale': 1.0,  # already m/s
        'species': {
            'temp': 'T', 'o': 'O', 'o2': 'O2', 'n2': 'N2',
            'no': 'NO', 'co2': 'CO2', 'h': 'H', 'o3': 'O3', 'ho2': 'HO2',
        },
    },
}
@dataclass
class ModelDataset:
    """A loaded NetCDF dataset with its metadata.

    Attributes:
        ds: The opened xarray Dataset.
        filename: The source filename (e.g. 'decsol_smin_2.5x0.25_sech_001.nc').
        model: The model type ('TIE-GCM' or 'WACCM-X').
        _time_set: Cached set of time values for fast lookup.
        _time_values: Cached array of ``ds['time'].values``.

    Both cache fields are (re)computed in ``__post_init__``, so any value
    passed for them is overwritten. They are excluded from ``repr`` (which
    would otherwise dump every timestamp) and from the generated ``__eq__``
    (comparing two distinct multi-element numpy arrays there raises
    ``ValueError``).
    """
    ds: xr.Dataset
    filename: str
    model: str
    _time_set: Optional[set] = field(default=None, repr=False, compare=False)
    _time_values: Optional[np.ndarray] = field(
        default=None, repr=False, compare=False
    )

    def __post_init__(self):
        # Build the caches once so has_time() is a set lookup rather than a
        # scan of the time coordinate. Presumably raises KeyError if the
        # dataset has no 'time' variable -- callers should ensure it exists.
        self._time_values = self.ds['time'].values
        self._time_set = set(self._time_values)

    def has_time(self, time):
        """Fast check whether a timestamp exists in this dataset.

        Args:
            time: A timestamp, compared against the values of ``ds['time']``.

        Returns:
            bool: True if *time* is in the cached time set.
        """
        return time in self._time_set
[docs]
@dataclass
class PlotData:
    """Container for data returned by arr_* functions when plot_mode=True.

    Attributes:
        values: The extracted variable values (numpy array).
        variable_unit: The unit string after any conversion.
        variable_long_name: The long descriptive name of the variable.
        model: The model type ('TIE-GCM' or 'WACCM-X').
        filename: The source dataset filename.
        levs: Level/ilevel coordinate array (if applicable).
        lats: Latitude coordinate array (if applicable).
        lons: Longitude coordinate array (if applicable).
        mtime: Single model time as [day, hour, min, sec] (for single-time plots).
        mtime_values: List of model times (for multi-time plots like lev_time, lat_time).
        selected_lat: The latitude value used for selection (if applicable).
        selected_lon: The longitude value used for selection (if applicable).
        selected_lev: The level value used for selection (if applicable).

    Every field after ``filename`` is optional and defaults to ``None``.

    Note:
        The dataclass-generated ``__eq__`` compares the numpy array fields
        element-wise. For two distinct arrays with more than one element,
        that comparison raises ``ValueError``, so compare the fields you
        care about explicitly (e.g. with ``np.array_equal``).
    """
    values: np.ndarray
    variable_unit: str
    variable_long_name: str
    model: str
    filename: str
    levs: Optional[np.ndarray] = None
    lats: Optional[np.ndarray] = None
    lons: Optional[np.ndarray] = None
    mtime: Optional[list] = None
    mtime_values: Optional[list] = None
    selected_lat: Optional[float] = None
    selected_lon: Optional[float] = None
    selected_lev: Optional[float] = None
[docs]
def get_species_names(model):
    """Return species name mapping for a model type.

    Uses ``MODEL_DEFAULTS`` as the single source of truth for
    mapping canonical role names to dataset variable names.

    Args:
        model (str): Model type (``'TIE-GCM'`` or ``'WACCM-X'``).

    Returns:
        dict: A new dict mapping canonical names (e.g. ``'temp'``, ``'o'``,
        ``'o2'``) to dataset variable names (e.g. ``'TN'``, ``'O1'``,
        ``'O2'``). It is a shallow copy, so mutating it does not alter
        ``MODEL_DEFAULTS``.

    Raises:
        ValueError: If *model* is not recognized.
    """
    if model not in MODEL_DEFAULTS:
        raise ValueError(
            f"Unknown model '{model}'. Known: {list(MODEL_DEFAULTS)}"
        )
    # Copy so a caller editing the result cannot corrupt the shared defaults.
    return dict(MODEL_DEFAULTS[model]['species'])
# ---------------------------------------------------------------------------
# Derived-variable registry
# ---------------------------------------------------------------------------
# Maps a variable name, or a glob-style pattern ending in ``*``, to
# ``{'handler': callable, 'plot_types': set or None}``. Filled by
# register_derived() and searched by resolve_derived().
DERIVED_VARIABLES: dict = {}
# Bounded LRU cache for data extraction and derived-variable computations.
# Scrubbing the timeline or re-clicking Plot with the same settings repeats
# calls with the same (datasets, variable, time, level, ...) arguments;
# caching turns every repeat after the first into an O(1) dict lookup. Used
# by the arr_* data extraction functions in data_parse.py and by the
# derived-variable handlers.
from collections import OrderedDict as _OrderedDict
# Maximum number of cached results; beyond this the least recently used
# entry is evicted.
_DATA_CACHE_MAX: int = 128
# Cache store: key -> result, ordered from least to most recently used.
_data_cache: _OrderedDict = _OrderedDict()
def _make_cache_key(fn_name, datasets, args, kwargs):
# Normalize: convert any list kwargs/args to tuples so keys hash.
norm_args = tuple(tuple(a) if isinstance(a, list) else a for a in args)
norm_kwargs = tuple(sorted(
(k, tuple(v) if isinstance(v, list) else v) for k, v in kwargs.items()
))
return (fn_name, id(datasets), norm_args, norm_kwargs)
def _cached_call(fn, datasets, *args, **kwargs):
    """Call ``fn(datasets, *args, **kwargs)`` through the shared LRU cache.

    On a hit, the stored result is marked most recently used and returned.
    On a miss, *fn* runs and its result is stored. The least recently used
    entry is evicted once more than ``_DATA_CACHE_MAX`` results are held.
    Calls whose arguments are unhashable (e.g. numpy arrays) run uncached.

    Returns:
        Whatever *fn* returns. This may be the same object an earlier call
        returned, so callers should not mutate it.
    """
    try:
        key = _make_cache_key(fn.__name__, datasets, args, kwargs)
        hash(key)
    except TypeError:
        # An argument is unhashable: fall back to a plain, uncached call.
        return fn(datasets, *args, **kwargs)
    # Membership test rather than ``get() is None``: results that are
    # legitimately None must count as hits, not be recomputed every call.
    if key in _data_cache:
        _data_cache.move_to_end(key)
        return _data_cache[key]
    result = fn(datasets, *args, **kwargs)
    _data_cache[key] = result
    if len(_data_cache) > _DATA_CACHE_MAX:
        # Evict the oldest (least recently used) entry.
        _data_cache.popitem(last=False)
    return result
[docs]
def clear_data_cache():
    """Remove every entry from the shared data cache.

    Call this whenever datasets are (re)loaded. Cache keys include
    ``id(datasets)``, and Python may reuse that id for a new object once
    the old one has been freed.
    """
    cache = _data_cache
    cache.clear()


# Older name kept because the GUI still imports it.
clear_derived_cache = clear_data_cache
def cache_data_fn(fn):
    """Decorator to memoize an arr_* data extraction function.

    Results go into the module-level LRU cache, keyed on (fn name,
    id(datasets), positional args, kwargs). Caching is skipped if any
    argument is unhashable (e.g. raw numpy arrays in arr_sat_track).

    Args:
        fn (callable): Function whose first positional parameter is the
            datasets container.

    Returns:
        callable: A wrapper with the same call signature.
        ``functools.wraps`` copies ``__name__``, ``__qualname__``, ``__doc__``
        and ``__module__`` from *fn* (keeping Sphinx/help() output intact)
        and sets ``__wrapped__`` to *fn*.
    """
    @functools.wraps(fn)
    def wrapped(datasets, *args, **kwargs):
        return _cached_call(fn, datasets, *args, **kwargs)
    return wrapped
def _wrap_cached(handler):
    """Return *handler* decorated with ``cache_data_fn`` memoization."""
    memoized = cache_data_fn(handler)
    return memoized
[docs]
def register_derived(name, handler, plot_types=None):
    """Add or replace an entry in the derived-variable registry.

    Args:
        name (str): Variable name (e.g. ``'NO53'``), or a glob-style
            pattern ending with ``*`` (e.g. ``'OH_*'``).
        handler (callable): Function with signature
            ``(datasets, variable_name, time, **kwargs) -> PlotData``.
        plot_types (set, optional): Plot types this variable supports
            (e.g. ``{'lat_lon', 'lev_lat'}``). *None* means all.
    """
    entry = dict(handler=handler, plot_types=plot_types)
    DERIVED_VARIABLES[name] = entry
[docs]
def resolve_derived(variable_name):
    """Look up the handler for a derived variable name.

    Lookup order:

    1. An exact match on *variable_name*.
    2. An exact match on its upper-case form.
    3. Pattern keys ending with ``*``, whose prefix is compared
       case-insensitively. The first registered matching pattern wins.

    Args:
        variable_name (str): The variable name to look up.

    Returns:
        tuple: ``(handler, True)`` if found, where the handler is wrapped
        with data caching. Otherwise ``(None, False)``, which is also the
        result when *variable_name* is not a string (e.g. ``None``).
    """
    # Exact match
    if variable_name in DERIVED_VARIABLES:
        return _wrap_cached(DERIVED_VARIABLES[variable_name]['handler']), True
    if not isinstance(variable_name, str):
        # Nothing else can match; avoid an AttributeError on .upper().
        return None, False
    # Also check upper-case form
    vn_upper = variable_name.upper()
    if vn_upper in DERIVED_VARIABLES:
        return _wrap_cached(DERIVED_VARIABLES[vn_upper]['handler']), True
    # Pattern match (e.g. 'OH_*'). Upper-case the pattern prefix too, so a
    # pattern registered in lower or mixed case can still match.
    for key, entry in DERIVED_VARIABLES.items():
        if key.endswith('*') and vn_upper.startswith(key[:-1].upper()):
            return _wrap_cached(entry['handler']), True
    return None, False