-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_da.py
More file actions
268 lines (235 loc) · 13.5 KB
/
run_da.py
File metadata and controls
268 lines (235 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import os, sys
from torch.utils.data import IterableDataset, DataLoader
import torch
from datetime import datetime, timedelta
import numpy as np
from src.dv import *
from src.obs_cummulative import ObsDatasetCum, ObsError
from src.var_4d_reformatted import FourDVar
from stormer.models.hub.vit_adaln import ViTAdaLN
from stormer.stormer_utils import StormerWrapper
from stormer.stormer_utils_pangu import StormerWrapperPangu
# Compute device. An optional third command-line argument (sys.argv[3]) picks
# the CUDA device index; gpu2use is then a str, otherwise it is the int 0.
# gpu2use and device_set are read again further down in this script.
# Without CUDA everything runs on the CPU.
device_set = len(sys.argv) > 3
gpu2use = sys.argv[3] if device_set else 0
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(gpu2use))
else:
    device = torch.device("cpu")
print('USING DEVICE :', device)
# Anomaly detection makes autograd report the forward op behind a failing
# backward pass (and raise on NaN gradients). PyTorch documents it as a
# debugging aid that slows execution, so expect a significant runtime cost.
torch.autograd.set_detect_anomaly(True)
import logging
def select_background_err_files(gpu2use, device_set):
    """Return the ``(sh_file, hf_file)`` background-error (B) statistics to use.

    ``sh_file`` holds the spherical-harmonic part of B and ``hf_file`` the
    grid-space high-frequency part.  The default is the 24 h Stormer-vs-ERA5
    statistics.  When a device index was given explicitly on the command line
    (``device_set``), indices 0-3 select the 12/24/72/144 h
    ``*_stormer_norm_NegB`` statistics instead, so one experiment per GPU can
    be launched side by side; any other index keeps the default.  Earlier
    experiments used the ``*_stormer_vs_era5`` and ``*_stormer_norm`` variants
    of the per-GPU files.

    Args:
        gpu2use: Device index, as an int or a numeric string (``sys.argv[3]``).
        device_set: True when the index came from the command line.

    Raises:
        ValueError: if ``device_set`` is true and ``gpu2use`` is not numeric.
    """
    data_dir = '/eagle/MDClimSim/mjp5595/ml4dvar/stormer/data'
    hours_by_gpu = {0: 12, 1: 24, 2: 72, 3: 144}
    if device_set:
        hours = hours_by_gpu.get(int(gpu2use))
        if hours is not None:
            return ('{}/sh_{}hr_stormer_norm_NegB.npy'.format(data_dir, hours),
                    '{}/hf_{}hr_stormer_norm_NegB.npy'.format(data_dir, hours))
    return ('{}/sh_24hr_stormer_vs_era5.npy'.format(data_dir),
            '{}/hf_24hr_stormer_vs_era5.npy'.format(data_dir))


def find_resume_background(filenames, default_file, save_dir):
    """Choose the background to start from, resuming an interrupted run.

    Scans ``filenames`` (the contents of ``save_dir``) for names containing
    ``'background'`` and parses the integer after the first ``'_'`` as the
    cycle index (e.g. ``background_12_...``).  The file with the largest
    index greater than 0 wins; ties keep the first one seen.  Every listing is
    scanned, including one that holds a single file.

    Returns:
        ``(start_idx, background_file)``, or ``(0, default_file)`` when no such
        file exists, i.e. the run starts from the initial background.

    Raises:
        ValueError, IndexError: if a name containing ``'background'`` does not
        have an integer as its second ``'_'``-separated field.
    """
    start_idx = 0
    background_file = default_file
    for name in filenames:
        if 'background' not in name:
            continue
        idx = int(name.split('_')[1])
        if idx > start_idx:
            start_idx = idx
            background_file = os.path.join(save_dir, name)
    return start_idx, background_file


