Skip to content

Commit 8981913

Browse files
iferencikIoan Ferencik
andauthored
Feat/building filtering (#119)
* building filtering inside a mask * remove rename col --------- Co-authored-by: Ioan Ferencik <[email protected]>
1 parent b921ecf commit 8981913

File tree

5 files changed

+282
-23
lines changed

5 files changed

+282
-23
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,6 @@ cython_debug/
167167
Pipfile
168168
Pipfile.lock
169169

170-
data/
170+
data/
171+
cbsurge/azure/auth.py
172+
cbsurge/azure/token_cache.json

cbsurge/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from osgeo import ogr
12
AGE_STRUCTURES_ROOT_URL = "https://hub.worldpop.org/geodata"
23
AZURE_BLOB_CONTAINER_NAME = "stacdata"
3-
AZURE_FILESHARE_NAME = "cbrapida"
4+
AZURE_FILESHARE_NAME = "cbrapida"
5+
# Lookup from an Arrow type name to the OGR field type constant used for it.
ARROWTYPE2OGRTYPE = {'string':ogr.OFTString, 'double':ogr.OFTReal, 'int64':ogr.OFTInteger64, 'int':ogr.OFTInteger}

cbsurge/exposure/builtenv/buildings/fgb.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from cbsurge.exposure.builtenv.buildings.fgbgdal import OVERPASS_API_URL, GMOSM_BUILDINGS_ROOT
66
from pyogrio.raw import open_arrow, write_arrow, read
77
from cbsurge.exposure.builtenv.buildings.pmt import WEB_MERCATOR_TMS
8+
from cbsurge.constants import ARROWTYPE2OGRTYPE
89
import morecantile as m
910
from cbsurge import util
1011
import logging
@@ -60,7 +61,7 @@ def render(self, task):
6061
}
6162
logger = logging.getLogger(__name__)
6263

63-
ARROWTYPE2OGRTYPE = {'string':ogr.OFTString, 'double':ogr.OFTReal, 'int64':ogr.OFTInteger64, 'int':ogr.OFTInteger}
64+
6465

