Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 76 additions & 29 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,11 @@
TInput = TypeVar("TInput")
TOutput = TypeVar("TOutput")


class VersionNotRegisteredException(Exception):
    """Raised when a requested orchestrator version has not been registered."""


def _log_all_threads(logger: logging.Logger, context: str = ""):
"""Helper function to log all currently active threads for debugging."""
active_threads = threading.enumerate()
Expand Down Expand Up @@ -100,15 +102,23 @@ def __init__(self):
self.latest_versioned_orchestrators_version_name = {}
self.activities = {}

def add_orchestrator(self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> str:
def add_orchestrator(
    self, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False
) -> str:
    """Register *fn* as an orchestrator and return the name it was registered under.

    The name is derived via ``task.get_name(fn)``; registration itself is
    delegated to ``add_named_orchestrator``.

    Raises:
        ValueError: if *fn* is None.
    """
    if fn is None:
        raise ValueError("An orchestrator function argument is required.")

    orchestrator_name = task.get_name(fn)
    self.add_named_orchestrator(orchestrator_name, fn, version_name, is_latest)
    return orchestrator_name

def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name: Optional[str] = None, is_latest: bool = False) -> None:
def add_named_orchestrator(
self,
name: str,
fn: task.Orchestrator,
version_name: Optional[str] = None,
is_latest: bool = False,
) -> None:
if not name:
raise ValueError("A non-empty orchestrator name is required.")

