66from asyncio import subprocess
77from html .parser import HTMLParser
88from concurrent .futures import ThreadPoolExecutor
9+ from typing import List , Optional
910
1011import aiofiles
1112import httpx
1213import rasterio
14+ from rasterio .windows import Window
1315import numpy as np
1416from tqdm .asyncio import tqdm_asyncio
1517
1618from cbsurge .azure .blob_storage import AzureBlobStorageManager
17- from cbsurge .exposure .population .constants import AZ_ROOT_FILE_PATH , WORLDPOP_AGE_MAPPING
19+ from cbsurge .exposure .population .constants import AZ_ROOT_FILE_PATH , WORLDPOP_AGE_MAPPING , DATA_YEAR
1820
1921logging .basicConfig (level = logging .INFO , format = "%(asctime)s - %(levelname)s - %(message)s" )
2022logging .getLogger ("azure" ).setLevel (logging .WARNING )
@@ -66,7 +68,7 @@ def chunker_function(iterable, chunk_size=4):
6668 yield iterable [i :i + chunk_size ]
6769
6870
69- async def get_available_data (country_code = None , year = "2020" ):
71+ async def get_available_data (country_code = None , year = DATA_YEAR ):
7072 """
7173 Args:
7274 country_code: The country code for which to fetch data
@@ -256,7 +258,7 @@ async def get_links_from_table(data_id=None):
256258 return await extract_links_from_table (response .text )
257259
258260
259- async def download_data (country_code = None , year = "2020" , force_reprocessing = False , download_path = None ):
261+ async def download_data (country_code = None , year = DATA_YEAR , force_reprocessing = False , download_path = None ):
260262 """
261263 Download all available data for a given country and year.
262264 Args:
@@ -299,90 +301,163 @@ async def download_data(country_code=None, year="2020", force_reprocessing=False
299301 await storage_manager .close ()
300302
301303
302-
def create_sum(input_file_paths, output_file_path, block_size=(256, 256)):
    """
    Sum multiple raster files and save the result to an output file, processing in blocks.
    If the data is not blocked, create blocks within the function.

    Args:
        input_file_paths (list of str): Paths to input raster files.
        output_file_path (str): Path to save the summed raster file.
        block_size (tuple): Tuple representing the block size (rows, cols). Default is (256, 256).

    Returns:
        None
    """
    logging.info("Starting create_sum function")
    logging.info("Input files: %s", input_file_paths)
    logging.info("Output file: %s", output_file_path)
    logging.info("Block size: %s", block_size)

    if not input_file_paths:
        logging.error("No input files provided")
        return

    # Open all input files one at a time so that a failure part-way through
    # does not leak the datasets that were already opened.
    datasets = []
    try:
        for file_path in input_file_paths:
            datasets.append(rasterio.open(file_path))
    except Exception as e:
        logging.error("Error opening input files: %s", e)
        for src in datasets:
            src.close()
        return

    logging.info("Successfully opened input raster files")

    try:
        # Use the first dataset as reference for metadata; output is a single
        # float32 band with 0 as the nodata value.
        ref_meta = datasets[0].meta.copy()
        ref_meta.update(dtype="float32", count=1, nodata=0)

        rows, cols = datasets[0].shape
        logging.info("Raster dimensions: %d rows x %d cols", rows, cols)

        with rasterio.open(output_file_path, "w", **ref_meta) as dst:
            logging.info("Output file created successfully")

            # Process the raster in manually constructed blocks so inputs that
            # are not internally tiled are still read in bounded chunks.
            for i in range(0, rows, block_size[0]):
                for j in range(0, cols, block_size[1]):
                    # Clamp the window at the right/bottom edges.
                    window = Window(j, i, min(block_size[1], cols - j), min(block_size[0], rows - i))
                    logging.info("Processing block: row %d to %d, col %d to %d", i, i + block_size[0], j,
                                 j + block_size[1])

                    output_data = np.zeros((window.height, window.width), dtype=np.float32)

                    for idx, src in enumerate(datasets):
                        try:
                            input_data = src.read(1, window=window).astype(np.float32)
                            # Treat nodata pixels as 0. A NaN nodata value needs
                            # np.isnan because NaN != NaN, and a None nodata
                            # means there is nothing to mask.
                            nodata = src.nodata
                            if nodata is not None:
                                if np.isnan(nodata):
                                    input_data = np.where(np.isnan(input_data), 0, input_data)
                                else:
                                    input_data = np.where(input_data == nodata, 0, input_data)
                            output_data += input_data
                            logging.debug("Added data from raster %d", idx + 1)
                        except Exception as e:
                            # Best-effort: skip the failing raster for this block
                            # but keep summing the others.
                            logging.error("Error reading block from raster %d: %s", idx + 1, e)

                    dst.write(output_data, window=window, indexes=1)

        logging.info("Finished processing all blocks")
        logging.info("create_sum function completed successfully")
    except Exception as e:
        logging.error("Error creating or writing to output file: %s", e)
    finally:
        # Close all input datasets regardless of success or failure.
        for src in datasets:
            src.close()
        logging.info("Closed all input raster files")
334373
335374
async def process_aggregates(country_code: str, age_group: Optional[str] = None, sex: Optional[str] = None):
    """
    Process the aggregate files based on sex and age group.

    Args:
        country_code (str): The country code to process the data for.
        age_group (Optional[str]): The age group to process (child, active, elderly).
        sex (Optional[str]): The sex to process (M, F).

    Raises:
        ValueError: If country_code is missing, if neither sex nor age_group is
            provided, or if an unknown age group is requested.
    """
    # Local imports: only needed to hand the blocking raster work to a thread.
    import asyncio
    from functools import partial

    # Validate with real exceptions; assert is stripped under `python -O`.
    if not country_code:
        raise ValueError("Country code must be provided")
    if not (sex or age_group):
        raise ValueError("Either age or sex must be provided")

    async with AzureBlobStorageManager(conn_str=os.getenv("AZURE_STORAGE_CONNECTION_STRING")) as storage_manager:
        logging.info("Processing aggregate files for country: %s", country_code)

        async def process_group(sexes: List[str], age_group: Optional[str] = None, output_blob_path: Optional[str] = None):
            """
            Processes a group of files for a specific sex and/or age group.

            Args:
                sexes (List[str]): List of sexes to process (e.g., ['M', 'F']).
                age_group (Optional[str]): The age group to process (child, active, elderly).
                output_blob_path (Optional[str]): Path to store the final output file.
            """
            # Construct paths for input blobs
            paths = [f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/{sex_group}/" for sex_group in sexes]
            if age_group:
                if age_group not in WORLDPOP_AGE_MAPPING:
                    raise ValueError("Invalid age group provided")
                paths = [f"{path}{age_group}/" for path in paths]

            # Fetch blobs for all paths
            blobs = []
            for path in paths:
                blobs += await storage_manager.list_blobs(path)

            if not blobs:
                logging.warning("No blobs found for paths: %s", paths)
                return

            # Download and process blobs. The temp directory is cleaned up
            # automatically when the `with` block exits — everything that
            # touches the files (summing + upload) happens inside it.
            dataset_files = []
            with tempfile.TemporaryDirectory() as temp_dir:
                for blob in blobs:
                    local_file = os.path.join(temp_dir, os.path.basename(blob))
                    await storage_manager.download_blob(blob_name=blob, local_directory=temp_dir)
                    dataset_files.append(local_file)

                output_file = f"{temp_dir}/{country_code}_{age_group or 'ALL'}_{'_'.join(sexes)}.tif"

                # Run the CPU/IO-heavy summation in a worker thread so the
                # event loop stays responsive; awaiting the result also
                # re-raises any exception instead of silently dropping it
                # (a bare executor.submit() would discard the Future).
                loop = asyncio.get_running_loop()
                await loop.run_in_executor(
                    None,
                    partial(create_sum, input_file_paths=dataset_files, output_file_path=output_file, block_size=(512, 512)),
                )

                # Upload the summed raster to the aggregate folder.
                blob_name = f"{output_blob_path}/{os.path.basename(output_file)}"
                await storage_manager.upload_blob(file_path=output_file, blob_name=blob_name)
                logging.info("Output saved to: %s", blob_name)

        # Processing logic for combinations of sex and age group
        if sex and age_group:
            logging.info("Processing for sex '%s' and age group '%s'", sex, age_group)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{sex}/{age_group}"
            await process_group([sex], age_group, output_blob_path)
        elif age_group:
            logging.info("Processing for age group '%s' (both sexes)", age_group)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{age_group}"
            await process_group(['M', 'F'], age_group, output_blob_path)
        elif sex:
            logging.info("Processing for sex '%s' (all age groups)", sex)
            output_blob_path = f"{AZ_ROOT_FILE_PATH}/{DATA_YEAR}/{country_code}/aggregate/{sex}"
            await process_group([sex], None, output_blob_path)
        logging.info("Processing complete")
383458
384459
if __name__ == "__main__":
    # asyncio.run(download_data(force_reprocessing=False))

    # Manual smoke test: sum the male and female 0-12 constrained rasters
    # for Burundi into a single output file.
    sample_rasters = [
        "/media/thuha/Data/worldpop_data/m/0-12/BDI_m_0_2020_constrained.tif",
        "/media/thuha/Data/worldpop_data/f/0-12/BDI_f_0_2020_constrained.tif",
    ]
    create_sum(sample_rasters, "data/BDI_0-12.tif")
0 commit comments