Skip to content

Commit 3c07255

Browse files
committed
Replace stacking gradient search with resample_blocks variant
1 parent 9becb0d commit 3c07255

File tree

4 files changed

+81
-509
lines changed

4 files changed

+81
-509
lines changed

pyresample/gradient/__init__.py

Lines changed: 10 additions & 233 deletions
Original file line numberDiff line numberDiff line change
@@ -53,244 +53,19 @@ def GradientSearchResampler(source_geo_def, target_geo_def):
5353

5454
def create_gradient_search_resampler(source_geo_def, target_geo_def):
5555
"""Create a gradient search resampler."""
56-
if isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition):
56+
if ((isinstance(source_geo_def, AreaDefinition) and isinstance(target_geo_def, AreaDefinition)) or
57+
(isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition))):
5758
return ResampleBlocksGradientSearchResampler(source_geo_def, target_geo_def)
58-
elif isinstance(source_geo_def, SwathDefinition) and isinstance(target_geo_def, AreaDefinition):
59-
return StackingGradientSearchResampler(source_geo_def, target_geo_def)
6059
raise NotImplementedError
6160

6261

6362
@da.as_gufunc(signature='(),()->(),()')
6463
def transform(x_coords, y_coords, src_prj=None, dst_prj=None):
6564
"""Calculate projection coordinates."""
66-
transformer = pyproj.Transformer.from_crs(src_prj, dst_prj)
65+
transformer = pyproj.Transformer.from_crs(src_prj, dst_prj, always_xy=True)
6766
return transformer.transform(x_coords, y_coords)
6867

6968

