remap_evag.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. import netCDF4 as nc
  2. import ctypes as ct
  3. import numpy as np
  4. import os
  5. import sys
  6. import math
  7. from mpi4py import MPI
  8. import time
  9. remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
  10. def from_mpas(filename):
  11. # construct vortex bounds from Mpas grid structure
  12. f = nc.Dataset(filename)
  13. # in this case it is must faster to first read the whole file into memory
  14. # before converting the data structure
  15. print "read"
  16. stime = time.time()
  17. lon_vert = np.array(f.variables["lonVertex"])
  18. lat_vert = np.array(f.variables["latVertex"])
  19. vert_cell = np.array(f.variables["verticesOnCell"])
  20. nvert_cell = np.array(f.variables["nEdgesOnCell"])
  21. ncell, nvert = vert_cell.shape
  22. assert(max(nvert_cell) <= nvert)
  23. lon = np.zeros(vert_cell.shape)
  24. lat = np.zeros(vert_cell.shape)
  25. etime = time.time()
  26. print "finished read, now convert", etime-stime
  27. scal = 180.0/math.pi
  28. for c in range(ncell):
  29. lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
  30. lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
  31. # signal "last vertex" by netCDF convetion
  32. lon[c,nvert_cell[c]] = lon[c,0]
  33. lat[c,nvert_cell[c]] = lat[c,0]
  34. print "convert end", time.time() - etime
  35. return lon, lat
  36. grid_types = {
  37. "dynamico:mesh": {
  38. "lon_name": "bounds_lon_i",
  39. "lat_name": "bounds_lat_i",
  40. "pole": [0,0,0]
  41. },
  42. "dynamico:vort": {
  43. "lon_name": "bounds_lon_v",
  44. "lat_name": "bounds_lat_v",
  45. "pole": [0,0,0]
  46. },
  47. "dynamico:restart": {
  48. "lon_name": "lon_i_vertices",
  49. "lat_name": "lat_i_vertices",
  50. "pole": [0,0,0]
  51. },
  52. "test:polygon": {
  53. "lon_name": "bounds_lon",
  54. "lat_name": "bounds_lat",
  55. "pole": [0,0,0]
  56. },
  57. "test:latlon": {
  58. "lon_name": "bounds_lon",
  59. "lat_name": "bounds_lat",
  60. "pole": [0,0,1]
  61. },
  62. "mpas": {
  63. "reader": from_mpas,
  64. "pole": [0,0,0]
  65. }
  66. }
  67. interp_types = {
  68. "FV1": 1,
  69. "FV2": 2
  70. }
  71. usage = """
  72. Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
  73. interp: type of interpolation
  74. choices:
  75. FV1: first order conservative Finite Volume
  76. FV2: second order conservative Finite Volume
  77. srctype, dsttype: grid type of source and destination
  78. choices: """ + " ".join(grid_types.keys()) + """
  79. srcfile, dstfile: grid file names, should mostly be netCDF file
  80. mode: modus of operation
  81. choices:
  82. weights: computes weight and stores them in outfile
  83. remap: computes the interpolated values on destination grid and stores them in outfile
  84. outfile: output filename
  85. """
  86. # parse command line arguments
  87. if not len(sys.argv) == 8:
  88. print usage
  89. sys.exit(2)
  90. interp = sys.argv[1]
  91. try:
  92. srctype = grid_types[sys.argv[2]]
  93. except KeyError:
  94. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  95. exit(2)
  96. srcfile = sys.argv[3]
  97. try:
  98. dsttype = grid_types[sys.argv[4]]
  99. except KeyError:
  100. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  101. exit(2)
  102. dstfile = sys.argv[5]
  103. mode = sys.argv[6]
  104. outfile = sys.argv[7]
  105. if not mode in ("weights", "remap"):
  106. print "Error: mode must be of of the following: weights remap."
  107. exit(2)
  108. remap.mpi_init()
  109. rank = remap.mpi_rank()
  110. size = remap.mpi_size()
  111. print rank, "/", size
  112. print "Reading grids from netCDF files."
  113. if "reader" in srctype:
  114. src_lon, src_lat = srctype["reader"](srcfile)
  115. else:
  116. src = nc.Dataset(srcfile)
  117. # the following two lines do not perform the actual read
  118. # the file is read later when assigning to the ctypes array
  119. # -> no unnecessary array copying in memory
  120. src_lon = src.variables[srctype["lon_name"]]
  121. src_lat = src.variables[srctype["lat_name"]]
  122. if "reader" in dsttype:
  123. dst_lon, dst_lat = dsttype["reader"](dstfile)
  124. else:
  125. dst = nc.Dataset(dstfile)
  126. dst_lon = dst.variables[dsttype["lon_name"]]
  127. dst_lat = dst.variables[dsttype["lat_name"]]
  128. src_ncell, src_nvert = src_lon.shape
  129. dst_ncell, dst_nvert = dst_lon.shape
  130. def compute_distribution(ncell):
  131. "Returns the local number and starting position in global array."
  132. if rank < ncell % size:
  133. return ncell//size + 1, \
  134. (ncell//size + 1)*rank
  135. else:
  136. return ncell//size, \
  137. (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
  138. src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
  139. dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
  140. print "src", src_ncell_loc, src_loc_start
  141. print "dst", dst_ncell_loc, dst_loc_start
  142. c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
  143. c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
  144. c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  145. c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  146. c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  147. c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  148. c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  149. c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  150. print "Calling remap library to compute weights."
  151. srcpole = (ct.c_double * (3))()
  152. dstpole = (ct.c_double * (3))()
  153. srcpole[:] = srctype["pole"]
  154. dstpole[:] = dsttype["pole"]
  155. c_src_ncell = ct.c_int(src_ncell_loc)
  156. c_src_nvert = ct.c_int(src_nvert)
  157. c_dst_ncell = ct.c_int(dst_ncell_loc)
  158. c_dst_nvert = ct.c_int(dst_nvert)
  159. order = ct.c_int(interp_types[interp])
  160. c_nweight = ct.c_int()
  161. print "src:", src_ncell, src_nvert
  162. print "dst:", dst_ncell, dst_nvert
  163. remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
  164. c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
  165. order, ct.byref(c_nweight))
  166. nwgt = c_nweight.value
  167. c_weights = (ct.c_double * nwgt)()
  168. c_dst_idx = (ct.c_int * nwgt)()
  169. c_src_idx = (ct.c_int * nwgt)()
  170. remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
  171. wgt_glo = MPI.COMM_WORLD.gather(c_weights[:])
  172. src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
  173. dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
  174. if rank == 0 and mode == 'weights':
  175. nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
  176. print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
  177. f = nc.Dataset(outfile,'w')
  178. f.createDimension('n_src', src_ncell)
  179. f.createDimension('n_dst', dst_ncell)
  180. f.createDimension('n_weight', nwgt_glo)
  181. var = f.createVariable('src_idx', 'i', ('n_weight'))
  182. var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
  183. var = f.createVariable('dst_idx', 'i', ('n_weight'))
  184. var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
  185. var = f.createVariable('weight', 'd', ('n_weight'))
  186. var[:] = np.hstack(wgt_glo)
  187. f.close()
  188. def test_fun(x, y, z):
  189. return (1-x**2)*(1-y**2)*z
  190. def test_fun_ll(lat, lon):
  191. #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
  192. return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
  193. #UNUSED
  194. #def sphe2cart(lat, lon):
  195. # phi = math.pi/180*lon[:]
  196. # theta = math.pi/2 - math.pi/180*lat[:]
  197. # return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
  198. if mode == 'remap':
  199. c_centre_lon = (ct.c_double * src_ncell_loc)()
  200. c_centre_lat = (ct.c_double * src_ncell_loc)()
  201. c_areas = (ct.c_double * src_ncell_loc)()
  202. remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
  203. c_centre_lon, c_centre_lat, c_areas)
  204. # src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
  205. # src_val_loc = src.variables["ps"]
  206. # src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  207. # src_val_glo = src_val_loc
  208. c_centre_lon = (ct.c_double * dst_ncell_loc)()
  209. c_centre_lat = (ct.c_double * dst_ncell_loc)()
  210. c_areas = (ct.c_double * dst_ncell_loc)()
  211. remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
  212. c_centre_lon, c_centre_lat, c_areas)
  213. # dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
  214. # dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
  215. dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
  216. dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
  217. dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
  218. if rank == 0 and mode == 'remap':
  219. from scipy import sparse
  220. A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
  221. # src_val = np.hstack(src_val_glo)
  222. # dst_ref = np.hstack(dst_val_glo)
  223. dst_areas = np.hstack(dst_areas_glo)
  224. dst_centre_lon = np.hstack(dst_centre_lon_glo)
  225. dst_centre_lat = np.hstack(dst_centre_lat_glo)
  226. # print "source:", src_val.shape
  227. # print "destin:", dst_ref.shape
  228. # dst_val = A*src_val
  229. # err = dst_val - dst_ref
  230. # print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
  231. # print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
  232. # print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
  233. lev=src.dimensions['lev']
  234. nq=src.dimensions['nq']
  235. f = nc.Dataset(outfile,'w')
  236. f.createDimension('nvert', dst_nvert)
  237. f.createDimension('cell', dst_ncell)
  238. f.createDimension('lev', len(lev))
  239. f.createDimension('nq', len(nq))
  240. var = f.createVariable('lat', 'd', ('cell'))
  241. var.setncattr("long_name", "latitude")
  242. var.setncattr("units", "degrees_north")
  243. var.setncattr("bounds", "bounds_lat")
  244. var[:] = dst_centre_lat
  245. var = f.createVariable('lon', 'd', ('cell'))
  246. var.setncattr("long_name", "longitude")
  247. var.setncattr("units", "degrees_east")
  248. var.setncattr("bounds", "bounds_lon")
  249. var[:] = dst_centre_lon
  250. var = f.createVariable('bounds_lon', 'd', ('cell','nvert'))
  251. var[:] = dst_lon
  252. var = f.createVariable('bounds_lat', 'd', ('cell','nvert'))
  253. var[:] = dst_lat
  254. var = f.createVariable('lev', 'd', ('lev'))
  255. var[:] = src.variables['lev']
  256. var.setncattr('axis', 'Z')
  257. var.setncattr('units', 'Pa')
  258. var.setncattr('positive', 'down')
  259. var[:] = src.variables['lev']
  260. ps = f.createVariable('ps', 'd', ('cell'))
  261. ps.setncattr("coordinates", "lon lat")
  262. phis = f.createVariable('phis', 'd', ('cell'))
  263. phis.setncattr("coordinates", "lon lat")
  264. theta_rhodz = f.createVariable('theta_rhodz', 'd', ('lev','cell'))
  265. theta_rhodz.setncattr("coordinates", "lev lon lat")
  266. ulon = f.createVariable('ulon', 'd', ('lev','cell'))
  267. ulon.setncattr("coordinates", "lev lon lat")
  268. ulat = f.createVariable('ulat', 'd', ('lev','cell'))
  269. ulat.setncattr("coordinates", "lev lon lat")
  270. q = f.createVariable('q', 'd', ('nq','lev','cell'))
  271. q.setncattr("coordinates", "nq lev lon lat")
  272. #for ps
  273. if mode == 'remap':
  274. src_val_loc = src.variables['ps']
  275. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  276. if rank == 0 and mode == 'remap':
  277. # print(src_val_glo)
  278. src_val = np.hstack(src_val_glo)
  279. # print src_val
  280. print A.shape
  281. print src_val.shape
  282. dst_val = A*src_val
  283. ps[:] = dst_val
  284. #for phis
  285. if mode == 'remap':
  286. src_val_loc = src.variables['phis']
  287. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[:]))
  288. if rank == 0 and mode == 'remap':
  289. src_val = np.hstack(src_val_glo)
  290. dst_val = A*src_val
  291. phis[:] = dst_val
  292. #for theta_rhodz
  293. if mode == 'remap':
  294. src_val_loc = src.variables['theta_rhodz']
  295. for l in range(0, len(lev)):
  296. if mode == 'remap':
  297. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  298. if rank == 0 and mode == 'remap':
  299. src_val = np.hstack(src_val_glo)
  300. dst_val = A*src_val
  301. theta_rhodz[l,:] = dst_val
  302. #for ulon
  303. if mode == 'remap':
  304. src_val_loc = src.variables['ulon']
  305. for l in range(0, len(lev)):
  306. if mode == 'remap':
  307. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  308. if rank == 0 and mode == 'remap':
  309. src_val = np.hstack(src_val_glo)
  310. dst_val = A*src_val
  311. ulon[l,:] = dst_val
  312. #for ulat
  313. if mode == 'remap':
  314. src_val_loc = src.variables['ulat']
  315. for l in range(0, len(lev)):
  316. if mode == 'remap':
  317. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[l,:]))
  318. if rank == 0 and mode == 'remap':
  319. src_val = np.hstack(src_val_glo)
  320. dst_val = A*src_val
  321. ulat[l,:] = dst_val
  322. #for q
  323. if mode == 'remap':
  324. src_val_loc = src.variables['q']
  325. for n in range(0, len(nq)):
  326. for l in range(0, len(lev)):
  327. if mode == 'remap':
  328. src_val_glo = MPI.COMM_WORLD.gather(np.array(src_val_loc[n,l,:]))
  329. if rank == 0 and mode == 'remap':
  330. src_val = np.hstack(src_val_glo)
  331. dst_val = A*src_val
  332. q[n,l,:] = dst_val
  333. if mode == 'remap':
  334. f.close()
  335. if not "reader" in srctype:
  336. src.close()
  337. if not "reader" in dsttype:
  338. dst.close()