Skip to content

Commit dfb2b79

Browse files
committed
run aggregations
1 parent 3d83c93 commit dfb2b79

File tree

4 files changed

+149
-40
lines changed

4 files changed

+149
-40
lines changed

cbsurge/azure/blob_storage.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,22 @@ async def list_blobs(self, prefix=None):
138138
return [blob.name async for blob in self.container_client.list_blobs(name_starts_with=prefix)]
139139

140140

141+
async def copy_file(self, source_blob=None, destination_blob=None):
    """
    Copy a file from one blob to another within the same container.

    Args:
        source_blob: (str) The name of the source blob to copy.
        destination_blob: (str) The name of the destination blob to copy to.

    Returns:
        str: The URL of the destination blob.

    Raises:
        ValueError: If either blob name is missing.
    """
    # Fail fast instead of letting a None name surface as an obscure SDK error.
    if not source_blob or not destination_blob:
        raise ValueError("Both source_blob and destination_blob must be provided")
    logging.info("Copying blob: %s to %s", source_blob, destination_blob)
    source_blob_client = self.container_client.get_blob_client(blob=source_blob)
    destination_blob_client = self.container_client.get_blob_client(blob=destination_blob)
    # start_copy_from_url only *initiates* a server-side copy; completion is
    # asynchronous on the service side. NOTE(review): callers that need a
    # guaranteed copy should poll the destination blob's copy status.
    await destination_blob_client.start_copy_from_url(source_blob_client.url)
    return destination_blob_client.url
141157
async def close(self):
142158
"""
143159
Close the Azure Blob Storage Manager.

cbsurge/azure/fileshare.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,3 +138,19 @@ async def download_file(self, file_name, download_path):
138138
progress_bar.close()
139139
return file_name
140140

141+
async def copy_file(self, source_file, destination_file):
    """
    Copy a file from one location to another in the Azure File Share.

    Args:
        source_file: (str) The path of the source file to copy.
        destination_file: (str) The path of the destination file.

    Returns:
        str: The destination file path.

    Raises:
        ValueError: If either file path is missing.
    """
    # Fail fast instead of letting a None path surface as an obscure SDK error.
    if not source_file or not destination_file:
        raise ValueError("Both source_file and destination_file must be provided")
    logging.info("Copying file: %s to %s", source_file, destination_file)
    source_file_client = self.share_client.get_file_client(source_file)
    destination_file_client = self.share_client.get_file_client(destination_file)
    # start_copy_from_url initiates a server-side copy; NOTE(review): callers
    # needing a guaranteed copy should check the copy status afterwards.
    await destination_file_client.start_copy_from_url(source_file_client.url)
    return destination_file

cbsurge/exposure/population/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,5 @@
77
"active": [15, 64],
88
"elderly": [65, 100],
99
}
10+
WORLDPOP_SEXES = ["M", "F"]
11+
DATA_YEAR = 2020

cbsurge/exposure/population/worldpop.py

Lines changed: 115 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
from asyncio import subprocess
77
from html.parser import HTMLParser
88
from concurrent.futures import ThreadPoolExecutor
9+
from typing import List, Optional
910

1011
import aiofiles
1112
import httpx
1213
import rasterio
14+
from rasterio.windows import Window
1315
import numpy as np
1416
from tqdm.asyncio import tqdm_asyncio
1517

1618
from cbsurge.azure.blob_storage import AzureBlobStorageManager
17-
from cbsurge.exposure.population.constants import AZ_ROOT_FILE_PATH, WORLDPOP_AGE_MAPPING
19+
from cbsurge.exposure.population.constants import AZ_ROOT_FILE_PATH, WORLDPOP_AGE_MAPPING, DATA_YEAR
1820

1921
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
2022
logging.getLogger("azure").setLevel(logging.WARNING)
@@ -66,7 +68,7 @@ def chunker_function(iterable, chunk_size=4):
6668
yield iterable[i:i + chunk_size]
6769

6870

69-
async def get_available_data(country_code=None, year="2020"):
71+
async def get_available_data(country_code=None, year=DATA_YEAR):
7072
"""
7173
Args:
7274
country_code: The country code for which to fetch data
@@ -256,7 +258,7 @@ async def get_links_from_table(data_id=None):
256258
return await extract_links_from_table(response.text)
257259