6566
def country_info(bbox=None, overpass_url=OVERPASS_API_URL):
6667
"""
Lines changed: 273 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,34 @@
11
import logging
2-
logger = logging.getLogger(__name__)
3-
from osgeo import gdal, ogr
4-
gdal.UseExceptions()
2+
import multiprocessing
3+
import random
4+
from multiprocessing import sharedctypes
5+
import threading
6+
import numpy as np
7+
import pyarrow as pa
8+
import pyproj
9+
import shapely
10+
from osgeo import gdal, ogr, osr
511

612

13+
from cbsurge import util
14+
from collections import deque
15+
from rich.progress import Progress
16+
import concurrent
17+
from pyogrio import read_info, read_arrow, read_dataframe
18+
from rasterio import warp
19+
# from pyogrio.raw import open_arrow, read_arrow
20+
import rasterio
21+
from rasterio.windows import Window
22+
from pyarrow import compute as pc
23+
from cbsurge.constants import ARROWTYPE2OGRTYPE
24+
logger = logging.getLogger(__name__)
25+
gdal.UseExceptions()
726

8-
'''
9-
ogr2ogr -sql "SELECT ST_PointOnSurface(geometry) as geometry, ogc_fid, area_in_meters, confidence FROM buildings" -dialect sqlite /tmp/bldgs1c.fgb /tmp/bldgs1.fgb
10-
'''
27+
def cb(complete, message, stop_event):
    """
    Progress callback that lets a caller cancel a running GDAL operation

    :param complete: progress value handed over by the caller (not used here)
    :param message: progress message handed over by the caller (not used here)
    :param stop_event: optional event-like object exposing is_set(); when it is
        set the operation is asked to stop
    :return: 0 when the stop event is set, 1 otherwise
    """
    if stop_event and stop_event.is_set():
        logger.info('GDAL was signalled to stop')
        return 0
    # Say "continue" explicitly instead of falling off the end and returning None.
    return 1
1132

1233
def create_centroid(src_path=None, src_layer_name=None, dst_path=None, dst_srs=None):
1334
"""
@@ -27,6 +48,10 @@ def create_centroid(src_path=None, src_layer_name=None, dst_path=None, dst_srs=N
2748
geometryType='POINT'
2849

2950
)
51+
52+
if not ogr.GetDriverByName(options['format']).TestCapability(ogr.OLCFastFeatureCount):
53+
logger.debug('No progress bar is available in create_centroid')
54+
3055
if dst_srs is not None:
3156
options['dstSRS'] = dst_srs
3257
options['reproject'] = True
@@ -39,42 +64,270 @@ def create_centroid(src_path=None, src_layer_name=None, dst_path=None, dst_srs=N
3964
del ds
4065

4166

67+
def geoarrow_schema_adapter(schema: pa.Schema) -> pa.Schema:
    """
    Convert a geoarrow-compatible schema to a proper geoarrow schema

    This assumes there is a single "geometry" column with WKB formatting

    Parameters
    ----------
    schema: pa.Schema

    Returns
    -------
    pa.Schema
        A copy of the input schema with the geometry field replaced with
        a new one with the proper geoarrow ARROW:extension metadata

    Raises
    ------
    KeyError
        If the schema has no field, or more than one field, named "geometry"

    """
    geometry_field_index = schema.get_field_index("geometry")
    # get_field_index reports a missing or duplicated name as -1. Used as a
    # positional index, -1 would select the last field and tag the wrong column.
    if geometry_field_index < 0:
        raise KeyError('The schema does not contain a unique "geometry" field')

    geometry_field = schema.field(geometry_field_index)
    geoarrow_geometry_field = geometry_field.with_metadata(
        {b"ARROW:extension:name": b"geoarrow.wkb"}
    )

    return schema.set(geometry_field_index, geoarrow_geometry_field)
4393

44-
def buildings_in_mask_ogrio(buildings_centroid_path=None, mask_path=None, mask_pixel_value=None,
45-
horizontal_chunks=None, vertical_chunks=None):
94+
def proj_are_equal(src_srs: osr.SpatialReference = None, dst_srs: osr.SpatialReference = None):
    """
    Decides if two projections are equal

    The authority name and authority code of both projections are compared first.
    When that comparison can not be carried out (any exception, e.g. a code that
    int() can not convert) the decision is delegated to IsSame.

    @param src_srs: the source projection
    @param dst_srs: the dst projection
    @return: bool, True if the source and dst projections are equal else False
    @raise Exception: re-raises whatever IsSame raised when the fallback fails too
    """
    auth_code_func_name = ".".join(
        [osr.SpatialReference.GetAuthorityCode.__module__, osr.SpatialReference.GetAuthorityCode.__name__])
    is_same_func_name = ".".join([osr.SpatialReference.IsSame.__module__, osr.SpatialReference.IsSame.__name__])
    try:
        # A bare code is only unique within its authority, so compare both parts.
        src_authority = (src_srs.GetAuthorityName(None).upper(), int(src_srs.GetAuthorityCode(None)))
        dst_authority = (dst_srs.GetAuthorityName(None).upper(), int(dst_srs.GetAuthorityCode(None)))
        return src_authority == dst_authority
    except Exception:
        logger.error(
            f'Failed to compare src and dst projections using {auth_code_func_name}. Trying using {is_same_func_name}')
    try:
        return bool(src_srs.IsSame(dst_srs))
    except Exception as evpe1:
        logger.error(
            f'Failed to compare src and dst projections using {is_same_func_name}. Error is \n {evpe1}')
        raise
118+
119+
def filter_buildings_in_block(buildings_ds_path=None, mask_ds=None, block=None, block_id=None, band=1):
    """
    Select the buildings of one raster block whose centroid lies on a mask pixel equal to True

    :param buildings_ds_path: path of the buildings dataset, read with read_dataframe
        restricted to the bounds of the block
    :param mask_ds: opened mask raster; its read() method and transform attribute are used
    :param block: (col_start, row_start, col_size, row_size) of the block in pixels
    :param block_id: identifier of the block, handed back unchanged
    :param band: index of the mask band to read
    :return: tuple (block_id, table, srs). table is a pyarrow table with the selected
        buildings whose geometry column is renamed to wkb_geometry and srs is an
        osr.SpatialReference built from the CRS authority of the buildings.
        table and srs are None when nothing was selected or the block failed
    """
    try:
        window = Window(*block)
        col_start, row_start, col_size, row_size = block
        mask = mask_ds.read(band, window=window)

        bbox = rasterio.windows.bounds(window=window, transform=mask_ds.transform)
        buildings = read_dataframe(buildings_ds_path, bbox=bbox, read_geometry=True)
        if len(buildings) == 0:
            return block_id, None, None

        # Centroid coordinates -> pixel coordinates relative to the block origin
        pcoords = buildings.centroid.get_coordinates()
        pcols, prows = ~mask_ds.transform * (pcoords.x, pcoords.y)
        prows, pcols = np.floor(prows).astype('i4'), np.floor(pcols).astype('i4')
        prows -= row_start
        pcols -= col_start

        # The bbox read can return buildings whose centroid is outside the block
        inside = (prows >= 0) & (prows < row_size) & (pcols >= 0) & (pcols < col_size)
        if not inside.any():
            return block_id, None, None
        prows = prows[inside]
        pcols = pcols[inside]
        buildings = buildings[inside]

        selected = buildings[mask[prows, pcols] == True]
        # Bail out before paying for the Arrow conversion of an empty selection
        if len(selected) == 0:
            return block_id, None, None

        table = pa.table(selected.to_arrow(index=False))
        table = table.rename_columns(names={'geometry':'wkb_geometry'})

        out_srs = osr.SpatialReference()
        out_srs.SetFromUserInput(':'.join(selected.crs.to_authority()))
        return block_id, table, out_srs
    except Exception:
        # Keep the best-effort contract, but leave a trace so that a failed block
        # can be told apart from a block without buildings.
        logger.exception(f'Failed to filter buildings in block {block_id}')
        return block_id, None, None
162+
163+
164+
165+
def worker(work=None, result=None, finished=None):
    """
    Consume block filtering jobs until the finished event is set

    :param work: deque of keyword-argument dicts for filter_buildings_in_block
    :param result: deque that receives the return value of every processed job
    :param finished: threading.Event; once it is set the worker leaves its loop
    :return: None
    """
    thread_name = threading.current_thread().name
    logger.debug(f'starting building filter thread {thread_name}')
    while True:
        try:
            job = work.pop()
        except IndexError:
            # Nothing queued: block briefly on the event instead of spinning on
            # the empty deque, which would burn CPU next to the main thread.
            if finished.wait(timeout=0.05):
                logger.debug(f'worker is finishing in {thread_name}')
                break
            continue

        if finished.is_set():
            break
        logger.debug(f'Starting job in block {job["block_id"]}')
        result.append(filter_buildings_in_block(**job))
187+
188+
189+
190+
def filter_buildings(buildings_path=None, mask_path=None, mask_pixel_value=None,
                     horizontal_chunks=None, vertical_chunks=None, nworkers=1,
                     out_path=None):
    """
    Select buildings whose centroid is inside the masked pixels

    The mask is processed in blocks of (width // horizontal_chunks) by
    (height // vertical_chunks) pixels. Worker threads filter the blocks and this
    function writes what they return into a FlatGeobuf layer named buildings_filtered.

    :param buildings_path: path of the buildings dataset
    :param mask_path: path of the mask raster, which must hold exactly one band
    :param mask_pixel_value: accepted but not used by this function
    :param horizontal_chunks: divisor applied to the mask width to get the block width
    :param vertical_chunks: divisor applied to the mask height to get the block height
    :param nworkers: number of worker threads, capped to the number of blocks
    :param out_path: path of the FlatGeobuf dataset to create
    :return: None
    """
    # Imported locally: the module imports time only under __main__, and a bare
    # "import concurrent" does not load the futures submodule.
    import time
    from concurrent.futures import ThreadPoolExecutor

    nfiltered = 0
    failed = []
    dst_lyr = None

    with ogr.GetDriverByName('FlatGeobuf').CreateDataSource(out_path) as dst_ds:
        with rasterio.open(mask_path) as mask_ds:
            # NOTE(review): msr is not consumed yet; presumably kept for a future
            # check that buildings and mask share a projection - confirm.
            msr = osr.SpatialReference()
            msr.SetFromUserInput(str(mask_ds.crs))
            width = mask_ds.width
            height = mask_ds.height
            assert mask_ds.count == 1, f'The mask dataset {mask_path} contains more than one band'
            # A block must be at least one pixel, even with more chunks than pixels
            block_xsize = max(1, width // horizontal_chunks)
            block_ysize = max(1, height // vertical_chunks)
            blocks = util.gen_blocks(blockxsize=block_xsize, blockysize=block_ysize, width=width, height=height)
            nblocks, blocks = util.generator_length(blocks)
            stop_event = threading.Event()
            jobs = deque()
            results = deque()
            nworkers = max(1, min(nworkers, nblocks))
            logger.debug(f'Using {nworkers} worker threads')
            with ThreadPoolExecutor(max_workers=nworkers) as executor:
                try:
                    for _ in range(nworkers):
                        executor.submit(worker, jobs, results, stop_event)
                    with Progress() as progress:
                        total_task = progress.add_task(
                            description=f'[red]Going to filter buildings from {nblocks} blocks/chunks',
                            total=nblocks)
                        for block_id, block in enumerate(blocks):
                            jobs.append(dict(
                                buildings_ds_path=buildings_path,
                                mask_ds=mask_ds,
                                block=block,
                                block_id=block_id,
                            ))

                        while True:
                            try:
                                try:
                                    block_id, table, dst_srs = results.pop()
                                except IndexError:
                                    if not jobs and progress.finished:
                                        stop_event.set()
                                        break
                                    # this one is necessary for ^C/KeyboardInterrupt
                                    time.sleep(random.random())
                                    continue

                                try:
                                    if table is not None:
                                        if dst_lyr is None:
                                            # The layer is created from the first block that has data
                                            dst_lyr = dst_ds.CreateLayer('buildings_filtered',
                                                                         geom_type=ogr.wkbPolygon, srs=dst_srs)
                                            for name in table.schema.names:
                                                if 'wkb' in name or 'geometry' in name:
                                                    continue
                                                field = table.schema.field(name)
                                                # The lookup table is keyed by the Arrow type name
                                                field_type = ARROWTYPE2OGRTYPE[str(field.type)]
                                                logger.debug(f'Creating field {name}: {field.type}: {field_type}')
                                                dst_lyr.CreateField(ogr.FieldDefn(name, field_type))
                                        try:
                                            dst_lyr.WritePyArrow(table)
                                        except Exception as e:
                                            logger.error(
                                                f'Failed to write {table.num_rows} features/rows in block id {block_id} because {e}. Skipping')
                                        dst_lyr.SyncToDisk()
                                        logger.debug(f'{block_id} was processed')
                                except Exception as e:
                                    failed.append(f'Error in block_id {block_id} failed: {e.__class__.__name__}("{e}")')

                                # Every popped result counts once, whatever its outcome
                                nfiltered += 1
                                progress.update(total_task,
                                                description=f'[red]Filtered buildings from {nfiltered} out of {nblocks} blocks',
                                                advance=1)
                            except KeyboardInterrupt:
                                logger.info(f'Cancelling jobs. Please wait/allow for a graceful shutdown')
                                stop_event.set()
                                break
                finally:
                    # Without this an unexpected error would leave the workers looping
                    # forever and the executor would never finish shutting down.
                    stop_event.set()

        # No layer exists when no block produced any building
        nwritten = dst_lyr.GetFeatureCount() if dst_lyr is not None else 0
        logger.info(f'{nwritten} feature were written to {out_path} ')

    for msg in failed:
        logger.error(msg)
60308

61309
if __name__ == '__main__':
    import time

    logger = util.setup_logger(name='rapida', level=logging.INFO, make_root=True)

    # Hard-coded local paths for a manual run of the building filter
    src_path = '/data/surge/buildings_eqar.fgb'
    dst_path = '/data/surge/bldgs1c.fgb'
    mask = '/data/surge/surge/stats/floods_mask.tif'
    filtered_buildings_path = '/data/surge/bldgs1_filtered.fgb'

    run_args = dict(
        buildings_path=src_path,
        mask_path=mask,
        mask_pixel_value=1,
        horizontal_chunks=10,
        vertical_chunks=20,
        out_path=filtered_buildings_path,
    )

    start = time.time()

    #create_centroid(src_path=src_path, src_layer_name='buildings', dst_path=dst_path, dst_srs='ESRI:54034')
    filter_buildings(**run_args)

    end = time.time()
    # Elapsed wall-clock seconds
    print((end-start))

cbsurge/util.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def setup_logger(name=None, make_root=True, level=logging.INFO):
183183

184184
if make_root:
185185
logger = logging.getLogger()
186+
186187
else:
187188
logger = logging.getLogger(name)
188189
formatter = logging.Formatter(

0 commit comments

Comments
 (0)