Neural Networks based unified physics parameterization for atmospheric models
Machine learning approaches to convective parametrization
Documentation
The documentation is hosted on github pages: https://nbren12.github.io/uwnet/
Setup
Obtaining permission to use SAM
The System for Atmospheric Modeling (SAM) is a key part of the pre-processing pipeline and prognostic evaluation of this machine learning project, but it is not necessary for offline evaluation or training.
If you want access to SAM, please email the author Marat Khairoutdinov (cc'ing me) to ask for permission. Then, I can give you access to the slightly modified version of SAM used for this project.
Once you have arranged this access, the SAM source code can be downloaded to the path ext/sam by running
git submodule update --init --recursive
Setting up the environment
This project uses two dependency management systems. Docker is needed to run the SAM model and the SAM-related preprocessing steps; you do not need it if you are only training a model from pre-processed data (the data on Zenodo). Poetry is a simpler, pure-python solution that should work for most common scenarios.
To use docker, you first need to build the image:
make build_image
If you get an error make: nvidia-docker: Command not found, edit the Makefile to have DOCKER = docker instead of nvidia-docker (assuming docker is already installed). Then, the docker environment can be entered by typing
make enter
This opens a shell in a docker container with all the necessary software installed.
To use poetry, you can install all the needed packages and enter a sandboxed environment by running
poetry install
poetry shell
The instructions below assume you are working in one of these environments.
Running the workflow
To train the models, type
snakemake -j <number of parallel jobs>
This will take a long time! To see all the steps and the corresponding commands in this workflow, type
snakemake -n -p
This whole analysis is specified in the Snakefile, which is the first place to look.
To reproduce the plots for the Journal of the Atmospheric Sciences paper, run
make jas2020
The figures for this paper require you to install chromedriver to export them to svg format. I did this on my Mac with these commands:
# for svg saving from altair
brew install chromedriver
# on mac os to allow unverified developers
xattr -d com.apple.quarantine /usr/local/bin/chromedriver
You also need to install Inkscape to convert the svg to pdf format.
Evaluating performance
Evaluating ML parameterizations is somewhat different from normal ML scoring.
Some useful metrics that work with xarray data are available in uwnet.metrics. In particular, uwnet.metrics.r2_score computes the ubiquitous R2 score.
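For orientation, the R2 score reduces to a sum-of-squares computation over the sample dimensions of the xarray data. The snippet below is only an illustrative sketch of that computation with assumed dimension names; it is not the actual signature of uwnet.metrics.r2_score:

import xarray as xr


def r2_sketch(truth: xr.DataArray, prediction: xr.DataArray, dims=("time", "y", "x")):
    # Illustrative coefficient of determination over the given dimensions;
    # see uwnet.metrics.r2_score for the metric actually used in this project.
    sse = ((truth - prediction) ** 2).sum(dims)
    sst = ((truth - truth.mean(dims)) ** 2).sum(dims)
    return 1 - sse / sst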
Performing online tests
SAM has been modified to call arbitrary python functions within its time-stepping loop. These python functions accept a dictionary of numpy arrays as input and store output arrays under specific names in this dictionary. SAM then pulls the contents of this dictionary back into Fortran and applies any computed tendency.
To extend this, one first needs to write a suitable function, which can be tested using the data stored at assets/sample_sam_state.pt. The following steps explore this data; a minimal sketch of such a hook function follows the exploration.
In [5]: state = torch.load("assets/sample_sam_state.pt")
In [6]: state.keys()
Out[6]: dict_keys(['layer_mass', 'p', 'pi', 'caseid', 'case', 'liquid_ice_static_energy', '_DIMS', '_ATTRIBUTES', 'total_water_mixing_ratio', 'air_temperature', 'upward_air_velocity', 'x_wind', 'y_wind', 'tendency_of_total_water_mixing_ratio_due_to_dynamics', 'tendency_of_liquid_ice_static_energy_due_to_dynamics', 'tendency_of_x_wind_due_to_dynamics', 'tendency_of_y_wind_due_to_dynamics', 'latitude', 'longitude', 'sea_surface_temperature', 'surface_air_pressure', 'toa_incoming_shortwave_flux', 'surface_upward_sensible_heat_flux', 'surface_upward_latent_heat_flux', 'air_pressure', 'air_pressure_on_interface_levels', 'dt', 'time', 'day', 'nstep'])
In [7]: qt = state['total_water_mixing_ratio']
In [8]: qt.shape
Out[8]: (34, 64, 128)
In [9]: state['sea_surface_temperature'].shape
Out[9]: (1, 64, 128)
In [10]: state['air_pressure_on_interface_levels'].shape
Out[10]: (35,)
In [11]: state['p'].shape
Out[11]: (34,)
In [12]: state['_ATTRIBUTES']
Out[12]:
{'liquid_ice_static_energy': {'units': 'K'}, 'total_water_mixing_ratio': {'units': 'g/kg'}, 'air_temperature': {'units': 'K'}, 'upward_air_velocity': {'units': 'm/s'}, 'x_wind': {'units': 'm/s'}, 'y_wind': {'units': 'm/s'}, 'tendency_of_total_water_mixing_ratio_due_to_dynamics': {'units': 'm/s'}, 'tendency_of_liquid_ice_static_energy_due_to_dynamics': {'units': 'm/s'}, 'tendency_of_x_wind_due_to_dynamics': {'units': 'm/s'}, 'tendency_of_y_wind_due_to_dynamics': {'units': 'm/s'}, 'latitude': {'units': 'degreeN'}, 'longitude': {'units': 'degreeN'}, 'sea_surface_temperature': {'units': 'K'}, 'surface_air_pressure': {'units': 'mbar'}, 'toa_incoming_shortwave_flux': {'units': 'W m^-2'}, 'surface_upward_sensible_heat_flux': {'units': 'W m^-2'}, 'surface_upward_latent_heat_flux': {'units': 'W m^-2'}, 'air_pressure': {'units': 'hPa'}}
In [13]: # tendency_of_total_water_mixing_ratio expected units = g/kg/day
In [14]: # tendency_of_liquid_ice_static_energy expected units = K/day
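Given this state format, a hook only needs to read the arrays it uses and write its outputs back into the dictionary with the expected units. The sketch below is illustrative: the function name and output key names are placeholders, not the names SAM actually reads back (see uwnet.sam_interface and the parameter files in assets/ for the real interface).

import numpy as np
import torch


def zero_tendency_hook(state):
    # Read inputs from the state dictionary (numpy arrays).
    qt = state["total_water_mixing_ratio"]    # shape (34, 64, 128), g/kg
    sli = state["liquid_ice_static_energy"]   # shape (34, 64, 128), K

    # Write outputs back into the dictionary. The key names here are
    # placeholders; the expected units are g/kg/day and K/day.
    state["tendency_of_total_water_mixing_ratio"] = np.zeros_like(qt)
    state["tendency_of_liquid_ice_static_energy"] = np.zeros_like(sli)
    return state


# Exercise the hook offline against the saved sample state:
state = torch.load("assets/sample_sam_state.pt")
state = zero_tendency_hook(state)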
Configuring SAM to call this function
- Write uwnet.sam_interface.call_random_forest.
- The rule sam_run in the Snakefile actually runs the SAM model.
- Parameters are passed to src.sam.create_case as a json file via the -p flag.
- Example parameters are at assets/parameters_sam_neural_network.json.
- parameters['python'] configures which python function is called (see the sketch below).
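For illustration, such a parameters file could be generated programmatically. The field names under "python" below are an assumption, and the output filename is a placeholder; consult assets/parameters_sam_neural_network.json for the actual schema:

import json

# Hypothetical sketch: the structure under "python" is assumed here,
# not the schema read by src.sam.create_case.
parameters = {
    "python": {
        "module": "uwnet.sam_interface",
        "function": "call_random_forest",
    }
}

with open("assets/parameters_my_case.json", "w") as f:
    json.dump(parameters, f, indent=2)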
Code Snippets
shell: "cd data/raw && curl {DATA_URL} | tar xv"
shell: """
echo {input} | ncrcat -o {output}
"""
script: "uwnet/data/preprocess.py"
script: "uwnet/data/preprocess.py"
script: "uwnet/data/preprocess.py"
script: "uwnet/data/reshape.py"
shell: "ncks -d y,24,40 {input} {output}"
run:
    import xarray as xr
    ds = xr.open_dataset(input[0])
    mean = ds.mean(['x', 'time'])
    mean.to_netcdf(output[0])
shell: """
papermill -p run_path $PWD/{params.run} -p training_data $PWD/{TRAINING_DATA} \
    -p caseid {wildcards.id} \
    -p training_data_mean {TRAINING_MEAN} \
    --prepare-only {params.template} {params.ipynb}
jupyter nbconvert --ExecutePreprocessor.timeout=600 \
    --allow-errors \
    --execute {params.ipynb}
# clean up the notebook
rm -f {params.ipynb}
"""
shell: """
rm -rf {params.rundir}
{sys.executable} -m src.sam.create_case -nn {input} \
    -n {params.ngaqua} \
    -s {params.sam_src} \
    -t {params.step} \
    -p {params.sam_params} \
    {params.rundir}
# run sam
cd {params.rundir}
sh run.sh >> log 2>> log
"""
shell: """
rm -rf {params.rundir}
{sys.executable} -m src.sam.create_case -sk {input} \
    -n {params.ngaqua} \
    -s {params.sam_src} \
    -p {params.sam_params} \
    -t {params.step} \
    {params.rundir}
# run sam
cd {params.rundir}
sh run.sh >> log 2>> log
"""
shell: """
rm -rf {params.rundir}
{sys.executable} -m src.sam.create_case \
    -n {params.ngaqua} \
    -s {params.sam_src} \
    -t {params.step} \
    -p {params.sam_params} \
    {params.rundir}
# run sam
cd {params.rundir}
sh run.sh
"""
shell: """
rm -rf {params.rundir}
{sys.executable} -m src.sam.create_case \
    -t 0 -p assets/parameters_{wildcards.kind}.json {params.rundir}
{RUN_SAM_SCRIPT} {params.rundir} >> {log} 2>> {log}
exit 0
"""
shell:"""
dir=$(mktemp --directory)
{sys.executable} -m src.sam.create_case \
    -p assets/parameters_save_python_state.json \
    $dir
mv -f $dir/state.pt assets/
rm -rf $dir
"""
shell: """
python -m uwnet.train train_pre_post with data={TRAINING_DATA} prepost.path={output}
"""
shell: """
python -m uwnet.ml_models.nn.train with {input.config} epochs={num_epoch} output_dir={params.dir} > {log} 2> {log}
"""
shell: """
python -m uwnet.ml_models.train_generic_sklearn -op {params.dir} -cf {input.config} \
    > {log} 2> {log}
"""
shell: "python uwnet/criticism/evaluate.py {input} > {output}"
script: "src/models/debias.py"
shell: "python -m src.visualizations.sam_run {params.run} {output}"
import xarray as xr
from uwnet.debias import insert_apparent_sources, LassoDebiasedModel
import torch

o = snakemake.output
p = snakemake.params

ds = xr.open_dataset(p.data)
model = torch.load(p.model)

ds = insert_apparent_sources(ds, prognostics=p.prognostics)

mapping = [
    ('QT', 'QT', 'QQT'),
    ('SLI', 'SLI', 'QSLI'),
]

debias = LassoDebiasedModel(model, mapping).fit(ds)
torch.save(debias, o[0])
import torch
from uwnet.loss import get_input_output
from uwnet.utils import mean_other_dims
import json
from tqdm import tqdm
import argparse
import uwnet.ml_models.nn.datasets_handler as d


def batch_to_residual(model, batch):
    from uwnet.timestepper import Batch
    batch = Batch(batch.float(), prognostics=["QT", "SLI"])
    with torch.no_grad():
        prediction, truth = get_input_output(model, 0.125, batch)
    return prediction - truth


def vertically_resolved_mse_from_residual(residual):
    return {k: mean_other_dims(residual[k] ** 2, 2).squeeze() for k in residual}


def batch_to_mse(model, batch):
    residual = batch_to_residual(model, batch)
    return vertically_resolved_mse_from_residual(residual)


def _parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data", help="path to zarr reshaped training or test data")
    parser.add_argument("model", help="path to model")
    return parser.parse_args()


args = _parse_args()
model_path = args.model
path = args.data
prognostics = ["QT", "SLI"]

model = torch.load(model_path)
train_dataset = d.get_dataset(path, predict_radiation=False)
dl = d.get_data_loader(train_dataset, prognostics, batch_size=64)

total = {}
count = 0
for batch in tqdm(dl):
    mse = batch_to_mse(model, batch)
    for key in mse:
        count += 1
        alpha = 1 / count
        zeros = torch.zeros_like(mse[key])
        total[key] = total.get(key, zeros) * (1 - alpha) + mse[key] * alpha

total = {"mse_apparent_source_" + k.lower(): total[k].numpy().tolist() for k in total}
print(json.dumps(total))
import xarray as xr
import json
import argparse
from src.sam.case import get_ngqaua_ic
from uwnet.data.blur import blur_dataset
from uwnet.thermo import layer_mass
from src.data.ngaqua import NGAqua
from src.sam.process_ngaqua import run_sam_nsteps
from subprocess import run
import os
import shutil
from os.path import abspath


class Namespace:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def get_args_snakemake():
    print("Using snakemake")
    sigma = snakemake.params.sigma
    if sigma:
        sigma = float(sigma)

    return Namespace(
        time_step=int(snakemake.wildcards.step),
        sam_parameters=snakemake.input.sam_parameters,
        sam=snakemake.config.get('sam_path', '/opt/sam'),
        ngaqua_root=snakemake.params.ngaqua_root,
        output=snakemake.output[0],
        sigma=sigma)


def get_args_argparse():
    parser = argparse.ArgumentParser(
        description='Pre-process a single time step')
    parser.add_argument('-n', '--ngaqua-root', type=str)
    parser.add_argument('-s', '--sam', type=str, default='/opt/sam')
    parser.add_argument('-t', '--time-step', type=int, default=0)
    parser.add_argument('-p', '--sam-parameters', type=str, default=0)
    parser.add_argument('--sigma', type=float, help='Radius for Gaussian blurring')
    parser.add_argument('output')
    return parser.parse_args()


def get_args():
    """Get the arguments needed to run this script

    This function will behave differently if run from snakemake"""
    try:
        snakemake
    except NameError:
        return get_args_argparse()
    else:
        return get_args_snakemake()


args = get_args()


def get_extra_features(ngaqua: NGAqua, time_step):
    # compute layer_mass
    stat = ngaqua.stat
    rho = stat.RHO.isel(time=0).drop('time')
    w = layer_mass(rho)

    # get 2D variables
    time = ngaqua.times_3d[time_step]
    d2 = ngaqua.data_2d.sel(time=time)

    # add variables to three-d
    d2['RADTOA'] = d2.LWNT - d2.SWNT
    d2['RADSFC'] = d2.LWNS - d2.SWNS
    d2['layer_mass'] = w
    d2['rho'] = rho

    # 3d variables
    qrad = ngaqua.data_3d.QRAD.drop(['x', 'y'])
    d2['QRAD'] = qrad.sel(time=time)

    return d2


# get initial condition
ic = get_ngqaua_ic(args.ngaqua_root, args.time_step)

# get data
ngaqua = NGAqua(args.ngaqua_root)
features = get_extra_features(ngaqua, args.time_step)

if args.sigma:
    print(f"Blurring data with radius {args.sigma}")
    ic = blur_dataset(ic, args.sigma)
    features = blur_dataset(features, args.sigma)

# compute the forcings by running through sam
if args.sam_parameters:
    with open(args.sam_parameters) as f:
        prm = json.load(f)

path = run_sam_nsteps(ic, prm, sam_src=abspath(args.sam))
files = os.path.join(path, 'OUT_3D', '*.nc')
ds = xr.open_mfdataset(files)
shutil.rmtree(path)

for key in ['QT', 'SLI', 'U', 'V']:
    forcing_key = 'F' + key
    src = ds[key].diff('time') / ds.time.diff('time') / 86400
    src = src.isel(time=0)
    ds[forcing_key] = src

ds = ds.isel(time=0)

# forcing data
ic['x'] = ds.x
ic['y'] = ds.y
for key in ic.data_vars:
    if key not in ds:
        ds[key] = ic[key]

ds = ds.merge(features).expand_dims('time')
ds.attrs['sam_namelist'] = json.dumps(prm)
ds.to_netcdf(args.output, unlimited_dims=['time'], engine='h5netcdf')
import xarray as xr
import numpy as np

# arguments
input_file = snakemake.input[0]
output_file = snakemake.output[0]
variables = snakemake.params.variables
shuffle = snakemake.params.shuffle
train_or_test = snakemake.wildcards.train_or_test

chunk_size = 2**10

# open data
ds = xr.open_dataset(input_file)

if train_or_test == "train":
    ds = ds.isel(x=slice(0, 64))
elif train_or_test == "test":
    ds = ds.isel(x=slice(64, None))
else:
    raise NotImplementedError(f"{train_or_test} is not \"train\" or \"test\"")

# perform basic validation
assert ds['SLI'].dims == ('time', 'z', 'y', 'x')

# stack data
variables = list(variables)  # needs to be a list for xarray
stacked = (ds[variables]
           .stack(sample=['y', 'x'])
           .drop('sample'))

# add needed variables
stacked['layer_mass'] = ds.layer_mass.isel(time=0)

# shuffle samples
if shuffle:
    n = len(stacked.sample)
    indices = np.random.choice(n, n, replace=False)
    stacked = stacked.isel(sample=indices)

chunked = stacked.chunk({'sample': chunk_size})

# save to disk
chunked.to_zarr(output_file)
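As a quick sanity check on the reshaped output, the resulting zarr store can be reopened with xarray; the path below is a placeholder for the actual output of this rule:

import xarray as xr

# Hypothetical path; substitute the output of the reshaping rule.
stacked = xr.open_zarr("data/processed/train.zarr")
print(stacked["SLI"].dims)    # ('time', 'z', 'sample') after stacking y and x
print(stacked["SLI"].chunks)  # chunked along 'sample' in blocks of 2**10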