Deep learning model for predicting environmental variables on river systems
Deep Graph Convolutional Neural Network for Predicting Environmental Variables on River Networks
This repository contains code for predicting environmental variables on river networks. The models included are all temporally or spatiotemporally aware and incorporate information from the river network. The original intent of this repository was to predict stream temperature and streamflow.
This work is being developed by researchers in the Data Science branch of the U.S. Geological Survey and by researchers at the University of Minnesota in Vipin Kumar's lab. Sources for specific models are cited in comments within the code.
Running the code
In addition to the models themselves, the repository includes functions for pre-processing and post-processing the data. The workflow_examples folder of the repository contains a number of example Snakemake workflows that show how to run the entire process with a variety of models and end goals.
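Each example workflow is an ordinary Snakefile whose rules call river-dl functions directly. As a flavor of what those rules look like (the full versions appear under Code Snippets below), here is a minimal sketch of a data-prep rule; the input/output paths are illustrative, not the actual paths used in the examples:

```python
# Minimal sketch of a river-dl data-prep rule (illustrative paths; the
# prep_all_data call mirrors the ones shown under Code Snippets below).
rule prep_io_data:
    input:
        "data/x_data.zarr",   # driver data (hypothetical path)
        "data/y_data.zarr",   # observations (hypothetical path)
    output:
        "prepped/prepped.npz",
    run:
        prep_all_data(
            x_data_file=input[0],
            y_data_file=input[1],
            x_vars=config["x_vars"],
            y_vars_finetune=config["y_vars"],
            train_start_date=config["train_start_date"],
            train_end_date=config["train_end_date"],
            val_start_date=config["val_start_date"],
            val_end_date=config["val_end_date"],
            test_start_date=config["test_start_date"],
            test_end_date=config["test_end_date"],
            out_file=output[0],
            trn_offset=config["trn_offset"],
            tst_val_offset=config["tst_val_offset"])
```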
To run the Snakemake workflow locally:

1. Install the dependencies in the `environment.yaml` file. With conda you can do this with `conda env create -f environment.yaml`
2. Activate your conda environment: `source activate rdl_torch_tf`
3. Install the local `river-dl` package: `pip install path/to/river-dl/`
4. (optional) Edit the river-dl run configuration (including paths for I/O data) in the appropriate `config.yml` from the workflow_examples folder.
5. Run Snakemake: `snakemake --configfile config.yml -s Snakemake --cores <n>`
To run the Snakemake workflow on TallGrass:

1. Request a GPU allocation and start an interactive shell:

   ```
   salloc -N 1 -t 2:00:00 -p gpu -A <account> --gres=gpu:1
   srun -A <account> --pty bash
   ```

2. Load the necessary CUDA toolkit module and add paths to the cuDNN drivers:

   ```
   module load cuda11.3/toolkit/11.3.0
   export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH
   ```

3. Follow steps 1-5 above as you would to run the workflow locally (note, you may need to change `tensorflow` to `tensorflow-gpu` in the `environment.yaml`).
After building your environment, you may want to make sure the recommended versions of PyTorch and CUDA were installed according to the PyTorch documentation. You can see the installed versions by calling `conda list` within your activated environment.
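For a quick programmatic check (a minimal sketch using only standard PyTorch calls), you can also ask PyTorch directly which versions it was built with and whether it can see the GPU:

```python
# Sanity-check the PyTorch/CUDA pairing from the activated environment.
import torch

print(torch.__version__)          # installed PyTorch version
print(torch.version.cuda)         # CUDA version PyTorch was built against
print(torch.cuda.is_available())  # True if PyTorch can see a GPU
```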
The data
The data used to run this model currently are specific to the Delaware River Basin but will soon be made more generic.
Disclaimer
This software is in the public domain because it contains materials that originally came from the U.S. Geological Survey, an agency of the United States Department of Interior. For more information, see the official USGS copyright policy.
Although this software program has been used by the U.S. Geological Survey (USGS), no warranty, expressed or implied, is made by the USGS or the U.S. Government as to the accuracy and functioning of the program and related program material nor shall the fact of distribution constitute any such warranty, and no responsibility is assumed by the USGS in connection therewith.
This software is provided “AS IS.”
Code Snippets
The snippets below are the run/shell blocks from the example Snakemake workflows in the workflow_examples folder, covering TensorFlow and PyTorch variants of the LSTM, RGCN, and Graph WaveNet models.
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        y_data_file=input[1],
        x_vars=config['x_vars'],
        y_vars_finetune=config['y_vars'],
        spatial_idx_name='segs_test',
        time_idx_name='times_test',
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    model.compile(optimizer=optimizer, loss=loss_function)
    data = np.load(input[0])
    train_model(model,
                x_trn=data['x_trn'],
                y_trn=data['y_obs_trn'],
                epochs=config['epochs'],
                batch_size=2,
                seed=config['seed'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                log_file=output[1],
                time_file=output[2],
                early_stop_patience=config['early_stopping'])
```

```python
run:
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model,
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset=config['trn_offset'],
                         spatial_idx_name='segs_test',
                         time_idx_name='times_test',
                         tst_val_offset=config['tst_val_offset'])
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     spatial_idx_name='segs_test',
                     time_idx_name='times_test',
                     group=params.grp_arg,
                     outfile=output[0])
```

```python
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
```
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
shell:
    """
    scp Snakefile {output[0]}
    """
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    data = reshape_for_gwn(data, keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(), lr=lrate, weight_decay=wdecay)
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=data['y_pre_trn'],
                max_epochs=config['pt_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=config['batch_size'],
                weights_file=output[0],
                log_file=output[1],
                device=device)
```

```python
run:
    data = np.load(input[0])
    data = reshape_for_gwn(data, keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(), lr=lrate, weight_decay=wdecay)
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=data['y_obs_trn'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                max_epochs=config['ft_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=config['batch_size'],
                weights_file=output[0],
                log_file=output[1],
                device=device)
```

```python
run:
    data = np.load(input[1])
    data = reshape_for_gwn(data, keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(), lr=lrate, weight_decay=wdecay)
    model.load_state_dict(torch.load(input[0]))
    predict_from_io_data(model, data, wildcards.partition,
                         outfile=output[0],
                         trn_offset=config['trn_offset'],
                         tst_val_offset=config['tst_val_offset'],
                         torch_model=True)
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     pred_tst=input[3],
                     group=params.grp_arg,
                     outfile=output[0])
```

```python
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
```
```python
run:
    prep_annual_signal_data(input[0], input[1], input[2],
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        out_file=output[0],
        extraResSegments=config['extraResSegments'],
        reach_file=config['reach_attr_file'],
        gw_loss_type=config['gw_loss_type'],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    temp_air_index = np.where(data['x_vars'] == 'seg_tave_air')[0]
    air_unscaled = data['x_trn'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
                   data['x_mean'][temp_air_index]
    y_trn_obs = np.concatenate(
        [data["y_obs_trn"], data["GW_trn_reshape"], air_unscaled], axis=2
    )
    air_val = data['x_val'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
              data['x_mean'][temp_air_index]
    y_val_obs = np.concatenate(
        [data["y_obs_val"], data["GW_val_reshape"], air_val], axis=2
    )
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx, device=device,
                    seed=config['seed'])
    opt = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
                loss_function=get_gw_loss(input[0]),
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=y_trn_obs,
                x_val=data['x_val'],
                y_val=y_val_obs,
                max_epochs=config['ft_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=num_segs,
                weights_file=output[0],
                log_file=output[1],
                device=device)
```

```python
run:
    calc_pred_ann_temp(input[0], input[1], input[2], input[3],
                       output[0], output[1], output[2])
```

```python
run:
    calc_gw_metrics(input[0], input[1], input[2],
                    output[0], output[1], output[2])
```
```python
run:
    prep_annual_signal_data(input[0], input[1], input[2],
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        out_file=output[0],
        extraResSegments=config['extraResSegments'],
        reach_file=config['reach_attr_file'],
        gw_loss_type=config['gw_loss_type'],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    temp_air_index = np.where(data['x_vars'] == 'seg_tave_air')[0]
    air_unscaled = data['x_trn'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
                   data['x_mean'][temp_air_index]
    y_trn_obs = np.concatenate(
        [data["y_obs_trn"], data["GW_trn_reshape"], air_unscaled], axis=2
    )
    air_val = data['x_val'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
              data['x_mean'][temp_air_index]
    y_val_obs = np.concatenate(
        [data["y_obs_val"], data["GW_val_reshape"], air_val], axis=2
    )
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A=data["dist_matrix"]
    )
    model.compile(optimizer=optimizer, loss=get_gw_loss(input[0]))
    model.load_weights(input[1] + "/")
    # run the finetuning within the training engine on CPU for the GW loss function
    train_model(model,
                x_trn=data['x_trn'],
                y_trn=y_trn_obs,
                epochs=config['pt_epochs'],
                seed=config['seed'],
                batch_size=num_segs,
                x_val=data['x_val'],
                y_val=y_val_obs,
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                best_val_weight_dir=output[1] + "/",
                log_file=output[2],
                time_file=output[3],
                early_stop_patience=config['early_stopping'],
                use_cpu=True)
```

```python
run:
    calc_pred_ann_temp(input[0], input[1], input[2], input[3],
                       output[0], output[1], output[2])
```

```python
run:
    calc_gw_metrics(input[0], input[1], input[2],
                    output[0], output[1], output[2])
```
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        spatial_idx_name='segs_test',
        time_idx_name='times_test',
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['pretrain_learning_rate'])
    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )
    model.compile(optimizer=optimizer, loss=loss_function)
    train_model(model,
                x_trn=data['x_pre_full'],
                y_trn=data['y_pre_full'],
                epochs=config['pt_epochs'],
                batch_size=2,
                seed=config['seed'],
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                log_file=output[1],
                time_file=output[2])
```

```python
run:
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )
    model.compile(optimizer=optimizer, loss=loss_function)
    model.load_weights(input[1] + "/")
    train_model(model,
                x_trn=data['x_trn'],
                y_trn=data['y_obs_trn'],
                epochs=config['pt_epochs'],
                batch_size=2,
                seed=config['seed'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                best_val_weight_dir=output[1] + "/",
                log_file=output[2],
                time_file=output[3],
                early_stop_patience=config['early_stopping'])
```

```python
run:
    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model,
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset=config['trn_offset'],
                         spatial_idx_name='segs_test',
                         time_idx_name='times_test',
                         tst_val_offset=config['tst_val_offset'])
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     spatial_idx_name='segs_test',
                     time_idx_name='times_test',
                     group=params.grp_arg,
                     outfile=output[0])
```

```python
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
```
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
shell:
    """
    scp Snakefile {output[0]}
    """
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=float(wildcards.offset),
        tst_val_offset=float(wildcards.offset),
        seq_len=int(wildcards.seq_length))
```

```python
run:
    os.system("module load analytics cuda11.3/toolkit/11.3.0")
    os.system("export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH")
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx, device=device)
    opt = optim.Adam(model.parameters(), lr=config['pretrain_learning_rate'])
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_pre_full'],
                y_train=data['y_pre_full'],
                max_epochs=config['pt_epochs'],
                batch_size=num_segs,
                weights_file=output[0],
                log_file=output[1],
                device=device,
                keep_portion=float(wildcards.offset))
```

```python
run:
    os.system("module load analytics cuda11.3/toolkit/11.3.0")
    os.system("export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH")
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx, device=device)
    opt = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'])
    scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.97 ** epoch)
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=data['y_obs_trn'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                max_epochs=config['ft_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=num_segs,
                weights_file=output[0],
                log_file=output[1],
                device=device,
                keep_portion=float(wildcards.offset))
```

```python
run:
    data = np.load(input[1])
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx)
    opt = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[0]))
    predict_from_io_data(model=model,
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset=float(wildcards.offset),
                         tst_val_offset=float(wildcards.offset))
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     pred_tst=input[3],
                     group=params.grp_arg,
                     outfile=output[0])
```
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
shell:
    """
    scp Snakefile_rgcn_pytorch.smk {output[0]}
    """
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx, device=device,
                    seed=config['seed'])
    opt = optim.Adam(model.parameters(), lr=config['pretrain_learning_rate'])
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=data['y_pre_trn'],
                max_epochs=config['pt_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=num_segs,
                weights_file=output[0],
                log_file=output[1],
                device=device)
```

```python
run:
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx, device=device,
                    seed=config['seed'])
    opt = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'])
    scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.97 ** epoch)
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
                loss_function=rmse_masked,
                optimizer=opt,
                x_train=data['x_trn'],
                y_train=data['y_obs_trn'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                max_epochs=config['ft_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size=num_segs,
                weights_file=output[0],
                log_file=output[1],
                device=device)
```

```python
run:
    data = np.load(input[1])
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx)
    opt = optim.Adam(model.parameters(), lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[0]))
    predict_from_io_data(model=model,
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset=config['trn_offset'],
                         tst_val_offset=config['tst_val_offset'])
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     group=params.grp_arg,
                     outfile=output[0])
```
```python
run:
    asRunConfig(config, code_dir, output[0])
```

```python
shell:
    """
    scp Snakefile_rgcn.smk {output[0]}
    """
```

```python
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
```

```python
run:
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['pretrain_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A=data["dist_matrix"]
    )
    model.compile(optimizer=optimizer, loss=loss_function)
    train_model(model,
                x_trn=data['x_pre_full'],
                y_trn=data['y_pre_full'],
                epochs=config['pt_epochs'],
                batch_size=num_segs,
                seed=config['seed'],
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                log_file=output[1],
                time_file=output[2])
```

```python
run:
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A=data["dist_matrix"]
    )
    model.compile(optimizer=optimizer, loss=loss_function)
    model.load_weights(input[1] + "/")
    train_model(model,
                x_trn=data['x_trn'],
                y_trn=data['y_obs_trn'],
                epochs=config['pt_epochs'],
                batch_size=num_segs,
                seed=config['seed'],
                x_val=data['x_val'],
                y_val=data['y_obs_val'],
                # the trailing slash is needed; otherwise the weights
                # get saved directly in the "outdir"
                weight_dir=output[0] + "/",
                best_val_weight_dir=output[1] + "/",
                log_file=output[2],
                time_file=output[3],
                early_stop_patience=config['early_stopping'])
```

```python
run:
    data = np.load(input[1])
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A=data["dist_matrix"]
    )
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model,
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset=config['trn_offset'],
                         tst_val_offset=config['tst_val_offset'])
```

```python
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     group=params.grp_arg,
                     outfile=output[0])
```

```python
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
```
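Outside of Snakemake, the same functions can be chained directly into a prep → train → predict → evaluate pipeline. The sketch below strings together the calls used in the rules above; the import paths, file paths, variable names, and hyperparameter values are assumptions for illustration, not verbatim from the repository:

```python
# Hypothetical standalone pipeline built from the river-dl functions used
# in the rules above. Module paths and all literal paths, variable names,
# and hyperparameters are assumed for illustration.
import numpy as np
import torch
import torch.optim as optim

from river_dl.preproc_utils import prep_all_data           # assumed module path
from river_dl.torch_models import RGCN_v1                  # assumed module path
from river_dl.torch_utils import train_torch, rmse_masked  # assumed module path
from river_dl.predict import predict_from_io_data          # assumed module path
from river_dl.evaluate import combined_metrics             # assumed module path

# 1) Prep: build a model-ready npz file from the raw inputs.
prep_all_data(
    x_data_file="data/x_data.zarr",          # hypothetical path
    y_data_file="data/obs_temp_flow.zarr",   # hypothetical path
    distfile="data/distance_matrix.npz",     # hypothetical path
    x_vars=["seg_tave_air", "seg_rain"],     # hypothetical driver list
    y_vars_finetune=["temp_c"],              # hypothetical target
    train_start_date="1985-10-01", train_end_date="2005-09-30",
    val_start_date="2005-10-01", val_end_date="2010-09-30",
    test_start_date="2010-10-01", test_end_date="2015-09-30",
    out_file="prepped.npz",
    trn_offset=1.0, tst_val_offset=1.0)

# 2) Train: fit an RGCN on the prepped data.
data = np.load("prepped.npz")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = RGCN_v1(len(data['x_vars']), 20, data['dist_matrix'], device=device)
opt = optim.Adam(model.parameters(), lr=0.005)
train_torch(model, loss_function=rmse_masked, optimizer=opt,
            x_train=data['x_trn'], y_train=data['y_obs_trn'],
            x_val=data['x_val'], y_val=data['y_obs_val'],
            max_epochs=100, early_stopping_patience=10,
            batch_size=len(np.unique(data['ids_trn'])),
            weights_file="weights.pth", log_file="train_log.csv",
            device=device)

# 3) Predict and 4) evaluate on the train and validation partitions.
for part in ("trn", "val"):
    predict_from_io_data(model=model, io_data="prepped.npz", partition=part,
                         outfile=f"preds_{part}.feather",
                         trn_offset=1.0, tst_val_offset=1.0)
combined_metrics(obs_file="data/obs_temp_flow.zarr",
                 pred_trn="preds_trn.feather",
                 pred_val="preds_val.feather",
                 outfile="metrics.csv")
```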