remap_orig.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import netCDF4 as nc
  2. import ctypes as ct
  3. import numpy as np
  4. import os
  5. import sys
  6. import math
  7. from mpi4py import MPI
  8. import time
# Load the native remapping library through ctypes.  The relative name
# 'libmapper.so' goes through os.path.realpath, so it is resolved against
# the current working directory of the process.
remap = ct.cdll.LoadLibrary(os.path.realpath('libmapper.so'))
  10. def from_mpas(filename):
  11. # construct vortex bounds from Mpas grid structure
  12. f = nc.Dataset(filename)
  13. # in this case it is must faster to first read the whole file into memory
  14. # before converting the data structure
  15. print "read"
  16. stime = time.time()
  17. lon_vert = np.array(f.variables["lonVertex"])
  18. lat_vert = np.array(f.variables["latVertex"])
  19. vert_cell = np.array(f.variables["verticesOnCell"])
  20. nvert_cell = np.array(f.variables["nEdgesOnCell"])
  21. ncell, nvert = vert_cell.shape
  22. assert(max(nvert_cell) <= nvert)
  23. lon = np.zeros(vert_cell.shape)
  24. lat = np.zeros(vert_cell.shape)
  25. etime = time.time()
  26. print "finished read, now convert", etime-stime
  27. scal = 180.0/math.pi
  28. for c in range(ncell):
  29. lat[c,:] = lat_vert[vert_cell[c,:]-1]*scal
  30. lon[c,:] = lon_vert[vert_cell[c,:]-1]*scal
  31. # signal "last vertex" by netCDF convetion
  32. lon[c,nvert_cell[c]] = lon[c,0]
  33. lat[c,nvert_cell[c]] = lat[c,0]
  34. print "convert end", time.time() - etime
  35. return lon, lat
  36. grid_types = {
  37. "dynamico:mesh": {
  38. "lon_name": "bounds_lon_i",
  39. "lat_name": "bounds_lat_i",
  40. "pole": [0,0,0]
  41. },
  42. "dynamico:vort": {
  43. "lon_name": "bounds_lon_v",
  44. "lat_name": "bounds_lat_v",
  45. "pole": [0,0,0]
  46. },
  47. "dynamico:restart": {
  48. "lon_name": "lon_i_vertices",
  49. "lat_name": "lat_i_vertices",
  50. "pole": [0,0,0]
  51. },
  52. "test:polygon": {
  53. "lon_name": "bounds_lon",
  54. "lat_name": "bounds_lat",
  55. "pole": [0,0,0]
  56. },
  57. "test:latlon": {
  58. "lon_name": "bounds_lon",
  59. "lat_name": "bounds_lat",
  60. "pole": [0,0,1]
  61. },
  62. "mpas": {
  63. "reader": from_mpas,
  64. "pole": [0,0,0]
  65. }
  66. }
  67. interp_types = {
  68. "FV1": 1,
  69. "FV2": 2
  70. }
# Help text printed when the script is not called with exactly seven
# arguments.  The list of valid grid types is generated from grid_types.
usage = """
Usage: python remap.py interp srctype srcfile dsttype dstfile mode outfile
interp: type of interpolation
choices:
FV1: first order conservative Finite Volume
FV2: second order conservative Finite Volume
srctype, dsttype: grid type of source and destination
choices: """ + " ".join(grid_types.keys()) + """
srcfile, dstfile: grid file names, should mostly be netCDF file
mode: modus of operation
choices:
weights: computes weight and stores them in outfile
remap: computes the interpolated values on destination grid and stores them in outfile
outfile: output filename
"""
  86. # parse command line arguments
  87. if not len(sys.argv) == 8:
  88. print usage
  89. sys.exit(2)
  90. interp = sys.argv[1]
  91. try:
  92. srctype = grid_types[sys.argv[2]]
  93. except KeyError:
  94. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  95. exit(2)
  96. srcfile = sys.argv[3]
  97. try:
  98. dsttype = grid_types[sys.argv[4]]
  99. except KeyError:
  100. print "Error: srctype needs to be one of the following: " + " ".join(grid_types.keys()) + "."
  101. exit(2)
  102. dstfile = sys.argv[5]
  103. mode = sys.argv[6]
  104. outfile = sys.argv[7]
  105. if not mode in ("weights", "remap"):
  106. print "Error: mode must be of of the following: weights remap."
  107. exit(2)
# Set up the parallel environment through the remap library and query this
# process's rank and the number of processes (presumably thin wrappers
# around MPI_Init / MPI_Comm_rank / MPI_Comm_size -- confirm in libmapper).
remap.mpi_init()
rank = remap.mpi_rank()
size = remap.mpi_size()
print rank, "/", size

print "Reading grids from netCDF files."

# Grid types that provide a "reader" are loaded by that function; all
# others are opened as netCDF files and their bounds variables are looked
# up by the names stored in the grid type.
if "reader" in srctype:
    src_lon, src_lat = srctype["reader"](srcfile)