70-
class StackingGradientSearchResampler(BaseResampler):
71-
"""Resample using gradient search based bilinear interpolation, using stacking for dask processing."""
72-
73-
def __init__(self, source_geo_def, target_geo_def):
74-
"""Init GradientResampler."""
75-
super().__init__(source_geo_def, target_geo_def)
76-
import warnings
77-
warnings.warn("You are using the Gradient Search Resampler, which is still EXPERIMENTAL.", stacklevel=2)
78-
self.use_input_coords = None
79-
self._src_dst_filtered = False
80-
self.prj = None
81-
self.src_x = None
82-
self.src_y = None
83-
self.src_slices = None
84-
self.dst_x = None
85-
self.dst_y = None
86-
self.dst_slices = None
87-
self.src_gradient_xl = None
88-
self.src_gradient_xp = None
89-
self.src_gradient_yl = None
90-
self.src_gradient_yp = None
91-
self.dst_polys = {}
92-
self.dst_mosaic_locations = None
93-
self.coverage_status = None
94-
95-
def _get_projection_coordinates(self, datachunks):
96-
"""Get projection coordinates."""
97-
if self.use_input_coords is None:
98-
try:
99-
self.src_x, self.src_y = self.source_geo_def.get_proj_coords(
100-
chunks=datachunks)
101-
src_crs = self.source_geo_def.crs
102-
self.use_input_coords = True
103-
except AttributeError:
104-
self.src_x, self.src_y = self.source_geo_def.get_lonlats(
105-
chunks=datachunks)
106-
src_crs = pyproj.CRS.from_string("+proj=longlat")
107-
self.use_input_coords = False
108-
try:
109-
self.dst_x, self.dst_y = self.target_geo_def.get_proj_coords(
110-
chunks=CHUNK_SIZE)
111-
dst_crs = self.target_geo_def.crs
112-
except AttributeError as err:
113-
if self.use_input_coords is False:
114-
raise NotImplementedError('Cannot resample lon/lat to lon/lat with gradient search.') from err
115-
self.dst_x, self.dst_y = self.target_geo_def.get_lonlats(
116-
chunks=CHUNK_SIZE)
117-
dst_crs = pyproj.CRS.from_string("+proj=longlat")
118-
if self.use_input_coords:
119-
self.dst_x, self.dst_y = transform(
120-
self.dst_x, self.dst_y,
121-
src_prj=dst_crs, dst_prj=src_crs)
122-
self.prj = pyproj.Proj(self.source_geo_def.crs)
123-
else:
124-
self.src_x, self.src_y = transform(
125-
self.src_x, self.src_y,
126-
src_prj=src_crs, dst_prj=dst_crs)
127-
self.prj = pyproj.Proj(self.target_geo_def.crs)
128-
129-
def _get_prj_poly(self, geo_def):
130-
# - None if out of Earth Disk
131-
# - False is SwathDefinition
132-
if isinstance(geo_def, SwathDefinition):
133-
return False
134-
try:
135-
poly = get_polygon(self.prj, geo_def)
136-
except (NotImplementedError, ValueError): # out-of-earth disk area or any valid projected boundary coordinates
137-
poly = None
138-
return poly
139-
140-
def _get_src_poly(self, src_y_start, src_y_end, src_x_start, src_x_end):
141-
"""Get bounding polygon for source chunk."""
142-
geo_def = self.source_geo_def[src_y_start:src_y_end,
143-
src_x_start:src_x_end]
144-
return self._get_prj_poly(geo_def)
145-
146-
def _get_dst_poly(self, idx,
147-
dst_x_start, dst_x_end,
148-
dst_y_start, dst_y_end):
149-
"""Get target chunk polygon."""
150-
dst_poly = self.dst_polys.get(idx, None)
151-
if dst_poly is None:
152-
geo_def = self.target_geo_def[dst_y_start:dst_y_end,
153-
dst_x_start:dst_x_end]
154-
dst_poly = self._get_prj_poly(geo_def)
155-
self.dst_polys[idx] = dst_poly
156-
return dst_poly
157-
158-
def get_chunk_mappings(self):
159-
"""Map source and target chunks together if they overlap."""
160-
src_y_chunks, src_x_chunks = self.src_x.chunks
161-
dst_y_chunks, dst_x_chunks = self.dst_x.chunks
162-
163-
coverage_status = []
164-
src_slices, dst_slices = [], []
165-
dst_mosaic_locations = []
166-
167-
src_x_start = 0
168-
for src_x_step in src_x_chunks:
169-
src_x_end = src_x_start + src_x_step
170-
src_y_start = 0
171-
for src_y_step in src_y_chunks:
172-
src_y_end = src_y_start + src_y_step
173-
# Get source chunk polygon
174-
src_poly = self._get_src_poly(src_y_start, src_y_end,
175-
src_x_start, src_x_end)
176-
177-
dst_x_start = 0
178-
for x_step_number, dst_x_step in enumerate(dst_x_chunks):
179-
dst_x_end = dst_x_start + dst_x_step
180-
dst_y_start = 0
181-
for y_step_number, dst_y_step in enumerate(dst_y_chunks):
182-
dst_y_end = dst_y_start + dst_y_step
183-
# Get destination chunk polygon
184-
dst_poly = self._get_dst_poly((x_step_number, y_step_number),
185-
dst_x_start, dst_x_end,
186-
dst_y_start, dst_y_end)
187-
188-
covers = check_overlap(src_poly, dst_poly)
189-
190-
coverage_status.append(covers)
191-
src_slices.append((src_y_start, src_y_end,
192-
src_x_start, src_x_end))
193-
dst_slices.append((dst_y_start, dst_y_end,
194-
dst_x_start, dst_x_end))
195-
dst_mosaic_locations.append((x_step_number, y_step_number))
196-
197-
dst_y_start = dst_y_end
198-
dst_x_start = dst_x_end
199-
src_y_start = src_y_end
200-
src_x_start = src_x_end
201-
202-
self.src_slices = src_slices
203-
self.dst_slices = dst_slices
204-
self.dst_mosaic_locations = dst_mosaic_locations
205-
self.coverage_status = coverage_status
206-
207-
def _filter_data(self, data, is_src=True, add_dim=False):
208-
"""Filter unused chunks from the given array."""
209-
if add_dim:
210-
if data.ndim not in [2, 3]:
211-
raise NotImplementedError('Gradient search resampling only '
212-
'supports 2D or 3D arrays.')
213-
if data.ndim == 2:
214-
data = data[np.newaxis, :, :]
215-
216-
data_out = []
217-
for i, covers in enumerate(self.coverage_status):
218-
if covers:
219-
if is_src:
220-
y_start, y_end, x_start, x_end = self.src_slices[i]
221-
else:
222-
y_start, y_end, x_start, x_end = self.dst_slices[i]
223-
try:
224-
val = data[:, y_start:y_end, x_start:x_end]
225-
except IndexError:
226-
val = data[y_start:y_end, x_start:x_end]
227-
else:
228-
val = None
229-
data_out.append(val)
230-
231-
return data_out
232-
233-
def _get_gradients(self):
234-
"""Get gradients in X and Y directions."""
235-
self.src_gradient_xl, self.src_gradient_xp = np.gradient(
236-
self.src_x, axis=[0, 1])
237-
self.src_gradient_yl, self.src_gradient_yp = np.gradient(
238-
self.src_y, axis=[0, 1])
239-
240-
def _filter_src_dst(self):
241-
"""Filter source and target chunks."""
242-
self.src_x = self._filter_data(self.src_x)
243-
self.src_y = self._filter_data(self.src_y)
244-
self.src_gradient_yl = self._filter_data(self.src_gradient_yl)
245-
self.src_gradient_yp = self._filter_data(self.src_gradient_yp)
246-
self.src_gradient_xl = self._filter_data(self.src_gradient_xl)
247-
self.src_gradient_xp = self._filter_data(self.src_gradient_xp)
248-
self.dst_x = self._filter_data(self.dst_x, is_src=False)
249-
self.dst_y = self._filter_data(self.dst_y, is_src=False)
250-
self._src_dst_filtered = True
251-
252-
def compute(self, data, fill_value=None, **kwargs):
253-
"""Resample the given data using gradient search algorithm."""
254-
if 'bands' in data.dims:
255-
datachunks = data.sel(bands=data.coords['bands'][0]).chunks
256-
else:
257-
datachunks = data.chunks
258-
data_dims = data.dims
259-
data_coords = data.coords
260-
261-
self._get_projection_coordinates(datachunks)
262-
263-
if self.src_gradient_xl is None:
264-
self._get_gradients()
265-
if self.coverage_status is None:
266-
self.get_chunk_mappings()
267-
if not self._src_dst_filtered:
268-
self._filter_src_dst()
269-
270-
data = self._filter_data(data.data, add_dim=True)
271-
272-
res = parallel_gradient_search(data,
273-
self.src_x, self.src_y,
274-
self.dst_x, self.dst_y,
275-
self.src_gradient_xl,
276-
self.src_gradient_xp,
277-
self.src_gradient_yl,
278-
self.src_gradient_yp,
279-
self.dst_mosaic_locations,
280-
self.dst_slices,
281-
**kwargs)
282-
283-
coords = _fill_in_coords(self.target_geo_def, data_coords, data_dims)
284-
285-
if fill_value is not None:
286-
res = da.where(np.isnan(res), fill_value, res)
287-
if res.ndim > len(data_dims):
288-
res = res.squeeze()
289-
290-
res = xr.DataArray(res, dims=data_dims, coords=coords)
291-
return res
292-
293-
29469
def check_overlap(src_poly, dst_poly):
29570
"""Check if the two polygons overlap."""
29671
# swath definition case
@@ -491,8 +266,10 @@ def __init__(self, source_geo_def, target_geo_def):
491266
"""Init GradientResampler."""
492267
if isinstance(target_geo_def, SwathDefinition):
493268
raise NotImplementedError("Cannot resample to a SwathDefinition.")
269+
if isinstance(source_geo_def, SwathDefinition):
270+
source_geo_def.lons = source_geo_def.lons.persist()
271+
source_geo_def.lats = source_geo_def.lats.persist()
494272
super().__init__(source_geo_def, target_geo_def)
495-
logger.debug("/!\\ Instantiating an experimental GradientSearch resampler /!\\")
496273
self.indices_xy = None
497274

