
Deep Graph Convolutional Neural Network for Predicting Environmental Variables on River Networks

This repository (https://github.com/jsadler2/river-dl) contains code for predicting environmental variables on river networks. The models included are all either temporally or spatiotemporally aware and incorporate information from the river network. The original intent of this repository was to predict stream temperature and streamflow.

This work is being developed by researchers in the Data Science branch of the U.S. Geological Survey and by researchers in Vipin Kumar's lab at the University of Minnesota. Sources for specific models are included as comments within the code.

Running the code

There are functions for facilitating pre-processing and post-processing of the data, in addition to functions for running the models themselves. The workflow_examples folder of the repository contains a number of example Snakemake workflows that show how to run the entire process with a variety of models and end goals.

To run the Snakemake workflow locally:

  1. Install the dependencies in the environment.yaml file. With conda you can do this with conda env create -f environment.yaml

  2. Activate your conda environment: source activate rdl_torch_tf

  3. Install the local river-dl package with pip install path/to/river-dl/ (optional)

  4. Edit the river-dl run configuration (including paths for I/O data) in the appropriate config.yml from the workflow_examples folder.

  5. Run Snakemake with snakemake --configfile config.yml -s Snakefile --cores <n>

To run the Snakemake workflow on TallGrass:

  1. Request a GPU allocation and start an interactive shell

     salloc -N 1 -t 2:00:00 -p gpu -A <account> --gres=gpu:1
     srun -A <account> --pty bash
    
  2. Load the necessary CUDA toolkit module and add paths to the cuDNN drivers

     module load cuda11.3/toolkit/11.3.0 
     export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH
    
  3. Follow steps 1-5 above as you would to run the workflow locally (note: you may need to change tensorflow to tensorflow-gpu in the environment.yaml).

After building your environment, you may want to make sure the recommended versions of PyTorch and CUDA were installed according to the PyTorch documentation. You can see the installed versions by calling conda list within your activated environment.
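
For a quick programmatic check, you can also query the libraries directly. This is a minimal sketch, assuming both PyTorch and TensorFlow are installed in the active environment:

    import torch
    import tensorflow as tf

    # Report the installed framework versions and the CUDA setup PyTorch was built against
    print("torch:", torch.__version__, "| built for CUDA:", torch.version.cuda)
    print("cuda available:", torch.cuda.is_available())
    print("tensorflow:", tf.__version__)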

The data

The data used to run this model are currently specific to the Delaware River Basin but will soon be made more generic.

Disclaimer

This software is in the public domain because it contains materials that originally came from the U.S. Geological Survey, an agency of the United States Department of the Interior. For more information, see the official USGS copyright policy.

Although this software program has been used by the U.S. Geological Survey (USGS), no warranty, expressed or implied, is made by the USGS or the U.S. Government as to the accuracy and functioning of the program and related program material nor shall the fact of distribution constitute any such warranty, and no responsibility is assumed by the USGS in connection therewith.

This software is provided “AS IS.”

Code Snippets
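
The snippets below are the run: and shell: bodies excerpted from the rules in the example Snakemake workflows. For orientation, here is a minimal sketch of how such a body sits inside a complete rule; the rule name and file paths are hypothetical, and the actual rules take their paths and parameters from the workflow's config.yml:

    rule prep_io_data:  # hypothetical rule name
        input:
            "in_data/x_data.zarr",  # hypothetical paths
            "in_data/y_data.zarr",
        output:
            "out_data/prepped.npz"
        run:
            prep_all_data(
                x_data_file=input[0],
                y_data_file=input[1],
                x_vars=config['x_vars'],
                out_file=output[0])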

run:
    asRunConfig(config, code_dir, output[0])
run:
    prep_all_data(
              x_data_file=input[0],
              y_data_file=input[1],
              x_vars=config['x_vars'],
              y_vars_finetune=config['y_vars'],
              spatial_idx_name='segs_test',
              time_idx_name='times_test',
              catch_prop_file=None,
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'],
              segs=None,
              out_file=output[0],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate']) 
    model.compile(optimizer=optimizer, loss=loss_function)
    data = np.load(input[0])
    train_model(model,
                x_trn = data['x_trn'],
                y_trn = data['y_obs_trn'],
                epochs = config['epochs'],
                batch_size = 2,
                seed = config['seed'],
                x_val = data['x_val'],
                y_val = data['y_obs_val'],
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                log_file = output[1],
                time_file = output[2],
                early_stop_patience=config['early_stopping'])
run:
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model, 
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset = config['trn_offset'],
                         spatial_idx_name='segs_test',
                         time_idx_name='times_test',
                         tst_val_offset = config['tst_val_offset'])
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     spatial_idx_name='segs_test',
                     time_idx_name='times_test',
                     group=params.grp_arg,
                     outfile=output[0])
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
run:
    asRunConfig(config, code_dir, output[0])
shell:
    """
    scp Snakefile {output[0]}
    """
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'])
run:
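    # Pretrain a Graph WaveNet (gwnet) on the river network: load the prepped
    # data, reshape it for GWN, and build the adjacency-matrix support tensors.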
    data = np.load(input[0])
    data = reshape_for_gwn(data,keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(),lr=lrate,weight_decay=wdecay)

    train_torch(model,
                loss_function = rmse_masked,
                optimizer= opt,
                x_train= data['x_trn'],
                y_train = data['y_pre_trn'],
                max_epochs = config['pt_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size = config['batch_size'],
                weights_file = output[0],
                log_file = output[1],
                device=device)
run:
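    # Finetune the Graph WaveNet: rebuild the same model, load the pretrained
    # weights, and train against the observed targets with early stopping.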
    data = np.load(input[0])
    data = reshape_for_gwn(data,keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(),lr=lrate,weight_decay=wdecay)

    model.load_state_dict(torch.load(input[1]))

    train_torch(model,
        loss_function=rmse_masked,
        optimizer=opt,
        x_train=data['x_trn'],
        y_train=data['y_obs_trn'],
        x_val=data['x_val'],
        y_val=data['y_obs_val'],
        max_epochs=config['ft_epochs'],
        early_stopping_patience=config['early_stopping'],
        batch_size = config['batch_size'],
        weights_file=output[0],
        log_file=output[1],
        device=device)
run:
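    # Rebuild the Graph WaveNet, load the trained weights, and write
    # predictions for the partition given by the wildcard.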
    data = np.load(input[1])
    data = reshape_for_gwn(data,keep_portion=config['trn_offset'])
    adj_mx = data['dist_matrix']
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    supports = [torch.tensor(adj_mx).to(device).float()]
    in_dim = len(data['x_vars'])
    out_dim = data['y_obs_trn'].shape[3]
    num_nodes = adj_mx.shape[0]
    lrate = 0.001
    wdecay = 0.0001
    model = gwnet(device, num_nodes, supports=supports, aptinit=supports[0],
                  in_dim=in_dim, out_dim=out_dim, layers=5, kernel_size=7, blocks=2)
    opt = optim.Adam(model.parameters(),lr=lrate,weight_decay=wdecay)

    model.load_state_dict(torch.load(input[0]))

    predict_from_io_data(model,
        data,
        wildcards.partition,
        outfile=output[0],
        trn_offset=config['trn_offset'],
        tst_val_offset=config['tst_val_offset'],
        torch_model=True,
        )
run:
    combined_metrics(obs_file=input[0],
        pred_trn=input[1],
        pred_val=input[2],
        pred_tst=input[3],
        group=params.grp_arg,
        outfile=output[0])
run:
    plot_obs(input[0],wildcards.variable,output[0],
        partition=wildcards.partition)
run:
    prep_annual_signal_data(input[0], input[1], input[2],
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'], 
              out_file=output[0],
              extraResSegments = config['extraResSegments'],
              reach_file= config['reach_attr_file'],
              gw_loss_type=config['gw_loss_type'],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
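    # Finetune an RGCN with the groundwater-aware loss: un-scale the air
    # temperature input and append it, along with the GW terms, to the
    # observation targets expected by get_gw_loss.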
    data = np.load(input[0])
    temp_air_index = np.where(data['x_vars'] == 'seg_tave_air')[0]
    air_unscaled = data['x_trn'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
                   data['x_mean'][temp_air_index]
    y_trn_obs = np.concatenate(
        [data["y_obs_trn"], data["GW_trn_reshape"], air_unscaled], axis=2
    )
    air_val = (data['x_val'][:, :, temp_air_index] * data['x_std'][temp_air_index]
               + data['x_mean'][temp_air_index])
    y_val_obs = np.concatenate(
        [data["y_obs_val"], data["GW_val_reshape"], air_val], axis=2
    )
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx,device=device, seed=config['seed'])
    opt = optim.Adam(model.parameters(),lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
        loss_function=get_gw_loss(input[0]),
        optimizer=opt,
        x_train=data['x_trn'],
        y_train=y_trn_obs,
        x_val=data['x_val'],
        y_val=y_val_obs,
        max_epochs=config['ft_epochs'],
        early_stopping_patience=config['early_stopping'],
        batch_size = num_segs,
        weights_file=output[0],
        log_file=output[1],
        device=device)
run: 
    calc_pred_ann_temp(input[0], input[1], input[2], input[3], output[0], output[1], output[2])
run:
    calc_gw_metrics(input[0], input[1], input[2], output[0], output[1], output[2])
run:
    prep_annual_signal_data(input[0], input[1], input[2],
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'], 
              out_file=output[0],
              extraResSegments = config['extraResSegments'],
              reach_file= config['reach_attr_file'],
              gw_loss_type=config['gw_loss_type'],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
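    # TensorFlow version of the groundwater-aware finetuning: build the target
    # arrays for get_gw_loss, then finetune the pretrained RGCNModel on the CPU.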
    data = np.load(input[0])
    temp_air_index = np.where(data['x_vars'] == 'seg_tave_air')[0]
    air_unscaled = data['x_trn'][:, :, temp_air_index] * data['x_std'][temp_air_index] + \
                   data['x_mean'][temp_air_index]
    y_trn_obs = np.concatenate(
        [data["y_obs_trn"], data["GW_trn_reshape"], air_unscaled], axis=2
    )
    air_val = (data['x_val'][:, :, temp_air_index] * data['x_std'][temp_air_index]
               + data['x_mean'][temp_air_index])
    y_val_obs = np.concatenate(
        [data["y_obs_val"], data["GW_val_reshape"], air_val], axis=2
    )
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A= data["dist_matrix"]
    )

    model.compile(optimizer=optimizer, loss=get_gw_loss(input[0]))
    model.load_weights(input[1] + "/")

    # Run the finetuning within the training engine on the CPU for the GW loss function
    train_model(model,
                x_trn = data['x_trn'],
                y_trn = y_trn_obs,
                epochs = config['pt_epochs'],
                seed = config['seed'],
                batch_size = num_segs,
                x_val = data['x_val'],
                y_val = y_val_obs,
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                best_val_weight_dir = output[1] + "/",
                log_file = output[2],
                time_file = output[3],
                early_stop_patience=config['early_stopping'],
                use_cpu = True)
run: 
    calc_pred_ann_temp(input[0], input[1], input[2], input[3], output[0], output[1], output[2])
run:
    calc_gw_metrics(input[0], input[1], input[2], output[0], output[1], output[2])
run:
    asRunConfig(config, code_dir, output[0])
run:
    prep_all_data(
              x_data_file=input[0],
              pretrain_file=input[0],
              y_data_file=input[1],
              distfile=input[2],
              x_vars=config['x_vars'],
              y_vars_pretrain=config['y_vars_pretrain'],
              y_vars_finetune=config['y_vars_finetune'],
              spatial_idx_name='segs_test',
              time_idx_name='times_test',
              catch_prop_file=None,
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'],
              segs=None,
              out_file=output[0],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
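    # Pretrain the LSTM on the full pretraining partition (x_pre_full / y_pre_full).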
    data = np.load(input[0])

    optimizer = tf.optimizers.Adam(learning_rate=config['pretrain_learning_rate']) 

    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )

    model.compile(optimizer=optimizer, loss=loss_function)
    train_model(model,
                x_trn = data['x_pre_full'],
                y_trn = data['y_pre_full'],
                epochs = config['pt_epochs'],
                batch_size = 2,
                seed=config['seed'],
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                log_file = output[1],
                time_file = output[2])
run:
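    # Finetune the pretrained LSTM on the observed targets, with validation
    # data and early stopping.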
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate']) 

    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )

    model.compile(optimizer=optimizer, loss=loss_function)
    model.load_weights(input[1] + "/")
    train_model(model,
                x_trn = data['x_trn'],
                y_trn = data['y_obs_trn'],
                epochs = config['pt_epochs'],
                batch_size = 2,
                seed=config['seed'],
                x_val = data['x_val'],
                y_val = data['y_obs_val'],
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                best_val_weight_dir = output[1] + "/",
                log_file = output[2],
                time_file = output[3],
                early_stop_patience=config['early_stopping'])
run:
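    # Rebuild the LSTM, load the trained weights, and write predictions for
    # the partition given by the wildcard.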
    model = LSTMModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
    )
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model, 
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset = config['trn_offset'],
                         spatial_idx_name='segs_test',
                         time_idx_name='times_test',
                         tst_val_offset = config['tst_val_offset'])
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     spatial_idx_name='segs_test',
                     time_idx_name='times_test',
                     group=params.grp_arg,
                     outfile=output[0])
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)
run:
    asRunConfig(config, code_dir, output[0])
shell:
    """
    scp Snakefile {output[0]}
    """
run:
    prep_all_data(
        x_data_file=input[0],
        pretrain_file=input[0],
        y_data_file=input[1],
        distfile=input[2],
        x_vars=config['x_vars'],
        y_vars_pretrain=config['y_vars_pretrain'],
        y_vars_finetune=config['y_vars_finetune'],
        catch_prop_file=None,
        train_start_date=config['train_start_date'],
        train_end_date=config['train_end_date'],
        val_start_date=config['val_start_date'],
        val_end_date=config['val_end_date'],
        test_start_date=config['test_start_date'],
        test_end_date=config['test_end_date'],
        segs=None,
        out_file=output[0],
        trn_offset= float(wildcards.offset),
        tst_val_offset= float(wildcards.offset),
        seq_len=int(wildcards.seq_length),
    )
run:
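    # Note: os.system runs each command in its own subshell, so the module
    # load and export below do not persist into this Python process.
    # Pretrain the RGCN with a keep_portion taken from the {offset} wildcard.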
    os.system("module load analytics cuda11.3/toolkit/11.3.0")
    os.system("export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH")

    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx,device=device)
    opt = optim.Adam(model.parameters(),lr=config['pretrain_learning_rate'])

    train_torch(model,
        loss_function=rmse_masked,
        optimizer=opt,
        x_train=data['x_pre_full'],
        y_train=data['y_pre_full'],
        max_epochs=config['pt_epochs'],
        batch_size=num_segs,
        weights_file=output[0],
        log_file=output[1],
        device=device,
        keep_portion=float(wildcards.offset))
run:
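    # Same subshell caveat as above; finetune the pretrained RGCN on the
    # observed targets, again with keep_portion from the {offset} wildcard.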
    os.system("module load analytics cuda11.3/toolkit/11.3.0")
    os.system("export LD_LIBRARY_PATH=/cm/shared/apps/nvidia/TensorRT-6.0.1.5/lib:/cm/shared/apps/nvidia/cudnn_8.0.5/lib64:$LD_LIBRARY_PATH")

    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx,device=device)
    opt = optim.Adam(model.parameters(),lr=config['finetune_learning_rate'])
    scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.97 ** epoch)  # note: created but never passed to train_torch
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
        loss_function=rmse_masked,
        optimizer=opt,
        x_train=data['x_trn'],
        y_train=data['y_obs_trn'],
        x_val=data['x_val'],
        y_val=data['y_obs_val'],
        max_epochs=config['ft_epochs'],
        early_stopping_patience=config['early_stopping'],
        batch_size=num_segs,
        weights_file=output[0],
        log_file=output[1],
        device=device,
        keep_portion=float(wildcards.offset))
run:
    data = np.load(input[1])
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx)
    opt = optim.Adam(model.parameters(),lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[0]))
    predict_from_io_data(model=model,
        io_data=input[1],
        partition=wildcards.partition,
        outfile=output[0],
        trn_offset=float(wildcards.offset),
        tst_val_offset=float(wildcards.offset))
run:
    combined_metrics(obs_file=input[0],
        pred_trn=input[1],
        pred_val=input[2],
        pred_tst=input[3],
        group=params.grp_arg,
        outfile=output[0])
run:
    asRunConfig(config, code_dir, output[0])
shell:
    """
    scp Snakefile_rgcn_pytorch.smk {output[0]}
    """
run:
    prep_all_data(
              x_data_file=input[0],
              pretrain_file=input[0],
              y_data_file=input[1],
              distfile=input[2],
              x_vars=config['x_vars'],
              y_vars_pretrain=config['y_vars_pretrain'],
              y_vars_finetune=config['y_vars_finetune'],
              catch_prop_file=None,
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'],
              segs=None,
              out_file=output[0],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
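    # Pretrain the RGCN on the pretraining targets (y_pre_trn), batching by
    # the number of unique river segments.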
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim, config['hidden_size'], adj_mx,device=device, seed=config['seed'])
    opt = optim.Adam(model.parameters(),lr=config['pretrain_learning_rate'])
    train_torch(model,
                loss_function = rmse_masked,
                optimizer= opt,
                x_train= data['x_trn'],
                y_train = data['y_pre_trn'],
                max_epochs = config['pt_epochs'],
                early_stopping_patience=config['early_stopping'],
                batch_size = num_segs,
                weights_file = output[0],
                log_file = output[1],
                device=device)
run:
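    # Finetune the pretrained RGCN on the observed targets with early stopping.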
    data = np.load(input[0])
    num_segs = len(np.unique(data['ids_trn']))
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx,device=device, seed=config['seed'])
    opt = optim.Adam(model.parameters(),lr=config['finetune_learning_rate'])
    scheduler = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda epoch: 0.97 ** epoch)  # note: created but never passed to train_torch
    model.load_state_dict(torch.load(input[1]))
    train_torch(model,
        loss_function=rmse_masked,
        optimizer=opt,
        x_train=data['x_trn'],
        y_train=data['y_obs_trn'],
        x_val=data['x_val'],
        y_val=data['y_obs_val'],
        max_epochs=config['ft_epochs'],
        early_stopping_patience=config['early_stopping'],
        batch_size = num_segs,
        weights_file=output[0],
        log_file=output[1],
        device=device)
run:
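    # Rebuild the RGCN, load the finetuned weights, and write predictions for
    # the requested partition.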
    data = np.load(input[1])
    adj_mx = data['dist_matrix']
    in_dim = len(data['x_vars'])
    model = RGCN_v1(in_dim,config['hidden_size'],adj_mx)
    opt = optim.Adam(model.parameters(),lr=config['finetune_learning_rate'])
    model.load_state_dict(torch.load(input[0]))
    predict_from_io_data(model=model, 
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset = config['trn_offset'],
                         tst_val_offset = config['tst_val_offset'])
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     group=params.grp_arg,
                     outfile=output[0])
run:
    asRunConfig(config, code_dir, output[0])
shell:
    """
    scp Snakefile_rgcn.smk {output[0]}
    """
run:
    prep_all_data(
              x_data_file=input[0],
              pretrain_file=input[0],
              y_data_file=input[1],
              distfile=input[2],
              x_vars=config['x_vars'],
              y_vars_pretrain=config['y_vars_pretrain'],
              y_vars_finetune=config['y_vars_finetune'],
              catch_prop_file=None,
              train_start_date=config['train_start_date'],
              train_end_date=config['train_end_date'],
              val_start_date=config['val_start_date'],
              val_end_date=config['val_end_date'],
              test_start_date=config['test_start_date'],
              test_end_date=config['test_end_date'],
              segs=None,
              out_file=output[0],
              trn_offset = config['trn_offset'],
              tst_val_offset = config['tst_val_offset'])
run:
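    # Pretrain the TensorFlow RGCNModel on the full pretraining partition.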
    data = np.load(input[0])

    optimizer = tf.optimizers.Adam(learning_rate=config['pretrain_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A= data["dist_matrix"]
    )

    model.compile(optimizer=optimizer, loss=loss_function)
    train_model(model,
                x_trn = data['x_pre_full'],
                y_trn = data['y_pre_full'],
                epochs = config['pt_epochs'],
                batch_size = num_segs,
                seed=config['seed'],
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                log_file = output[1],
                time_file = output[2])
run:
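    # Finetune the pretrained RGCNModel on the observed targets with early stopping.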
    data = np.load(input[0])
    optimizer = tf.optimizers.Adam(learning_rate=config['finetune_learning_rate'])
    num_segs = len(np.unique(data['ids_trn']))
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A= data["dist_matrix"]
    )

    model.compile(optimizer=optimizer, loss=loss_function)
    model.load_weights(input[1] + "/")
    train_model(model,
                x_trn = data['x_trn'],
                y_trn = data['y_obs_trn'],
                epochs = config['pt_epochs'],
                batch_size = num_segs,
                seed=config['seed'],
                x_val = data['x_val'],
                y_val = data['y_obs_val'],
                # A trailing slash is needed here; otherwise the weights are
                # saved as a file named after the directory instead of inside it
                weight_dir = output[0] + "/",
                best_val_weight_dir = output[1] + "/",
                log_file = output[2],
                time_file = output[3],
                early_stop_patience=config['early_stopping'])
run:
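    # Rebuild the RGCNModel, load the trained weights, and write predictions
    # for the requested partition.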
    data = np.load(input[1])
    model = RGCNModel(
        config['hidden_size'],
        recurrent_dropout=config['recurrent_dropout'],
        dropout=config['dropout'],
        num_tasks=len(config['y_vars_pretrain']),
        A= data["dist_matrix"]
    )
    weight_dir = input[0] + '/'
    model.load_weights(weight_dir)
    predict_from_io_data(model=model, 
                         io_data=input[1],
                         partition=wildcards.partition,
                         outfile=output[0],
                         trn_offset = config['trn_offset'],
                         tst_val_offset = config['tst_val_offset'])
run:
    combined_metrics(obs_file=input[0],
                     pred_trn=input[1],
                     pred_val=input[2],
                     group=params.grp_arg,
                     outfile=output[0])
run:
    plot_obs(input[0], wildcards.variable, output[0],
             partition=wildcards.partition)