88from logging import getLogger
99from pathlib import Path
1010from typing import Dict , List , Optional , Set
11+ import shutil
1112
1213from anyio import Path as AsyncPath , create_task_group
1314import sqlalchemy
@@ -98,7 +99,7 @@ def __init__(
9899 self ,
99100 * ,
100101 _sql_engine : sqlalchemy .engine .Engine ,
101- _sources_by_id : Dict [str , ProtocolSource ],
102+ _sources_by_id : dict [str , ProtocolSource | _BadProtocolSource ],
102103 ) -> None :
103104 """Do not call directly.
104105
@@ -117,8 +118,7 @@ def create_empty(
117118 Params:
118119 sql_engine: A reference to the database that this ProtocolStore should
119120 use as its backing storage.
120- This is expected to already have the proper tables set up;
121- see `add_tables_to_db()`.
121+ This is expected to already have the proper tables set up.
122122 This should have no protocol data currently stored.
123123 If there is data, use `rehydrate()` instead.
124124 """
@@ -141,8 +141,7 @@ async def rehydrate(
141141 Params:
142142 sql_engine: A reference to the database that this ProtocolStore should
143143 use as its backing storage.
144- This is expected to already have the proper tables set up;
145- see `add_tables_to_db()`.
144+ This is expected to already have the proper tables set up.
146145 protocols_directory: Where to look for protocol files while rehydrating.
147146 This is expected to have one subdirectory per protocol,
148147 named after its protocol ID.
@@ -157,7 +156,7 @@ async def rehydrate(
157156
158157 sources_by_id = await _compute_protocol_sources (
159158 expected_protocol_ids = expected_ids ,
160- protocols_directory = AsyncPath ( protocols_directory ) ,
159+ protocols_directory = protocols_directory ,
161160 protocol_reader = protocol_reader ,
162161 )
163162
@@ -171,16 +170,18 @@ def insert(self, resource: ProtocolResource) -> None:
171170
172171 The resource must have a unique ID.
173172 """
174- self ._sql_insert (
175- resource = _DBProtocolResource (
176- protocol_id = resource .protocol_id ,
177- created_at = resource .created_at ,
178- protocol_key = resource .protocol_key ,
179- protocol_kind = _http_protocol_kind_to_sql (resource .protocol_kind ),
173+ try :
174+ self ._sql_insert (
175+ resource = _DBProtocolResource (
176+ protocol_id = resource .protocol_id ,
177+ created_at = resource .created_at ,
178+ protocol_key = resource .protocol_key ,
179+ protocol_kind = _http_protocol_kind_to_sql (resource .protocol_kind ),
180+ )
180181 )
181- )
182- self . _sources_by_id [ resource . protocol_id ] = resource . source
183- self ._clear_caches ()
182+ self . _sources_by_id [ resource . protocol_id ] = resource . source
183+ finally :
184+ self ._clear_caches ()
184185
185186 @lru_cache (maxsize = _CACHE_ENTRIES )
186187 def get (self , protocol_id : str ) -> ProtocolResource :
@@ -190,30 +191,48 @@ def get(self, protocol_id: str) -> ProtocolResource:
190191 ProtocolNotFoundError
191192 """
192193 sql_resource = self ._sql_get (protocol_id = protocol_id )
193- return ProtocolResource (
194- protocol_id = sql_resource .protocol_id ,
195- created_at = sql_resource .created_at ,
196- protocol_key = sql_resource .protocol_key ,
197- protocol_kind = _sql_protocol_kind_to_http (sql_resource .protocol_kind ),
198- source = self ._sources_by_id [sql_resource .protocol_id ],
199- )
194+ protocol_source = self ._sources_by_id [sql_resource .protocol_id ]
195+ match protocol_source :
196+ case ProtocolSource () as protocol_source :
197+ return ProtocolResource (
198+ protocol_id = sql_resource .protocol_id ,
199+ created_at = sql_resource .created_at ,
200+ protocol_key = sql_resource .protocol_key ,
201+ protocol_kind = _sql_protocol_kind_to_http (
202+ sql_resource .protocol_kind
203+ ),
204+ source = protocol_source ,
205+ )
206+ case _BadProtocolSource (reason = reason ):
207+ raise reason
200208
201209 @lru_cache (maxsize = _CACHE_ENTRIES )
202210 def get_all (self ) -> List [ProtocolResource ]:
203211 """Get all protocols currently saved in this store.
204212
205213 Results are ordered from first-added to last-added.
214+
215+ If there was an error processing a protocol, it's excluded from the returned
216+ list. This can happen, for example, if a software downgrade left the robot with
217+ protocol files that are too new for the software that it's running now.
206218 """
207219 all_sql_resources = self ._sql_get_all ()
220+ all_sql_resources_and_protocol_sources = (
221+ (r , self ._sources_by_id [r .protocol_id ]) for r in all_sql_resources
222+ )
208223 return [
209224 ProtocolResource (
210- protocol_id = r .protocol_id ,
211- created_at = r .created_at ,
212- protocol_key = r .protocol_key ,
213- protocol_kind = _sql_protocol_kind_to_http (r .protocol_kind ),
214- source = self . _sources_by_id [ r . protocol_id ] ,
225+ protocol_id = sql_resource .protocol_id ,
226+ created_at = sql_resource .created_at ,
227+ protocol_key = sql_resource .protocol_key ,
228+ protocol_kind = _sql_protocol_kind_to_http (sql_resource .protocol_kind ),
229+ source = protocol_source ,
215230 )
216- for r in all_sql_resources
231+ for (
232+ sql_resource ,
233+ protocol_source ,
234+ ) in all_sql_resources_and_protocol_sources
235+ if not isinstance (protocol_source , _BadProtocolSource )
217236 ]
218237
219238 @lru_cache (maxsize = _CACHE_ENTRIES )
@@ -258,17 +277,20 @@ def remove(self, protocol_id: str) -> None:
258277 ProtocolUsedByRunError: the protocol could not be deleted because
259278 there is a run currently referencing the protocol.
260279 """
261- self ._sql_remove (protocol_id = protocol_id )
262-
263- deleted_source = self ._sources_by_id .pop (protocol_id )
264- protocol_dir = deleted_source .directory
265-
266- for source_file in deleted_source .files :
267- source_file .path .unlink ()
268- if protocol_dir :
269- protocol_dir .rmdir ()
270-
271- self ._clear_caches ()
280+ try :
281+ self ._sql_remove (protocol_id = protocol_id )
282+
283+ deleted_source = self ._sources_by_id .pop (protocol_id )
284+ match deleted_source :
285+ case ProtocolSource (directory = directory , files = files ):
286+ for source_file in files :
287+ source_file .path .unlink ()
288+ if directory :
289+ directory .rmdir ()
290+ case _BadProtocolSource (directory = directory ):
291+ shutil .rmtree (directory , ignore_errors = True )
292+ finally :
293+ self ._clear_caches ()
272294
273295 # Note that this is NOT cached like the other getters because we would need
274296 # to invalidate the cache whenever the runs table changes, which is not something
@@ -448,18 +470,11 @@ def _clear_caches(self) -> None:
448470 self .has .cache_clear ()
449471
450472
451- # TODO(mm, 2022-04-18):
452- # Restructure to degrade gracefully in the face of ProtocolReader failures.
453- #
454- # * ProtocolStore.get_all() should omit protocols for which it failed to compute
455- # a ProtocolSource.
456- # * ProtocolStore.get(id) should continue to raise an exception if it failed to compute
457- # that protocol's ProtocolSource.
458473async def _compute_protocol_sources (
459474 expected_protocol_ids : Set [str ],
460- protocols_directory : AsyncPath ,
475+ protocols_directory : Path ,
461476 protocol_reader : ProtocolReader ,
462- ) -> Dict [str , ProtocolSource ]:
477+ ) -> dict [str , ProtocolSource | _BadProtocolSource ]:
463478 """Compute `ProtocolSource` objects from protocol source files.
464479
465480 We don't store these `ProtocolSource` objects in the SQL database because
@@ -475,19 +490,19 @@ async def _compute_protocol_sources(
475490 protocol_reader: An interface to use to compute `ProtocolSource`s.
476491
477492 Returns:
478- A map from protocol ID to computed `ProtocolSource`.
493+ A map from protocol ID to computed `ProtocolSource`, or a `_BadProtocolSource`
494+ (wrapping the exception) if there was a problem processing that particular protocol.
479495
480496 Raises:
481497 Exception: This is not expected to raise anything,
482498 but it might if a software update makes ProtocolReader reject files
483499 that it formerly accepted.
484500 """
485- sources_by_id : Dict [str , ProtocolSource ] = {}
501+ sources_by_id : dict [str , ProtocolSource | _BadProtocolSource ] = {}
486502
487- directory_members = [m async for m in protocols_directory .iterdir ()]
503+ directory_members = [m async for m in AsyncPath ( protocols_directory ) .iterdir ()]
488504 directory_member_names = set (m .name for m in directory_members )
489505 extra_members = directory_member_names - expected_protocol_ids
490- missing_members = expected_protocol_ids - directory_member_names
491506
492507 if extra_members :
493508 # Extra members may be left over from prior interrupted writes
@@ -498,38 +513,48 @@ async def _compute_protocol_sources(
498513 f" Ignoring them."
499514 )
500515
501- if missing_members :
502- raise SubdirectoryMissingError (
503- f"Missing subdirectories for protocols: { missing_members } "
504- )
505-
506516 async def compute_source (
507- protocol_id : str , protocol_subdirectory : AsyncPath
517+ protocol_subdirectory : Path ,
518+ ) -> ProtocolSource | _BadProtocolSource :
519+ try :
520+ # Given that the expected protocol subdirectory exists,
521+ # we trust that the files in it are correct.
522+ # No extra files, and no files missing.
523+ #
524+ # This is a safe assumption as long as:
525+ # * Nobody has tampered with the file storage.
526+ # * We don't try to compute the source of any protocol whose insertion
527+ # failed halfway through and left files behind.
528+ protocol_files = [
529+ Path (f ) async for f in AsyncPath (protocol_subdirectory ).iterdir ()
530+ ]
531+ protocol_source = await protocol_reader .read_saved (
532+ files = protocol_files ,
533+ directory = Path (protocol_subdirectory ),
534+ files_are_prevalidated = True ,
535+ python_parse_mode = PythonParseMode .ALLOW_LEGACY_METADATA_AND_REQUIREMENTS ,
536+ )
537+ return protocol_source
538+ except Exception as exception :
539+ # e.g. if a software downgrade left the robot with some protocol files that
540+ # are too new for the software version that it's running now.
541+ _log .exception (f"Error reading protocol in { protocol_subdirectory } ." )
542+ return _BadProtocolSource (directory = protocol_subdirectory , reason = exception )
543+
544+ async def compute_source_and_store_in_result_dict (
545+ protocol_id : str , protocol_subdirectory : Path
508546 ) -> None :
509- # Given that the expected protocol subdirectory exists,
510- # we trust that the files in it are correct.
511- # No extra files, and no files missing.
512- #
513- # This is a safe assumption as long as:
514- # * Nobody has tampered with file the storage.
515- # * We don't try to compute the source of any protocol whose insertion
516- # failed halfway through and left files behind.
517- protocol_files = [Path (f ) async for f in protocol_subdirectory .iterdir ()]
518- protocol_source = await protocol_reader .read_saved (
519- files = protocol_files ,
520- directory = Path (protocol_subdirectory ),
521- files_are_prevalidated = True ,
522- python_parse_mode = PythonParseMode .ALLOW_LEGACY_METADATA_AND_REQUIREMENTS ,
523- )
524- sources_by_id [protocol_id ] = protocol_source
547+ result = await compute_source (protocol_subdirectory )
548+ sources_by_id [protocol_id ] = result
525549
526550 async with create_task_group () as task_group :
527- # Use a TaskGroup instead of asyncio.gather() so,
528- # if any task raises an unexpected exception,
529- # it cancels every other task and raises an exception to signal the bug.
530551 for protocol_id in expected_protocol_ids :
531552 protocol_subdirectory = protocols_directory / protocol_id
532- task_group .start_soon (compute_source , protocol_id , protocol_subdirectory )
553+ task_group .start_soon (
554+ compute_source_and_store_in_result_dict ,
555+ protocol_id ,
556+ protocol_subdirectory ,
557+ )
533558
534559 for id in expected_protocol_ids :
535560 assert id in sources_by_id
@@ -547,6 +572,14 @@ class _DBProtocolResource:
547572 protocol_kind : ProtocolKindSQLEnum
548573
549574
575+ @dataclass (frozen = True )
576+ class _BadProtocolSource :
577+ """Information about files that we failed to process into a ProtocolSource."""
578+
579+ directory : Path
580+ reason : Exception
581+
582+
550583def _convert_sql_row_to_dataclass (
551584 sql_row : sqlalchemy .engine .Row ,
552585) -> _DBProtocolResource :
0 commit comments