# nn_interp.py
  1. import numpy as np
  2. def nn_interp(input_arr):
  3. shape = np.shape(input_arr)
  4. n_shape = np.size(shape)
  5. if(n_shape < 2):
  6. print('error. cant do nn_interp on arrays with rank < 2')
  7. sys.exit()
  8. else:
  9. iter_arr = np.ma.copy(input_arr)
  10. while( np.count_nonzero(iter_arr.mask) > 0 ):
  11. for shift in (-1,1):
  12. for axis in (-1,-2):
  13. arr_shifted=np.roll(iter_arr,shift=shift,axis=axis)
  14. idx=~arr_shifted.mask * iter_arr.mask
  15. iter_arr[idx]=arr_shifted[idx]
  16. return iter_arr