nn_interp.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. import numpy as np
  2. def nn_interp(input_arr):
  3. shape = np.shape(input_arr)
  4. n_shape = np.size(shape)
  5. if(n_shape < 2):
  6. print('error. cant do nn_interp on arrays with rank < 2')
  7. sys.exit()
  8. else:
  9. iter_arr = np.ma.copy(input_arr)
  10. while( np.count_nonzero(iter_arr.mask) > 0 ):
  11. for shift in (-1,1):
  12. for axis in (-1,-2):
  13. arr_shifted=np.roll(iter_arr,shift=shift,axis=axis)
  14. idx=~arr_shifted.mask * iter_arr.mask
  15. iter_arr[idx]=arr_shifted[idx]
  16. return iter_arr
  17. def nn_interp_1d(input_arr):
  18. shape = np.shape(input_arr)
  19. n_shape = np.size(shape)
  20. iter_arr = np.ma.copy(input_arr)
  21. while( np.count_nonzero(iter_arr.mask) > 0 ):
  22. for shift in (-1,1):
  23. arr_shifted=np.roll(iter_arr,shift=shift,axis=-1)
  24. idx=~arr_shifted.mask * iter_arr.mask
  25. iter_arr[idx]=arr_shifted[idx]
  26. return iter_arr