def setup_logger(exp_dir, save_dir_name):
    """Configure the root logger to write DEBUG+ records to a new log file.

    The file is ``<exp_dir>/logs/<save_dir_name>_<n>.log``, where ``n`` is the
    number of entries already in the log directory, so a restarted run
    normally gets a fresh file instead of overwriting the previous one.
    ``logging.basicConfig`` is a no-op if the root logger already has handlers.
    """
    log_dir = os.path.join(exp_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    num_logs = len(os.listdir(log_dir))
    logging.basicConfig(
        filename=os.path.join(log_dir, '{}_{}.log'.format(save_dir_name, num_logs)),
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        filemode='w',
    )
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    return logger


def main():
    """Run cycled 3D-Var / 4D-Var data assimilation with the Stormer model.

    Usage: ``python run_da.py DA_TYPE SAVE_DIR_NAME [GPU_INDEX]``

    * DA_TYPE -- ``'var3d'`` or ``'var4d'``; passed to ``FourDVar`` and used
      here to choose how observations are windowed.
    * SAVE_DIR_NAME -- experiment name; outputs go to
      ``/eagle/MDClimSim/mjp5595/data/stormer/<SAVE_DIR_NAME>/{data,logs}``.
      A name containing ``'rand'`` starts a fresh run from a standard-normal
      random background instead of the ERA5-based initial state.
    * GPU_INDEX -- read at module level to pick ``device``; it also selects
      the B matrix (see ``select_background_err_files``).

    If ``data/`` already holds saved backgrounds, the run resumes from the one
    with the largest cycle index.
    """
    if len(sys.argv) < 3:
        sys.exit('usage: {} DA_TYPE(var3d|var4d) SAVE_DIR_NAME [GPU_INDEX]'.format(sys.argv[0]))
    da_type = str(sys.argv[1])
    save_dir_name = str(sys.argv[2])

    # Assimilation period and timing parameters, in hours (see the 4D-Var
    # notes below: a 12 h window split into two 6 h model steps).
    start_date = datetime(2014, 1, 1, hour=0)
    end_date = datetime(2015, 12, 31, hour=12)
    da_window = 12
    model_step = 6
    obs_freq = 3

    exp_dir = '/eagle/MDClimSim/mjp5595/data/stormer/{}'.format(save_dir_name)
    save_dir = os.path.join(exp_dir, 'data')
    os.makedirs(save_dir, exist_ok=True)  # creates exp_dir as well

    obs_filepath = "/eagle/MDClimSim/mjp5595/ml4dvar/obs/igra_141520_stormer_obs_standardized_360_3.hdf5"
    means_file = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/normalize_mean.npz'
    stds_file = '/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/normalize_std.npz'
    means = np.load(means_file)
    stds = np.load(stds_file)
    dv_param_file = '/eagle/MDClimSim/awikner/dv_params_128_256.hdf5'
    background_err_file, background_err_hf_file = select_background_err_files(gpu2use, device_set)
    ckpt_pth = '/eagle/MDClimSim/tungnd/stormer/models/6_12_24_climax_large_2_True_delta_8/checkpoints/epoch_015.ckpt'
    # B inflation factor handed to FourDVar; currently 1 for both var3d and var4d.
    b_inflation = 1

    # Resume from the latest saved background, if any; otherwise start from
    # the initial state.
    init_background_file = '/eagle/MDClimSim/mjp5595/ml4dvar/stormer/data/background_init_stormer_norm_hr12.npy'  # Init with 'random' era5 weather state from 1990
    start_idx, background_file_np = find_resume_background(
        os.listdir(save_dir), init_background_file, save_dir)
    print('Starting with background file : {}'.format(background_file_np))

    logger = setup_logger(exp_dir, save_dir_name)
    logger.info('')
    logger.info('')
    logger.info('da_type : {}'.format(da_type))
    logger.info('save_dir_name : {}'.format(save_dir_name))
    logger.info('')
    logger.info('b_inflation : {}'.format(b_inflation))
    logger.info('Using checkpoint pth : {}'.format(ckpt_pth))
    logger.info('obs_filepath : {}'.format(obs_filepath))
    logger.info('means_file : {}'.format(means_file))
    logger.info('stds_file : {}'.format(stds_file))
    logger.info('dv_param file : {}'.format(dv_param_file))
    logger.info('')
    logger.info('Using background_err_file : {}'.format(background_err_file))
    logger.info('Using background_err_hf_file : {}'.format(background_err_hf_file))
    logger.info('Starting with background file : {}'.format(background_file_np))
    logger.info('')

    from stormer.varsStormer import varsStormer
    vars_stormer = varsStormer().vars_stormer

    # Observation error, one entry per variable type.
    var_types = ['geopotential', 'temperature', 'specific_humidity', 'u_component_of_wind', 'v_component_of_wind', 'pressure']
    var_obs_err = [100., 1.0, 1e-4, 1.0, 1.0, 100.]
    obs_perc_err = [False, False, False, False, False, False]
    obs_err = ObsError(vars_stormer, var_types, var_obs_err, obs_perc_err, stds, device)
    print('obs_err :', obs_err.obs_err)
    logger.info('obs_err : {}'.format(obs_err.obs_err))

    # Divergence/vorticity layer from src/dv.py.
    dv_layer = DivergenceVorticity(vars_stormer, means, stds, dv_param_file, device)

    # Reorder the first axis of both B components into dv_layer's
    # (no-wind, u-wind, v-wind) variable order.
    # NOTE: negating B here (be = -1 * be) was tried and breaks the run; the
    # earlier note suggests computing pred12 - pred36 in calc_b.py instead.
    var_order = torch.concat((dv_layer.nowind_idxs, dv_layer.uwind_idxs, dv_layer.vwind_idxs))
    background_err = torch.from_numpy(np.load(background_err_file)).float().to(device)[var_order]
    background_err_hf = torch.from_numpy(np.load(background_err_hf_file)).float().to(device)[var_order]

    # 3D-Var assumes every observation in the window happens at the analysis
    # time, so it uses one observation step of the most recent obs only.
    # 4D-Var optimises a trajectory: e.g. the analysis at 2014-01-01 12:00
    # uses obs from 01:00-06:00 and 07:00-12:00 (da_window // model_step steps).
    use_only_recent_obs = da_type == 'var3d'
    if da_type == 'var4d':
        obs_steps = da_window // model_step
    else:
        obs_steps = 1
    obs_dataset = ObsDatasetCum(obs_filepath, start_date, end_date, vars_stormer,
                                obs_freq=obs_freq, da_window=da_window,
                                obs_start_idx=start_idx, obs_steps=obs_steps,
                                only_recent_obs=use_only_recent_obs, logger=logger,
                                device=device)
    obs_loader = DataLoader(obs_dataset, batch_size=1, num_workers=0)

    net = ViTAdaLN(
        in_img_size=(128, 256),
        list_variables=vars_stormer,
        patch_size=2,
        embed_norm=True,
        hidden_size=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
    )
    net.to(device)
    net.eval()
    stormer_wrapper = StormerWrapperPangu(
        root_dir='/eagle/MDClimSim/tungnd/data/wb2/1.40625deg_from_full_res_1_step_6hr_h5df/',
        variables=vars_stormer,
        net=net,
        base_lead_time=6,
        possible_lead_times=[24, 12, 6],
        ckpt=ckpt_pth,
        device=device,
        logger=logger,
    )
    pytorch_total_params = sum(p.numel() for p in stormer_wrapper.net.parameters())
    pytorch_trainable_params = sum(p.numel() for p in stormer_wrapper.net.parameters() if p.requires_grad)
    print('Total model parameters : {}'.format(pytorch_total_params))
    print('Trainable model parameters : {}'.format(pytorch_trainable_params))
    logger.info('Total model parameters : {}'.format(pytorch_total_params))
    logger.info('Trainable model parameters : {}'.format(pytorch_trainable_params))

    print('background_file_np :', background_file_np)
    logger.info('background_file_np : {}'.format(background_file_np))
    # Memory-map the file read-only; .copy() below hands torch.from_numpy an
    # in-memory array rather than the read-only mapping.
    background_f = np.load(background_file_np, mmap_mode='r')
    # Random-background experiments only randomise the *initial* state. When
    # resuming (start_idx > 0) the saved background must be kept, otherwise
    # every restart would throw the assimilated state away.
    if start_idx == 0 and 'rand' in save_dir_name:
        print('using random background')
        # NOTE(review): randn yields float64 regardless of the file's dtype
        # and is unseeded — confirm FourDVar copes with both.
        background_f = np.random.randn(*background_f.shape)
    background = torch.from_numpy(background_f.copy())

    fourd_da = FourDVar(stormer_wrapper, obs_loader,
                        background, background_err, background_err_hf,
                        obs_err, dv_layer,
                        model_step=model_step,
                        da_window=da_window,
                        obs_freq=obs_freq,
                        da_type=da_type,
                        vars=vars_stormer,
                        b_inflation=b_inflation,
                        max_iter=700,
                        savedir=save_dir,
                        device=device,
                        save_idx=start_idx,
                        logger=logger,
                        )
    fourd_da.cycleDataAssimilation(forecast=True)


if __name__ == '__main__':
    main()