@@ -53,244 +53,19 @@ def GradientSearchResampler(source_geo_def, target_geo_def):
5353
5454def create_gradient_search_resampler (source_geo_def , target_geo_def ):
5555 """Create a gradient search resampler."""
56- if isinstance (source_geo_def , AreaDefinition ) and isinstance (target_geo_def , AreaDefinition ):
56+ if ((isinstance (source_geo_def , AreaDefinition ) and isinstance (target_geo_def , AreaDefinition )) or
57+ (isinstance (source_geo_def , SwathDefinition ) and isinstance (target_geo_def , AreaDefinition ))):
5758 return ResampleBlocksGradientSearchResampler (source_geo_def , target_geo_def )
58- elif isinstance (source_geo_def , SwathDefinition ) and isinstance (target_geo_def , AreaDefinition ):
59- return StackingGradientSearchResampler (source_geo_def , target_geo_def )
6059 raise NotImplementedError
6160
6261
6362@da .as_gufunc (signature = '(),()->(),()' )
6463def transform (x_coords , y_coords , src_prj = None , dst_prj = None ):
6564 """Calculate projection coordinates."""
66- transformer = pyproj .Transformer .from_crs (src_prj , dst_prj )
65+ transformer = pyproj .Transformer .from_crs (src_prj , dst_prj , always_xy = True )
6766 return transformer .transform (x_coords , y_coords )
6867
6968
70- class StackingGradientSearchResampler (BaseResampler ):
71- """Resample using gradient search based bilinear interpolation, using stacking for dask processing."""
72-
73- def __init__ (self , source_geo_def , target_geo_def ):
74- """Init GradientResampler."""
75- super ().__init__ (source_geo_def , target_geo_def )
76- import warnings
77- warnings .warn ("You are using the Gradient Search Resampler, which is still EXPERIMENTAL." , stacklevel = 2 )
78- self .use_input_coords = None
79- self ._src_dst_filtered = False
80- self .prj = None
81- self .src_x = None
82- self .src_y = None
83- self .src_slices = None
84- self .dst_x = None
85- self .dst_y = None
86- self .dst_slices = None
87- self .src_gradient_xl = None
88- self .src_gradient_xp = None
89- self .src_gradient_yl = None
90- self .src_gradient_yp = None
91- self .dst_polys = {}
92- self .dst_mosaic_locations = None
93- self .coverage_status = None
94-
95- def _get_projection_coordinates (self , datachunks ):
96- """Get projection coordinates."""
97- if self .use_input_coords is None :
98- try :
99- self .src_x , self .src_y = self .source_geo_def .get_proj_coords (
100- chunks = datachunks )
101- src_crs = self .source_geo_def .crs
102- self .use_input_coords = True
103- except AttributeError :
104- self .src_x , self .src_y = self .source_geo_def .get_lonlats (
105- chunks = datachunks )
106- src_crs = pyproj .CRS .from_string ("+proj=longlat" )
107- self .use_input_coords = False
108- try :
109- self .dst_x , self .dst_y = self .target_geo_def .get_proj_coords (
110- chunks = CHUNK_SIZE )
111- dst_crs = self .target_geo_def .crs
112- except AttributeError as err :
113- if self .use_input_coords is False :
114- raise NotImplementedError ('Cannot resample lon/lat to lon/lat with gradient search.' ) from err
115- self .dst_x , self .dst_y = self .target_geo_def .get_lonlats (
116- chunks = CHUNK_SIZE )
117- dst_crs = pyproj .CRS .from_string ("+proj=longlat" )
118- if self .use_input_coords :
119- self .dst_x , self .dst_y = transform (
120- self .dst_x , self .dst_y ,
121- src_prj = dst_crs , dst_prj = src_crs )
122- self .prj = pyproj .Proj (self .source_geo_def .crs )
123- else :
124- self .src_x , self .src_y = transform (
125- self .src_x , self .src_y ,
126- src_prj = src_crs , dst_prj = dst_crs )
127- self .prj = pyproj .Proj (self .target_geo_def .crs )
128-
129- def _get_prj_poly (self , geo_def ):
130- # - None if out of Earth Disk
131- # - False is SwathDefinition
132- if isinstance (geo_def , SwathDefinition ):
133- return False
134- try :
135- poly = get_polygon (self .prj , geo_def )
136- except (NotImplementedError , ValueError ): # out-of-earth disk area or any valid projected boundary coordinates
137- poly = None
138- return poly
139-
140- def _get_src_poly (self , src_y_start , src_y_end , src_x_start , src_x_end ):
141- """Get bounding polygon for source chunk."""
142- geo_def = self .source_geo_def [src_y_start :src_y_end ,
143- src_x_start :src_x_end ]
144- return self ._get_prj_poly (geo_def )
145-
146- def _get_dst_poly (self , idx ,
147- dst_x_start , dst_x_end ,
148- dst_y_start , dst_y_end ):
149- """Get target chunk polygon."""
150- dst_poly = self .dst_polys .get (idx , None )
151- if dst_poly is None :
152- geo_def = self .target_geo_def [dst_y_start :dst_y_end ,
153- dst_x_start :dst_x_end ]
154- dst_poly = self ._get_prj_poly (geo_def )
155- self .dst_polys [idx ] = dst_poly
156- return dst_poly
157-
158- def get_chunk_mappings (self ):
159- """Map source and target chunks together if they overlap."""
160- src_y_chunks , src_x_chunks = self .src_x .chunks
161- dst_y_chunks , dst_x_chunks = self .dst_x .chunks
162-
163- coverage_status = []
164- src_slices , dst_slices = [], []
165- dst_mosaic_locations = []
166-
167- src_x_start = 0
168- for src_x_step in src_x_chunks :
169- src_x_end = src_x_start + src_x_step
170- src_y_start = 0
171- for src_y_step in src_y_chunks :
172- src_y_end = src_y_start + src_y_step
173- # Get source chunk polygon
174- src_poly = self ._get_src_poly (src_y_start , src_y_end ,
175- src_x_start , src_x_end )
176-
177- dst_x_start = 0
178- for x_step_number , dst_x_step in enumerate (dst_x_chunks ):
179- dst_x_end = dst_x_start + dst_x_step
180- dst_y_start = 0
181- for y_step_number , dst_y_step in enumerate (dst_y_chunks ):
182- dst_y_end = dst_y_start + dst_y_step
183- # Get destination chunk polygon
184- dst_poly = self ._get_dst_poly ((x_step_number , y_step_number ),
185- dst_x_start , dst_x_end ,
186- dst_y_start , dst_y_end )
187-
188- covers = check_overlap (src_poly , dst_poly )
189-
190- coverage_status .append (covers )
191- src_slices .append ((src_y_start , src_y_end ,
192- src_x_start , src_x_end ))
193- dst_slices .append ((dst_y_start , dst_y_end ,
194- dst_x_start , dst_x_end ))
195- dst_mosaic_locations .append ((x_step_number , y_step_number ))
196-
197- dst_y_start = dst_y_end
198- dst_x_start = dst_x_end
199- src_y_start = src_y_end
200- src_x_start = src_x_end
201-
202- self .src_slices = src_slices
203- self .dst_slices = dst_slices
204- self .dst_mosaic_locations = dst_mosaic_locations
205- self .coverage_status = coverage_status
206-
207- def _filter_data (self , data , is_src = True , add_dim = False ):
208- """Filter unused chunks from the given array."""
209- if add_dim :
210- if data .ndim not in [2 , 3 ]:
211- raise NotImplementedError ('Gradient search resampling only '
212- 'supports 2D or 3D arrays.' )
213- if data .ndim == 2 :
214- data = data [np .newaxis , :, :]
215-
216- data_out = []
217- for i , covers in enumerate (self .coverage_status ):
218- if covers :
219- if is_src :
220- y_start , y_end , x_start , x_end = self .src_slices [i ]
221- else :
222- y_start , y_end , x_start , x_end = self .dst_slices [i ]
223- try :
224- val = data [:, y_start :y_end , x_start :x_end ]
225- except IndexError :
226- val = data [y_start :y_end , x_start :x_end ]
227- else :
228- val = None
229- data_out .append (val )
230-
231- return data_out
232-
233- def _get_gradients (self ):
234- """Get gradients in X and Y directions."""
235- self .src_gradient_xl , self .src_gradient_xp = np .gradient (
236- self .src_x , axis = [0 , 1 ])
237- self .src_gradient_yl , self .src_gradient_yp = np .gradient (
238- self .src_y , axis = [0 , 1 ])
239-
240- def _filter_src_dst (self ):
241- """Filter source and target chunks."""
242- self .src_x = self ._filter_data (self .src_x )
243- self .src_y = self ._filter_data (self .src_y )
244- self .src_gradient_yl = self ._filter_data (self .src_gradient_yl )
245- self .src_gradient_yp = self ._filter_data (self .src_gradient_yp )
246- self .src_gradient_xl = self ._filter_data (self .src_gradient_xl )
247- self .src_gradient_xp = self ._filter_data (self .src_gradient_xp )
248- self .dst_x = self ._filter_data (self .dst_x , is_src = False )
249- self .dst_y = self ._filter_data (self .dst_y , is_src = False )
250- self ._src_dst_filtered = True
251-
252- def compute (self , data , fill_value = None , ** kwargs ):
253- """Resample the given data using gradient search algorithm."""
254- if 'bands' in data .dims :
255- datachunks = data .sel (bands = data .coords ['bands' ][0 ]).chunks
256- else :
257- datachunks = data .chunks
258- data_dims = data .dims
259- data_coords = data .coords
260-
261- self ._get_projection_coordinates (datachunks )
262-
263- if self .src_gradient_xl is None :
264- self ._get_gradients ()
265- if self .coverage_status is None :
266- self .get_chunk_mappings ()
267- if not self ._src_dst_filtered :
268- self ._filter_src_dst ()
269-
270- data = self ._filter_data (data .data , add_dim = True )
271-
272- res = parallel_gradient_search (data ,
273- self .src_x , self .src_y ,
274- self .dst_x , self .dst_y ,
275- self .src_gradient_xl ,
276- self .src_gradient_xp ,
277- self .src_gradient_yl ,
278- self .src_gradient_yp ,
279- self .dst_mosaic_locations ,
280- self .dst_slices ,
281- ** kwargs )
282-
283- coords = _fill_in_coords (self .target_geo_def , data_coords , data_dims )
284-
285- if fill_value is not None :
286- res = da .where (np .isnan (res ), fill_value , res )
287- if res .ndim > len (data_dims ):
288- res = res .squeeze ()
289-
290- res = xr .DataArray (res , dims = data_dims , coords = coords )
291- return res
292-
293-
29469def check_overlap (src_poly , dst_poly ):
29570 """Check if the two polygons overlap."""
29671 # swath definition case
@@ -491,8 +266,10 @@ def __init__(self, source_geo_def, target_geo_def):
491266 """Init GradientResampler."""
492267 if isinstance (target_geo_def , SwathDefinition ):
493268 raise NotImplementedError ("Cannot resample to a SwathDefinition." )
269+ if isinstance (source_geo_def , SwathDefinition ):
270+ source_geo_def .lons = source_geo_def .lons .persist ()
271+ source_geo_def .lats = source_geo_def .lats .persist ()
494272 super ().__init__ (source_geo_def , target_geo_def )
495- logger .debug ("/!\\ Instantiating an experimental GradientSearch resampler /!\\ " )
496273 self .indices_xy = None
497274
498275 def precompute (self , ** kwargs ):
@@ -590,11 +367,11 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
590367def _get_coordinates_in_same_projection (source_area , target_area ):
591368 try :
592369 src_x , src_y = source_area .get_proj_coords ()
593- transformer = pyproj .Transformer .from_crs (target_area .crs , source_area .crs , always_xy = True )
594370 except AttributeError as err :
595- raise NotImplementedError ( "Cannot resample from Swath for now." ) from err
596-
371+ lons , lats = source_area . get_lonlats ()
372+ src_x , src_y = da . compute ( lons , lats )
597373 try :
374+ transformer = pyproj .Transformer .from_crs (target_area .crs , source_area .crs , always_xy = True )
598375 dst_x , dst_y = transformer .transform (* target_area .get_proj_coords ())
599376 except AttributeError as err :
600377 raise NotImplementedError ("Cannot resample to Swath for now." ) from err
@@ -618,7 +395,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
618395 res = ((1 - weight_l ) * (1 - weight_p ) * data [..., l_start , p_start ] +
619396 (1 - weight_l ) * weight_p * data [..., l_start , p_end ] +
620397 weight_l * (1 - weight_p ) * data [..., l_end , p_start ] +
621- weight_l * weight_p * data [..., l_end , p_end ])
398+ weight_l * weight_p * data [..., l_end , p_end ]). astype ( data . dtype )
622399 res = np .where (mask , fill_value , res )
623400 return res
624401
0 commit comments