else:
    src = nc.Dataset(srcfile)
    # the following two lines do not perform the actual read
    # the file is read later when assigning to the ctypes array
    # -> no unnecessary array copying in memory
    src_lon = src.variables[srctype["lon_name"]]
    src_lat = src.variables[srctype["lat_name"]]

# same procedure for the destination grid
if "reader" in dsttype:
    dst_lon, dst_lat = dsttype["reader"](dstfile)
else:
    dst = nc.Dataset(dstfile)
    dst_lon = dst.variables[dsttype["lon_name"]]
    dst_lat = dst.variables[dsttype["lat_name"]]

# Global sizes of both grids.  The unpacking requires the longitude
# variables to be two-dimensional; the names suggest (cell, vertex) order.
src_ncell, src_nvert = src_lon.shape
dst_ncell, dst_nvert = dst_lon.shape
  130. def compute_distribution(ncell):
  131. "Returns the local number and starting position in global array."
  132. if rank < ncell % size:
  133. return ncell//size + 1, \
  134. (ncell//size + 1)*rank
  135. else:
  136. return ncell//size, \
  137. (ncell//size + 1)*(ncell%size) + (ncell//size)*(rank - ncell%size)
  138. src_ncell_loc, src_loc_start = compute_distribution(src_ncell)
  139. dst_ncell_loc, dst_loc_start = compute_distribution(dst_ncell)
  140. print "src", src_ncell_loc, src_loc_start
  141. print "dst", dst_ncell_loc, dst_loc_start
  142. c_src_lon = (ct.c_double * (src_ncell_loc*src_nvert))()
  143. c_src_lat = (ct.c_double * (src_ncell_loc*src_nvert))()
  144. c_dst_lon = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  145. c_dst_lat = (ct.c_double * (dst_ncell_loc*dst_nvert))()
  146. c_src_lon[:] = nc.numpy.reshape(src_lon[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  147. c_src_lat[:] = nc.numpy.reshape(src_lat[src_loc_start:src_loc_start+src_ncell_loc,:], (len(c_src_lon),1))
  148. c_dst_lon[:] = nc.numpy.reshape(dst_lon[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  149. c_dst_lat[:] = nc.numpy.reshape(dst_lat[dst_loc_start:dst_loc_start+dst_ncell_loc,:], (len(c_dst_lon),1))
  150. print "Calling remap library to compute weights."
  151. srcpole = (ct.c_double * (3))()
  152. dstpole = (ct.c_double * (3))()
  153. srcpole[:] = srctype["pole"]
  154. dstpole[:] = dsttype["pole"]
  155. c_src_ncell = ct.c_int(src_ncell_loc)
  156. c_src_nvert = ct.c_int(src_nvert)
  157. c_dst_ncell = ct.c_int(dst_ncell_loc)
  158. c_dst_nvert = ct.c_int(dst_nvert)
  159. order = ct.c_int(interp_types[interp])
  160. c_nweight = ct.c_int()
  161. print "src:", src_ncell, src_nvert
  162. print "dst:", dst_ncell, dst_nvert
  163. remap.remap_get_num_weights(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
  164. c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
  165. order, ct.byref(c_nweight))
  166. nwgt = c_nweight.value
  167. c_weights = (ct.c_double * nwgt)()
  168. c_dst_idx = (ct.c_int * nwgt)()
  169. c_src_idx = (ct.c_int * nwgt)()
  170. remap.remap_get_weights(c_weights, c_src_idx, c_dst_idx)
  171. wgt_glo = MPI.COMM_WORLD.gather(c_weights[:])
  172. src_idx_glo = MPI.COMM_WORLD.gather(c_src_idx[:])
  173. dst_idx_glo = MPI.COMM_WORLD.gather(c_dst_idx[:])
  174. if rank == 0 and mode == 'weights':
  175. nwgt_glo = sum(len(wgt) for wgt in wgt_glo)
  176. print "Writing", nwgt_glo, "weights to netCDF-file '" + outfile + "'."
  177. f = nc.Dataset(outfile,'w')
  178. f.createDimension('n_src', src_ncell)
  179. f.createDimension('n_dst', dst_ncell)
  180. f.createDimension('n_weight', nwgt_glo)
  181. var = f.createVariable('src_idx', 'i', ('n_weight'))
  182. var[:] = np.hstack(src_idx_glo) + 1 # make indices start from 1
  183. var = f.createVariable('dst_idx', 'i', ('n_weight'))
  184. var[:] = np.hstack(dst_idx_glo) + 1 # make indices start from 1
  185. var = f.createVariable('weight', 'd', ('n_weight'))
  186. var[:] = np.hstack(wgt_glo)
  187. f.close()
  188. def test_fun(x, y, z):
  189. return (1-x**2)*(1-y**2)*z
  190. def test_fun_ll(lat, lon):
  191. #return np.cos(lat*math.pi/180)*np.cos(lon*math.pi/180)
  192. return 2.0 + np.cos(lat*math.pi/180.)**2 * np.cos(2*lon*math.pi/180.);
  193. #UNUSED
  194. #def sphe2cart(lat, lon):
  195. # phi = math.pi/180*lon[:]
  196. # theta = math.pi/2 - math.pi/180*lat[:]
  197. # return np.sin(theta)*np.cos(phi), np.sin(theta)*np.sin(phi), np.cos(theta)
# In "remap" mode every rank evaluates the analytic test function on its
# local part of both grids.  The gather calls are collective, so this block
# is executed by all ranks, not only by rank 0.
if mode == 'remap':
    # output buffers for the local source cells
    c_centre_lon = (ct.c_double * src_ncell_loc)()
    c_centre_lat = (ct.c_double * src_ncell_loc)()
    c_areas = (ct.c_double * src_ncell_loc)()
    # presumably fills the three buffers with cell barycentres and areas --
    # confirm against the libmapper interface
    remap.remap_get_barycentres_and_areas(c_src_lon, c_src_lat, c_src_nvert, c_src_ncell, srcpole,
                                          c_centre_lon, c_centre_lat, c_areas)
    # values that will be interpolated
    src_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))
    src_val_glo = MPI.COMM_WORLD.gather(src_val_loc)

    # the buffer names are reused for the destination grid, so from here on
    # c_centre_lon, c_centre_lat and c_areas refer to destination cells
    c_centre_lon = (ct.c_double * dst_ncell_loc)()
    c_centre_lat = (ct.c_double * dst_ncell_loc)()
    c_areas = (ct.c_double * dst_ncell_loc)()
    remap.remap_get_barycentres_and_areas(c_dst_lon, c_dst_lat, c_dst_nvert, c_dst_ncell, dstpole,
                                          c_centre_lon, c_centre_lat, c_areas)
    # reference values the interpolated result is compared against
    dst_val_loc = test_fun_ll(np.array(c_centre_lat[:]), np.array(c_centre_lon[:]))

    dst_val_glo = MPI.COMM_WORLD.gather(dst_val_loc)
    dst_areas_glo = MPI.COMM_WORLD.gather(np.array(c_areas[:]))
    dst_centre_lon_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lon[:]))
    dst_centre_lat_glo = MPI.COMM_WORLD.gather(np.array(c_centre_lat[:]))
  216. if rank == 0 and mode == 'remap':
  217. from scipy import sparse
  218. A = sparse.csr_matrix(sparse.coo_matrix((np.hstack(wgt_glo),(np.hstack(dst_idx_glo),np.hstack(src_idx_glo)))))
  219. src_val = np.hstack(src_val_glo)
  220. dst_ref = np.hstack(dst_val_glo)
  221. dst_areas = np.hstack(dst_areas_glo)
  222. dst_centre_lon = np.hstack(dst_centre_lon_glo)
  223. dst_centre_lat = np.hstack(dst_centre_lat_glo)
  224. print "source:", src_val.shape
  225. print "destin:", dst_ref.shape
  226. dst_val = A*src_val
  227. err = dst_val - dst_ref
  228. print "absolute maximum error, maximum value:", np.max(np.abs(err)), np.max(np.abs(dst_ref))
  229. print "relative maximum error, normalized L2 error, average target cell size (edgelength of same-area square):"
  230. print np.max(np.abs(err))/np.max(np.abs(dst_ref)), np.linalg.norm(err)/np.linalg.norm(dst_ref), np.mean(np.sqrt(dst_areas))
  231. f = nc.Dataset(outfile,'w')
  232. f.createDimension('n_vert', dst_nvert)
  233. f.createDimension('n_cell', dst_ncell)
  234. var = f.createVariable('lat', 'd', ('n_cell'))
  235. var.setncattr("long_name", "latitude")
  236. var.setncattr("units", "degrees_north")
  237. var.setncattr("bounds", "bounds_lat")
  238. var[:] = dst_centre_lat
  239. var = f.createVariable('lon', 'd', ('n_cell'))
  240. var.setncattr("long_name", "longitude")
  241. var.setncattr("units", "degrees_east")
  242. var.setncattr("bounds", "bounds_lon")
  243. var[:] = dst_centre_lon
  244. var = f.createVariable('bounds_lon', 'd', ('n_cell','n_vert'))
  245. var[:] = dst_lon
  246. var = f.createVariable('bounds_lat', 'd', ('n_cell','n_vert'))
  247. var[:] = dst_lat
  248. var = f.createVariable('val', 'd', ('n_cell'))
  249. var.setncattr("coordinates", "lon lat")
  250. var[:] = dst_val
  251. var = f.createVariable('val_ref', 'd', ('n_cell'))
  252. var.setncattr("coordinates", "lon lat")
  253. var[:] = dst_ref
  254. var = f.createVariable('err', 'd', ('n_cell'))
  255. var.setncattr("coordinates", "lon lat")
  256. var[:] = err
  257. var = f.createVariable('area', 'd', ('n_cell'))
  258. var.setncattr("coordinates", "lon lat")
  259. var[:] = dst_areas[:] # dest
  260. f.close()
  261. if not "reader" in srctype:
  262. src.close()
  263. if not "reader" in dsttype:
  264. dst.close()