Expand All @@ -120,12 +130,16 @@ def add_named_orchestrator(self, name: str, fn: task.Orchestrator, version_name:
if name not in self.versioned_orchestrators:
self.versioned_orchestrators[name] = {}
if version_name in self.versioned_orchestrators[name]:
raise ValueError(f"The version '{version_name}' of '{name}' orchestrator already exists.")
raise ValueError(
f"The version '{version_name}' of '{name}' orchestrator already exists."
)
self.versioned_orchestrators[name][version_name] = fn
if is_latest:
self.latest_versioned_orchestrators_version_name[name] = version_name

def get_orchestrator(self, name: str, version_name: Optional[str] = None) -> Optional[tuple[task.Orchestrator, str]]:
def get_orchestrator(
self, name: str, version_name: Optional[str] = None
) -> Optional[tuple[task.Orchestrator, str]]:
if name in self.orchestrators:
return self.orchestrators.get(name), None

Expand Down Expand Up @@ -282,7 +296,7 @@ def __init__(
self._channel_options = channel_options
self._stop_timeout = stop_timeout
self._current_channel: Optional[grpc.Channel] = None # Store channel reference for cleanup

self._stream_ready = threading.Event()
# Use provided concurrency options or create default ones
self._concurrency_options = (
concurrency_options if concurrency_options is not None else ConcurrencyOptions()
Expand All @@ -298,7 +312,7 @@ def __init__(
else:
self._interceptors = None

self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)

@property
def concurrency_options(self) -> ConcurrencyOptions:
Expand All @@ -323,6 +337,9 @@ def add_activity(self, fn: task.Activity) -> str:
raise RuntimeError("Activities cannot be added while the worker is running.")
return self._registry.add_activity(fn)

def is_worker_ready(self) -> bool:
    """Report readiness: the work-item stream is connected AND the run loop is marked running."""
    # Guard clause: without a connected stream the worker cannot be ready,
    # regardless of the running flag.
    if not self._stream_ready.is_set():
        return False
    return self._is_running

def start(self):
"""Starts the worker on a background thread and begins listening for work items."""
if self._is_running:
Expand All @@ -336,6 +353,8 @@ def run_loop():
self._logger.info(f"Starting gRPC worker that connects to {self._host_address}")
self._runLoop = Thread(target=run_loop, name="WorkerRunLoop")
self._runLoop.start()
if not self._stream_ready.wait(timeout=10):
raise RuntimeError("Failed to establish work item stream connection within 10 seconds")
self._is_running = True

# TODO: refactor this to be more readable and maintainable.
Expand Down Expand Up @@ -446,10 +465,13 @@ def should_invalidate_connection(rpc_error):
maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items,
maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items,
)
self._response_stream = stub.GetWorkItems(get_work_items_request)
self._logger.info(
f"Successfully connected to {self._host_address}. Waiting for work items..."
)
try:
self._response_stream = stub.GetWorkItems(get_work_items_request)
self._logger.info(
f"Successfully connected to {self._host_address}. Waiting for work items..."
)
except Exception:
raise

# Use a thread to read from the blocking gRPC stream and forward to asyncio
import queue
Expand All @@ -460,12 +482,15 @@ def should_invalidate_connection(rpc_error):
# NOTE: This is equivalent to the Durabletask Go goroutine calling stream.Recv() in worker_grpc.go StartWorkItemListener()
def stream_reader():
try:
if self._response_stream is None:
return
stream = self._response_stream

# Use next() to allow shutdown check between items
# This matches Go's pattern: check ctx.Err() after each stream.Recv()
while True:
if self._shutdown.is_set():
self._logger.debug("Stream reader: shutdown detected, exiting loop")
break

try:
Expand Down Expand Up @@ -502,15 +527,26 @@ def stream_reader():
self._logger.debug(
f"Stream reader: exception during shutdown: {type(stream_error).__name__}: {stream_error}"
)
break
# Other stream errors - put in queue for async loop to handle
self._logger.warning(
f"Stream reader: unexpected error: {stream_error}"
self._logger.error(
f"Stream reader: unexpected error: {type(stream_error).__name__}: {stream_error}",
exc_info=True,
)
raise

except Exception as e:
self._logger.exception(
f"Stream reader: fatal exception in stream_reader: {type(e).__name__}: {e}",
exc_info=True,
)
if not self._shutdown.is_set():
work_item_queue.put(e)
try:
work_item_queue.put(e)
except Exception as queue_error:
self._logger.error(
f"Stream reader: failed to put exception in queue: {queue_error}"
)
finally:
# signal that the stream reader is done (ie matching Go's context cancellation)
try:
Expand All @@ -519,16 +555,20 @@ def stream_reader():
# queue might be closed so ignore this
pass

import threading

# Use non-daemon thread (daemon=False) to ensure proper resource cleanup.
# Daemon threads exit immediately when the main program exits, which prevents
# cleanup of gRPC channel resources and OTel interceptors. Non-daemon threads
# block shutdown until they complete, ensuring all resources are properly closed.
current_reader_thread = threading.Thread(
target=stream_reader, daemon=False, name="StreamReader"
)
current_reader_thread.start()

try:
current_reader_thread.start()
self._stream_ready.set()
except Exception:
raise

loop = asyncio.get_running_loop()

# NOTE: This is a blocking call that will wait for a work item to become available or the shutdown sentinel
Expand Down Expand Up @@ -760,7 +800,6 @@ def _execute_orchestrator(
version = version or pb.OrchestrationVersion()
version.patches.extend(result.patches)


res = pb.OrchestratorResponse(
instanceId=req.instanceId,
actions=result.actions,
Expand Down Expand Up @@ -932,14 +971,12 @@ def set_failed(self, ex: Exception):
)
self._pending_actions[action.id] = action


def set_version_not_registered(self):
    """Mark this orchestration as stalled because no matching orchestrator version is registered.

    Clears any actions queued so far, sets the completion status to
    ORCHESTRATION_STATUS_STALLED, and records a single
    version-not-available action as the only pending action.
    """
    # Previously queued actions are superseded by the stall outcome.
    self._pending_actions.clear()
    self._completion_status = pb.ORCHESTRATION_STATUS_STALLED
    # NOTE(review): `ph.new_orchestrator_version_not_available_action` presumably
    # builds the protobuf action describing the stall — confirm in the helper module.
    action = ph.new_orchestrator_version_not_available_action(self.next_sequence_number())
    self._pending_actions[action.id] = action


def set_continued_as_new(self, new_input: Any, save_events: bool):
if self._is_complete:
return
Expand Down Expand Up @@ -1150,7 +1187,6 @@ def continue_as_new(self, new_input, *, save_events: bool = False) -> None:

self.set_continued_as_new(new_input, save_events)


def is_patched(self, patch_name: str) -> bool:
is_patched = self._is_patched(patch_name)
if is_patched:
Expand Down Expand Up @@ -1178,7 +1214,13 @@ class ExecutionResults:
version_name: Optional[str]
patches: Optional[list[str]]

def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str], version_name: Optional[str] = None, patches: Optional[list[str]] = None):
def __init__(
self,
actions: list[pb.OrchestratorAction],
encoded_custom_status: Optional[str],
version_name: Optional[str] = None,
patches: Optional[list[str]] = None,
):
self.actions = actions
self.encoded_custom_status = encoded_custom_status
self.version_name = version_name
Expand Down Expand Up @@ -1254,8 +1296,8 @@ def execute(
return ExecutionResults(
actions=actions,
encoded_custom_status=ctx._encoded_custom_status,
version_name=getattr(ctx, '_version_name', None),
patches=ctx._encountered_patches
version_name=getattr(ctx, "_version_name", None),
patches=ctx._encountered_patches,
)

def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None:
Expand Down Expand Up @@ -1283,9 +1325,10 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven
if ctx._orchestrator_version_name:
version_name = ctx._orchestrator_version_name


# TODO: Check if we already started the orchestration
fn, version_used = self._registry.get_orchestrator(event.executionStarted.name, version_name=version_name)
fn, version_used = self._registry.get_orchestrator(
event.executionStarted.name, version_name=version_name
)

if fn is None:
raise OrchestratorNotRegisteredError(
Expand Down Expand Up @@ -1693,7 +1736,7 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:


class _AsyncWorkerManager:
def __init__(self, concurrency_options: ConcurrencyOptions):
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
self.concurrency_options = concurrency_options
self.activity_semaphore = None
self.orchestration_semaphore = None
Expand All @@ -1709,14 +1752,16 @@ def __init__(self, concurrency_options: ConcurrencyOptions):
thread_name_prefix="DurableTask",
)
self._shutdown = False
self._logger = logger

def _ensure_queues_for_current_loop(self):
"""Ensure queues are bound to the current event loop."""
try:
current_loop = asyncio.get_running_loop()
if current_loop.is_closed():
return
except RuntimeError:
except RuntimeError as e:
self._logger.exception(f"Failed to get event loop {e}")
# No event loop running, can't create queues
return

Expand All @@ -1735,14 +1780,16 @@ def _ensure_queues_for_current_loop(self):
try:
while not self.activity_queue.empty():
existing_activity_items.append(self.activity_queue.get_nowait())
except Exception:
except Exception as e:
self._logger.debug(f"Failed to append to the activity queue {e}")
pass

if self.orchestration_queue is not None:
try:
while not self.orchestration_queue.empty():
existing_orchestration_items.append(self.orchestration_queue.get_nowait())
except Exception:
except Exception as e:
self._logger.debug(f"Failed to append to the orchestration queue {e}")
pass

# Create fresh queues for the current event loop
Expand Down
Loading