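# This script extrapolates each level of a 3D (z,y,x) field horizontally,
# filling the requested points with the value of their nearest valid
# neighbour on the same level.
#
# Usage : python extrap.py <input file> <input var name>
#                          <grid file> <2d longitude name in grid file>
#                          <2d latitude name in grid file>
#                          <fill file> <fill var name (>= 1 at points to fill)>
#                          <output file>
#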
# History : Virginie Guemas - Initial version - 2012
# Virginie Guemas - Masking the outputs - March 2014
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
import cdms2 as cdms
import sys,string
from cdms2 import MV
import numpy as N
from numpy import ma
#
# 1. Input arguments
# ===================
#
# Eight positional arguments are required. Stop with the usage line rather
# than an IndexError when some are missing (extra arguments are ignored).
if len(sys.argv) < 9 :
    sys.exit('Usage : python '+sys.argv[0]+' <input file> <input var name>'
             ' <grid file> <2d longitude name> <2d latitude name>'
             ' <fill file> <fill var name> <output file>')
#
# Input file and var names :
# ---------------------------
fileIN=sys.argv[1]
varIN=sys.argv[2]
#
# Meshmask file (holds the 2D longitude and latitude variables) :
# ----------------------------------------------------------------
fileM=sys.argv[3]
varlon=sys.argv[4]
varlat=sys.argv[5]
#
# Location of points to fill (variable is >= 1 at the points to fill) :
# ----------------------------------------------------------------------
fileF=sys.argv[6]
varF=sys.argv[7]
#
# Output file name :
# -------------------
fileOUT=sys.argv[8]
#
# 2. Get the input files
# =======================
#
# Field to extrapolate. After squeezing singleton dimensions it must be 3D
# (z,y,x); the shape unpacking below fails otherwise.
f=cdms.open(fileIN)
var0_a=f(varIN,squeeze=1)
f.close()
# Full boolean mask (True = missing value). ma.getmaskarray is used rather
# than var0_a.mask, because the latter is the scalar `nomask` when no value
# is masked, which would make the per-level indexing mask3d[jz,:,:] fail.
mask3d=ma.getmaskarray(var0_a)
(lz1,ly0,lx0)=var0_a.shape
#
# 2D longitude and latitude of the grid points (converted from degrees to
# radians further down).
f=cdms.open(fileM)
lon=f(varlon,squeeze=1)
lat=f(varlat,squeeze=1)
f.close()
#
# Points to fill: (z,y,x) field that is >= 1 where a value must be extrapolated.
f=cdms.open(fileF)
Pfill=f(varF,squeeze=1)
f.close()
#
# 3. Nearest-neighbour extrapolation
# ===================================
#
# Output array: input values where they are valid, 0 elsewhere. The points
# to fill are overwritten below.
var4=N.zeros((lz1,ly0,lx0))
var4=var4.astype('d')
var4=N.where(mask3d==False,var0_a,var4)
#
# Trigonometric terms of the grid coordinates (degrees -> radians), computed
# once for the great-circle distance formula.
pi=N.pi
coslat=N.cos(lat*pi/180)
coslon=N.cos(lon*pi/180)
sinlat=N.sin(lat*pi/180)
sinlon=N.sin(lon*pi/180)
#
# Horizontal positions where at least one level is flagged for filling.
indexes1=N.where(N.sum(Pfill[:,:,:],0)>=1)
for jy,jx in zip(indexes1[0],indexes1[1]) :
    # Cosine of the angular distance between (jy,jx) and every grid point:
    # cos(d) = cos(lat0)cos(lat)cos(lon0-lon) + sin(lat0)sin(lat).
    # Rounding can push it slightly outside [-1,1] (notably at (jy,jx)
    # itself), where arccos is undefined, so it is clipped first.
    cosdist=coslat[jy,jx]*coslat*(coslon[jy,jx]*coslon+sinlon[jy,jx]*sinlon)+sinlat[jy,jx]*sinlat
    distance=MV.arccos(N.clip(cosdist,-1.,1.))
    # Levels to fill at this horizontal position.
    for jz in N.where(Pfill[:,jy,jx]>=1)[0] :
        levmask=mask3d[jz,:,:]
        # Skip levels with no valid point at all: there is nothing to copy.
        # NOTE(review): such points keep the value 0 and are still unmasked
        # in the output (the output mask only looks at Pfill) - confirm.
        if levmask.mean() < 1 :
            # Exclude the missing points of THIS level only. The result goes
            # into a separate array so that `distance`, shared by all levels
            # of this column, is not progressively masked level after level.
            distlev=N.where(levmask==False,distance,1e20)
            dismin=distlev.min()
            # Average of the valid point(s) at the minimum distance (several
            # points can tie).
            test=cdms.createVariable(var0_a[jz,:,:],mask=MV.where(distlev==dismin,False,True))
            var4[jz,jy,jx]=test.mean()
# 4. Write the output
# ====================
#
# Output mask: points flagged in Pfill become valid, all others keep the
# input mask. Masked output points are set to 1e20.
maskout=MV.where(Pfill>0.5,False,mask3d)
var4=cdms.createVariable(var4,id=var0_a.id)
var4=MV.where(maskout<0.5,var4,1e20)
#
# Rename the three axes and copy the vertical coordinate values of the input.
for iax,axname in enumerate(('z','y','x')) :
    var4.getAxis(iax).id=axname
var4.getAxis(0)[:]=var0_a.getAxis(0)[:]
# The variable keeps the input variable name.
var4.id=var0_a.id
#
h=cdms.open(fileOUT,'w')
h.write(var4)
h.close()