// test-main.cpp (4.4 KB)
  1. #include "mpi.hpp"
  2. #include <netcdf.h>
  3. #include "errhandle.hpp"
  4. #include "libmapper.hpp"
  5. #include "mapper.hpp"
  6. #include "node.hpp"
  7. using namespace sphereRemap ;
  8. int ncread(double *lon, double *lat, double* val, int nElt, const char* filename)
  9. {
  10. int ncid, blonid, blatid, valid, neltid;
  11. size_t nEltCheck;
  12. exit_on_failure(nc_open(filename, NC_NOWRITE, &ncid), std::string("Cannot open netCDF file ") + filename);
  13. exit_on_failure(nc_inq_dimid(ncid, "elt", &neltid), std::string("No dimension elt in file ") + filename);
  14. nc_inq_dimlen(ncid, neltid, &nEltCheck);
  15. exit_on_failure(nElt != nEltCheck, std::string("Array sizes do not match!"));
  16. nc_inq_varid(ncid, "bounds_lon", &blonid);
  17. nc_inq_varid(ncid, "bounds_lat", &blatid);
  18. nc_inq_varid(ncid, "val", &valid);
  19. nc_get_var_double(ncid, blonid, lon);
  20. nc_get_var_double(ncid, blatid, lat);
  21. nc_get_var_double(ncid, valid, val);
  22. nc_close(ncid);
  23. }
  24. int ncwriteValue(double *val, const char* filename)
  25. {
  26. int ncid, valid ;
  27. nc_open(filename, NC_WRITE, &ncid);
  28. nc_inq_varid(ncid, "val", &valid);
  29. nc_put_var_double(ncid, valid, val);
  30. nc_close(ncid);
  31. }
  32. void compute_distribution(int nGlobalElts, int &start, int &nLocalElts)
  33. {
  34. int mpiSize, mpiRank;
  35. MPI_Comm_size(MPI_COMM_WORLD, &mpiSize);
  36. MPI_Comm_rank(MPI_COMM_WORLD, &mpiRank);
  37. start = 0;
  38. nLocalElts = 0;
  39. for (int i = 0; i <= mpiRank; i++)
  40. {
  41. start += nLocalElts;
  42. nLocalElts = nGlobalElts/mpiSize;
  43. if (i < nGlobalElts % mpiSize) nLocalElts++;
  44. }
  45. }
  46. int main()
  47. {
  48. int interpOrder = 2;
  49. /* low resolution */
  50. /*
  51. char srcFile[] = "h14.nc";
  52. char dstFile[] = "r180x90.nc";
  53. double srcPole[] = {0, 0, 0};
  54. double dstPole[] = {0, 0, 1};
  55. int nSrcElt = 13661;
  56. int nDstElt = 16200;
  57. */
  58. /* high resolution */
  59. char srcFile[] = "t740.nc";
  60. char dstFile[] = "r1440x720.nc";
  61. double srcPole[] = {0, 0, 0};
  62. double dstPole[] = {0, 0, 1};
  63. int nSrcElt = 741034;
  64. int nDstElt = 1036800;
  65. int nVert = 10;
  66. int nSrcLocal, nDstLocal, startSrc, startDst;
  67. int nWeight;
  68. int mpi_rank, mpi_size ;
  69. MPI_Init(NULL, NULL);
  70. MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank) ;
  71. MPI_Comm_size(MPI_COMM_WORLD, &mpi_size) ;
  72. double *srcLon = (double *) malloc(nSrcElt*nVert*sizeof(double));
  73. double *srcLat = (double *) malloc(nSrcElt*nVert*sizeof(double));
  74. double *srcVal = (double *) malloc(nSrcElt*sizeof(double));
  75. double *dstLon = (double *) malloc(nDstElt*nVert*sizeof(double));
  76. double *dstLat = (double *) malloc(nDstElt*nVert*sizeof(double));
  77. double *dstVal = (double *) malloc(nDstElt*sizeof(double));
  78. double *globalVal = (double *) malloc(nDstElt*sizeof(double));
  79. ncread(srcLon, srcLat, srcVal, nSrcElt, srcFile);
  80. ncread(dstLon, dstLat, dstVal, nDstElt, dstFile);
  81. compute_distribution(nSrcElt, startSrc, nSrcLocal);
  82. compute_distribution(nDstElt, startDst, nDstLocal);
  83. Mapper mapper(MPI_COMM_WORLD);
  84. mapper.setVerbosity(PROGRESS) ;
  85. mapper.setSourceMesh(srcLon + startSrc*nVert, srcLat + startSrc*nVert, nVert, nSrcLocal, srcPole ) ;
  86. mapper.setTargetMesh(dstLon + startDst*nVert, dstLat + startDst*nVert, nVert, nDstLocal, dstPole ) ;
  87. for(int i=0;i<nDstElt;i++) dstVal[i]=0 ;
  88. mapper.setSourceValue(srcVal+ startSrc) ;
  89. vector<double> timings = mapper.computeWeights(interpOrder);
  90. for(int i=0;i<mapper.nWeights;i++) dstVal[mapper.dstAddress[i]]=dstVal[mapper.dstAddress[i]]+mapper.remapMatrix[i]*srcVal[mapper.sourceWeightId[i]];
  91. /*
  92. remap_get_num_weights(srcLon + startSrc*nVert, srcLat + startSrc*nVert, nVert, nSrcLocal, srcPole,
  93. dstLon + startDst*nVert, dstLat + startDst*nVert, nVert, nDstLocal, dstPole,
  94. interpOrder, &nWeight);
  95. double *weights = (double *) malloc(nWeight*sizeof(double));
  96. int *srcIdx = (int *) malloc(nWeight*sizeof(int));
  97. int *srcRank = (int *) malloc(nWeight*sizeof(int));
  98. int *dstIdx = (int *) malloc(nWeight*sizeof(int));
  99. remap_get_weights(weights, srcIdx, dstIdx);
  100. free(srcLon); free(srcLat);
  101. free(dstLon); free(dstLat);
  102. free(srcIdx); free(dstIdx);
  103. free(srcRank);
  104. free(weights);
  105. #ifdef DEBUG
  106. memory_report();
  107. #endif
  108. */
  109. int* displs=new int[mpi_size] ;
  110. int* recvCount=new int[mpi_size] ;
  111. MPI_Gather(&startDst,1,MPI_INT,displs,1,MPI_INT,0,MPI_COMM_WORLD) ;
  112. MPI_Gather(&nDstLocal,1,MPI_INT,recvCount,1,MPI_INT,0,MPI_COMM_WORLD) ;
  113. MPI_Gatherv(dstVal,nDstLocal,MPI_DOUBLE,globalVal,recvCount,displs,MPI_DOUBLE,0,MPI_COMM_WORLD) ;
  114. if (mpi_rank==0) ncwriteValue(globalVal, dstFile);
  115. MPI_Finalize();
  116. return 0;
  117. }