258260

259-
async def download_data(country_code=None, year="2020", force_reprocessing=False, download_path=None):
261+
async def download_data(country_code=None, year=DATA_YEAR, force_reprocessing=False, download_path=None):
260262
"""
261263
Download all available data for a given country and year.
262264
Args:
@@ -299,90 +301,163 @@ async def download_data(country_code=None, year="2020", force_reprocessing=False
299301
await storage_manager.close()
300302

301303

302-
303-
def create_sum(input_file_paths, output_file_path, block_size=(256, 256)):
    """
    Sum multiple raster files and save the result to an output file, processing in blocks.
    If the data is not blocked, create blocks within the function.

    Args:
        input_file_paths (list of str): Paths to input raster files.
        output_file_path (str): Path to save the summed raster file.
        block_size (tuple): Block size as (rows, cols). Default is (256, 256).

    Returns:
        None

    Raises:
        Exception: Re-raised after logging if any raster cannot be opened,
            read, or written, so callers cannot mistake a failure for success.
    """
    logging.info("Starting create_sum function")
    logging.info("Input files: %s", input_file_paths)
    logging.info("Output file: %s", output_file_path)
    logging.info("Block size: %s", block_size)

    datasets = []
    try:
        # Open inputs one at a time so already-opened datasets are released by
        # the finally block if a later open fails (a list comprehension inside
        # the try would leak the handles opened before the failure).
        for file_path in input_file_paths:
            datasets.append(rasterio.open(file_path))
        logging.info("Successfully opened input raster files")

        # Use the first dataset as reference for metadata.
        # NOTE(review): assumes all inputs share grid/extent — confirm upstream.
        ref_meta = datasets[0].meta.copy()
        ref_meta.update(dtype="float32", count=1, nodata=0)

        rows, cols = datasets[0].shape
        logging.info("Raster dimensions: %d rows x %d cols", rows, cols)

        with rasterio.open(output_file_path, "w", **ref_meta) as dst:
            logging.info("Output file created successfully")

            # Walk the raster in fixed-size tiles; edge windows are clipped to
            # the raster bounds via min().
            for i in range(0, rows, block_size[0]):
                for j in range(0, cols, block_size[1]):
                    window = Window(j, i, min(block_size[1], cols - j), min(block_size[0], rows - i))
                    # Per-tile logging at DEBUG — INFO would flood the log on
                    # large rasters.
                    logging.debug("Processing block: row %d to %d, col %d to %d",
                                  i, i + window.height, j, j + window.width)

                    output_data = np.zeros((window.height, window.width), dtype=np.float32)
                    for idx, src in enumerate(datasets):
                        input_data = src.read(1, window=window)
                        # Treat nodata pixels as zero so they don't poison the sum.
                        input_data = np.where(input_data == src.nodata, 0, input_data)
                        output_data += input_data
                        logging.debug("Added data from raster %d", idx + 1)

                    dst.write(output_data, window=window, indexes=1)

        logging.info("Finished processing all blocks")
    except Exception:
        # Log with traceback and re-raise: silently returning here would let
        # callers treat a failed aggregation as success.
        logging.exception("create_sum failed")
        raise
    finally:
        for src in datasets:
            src.close()
        logging.info("Closed all input raster files")

    logging.info("create_sum function completed successfully")
336-
async def process_aggregates(country_code: str, age_group: Optional[str] = None, sex: Optional[str] = None):
    """
    Process the aggregate files based on sex and age group.

    Downloads the matching WorldPop rasters from blob storage, sums them with
    ``create_sum``, and uploads the aggregate raster back to blob storage under
    an ``aggregate/`` prefix.

    Args:
        country_code (str): The country code to process the data for.
        age_group (Optional[str]): The age group to process (child, active, elderly).
        sex (Optional[str]): The sex to process (M, F).

    Raises:
        ValueError: If country_code is missing, neither sex nor age_group is
            given, or an unknown age_group is supplied.
    """
    # Explicit exceptions rather than assert: asserts are stripped under `python -O`.
    if not country_code:
        raise ValueError("Country code must be provided")
    if not (sex or age_group):
        raise ValueError("Either age or sex must be provided")

    async with AzureBlobStorageManager(conn_str=os.getenv("AZURE_STORAGE_CONNECTION_STRING")) as storage_manager:
        logging.info("Processing aggregate files for country: %s", country_code)

        async def process_group(sexes: List[str], age_group: Optional[str] = None, output_blob_path: Optional[str] = None):
            """
            Processes a group of files for a specific sex and/or age group.

            Args:
                sexes (List[str]): List of sexes to process (e.g., ['M', 'F']).
                age_group (Optional[str]): The age group to process (child, active, elderly).
                output_blob_path (Optional[str]): Blob path prefix for the final output file.
            """
            # Construct the input blob prefixes, one per sex.
            paths = [f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/{sex_group}/" for sex_group in sexes]
            if age_group:
                if age_group not in WORLDPOP_AGE_MAPPING:
                    raise ValueError("Invalid age group provided")
                paths = [f"{path}{age_group}/" for path in paths]

            # Collect the blobs found under every prefix.
            blobs = []
            for path in paths:
                blobs += await storage_manager.list_blobs(path)

            if not blobs:
                logging.warning("No blobs found for paths: %s", paths)
                return

            dataset_files = []
            # Plain TemporaryDirectory: the upload happens inside the `with`,
            # so normal cleanup is safe. (`delete=False` is a 3.12-only kwarg
            # and would leak the scratch directory on every run.)
            with tempfile.TemporaryDirectory() as temp_dir:
                for blob in blobs:
                    local_file = os.path.join(temp_dir, os.path.basename(blob))
                    await storage_manager.download_blob(blob_name=blob, local_directory=temp_dir)
                    dataset_files.append(local_file)

                output_file = f"{temp_dir}/{country_code}_{age_group or 'ALL'}_{'_'.join(sexes)}.tif"

                # create_sum is heavy, so run it on a worker thread; .result()
                # blocks until done and propagates any exception instead of
                # silently discarding it as a bare submit() would.
                with ThreadPoolExecutor() as executor:
                    future = executor.submit(create_sum, input_file_paths=dataset_files,
                                             output_file_path=output_file, block_size=(512, 512))
                    future.result()

                await storage_manager.upload_blob(file_path=output_file,
                                                  blob_name=f"{output_blob_path}/{os.path.basename(output_file)}")
                logging.info("Output saved to: %s", f"{output_blob_path}/{os.path.basename(output_file)}")

        # Dispatch on the requested combination of sex and age group.
        if sex and age_group:
            logging.info("Processing for sex '%s' and age group '%s'", sex, age_group)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{sex}/{age_group}"
            await process_group([sex], age_group, output_blob_path)
        elif age_group:
            logging.info("Processing for age group '%s' (both sexes)", age_group)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{age_group}"
            await process_group(['M', 'F'], age_group, output_blob_path)
        elif sex:
            logging.info("Processing for sex '%s' (all age groups)", sex)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{sex}"
            await process_group([sex], None, output_blob_path)
        logging.info("Processing complete")

384459

385460
if __name__ == "__main__":
    # asyncio.run(download_data(force_reprocessing=False))
    # Ad-hoc local run: sum one male + one female raster for BDI ages 0-12.
    _input_rasters = [
        "/media/thuha/Data/worldpop_data/m/0-12/BDI_m_0_2020_constrained.tif",
        "/media/thuha/Data/worldpop_data/f/0-12/BDI_f_0_2020_constrained.tif",
    ]
    create_sum(_input_rasters, "data/BDI_0-12.tif")

0 commit comments

Comments
 (0)