-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
134 lines (115 loc) · 4.24 KB
/
utils.py
File metadata and controls
134 lines (115 loc) · 4.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
'''
Author: xiaoniu
Date: 2026-01-10 22:21:43
LastEditors: xiaoniu
LastEditTime: 2026-01-10 22:31:05
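Description: Helpers for Earth-specific window attention: relative position index construction, 3D window padding, and cropping padded tensors back to their original resolution.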
'''
import torch
def get_earth_position_index(window_size, ndim=3):
    """
    Build the Earth-specific relative position index for a window.

    Revised from WeatherLearn https://github.com/lizhuoq/WeatherLearn, which
    follows the reference implementation in
    https://github.com/198808xc/Pangu-Weather/blob/main/pseudocode.py

    The returned index is used to look up a shared position-bias table.
    Along pressure level and latitude, every (query, key) pair of absolute
    positions gets its own index. Along longitude only the relative offset
    (query - key) is used, so pairs with the same longitude offset share an
    entry.

    Args:
        window_size (tuple[int]): [pressure levels, latitude, longitude] or
            [latitude, longitude]
        ndim (int): dimension of tensor, 3 or 2

    Returns:
        position_index (torch.Tensor): int64 tensor of shape
            [win_pl * win_lat * win_lon, win_pl * win_lat * win_lon] (3D) or
            [win_lat * win_lon, win_lat * win_lon] (2D). Values lie in
            [0, (2 * win_lon - 1) * win_lat**2 * win_pl**2) for 3D and
            [0, (2 * win_lon - 1) * win_lat**2) for 2D.

    Raises:
        ValueError: if ``ndim`` is not 2 or 3.
    """
    if ndim not in (2, 3):
        raise ValueError(f"ndim must be 2 or 3, got {ndim}")
    if ndim == 3:
        win_pl, win_lat, win_lon = window_size
    else:
        win_lat, win_lon = window_size

    # Latitude: query index i and key index j combine to i + j * win_lat,
    # which is distinct for every (i, j) pair.
    coords_hi = torch.arange(win_lat)
    coords_hj = -torch.arange(win_lat) * win_lat
    # Longitude: only the relative offset (i - j) is kept.
    coords_w = torch.arange(win_lon)

    if ndim == 3:
        # Pressure level uses the same absolute-pair scheme as latitude.
        coords_zi = torch.arange(win_pl)
        coords_zj = -torch.arange(win_pl) * win_pl
        query_axes = (coords_zi, coords_hi, coords_w)
        key_axes = (coords_zj, coords_hj, coords_w)
    else:
        query_axes = (coords_hi, coords_w)
        key_axes = (coords_hj, coords_w)

    # indexing="ij" is torch's default behaviour. Passing it explicitly
    # avoids the UserWarning emitted when the argument is omitted.
    coords_1 = torch.stack(torch.meshgrid(*query_axes, indexing="ij"))
    coords_2 = torch.stack(torch.meshgrid(*key_axes, indexing="ij"))
    coords_flatten_1 = torch.flatten(coords_1, 1)  # [ndim, N]
    coords_flatten_2 = torch.flatten(coords_2, 1)  # [ndim, N]
    # Pairwise differences: [ndim, N, N] -> [N, N, ndim]
    coords = coords_flatten_1[:, :, None] - coords_flatten_2[:, None, :]
    coords = coords.permute(1, 2, 0).contiguous()

    # Shift longitude offsets (range [-(win_lon-1), win_lon-1]) to start at 0.
    # Then scale the other axes so the per-axis indices combine into one
    # flat mixed-radix index.
    lon_span = 2 * win_lon - 1
    coords[:, :, -1] += win_lon - 1
    coords[:, :, -2] *= lon_span
    if ndim == 3:
        coords[:, :, 0] *= lon_span * win_lat * win_lat

    # Sum up the indexes in two/three dimensions
    return coords.sum(-1)
def get_pad3d(input_resolution, window_size):
    """
    Compute the padding that makes a 3D grid an exact multiple of a window.

    On each axis the missing cells are split in two. The smaller half
    (``extra // 2``) goes before the data and the rest goes after it. An axis
    that is already divisible by its window gets no padding.

    Args:
        input_resolution (tuple[int]): (Pl, Lat, Lon)
        window_size (tuple[int]): (Pl, Lat, Lon)
    Returns:
        padding (tuple[int]): (padding_left, padding_right, padding_top,
            padding_bottom, padding_front, padding_back): the longitude pair
            first, then latitude, then pressure level.
    """
    size_pl, size_lat, size_lon = input_resolution
    step_pl, step_lat, step_lon = window_size

    def _halves(size, step):
        # Cells missing to reach the next multiple of ``step`` (0 if aligned).
        leftover = size % step
        extra = step - leftover if leftover else 0
        before = extra // 2
        return before, extra - before

    before_pl, after_pl = _halves(size_pl, step_pl)
    before_lat, after_lat = _halves(size_lat, step_lat)
    before_lon, after_lon = _halves(size_lon, step_lon)
    return before_lon, after_lon, before_lat, after_lat, before_pl, after_pl
def crop3d(x: torch.Tensor, resolution):
    """
    Crop the last three dimensions of ``x`` down to ``resolution``.

    The surplus on each axis is removed with the same split ``get_pad3d``
    uses for padding. ``surplus // 2`` is dropped from the start and the
    remainder from the end, so cropping undoes that padding.

    Args:
        x (torch.Tensor): tensor whose last three dimensions are
            (Pl, Lat, Lon), typically B, C, Pl, Lat, Lon. Any leading
            dimensions are kept unchanged.
        resolution (tuple[int]): target Pl, Lat, Lon. Each entry must not
            exceed the corresponding size of ``x``.

    Returns:
        torch.Tensor: a slice (view) of ``x`` whose last three dimensions
        equal ``resolution``.

    Raises:
        ValueError: if ``x`` has fewer than three dimensions, or if
            ``resolution`` is larger than ``x`` along any of the last three
            dimensions.
    """
    *_, Pl, Lat, Lon = x.shape
    pl_pad = Pl - resolution[0]
    lat_pad = Lat - resolution[1]
    lon_pad = Lon - resolution[2]
    # A negative surplus would silently turn into a wrong (wrapped) slice.
    if pl_pad < 0 or lat_pad < 0 or lon_pad < 0:
        raise ValueError(
            f"resolution {tuple(resolution[:3])} exceeds the spatial size "
            f"{(Pl, Lat, Lon)} of x"
        )
    # The smaller half of the surplus sits before the data.
    padding_front = pl_pad // 2
    padding_top = lat_pad // 2
    padding_left = lon_pad // 2
    # Ending at start + target equals size - padding_after, the original bound.
    return x[
        ...,
        padding_front : padding_front + resolution[0],
        padding_top : padding_top + resolution[1],
        padding_left : padding_left + resolution[2],
    ]