498275
def precompute(self, **kwargs):
@@ -590,11 +367,11 @@ def gradient_resampler_indices(source_area, target_area, block_info=None, **kwar
590367
def _get_coordinates_in_same_projection(source_area, target_area):
591368
try:
592369
src_x, src_y = source_area.get_proj_coords()
593-
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
594370
except AttributeError as err:
595-
raise NotImplementedError("Cannot resample from Swath for now.") from err
596-
371+
lons, lats = source_area.get_lonlats()
372+
src_x, src_y = da.compute(lons, lats)
597373
try:
374+
transformer = pyproj.Transformer.from_crs(target_area.crs, source_area.crs, always_xy=True)
598375
dst_x, dst_y = transformer.transform(*target_area.get_proj_coords())
599376
except AttributeError as err:
600377
raise NotImplementedError("Cannot resample to Swath for now.") from err
@@ -618,7 +395,7 @@ def block_bilinear_interpolator(data, indices_xy, fill_value=np.nan, block_info=
618395
res = ((1 - weight_l) * (1 - weight_p) * data[..., l_start, p_start] +
619396
(1 - weight_l) * weight_p * data[..., l_start, p_end] +
620397
weight_l * (1 - weight_p) * data[..., l_end, p_start] +
621-
weight_l * weight_p * data[..., l_end, p_end])
398+
weight_l * weight_p * data[..., l_end, p_end]).astype(data.dtype)
622399
res = np.where(mask, fill_value, res)
623400
return res
624401

pyresample/gradient/_gradient_search.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ cdef inline void bil(const data_type[:, :, :] data, int l0, int p0, float_index
8080
p_b = min(p0 + 1, pmax)
8181
w_p = dp
8282
for i in range(z_size):
83-
res[i] = ((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
84-
(1 - w_l) * w_p * data[i, l_a, p_b] +
85-
w_l * (1 - w_p) * data[i, l_b, p_a] +
86-
w_l * w_p * data[i, l_b, p_b])
83+
res[i] = <data_type>((1 - w_l) * (1 - w_p) * data[i, l_a, p_a] +
84+
(1 - w_l) * w_p * data[i, l_a, p_b] +
85+
w_l * (1 - w_p) * data[i, l_b, p_a] +
86+
w_l * w_p * data[i, l_b, p_b])
8787

8888

8989
@cython.boundscheck(False)

pyresample/resampler.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,9 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
214214
fill_value: Desired value for any invalid values in the output array
215215
kwargs: any other keyword arguments that will be passed on to func.
216216
217+
Returns:
218+
A dask array, chunked as dst_area, containing the resampled data.
219+
217220
218221
Principle of operations:
219222
Resample_blocks works by iterating over chunks on the dst_area domain. For each chunk, the corresponding slice
@@ -235,10 +238,6 @@ def resample_blocks(func, src_area, src_arrays, dst_area,
235238
236239
237240
"""
238-
if dst_area == src_area:
239-
raise ValueError("Source and destination areas are identical."
240-
" Should you be running `map_blocks` instead of `resample_blocks`?")
241-
242241
name = _create_dask_name(name, func,
243242
src_area, src_arrays,
244243
dst_area, dst_arrays,

0 commit comments

Comments
 (0)