remap_ECDYN.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. import netCDF4 as nc
  2. import ctypes as ct
  3. import numpy as np
  4. import os
  5. import sys
  6. import math
  7. from mpi4py import MPI
  8. import time
  9. remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
  10. def from_mpas(filename):
  11. # construct vortex bounds from Mpas grid structure
  12. f = nc.Dataset(filename)
  13. # in this case it is must faster to first read the whole file into memory
  14. # before converting the data structure
  15. print "read"
  16. stime = time.time()
  17. lon_vert = np.array(f.variables["lonVertex"])
  18. lat_vert = np.array(f.variables["latVertex"])
  19. vert_cell = np.array(f.variables["verticesOnCell"])
  20. nvert_cell = np.array(f.variables["nEdgesOnCell"])
  21. ncell, nvert = vert_cell.shape
  22. assert(max(nvert_cell) <= nvert)
  23. lon = np.zeros(vert_cell.shape)
  24. lat = np.zeros(vert_cell.shape)
  25. etime = time.time()
  26. print "finished read, now convert", etime-stime
  27. scal = 180.0/math.pi
  28. for c in range(ncell):
  29. lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
  30. lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
  31. # signal "last vertex" by netCDF convetion
  32. lon[c,nvert_cell[c]] = lon[c,0]
  33. lat[c,nvert_cell[c]] = lat[c,0]
  34. print "convert end", time.time() - etime
  35. return lon, lat
  36. grid_types = {
  37. "dynamico:mesh": {
  38. "lon_name": "bounds_lon_i",
  39. "lat_name": "bounds_lat_i",
  40. "pole": [0,0,0]
  41. },
  42. "dynamico:vort": {
  43. "lon_name": "bounds_lon_v",
  44. "lat_name": "bounds_lat_v",
  45. "pole": [0,0,0]
  46. },
  47. "dynamico:restart": {
  48. "lon_name": "lon_i_vertices",
  49. "lat_name": "lat_i_vertices",
  50. "pole": [0,0,0]
  51. },
  52. "test:polygon": {
  53. "lon_name": "bounds_lon",
  54. "lat_name": "bounds_lat",
  55. "pole": [0,0,0]
  56. },
  57. "test:latlon": {
  58. "lon_name": "bounds_lon",
  59. "lat_name": "bounds_lat",
  60. "pole": [0,0,1]
  61. },
  62. "mpas": {
  63. "reader": from_mpas,
  64. "pole": [0,0,0]
  65. }
  66. }
  67. interp_types = {
  68. "FV1": 1,
  69. "FV2": 2
  70. }
  71. usage = """
  72. Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
  73. interp: type of interpolation
  74. choices:
  75. FV1: first order conservative Finite Volume
  76. FV2: second order conservative Finite Volume
  77. srctype, dsttype: grid type of source and destination
  78. choices: """ + " ".join(grid_types.keys()) + """
  79. srcfile, dstfile: grid file names, should mostly be netCDF file
  80. mode: modus of operation
  81. choices:
  82. weights: computes weight and stores them in outfile
  83. remap: computes the interpolated values on destination grid and stores them in outfile
  84. outfile: output filename
  85. """
  86. # parse command line arguments
  87. if not len(sys.argv) == 8:
  88. print usage
  89. sys.exit(2)
  90. interp = sys.argv[1]
  91. try:
  92. srctype = grid_types[sys.argv[2]]
  93. except KeyError:
  94. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  95. exit(2)
  96. srcfile = sys.argv[3]
  97. try:
  98. dsttype = grid_types[sys.argv[4]]
  99. except KeyError:
  100. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  101. exit(2)
  102. dstfile = sys.argv[5]
  103. mode = sys.argv[6]
  104. outfile = sys.argv[7]
  105. if not mode in ("weights", "remap"):
  106. print "Error: mode must be of of the following: weights remap."
  107. exit(2)
  108. remap.mpi_init()
  109. rank = remap.mpi_rank()
  110. size = remap.mpi_size()
  111. print rank, "/", size
  112. print "Reading grids from netCDF files."
  113. if "reader" in srctype:
  114. src_lon, src_lat = srctype["reader"](srcfile)
  115. else:
  116. src = nc.Dataset(srcfile)
  117. # the following two lines do not perform the actual read
  118. # the file is read later when assigning to the ctypes array
  119. # -> no unnecessary array copying in memory
  120. src_lon = src.variables[srctype["lon_name"]]
  121. src_lat = src.variables[srctype["lat_name"]]
  122. if "reader" in dsttype:
  123. dst_lon, dst_lat = dsttype["reader"](dstfile)
  124. else:
  125. dst = nc.Dataset(dstfile)
  126. dst_lon = dst.variables[dsttype["lon_name"]]
  127. dst_lat = dst.variables[dsttype["lat_name"]]
  128. src_ncell, src_nvert = src_lon.shape
  129. dst_ncell, dst_nvert = dst_lon.shape
  130. def compute_distribution(ncell):
  131. "Returns the local number and starting position in global array."
  132. if rank < ncell % size:
  133. return ncell//size + 1, \
  134. (ncell//size + 1)*rank
  135. else:
  136. return ncell//size, \
  137. (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
  138. src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
  139. dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
  140. print "src", src_ncell_loc, src_loc_start
  141. print "dst", dst_ncell_loc, dst_loc_start
  142. c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
  143. c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
  144. c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  145. c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  146. c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  147. c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  148. c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  149. c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  150. print "Calling remap library to compute weights."
  151. srcpole = (ct.c_double * (3))()
  152. dstpole = (ct.c_double * (3))()
  153. srcpole[:] = srctype["pole"]
  154. dstpole[:] = dsttype["pole"]
  155. c_src_ncell = ct.c_int(src_ncell_loc)
  156. c_src_nvert = ct.c_int(src_nvert)
  157. c_dst_ncell = ct.c_int(dst_ncell_loc)
  158. c_dst_nvert = ct.c_int(dst_nvert)
  159. order = ct.c_int(interp_types[interp])
  160. c_nweight = ct.c_int()
  161. print "src:", src_ncell, src_nvert
  162. print "dst:", dst_ncell, dst_nvert
  163. remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
  164. c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
  165. order, ct.byref(c_nweight))
  166. nwgt = c_nweight.value
  167. c_weights = (ct.c_double * nwgt)()
  168. c_dst_idx = (ct.c_int * nwgt)()
  169. c_src_idx = (ct.c_int * nwgt)()
  170. remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
  171. wgt_glo = MPI.COMM_WORLD.gather(c_weights[:])
  172. src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
  173. dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
  174. if rank == 0 and mode == 'weights':
  175. nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
  176. print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
  177. f = nc.Dataset(outfile,'w')
  178. f.createDimension('n_src', src_ncell)
  179. f.createDimension('n_dst', dst_ncell)
  180. f.createDimension('n_weight', nwgt_glo)
  181. var = f.createVariable('src_idx', 'i', ('n_weight'))
  182. var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
  183. var = f.createVariable('dst_idx', 'i', ('n_weight'))
  184. var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
  185. var = f.createVariable('weight', 'd', ('n_weight'))
  186. var[:] = np.hstack(wgt_glo)
  187. f.close()
  188. def test_fun(x, y, z):
  189. return (1-x**2)*(1-y**2)*z
  190. def test_fun_ll(lat, lon):
  191. #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
  192. return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
  193. #UNUSED
  194. #def sphe2cart(lat, lon):
  195. # phi = math.pi/180*lon[:]
  196. # theta = math.pi/2 - math.pi/180*lat[:]
  197. # return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
  198. if mode == 'remap':
  199. c_centre_lon = (ct.c_double * src_ncell_loc)()
  200. c_centre_lat = (ct.c_double * src_ncell_loc)()
  201. c_areas = (ct.c_double * src_ncell_loc)()
  202. remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
  203. c_centre_lon, c_centre_lat, c_areas)
  204. # src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
  205. # src_val_loc = src.variables["ps"]
  206. # src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  207. # src_val_glo = src_val_loc
  208. c_centre_lon = (ct.c_double * dst_ncell_loc)()
  209. c_centre_lat = (ct.c_double * dst_ncell_loc)()
  210. c_areas = (ct.c_double * dst_ncell_loc)()
  211. remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
  212. c_centre_lon, c_centre_lat, c_areas)
  213. # dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
  214. # dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
  215. dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
  216. dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
  217. dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
  218. if rank == 0 and mode == 'remap':
  219. from scipy import sparse
  220. A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
  221. # src_val = np.hstack(src_val_glo)
  222. # dst_ref = np.hstack(dst_val_glo)
  223. dst_areas = np.hstack(dst_areas_glo)
  224. dst_centre_lon = np.hstack(dst_centre_lon_glo)
  225. dst_centre_lat = np.hstack(dst_centre_lat_glo)
  226. # print "source:", src_val.shape
  227. # print "destin:", dst_ref.shape
  228. # dst_val = A*src_val
  229. # err = dst_val - dst_ref
  230. # print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
  231. # print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
  232. # print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
  233. lev=src.dimensions['lev']
  234. f = nc.Dataset(outfile,'w')
  235. f.createDimension('nvert', dst_nvert)
  236. f.createDimension('cell', dst_ncell)
  237. f.createDimension('lev', len(lev))
  238. var = f.createVariable('lat', 'd', ('cell'))
  239. var.setncattr("long_name", "latitude")
  240. var.setncattr("units", "degrees_north")
  241. var.setncattr("bounds", "bounds_lat")
  242. var[:] = dst_centre_lat
  243. var = f.createVariable('lon', 'd', ('cell'))
  244. var.setncattr("long_name", "longitude")
  245. var.setncattr("units", "degrees_east")
  246. var.setncattr("bounds", "bounds_lon")
  247. var[:] = dst_centre_lon
  248. var = f.createVariable('bounds_lon', 'd', ('cell','nvert'))
  249. var[:] = dst_lon
  250. var = f.createVariable('bounds_lat', 'd', ('cell','nvert'))
  251. var[:] = dst_lat
  252. var = f.createVariable('lev', 'd', ('lev'))
  253. var[:] = src.variables['lev']
  254. var.setncattr('axis', 'Z')
  255. var.setncattr('units', 'Pa')
  256. var.setncattr('positive', 'down')
  257. var[:] = src.variables['lev']
  258. U = f.createVariable('U', 'd', ('lev','cell'))
  259. U.setncattr("coordinates", "lev lon lat")
  260. V = f.createVariable('V', 'd', ('lev','cell'))
  261. V.setncattr("coordinates", "lev lon lat")
  262. TEMP = f.createVariable('TEMP', 'd', ('lev','cell'))
  263. TEMP.setncattr("coordinates", "lev lon lat")
  264. R = f.createVariable('R', 'd', ('lev','cell'))
  265. R.setncattr("coordinates", "lev lon lat")
  266. Z = f.createVariable('Z', 'd', ('cell'))
  267. Z.setncattr("coordinates", "lon lat")
  268. ST = f.createVariable('ST', 'd', ('cell'))
  269. ST.setncattr("coordinates", "lon lat")
  270. CDSW = f.createVariable('CDSW', 'd', ('cell'))
  271. CDSW.setncattr("coordinates", "lon lat")
  272. SP = f.createVariable('SP', 'd', ('cell'))
  273. SP.setncattr("coordinates", "lon lat")
  274. #for U
  275. if mode == 'remap':
  276. src_val_loc = src.variables['U']
  277. for l in range(0, len(lev)):
  278. if mode == 'remap':
  279. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  280. if rank == 0 and mode == 'remap':
  281. src_val = np.hstack(src_val_glo)
  282. dst_val = A*src_val
  283. U[l,:] = dst_val
  284. #for V
  285. if mode == 'remap':
  286. src_val_loc = src.variables['V']
  287. for l in range(0, len(lev)):
  288. if mode == 'remap':
  289. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  290. if rank == 0 and mode == 'remap':
  291. src_val = np.hstack(src_val_glo)
  292. dst_val = A*src_val
  293. V[l,:] = dst_val
  294. #for TEMP
  295. if mode == 'remap':
  296. src_val_loc = src.variables['TEMP']
  297. for l in range(0, len(lev)):
  298. if mode == 'remap':
  299. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  300. if rank == 0 and mode == 'remap':
  301. src_val = np.hstack(src_val_glo)
  302. dst_val = A*src_val
  303. TEMP[l,:] = dst_val
  304. #for R
  305. if mode == 'remap':
  306. src_val_loc = src.variables['R']
  307. for l in range(0, len(lev)):
  308. if mode == 'remap':
  309. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  310. if rank == 0 and mode == 'remap':
  311. src_val = np.hstack(src_val_glo)
  312. dst_val = A*src_val
  313. R[l,:] = dst_val
  314. #for Z
  315. if mode == 'remap':
  316. src_val_loc = src.variables['Z']
  317. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  318. if rank == 0 and mode == 'remap':
  319. src_val = np.hstack(src_val_glo)
  320. dst_val = A*src_val
  321. Z[:] = dst_val
  322. #for ST
  323. if mode == 'remap':
  324. src_val_loc = src.variables['ST']
  325. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  326. if rank == 0 and mode == 'remap':
  327. src_val = np.hstack(src_val_glo)
  328. dst_val = A*src_val
  329. ST[:] = dst_val
  330. #for CDSW
  331. if mode == 'remap':
  332. src_val_loc = src.variables['CDSW']
  333. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  334. if rank == 0 and mode == 'remap':
  335. src_val = np.hstack(src_val_glo)
  336. dst_val = A*src_val
  337. CDSW[:] = dst_val
  338. #for SP
  339. if mode == 'remap':
  340. src_val_loc = src.variables['SP']
  341. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  342. if rank == 0 and mode == 'remap':
  343. src_val = np.hstack(src_val_glo)
  344. dst_val = A*src_val
  345. SP[:] = dst_val
  346. if mode == 'remap':
  347. f.close()
  348. if not "reader" in srctype:
  349. src.close()
  350. if not "reader" in dsttype:
  351. dst.close()