forked from COSIMA/ocean-ic
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ic_stability_metric.py
executable file
·159 lines (121 loc) · 4.51 KB
/
ic_stability_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python
from __future__ import print_function
import sys, os
import argparse
import subprocess as sp
import netCDF4 as nc
import numpy as np
from seawater import eos80
from scipy import ndimage as nd
from regridder import regrid
"""
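Calculate a 'stability metric' for an ocean initial condition (IC).
For each water column, this is the fraction of cells that would need to move
to make the column completely stable, i.e. the fraction of cells whose density
differs from that of the density-sorted column.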
"""
def level_of_first_masked(array):
    """Return the index of the first masked element of a 1-D array.

    If no element is masked (including plain ndarrays and masked arrays whose
    mask is the scalar ``nomask``), return ``len(array)``, so that
    ``array[:level_of_first_masked(array)]`` always selects exactly the
    leading run of unmasked elements.
    """
    assert len(array.shape) == 1
    # getmaskarray expands a scalar `nomask` (or a missing mask) to a full
    # boolean array, so indexing below never hits a 0-d mask.
    masked_idx = np.flatnonzero(np.ma.getmaskarray(array))
    if masked_idx.size == 0:
        return len(array)
    return int(masked_idx[0])
def calc_density(temp, salt, levels):
    """Compute density from temperature, salinity and per-level depths.

    temp and salt must be 3-D and levels 1-D. The horizontal sizes are taken
    from salt.shape[1] and salt.shape[2], and the level axis is leading
    (presumably (level, lat, lon); TODO confirm against callers). The level
    depths are replicated to every horizontal point, scaled by 0.1 to form
    the pressure argument, and passed with salt and temp to seawater's
    eos80.dens, whose result is returned.
    """
    # Same checks, same order: temp, salt, then levels.
    for arr, expected_ndim in ((temp, 3), (salt, 3), (levels, 1)):
        assert len(arr.shape) == expected_ndim

    n_levels = levels.shape[0]
    n_lat, n_lon = salt.shape[1], salt.shape[2]

    # Stack one copy of the level depths per horizontal point, then transpose
    # so the level axis leads before reshaping to the 3-D grid.
    per_point_columns = [levels] * n_lat * n_lon
    depth_field = np.vstack(per_point_columns).T.reshape(n_levels, n_lat, n_lon)

    # Pressure in dbar.
    # NOTE(review): pressure in dbar is roughly depth in metres, so the 0.1
    # factor only fits if `levels` is not in metres. Confirm the units of the
    # depth variable; seawater's eos80.pres(depth, lat) may be the intended
    # conversion.
    pressure_dbar = depth_field * 0.1
    return eos80.dens(salt, temp, pressure_dbar)
def calc_stability_index(temp, salt, levels):
    """Return the per-column stability metric for a temperature/salinity field.

    Density comes from calc_density(temp, salt, levels), with the level axis
    first and the horizontal sizes taken from salt.shape[1] and salt.shape[2].
    For each column, only the leading run of unmasked levels is considered:
    everything from the first masked level downwards is ignored.

    A column's score is the fraction of those levels whose density differs
    from the density at the same position once the run is sorted in ascending
    order. That is the fraction of cells that would have to move to make
    density non-decreasing with level index (high is bad).

    Returns a (lats, lons) float64 array. Columns with no valid levels score 0.
    """
    density = calc_density(temp, salt, levels)
    num_levs = density.shape[0]
    lats = salt.shape[1]
    lons = salt.shape[2]
    si_ret = np.zeros((lats, lons))
    if num_levs == 0:
        return si_ret

    # Works for masked and plain arrays; a scalar `nomask` becomes all-False.
    mask = np.ma.getmaskarray(density)
    # Usable levels per column: index of the first masked level, or all
    # levels when the column has none. (argmax on a bool array gives the
    # first True, but is 0 when there is none, hence the `any` guard.)
    valid_levs = np.where(mask.any(axis=0), mask.argmax(axis=0), num_levs)
    level_idx = np.arange(num_levs).reshape(num_levs, 1, 1)
    in_column = level_idx < valid_levs

    # Push cells below each column's valid run to +inf. A single sort along
    # the level axis then leaves the sorted valid run in the leading
    # positions, the same as sorting the run on its own (for finite
    # densities).
    values = np.array(np.ma.getdata(density), dtype=np.float64)
    values[~in_column] = np.inf
    out_of_place = (np.sort(values, axis=0) != values) & in_column

    # Fraction of valid cells out of place. Columns with no valid levels keep 0.
    np.true_divide(out_of_place.sum(axis=0), valid_levs,
                   out=si_ret, where=valid_levs > 0)
    return si_ret
def _masked_gaussian_filter(field, sigma):
    """Gaussian-smooth `field`, ignoring masked cells and keeping them masked."""
    mask = np.ma.getmaskarray(field)
    valid = np.logical_not(mask).astype(np.float64)
    # Zero out masked cells so fill values (e.g. land) cannot leak into the
    # smoothed result. Then renormalise by the smoothed weight of valid cells.
    data = np.ma.filled(field, 0).astype(np.float64)
    weighted_sum = nd.gaussian_filter(data, sigma)
    weight = nd.gaussian_filter(valid, sigma)
    smoothed = np.zeros_like(weighted_sum)
    np.true_divide(weighted_sum, weight, out=smoothed, where=weight > 0)
    return np.ma.masked_array(smoothed, mask=mask)


def make_more_stable_ic(ic_file, temp_var, salt_var, sigma=(2, 3, 3)):
    """Smooth temperature and salinity in `ic_file` in place to make the IC more stable.

    Only index 0 of each variable's leading axis is read, smoothed and
    written back. Temperature is processed first, then salinity.

    sigma: standard deviations of the Gaussian kernel, one per axis of the
        3-D field var[0, :, :, :], in grid cells. The default (2, 3, 3) is
        the value previously hard-coded.

    Masked cells (e.g. land) are excluded from the averaging and stay masked
    on write-back, so fill values are not smeared into valid cells.
    """
    with nc.Dataset(ic_file, 'r+') as f:
        for var_name in (temp_var, salt_var):
            var = f.variables[var_name]
            var[0, :, :, :] = _masked_gaussian_filter(var[0, :, :, :], sigma)
def _find_var(dataset, candidates, path):
    """Return the first name in `candidates` that is a variable of `dataset`; KeyError if none."""
    for name in candidates:
        if name in dataset.variables:
            return name
    raise KeyError('None of {} found in {}'.format(candidates, path))


def _append_variable(src_path, var_name, dst_path):
    """Copy variable `var_name` (data, attributes, missing dimensions) from src_path into dst_path."""
    with nc.Dataset(src_path) as src, nc.Dataset(dst_path, 'r+') as dst:
        src_var = src.variables[var_name]
        if var_name in dst.variables:
            dst_var = dst.variables[var_name]
        else:
            for dim_name in src_var.dimensions:
                if dim_name not in dst.dimensions:
                    dim = src.dimensions[dim_name]
                    dst.createDimension(dim_name,
                                        None if dim.isunlimited() else len(dim))
            attrs = dict((k, src_var.getncattr(k)) for k in src_var.ncattrs())
            # _FillValue can only be set at creation time in netCDF4.
            fill_value = attrs.pop('_FillValue', None)
            dst_var = dst.createVariable(var_name, src_var.dtype,
                                         src_var.dimensions,
                                         fill_value=fill_value)
            dst_var.setncatts(attrs)
        dst_var[:] = src_var[:]


def _write_more_stable_ic(temp_ic, temp_var, salt_ic, salt_var, out_path):
    """Gather temp and salt into out_path, then smooth them in place."""
    if os.path.abspath(temp_ic) == os.path.abspath(salt_ic):
        # Both fields are in one file: copy them in a single nccopy.
        sp.check_call(['nccopy', '-v', '{},{}'.format(temp_var, salt_var),
                       temp_ic, out_path])
    else:
        # nccopy writes a fresh output file, so a second nccopy to the same
        # path would discard the temperature copy. Copy temp, then append salt.
        sp.check_call(['nccopy', '-v', temp_var, temp_ic, out_path])
        _append_variable(salt_ic, salt_var, out_path)
    make_more_stable_ic(out_path, temp_var, salt_var)


def main():
    """Compute the stability metric of an IC and write it to ./stability_index.nc.

    Reads the first record of salinity from `salt_ic` and of temperature (plus
    the depth levels) from `temp_ic`. Writes the per-column metric and prints
    its average over all columns. With --output_more_stable, it also writes a
    smoothed copy of temp and salt to ./more_stable_ic.nc.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('temp_ic', help="The initial condition file containing temp")
    parser.add_argument('salt_ic', help="The initial condition file containing salt")
    parser.add_argument('--output_more_stable', action='store_true',
                        default=False, help="Output a more stable version of the IC.")
    args = parser.parse_args()

    salt_var_names = ['vosaline', 'salt', 'SALT']
    temp_var_names = ['votemper', 'temp', 'TEMP', 'pottmp']
    depth_var_names = ['depth', 'zt', 'ZT', 'AZ_50', 'level']

    with nc.Dataset(args.salt_ic) as f:
        salt_var = _find_var(f, salt_var_names, args.salt_ic)
        salt = f.variables[salt_var][0, :, :, :]
        # Rescale a kg/kg mass fraction by 1000 (to g/kg), presumably the
        # scale eos80 expects. TODO confirm.
        if getattr(f.variables[salt_var], 'units', None) == "kg/kg":
            salt *= 1000

    with nc.Dataset(args.temp_ic) as f:
        temp_var = _find_var(f, temp_var_names, args.temp_ic)
        temp = f.variables[temp_var][0, :, :, :]
        # Convert Kelvin to degrees Celsius.
        if getattr(f.variables[temp_var], 'units', None) == "K":
            temp -= 273.15
        # Depth levels are looked up in the temperature file only.
        depth_var = _find_var(f, depth_var_names, args.temp_ic)
        depth = f.variables[depth_var][:]

    si = calc_stability_index(temp, salt, depth)
    lats = salt.shape[1]
    lons = salt.shape[2]

    if args.output_more_stable:
        _write_more_stable_ic(args.temp_ic, temp_var, args.salt_ic, salt_var,
                              './more_stable_ic.nc')

    with nc.Dataset('./stability_index.nc', 'w') as f:
        f.createDimension('x', lons)
        f.createDimension('y', lats)
        si_nc = f.createVariable('stability', 'f8', ('y', 'x'))
        si_nc[:] = si[:]

    # Total score is sum of all columns divided by total columns
    print('Average stability metric (high is bad) {}'.format(np.sum(si) / (lats*lons)))
    return 0
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)