Skip to content

Commit 07d4876

Browse files
committed
Add support for new event types
1 parent b4086fd commit 07d4876

File tree

1 file changed

+80
-23
lines changed

1 file changed

+80
-23
lines changed

durabletask/worker.py

Lines changed: 80 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -828,6 +828,7 @@ def __init__(self, instance_id: str, registry: _Registry):
828828
self._pending_tasks: dict[int, task.CompletableTask] = {}
829829
# Maps entity ID to task ID
830830
self._entity_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
831+
self._entity_lock_task_id_map: dict[str, tuple[EntityInstanceId, int]] = {}
831832
# Maps criticalSectionId to task ID
832833
self._entity_lock_id_map: dict[str, int] = {}
833834
self._sequence_number = 0
@@ -1590,33 +1591,70 @@ def process_event(
15901591
else:
15911592
raise TypeError("Unexpected sub-orchestration task type")
15921593
elif event.HasField("eventRaised"):
1593-
# event names are case-insensitive
1594-
event_name = event.eventRaised.name.casefold()
1595-
if not ctx.is_replaying:
1596-
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1597-
task_list = ctx._pending_events.get(event_name, None)
1598-
decoded_result: Optional[Any] = None
1599-
if task_list:
1600-
event_task = task_list.pop(0)
1594+
if event.eventRaised.name in ctx._entity_task_id_map:
1595+
# This eventRaised represents the result of an entity operation after being translated to the old
1596+
# entity protocol by the Durable WebJobs extension
1597+
entity_id, task_id = ctx._entity_task_id_map.get(event.eventRaised.name, (None, None))
1598+
if entity_id is None:
1599+
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1600+
if task_id is None:
1601+
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1602+
entity_task = ctx._pending_tasks.pop(task_id, None)
1603+
if not entity_task:
1604+
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1605+
result = None
16011606
if not ph.is_empty(event.eventRaised.input):
1602-
decoded_result = shared.from_json(event.eventRaised.input.value)
1603-
event_task.complete(decoded_result)
1604-
if not task_list:
1605-
del ctx._pending_events[event_name]
1607+
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1608+
result = shared.from_json(event.eventRaised.input.value)["result"]
1609+
ctx._entity_context.recover_lock_after_call(entity_id)
1610+
entity_task.complete(result)
16061611
ctx.resume()
1607-
else:
1608-
# buffer the event
1609-
event_list = ctx._received_events.get(event_name, None)
1610-
if not event_list:
1611-
event_list = []
1612-
ctx._received_events[event_name] = event_list
1612+
elif event.eventRaised.name in ctx._entity_lock_task_id_map:
1613+
# This eventRaised represents the result of an entity operation after being translated to the old
1614+
# entity protocol by the Durable WebJobs extension
1615+
entity_id, task_id = ctx._entity_lock_task_id_map.get(event.eventRaised.name, (None, None))
1616+
if entity_id is None:
1617+
raise RuntimeError(f"Could not retrieve entity ID for entity-related eventRaised with ID '{event.eventId}'")
1618+
if task_id is None:
1619+
raise RuntimeError(f"Could not retrieve task ID for entity-related eventRaised with ID '{event.eventId}'")
1620+
entity_task = ctx._pending_tasks.pop(task_id, None)
1621+
if not entity_task:
1622+
raise RuntimeError(f"Could not retrieve entity task for entity-related eventRaised with ID '{event.eventId}'")
1623+
result = None
16131624
if not ph.is_empty(event.eventRaised.input):
1614-
decoded_result = shared.from_json(event.eventRaised.input.value)
1615-
event_list.append(decoded_result)
1625+
# TODO: Investigate why the event result is wrapped in a dict with "result" key
1626+
result = shared.from_json(event.eventRaised.input.value)["result"]
1627+
ctx._entity_context.complete_acquire(event.eventRaised.name)
1628+
entity_task.complete(EntityLock(ctx))
1629+
ctx.resume()
1630+
else:
1631+
# event names are case-insensitive
1632+
event_name = event.eventRaised.name.casefold()
16161633
if not ctx.is_replaying:
1617-
self._logger.info(
1618-
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1619-
)
1634+
self._logger.info(f"{ctx.instance_id} Event raised: {event_name}")
1635+
task_list = ctx._pending_events.get(event_name, None)
1636+
decoded_result: Optional[Any] = None
1637+
if task_list:
1638+
event_task = task_list.pop(0)
1639+
if not ph.is_empty(event.eventRaised.input):
1640+
decoded_result = shared.from_json(event.eventRaised.input.value)
1641+
event_task.complete(decoded_result)
1642+
if not task_list:
1643+
del ctx._pending_events[event_name]
1644+
ctx.resume()
1645+
else:
1646+
# buffer the event
1647+
event_list = ctx._received_events.get(event_name, None)
1648+
if not event_list:
1649+
event_list = []
1650+
ctx._received_events[event_name] = event_list
1651+
if not ph.is_empty(event.eventRaised.input):
1652+
decoded_result = shared.from_json(event.eventRaised.input.value)
1653+
event_list.append(decoded_result)
1654+
if not ctx.is_replaying:
1655+
self._logger.info(
1656+
f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it."
1657+
)
16201658
elif event.HasField("executionSuspended"):
16211659
if not self._is_suspended and not ctx.is_replaying:
16221660
self._logger.info(f"{ctx.instance_id}: Execution suspended.")
@@ -1743,6 +1781,25 @@ def process_event(
17431781
self._logger.info(f"{ctx.instance_id}: Entity operation failed.")
17441782
self._logger.info(f"Data: {json.dumps(event.entityOperationFailed)}")
17451783
pass
1784+
elif event.HasField("orchestratorCompleted"):
1785+
# Added in Functions only (for some reason) and does not affect orchestrator flow
1786+
pass
1787+
elif event.HasField("eventSent"):
1788+
# Check if this eventSent corresponds to an entity operation call after being translated to the old
1789+
# entity protocol by the Durable WebJobs extension. If so, treat this message similarly to
1790+
# entityOperationCalled and remove the pending action. Also store the entity id and event id for later
1791+
action = ctx._pending_actions.pop(event.eventId, None)
1792+
if action and action.HasField("sendEntityMessage"):
1793+
if action.sendEntityMessage.HasField("entityOperationCalled"):
1794+
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1795+
event_id = json.loads(event.eventSent.input.value)["id"]
1796+
ctx._entity_task_id_map[event_id] = (entity_id, event.eventId)
1797+
elif action.sendEntityMessage.HasField("entityLockRequested"):
1798+
entity_id = EntityInstanceId.parse(event.eventSent.instanceId)
1799+
event_id = json.loads(event.eventSent.input.value)["id"]
1800+
ctx._entity_lock_task_id_map[event_id] = (entity_id, event.eventId)
1801+
else:
1802+
return
17461803
else:
17471804
eventType = event.WhichOneof("eventType")
17481805
raise task.OrchestrationStateError(

0 commit comments

Comments
 (0)