|
1 | 1 | """Tool execution functionality for the event loop.""" |
2 | 2 |
|
3 | 3 | import logging |
| 4 | +import queue |
| 5 | +import threading |
4 | 6 | import time |
5 | | -from concurrent.futures import TimeoutError |
6 | | -from typing import Any, Callable, List, Optional, Tuple |
| 7 | +from typing import Any, Callable, Generator, Optional, cast |
7 | 8 |
|
8 | 9 | from opentelemetry import trace |
9 | 10 |
|
|
19 | 20 |
|
20 | 21 | def run_tools( |
21 | 22 | handler: Callable[[ToolUse], ToolResult], |
22 | | - tool_uses: List[ToolUse], |
| 23 | + tool_uses: list[ToolUse], |
23 | 24 | event_loop_metrics: EventLoopMetrics, |
24 | | - request_state: Any, |
25 | | - invalid_tool_use_ids: List[str], |
26 | | - tool_results: List[ToolResult], |
| 25 | + invalid_tool_use_ids: list[str], |
| 26 | + tool_results: list[ToolResult], |
27 | 27 | cycle_trace: Trace, |
28 | 28 | parent_span: Optional[trace.Span] = None, |
29 | 29 | parallel_tool_executor: Optional[ParallelToolExecutorInterface] = None, |
30 | | -) -> bool: |
| 30 | +) -> Generator[dict[str, Any], None, None]: |
31 | 31 | """Execute tools either in parallel or sequentially. |
32 | 32 |
|
33 | 33 | Args: |
34 | 34 | handler: Tool handler processing function. |
35 | 35 | tool_uses: List of tool uses to execute. |
36 | 36 | event_loop_metrics: Metrics collection object. |
37 | | - request_state: Current request state. |
38 | 37 | invalid_tool_use_ids: List of invalid tool use IDs. |
39 | 38 | tool_results: List to populate with tool results. |
40 | 39 | cycle_trace: Parent trace for the current cycle. |
41 | 40 | parent_span: Parent span for the current cycle. |
42 | 41 | parallel_tool_executor: Optional executor for parallel processing. |
43 | 42 |
|
44 | | - Returns: |
45 | | - bool: True if any tool failed, False otherwise. |
| 43 | + Yields: |
| 44 | + Events of the tool invocations. Tool results are appended to `tool_results`. |
46 | 45 | """ |
47 | 46 |
|
48 | | - def _handle_tool_execution(tool: ToolUse) -> Tuple[bool, Optional[ToolResult]]: |
49 | | - result = None |
50 | | - tool_succeeded = False |
51 | | - |
| 47 | + def handle(tool: ToolUse) -> Generator[dict[str, Any], None, ToolResult]: |
52 | 48 | tracer = get_tracer() |
53 | 49 | tool_call_span = tracer.start_tool_call_span(tool, parent_span) |
54 | 50 |
|
55 | | - try: |
56 | | - if "toolUseId" not in tool or tool["toolUseId"] not in invalid_tool_use_ids: |
57 | | - tool_name = tool["name"] |
58 | | - tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) |
59 | | - tool_start_time = time.time() |
60 | | - result = handler(tool) |
61 | | - tool_success = result.get("status") == "success" |
62 | | - if tool_success: |
63 | | - tool_succeeded = True |
64 | | - |
65 | | - tool_duration = time.time() - tool_start_time |
66 | | - message = Message(role="user", content=[{"toolResult": result}]) |
67 | | - event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) |
68 | | - cycle_trace.add_child(tool_trace) |
69 | | - |
70 | | - if tool_call_span: |
71 | | - tracer.end_tool_call_span(tool_call_span, result) |
72 | | - except Exception as e: |
73 | | - if tool_call_span: |
74 | | - tracer.end_span_with_error(tool_call_span, str(e), e) |
75 | | - |
76 | | - return tool_succeeded, result |
77 | | - |
78 | | - any_tool_failed = False |
| 51 | + tool_name = tool["name"] |
| 52 | + tool_trace = Trace(f"Tool: {tool_name}", parent_id=cycle_trace.id, raw_name=tool_name) |
| 53 | + tool_start_time = time.time() |
| 54 | + |
| 55 | + result = handler(tool) |
| 56 | + yield {"result": result} # Placeholder until handler becomes a generator from which we can yield from |
| 57 | + |
| 58 | + tool_success = result.get("status") == "success" |
| 59 | + tool_duration = time.time() - tool_start_time |
| 60 | + message = Message(role="user", content=[{"toolResult": result}]) |
| 61 | + event_loop_metrics.add_tool_usage(tool, tool_duration, tool_trace, tool_success, message) |
| 62 | + cycle_trace.add_child(tool_trace) |
| 63 | + |
| 64 | + if tool_call_span: |
| 65 | + tracer.end_tool_call_span(tool_call_span, result) |
| 66 | + |
| 67 | + return result |
| 68 | + |
| 69 | + def work( |
| 70 | + tool: ToolUse, |
| 71 | + worker_id: int, |
| 72 | + worker_queue: queue.Queue, |
| 73 | + worker_event: threading.Event, |
| 74 | + ) -> ToolResult: |
| 75 | + events = handle(tool) |
| 76 | + |
| 77 | + while True: |
| 78 | + try: |
| 79 | + event = next(events) |
| 80 | + worker_queue.put((worker_id, event)) |
| 81 | + worker_event.wait() |
| 82 | + |
| 83 | + except StopIteration as stop: |
| 84 | + return cast(ToolResult, stop.value) |
| 85 | + |
| 86 | + tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids] |
| 87 | + |
79 | 88 | if parallel_tool_executor: |
80 | 89 | logger.debug( |
81 | 90 | "tool_count=<%s>, tool_executor=<%s> | executing tools in parallel", |
82 | 91 | len(tool_uses), |
83 | 92 | type(parallel_tool_executor).__name__, |
84 | 93 | ) |
85 | | - # Submit all tasks with their associated tools |
86 | | - future_to_tool = { |
87 | | - parallel_tool_executor.submit(_handle_tool_execution, tool_use): tool_use for tool_use in tool_uses |
88 | | - } |
| 94 | + |
| 95 | + worker_queue: queue.Queue[tuple[int, dict[str, Any]]] = queue.Queue() |
| 96 | + worker_events = [threading.Event() for _ in range(len(tool_uses))] |
| 97 | + |
| 98 | + workers = [ |
| 99 | + parallel_tool_executor.submit(work, tool_use, worker_id, worker_queue, worker_events[worker_id]) |
| 100 | + for worker_id, tool_use in enumerate(tool_uses) |
| 101 | + ] |
89 | 102 | logger.debug("tool_count=<%s> | submitted tasks to parallel executor", len(tool_uses)) |
90 | 103 |
|
91 | | - # Collect results truly in parallel using the provided executor's as_completed method |
92 | | - completed_results = [] |
93 | | - try: |
94 | | - for future in parallel_tool_executor.as_completed(future_to_tool): |
95 | | - try: |
96 | | - succeeded, result = future.result() |
97 | | - if result is not None: |
98 | | - completed_results.append(result) |
99 | | - if not succeeded: |
100 | | - any_tool_failed = True |
101 | | - except Exception as e: |
102 | | - tool = future_to_tool[future] |
103 | | - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], e) |
104 | | - any_tool_failed = True |
105 | | - except TimeoutError: |
106 | | - logger.error("timeout_seconds=<%s> | parallel tool execution timed out", parallel_tool_executor.timeout) |
107 | | - # Process any completed tasks |
108 | | - for future in future_to_tool: |
109 | | - if future.done(): # type: ignore |
110 | | - try: |
111 | | - succeeded, result = future.result(timeout=0) |
112 | | - if result is not None: |
113 | | - completed_results.append(result) |
114 | | - except Exception as tool_e: |
115 | | - tool = future_to_tool[future] |
116 | | - logger.debug("tool_name=<%s> | tool execution failed | %s", tool["name"], tool_e) |
117 | | - else: |
118 | | - # This future didn't complete within the timeout |
119 | | - tool = future_to_tool[future] |
120 | | - logger.debug("tool_name=<%s> | tool execution timed out", tool["name"]) |
121 | | - |
122 | | - any_tool_failed = True |
123 | | - |
124 | | - # Add completed results to tool_results |
125 | | - tool_results.extend(completed_results) |
| 104 | + while not all(worker.done() for worker in workers): |
| 105 | + if not worker_queue.empty(): |
| 106 | + worker_id, event = worker_queue.get() |
| 107 | + yield event |
| 108 | + worker_events[worker_id].set() |
| 109 | + |
| 110 | + tool_results.extend([worker.result() for worker in workers]) |
| 111 | + |
126 | 112 | else: |
127 | 113 | # Sequential execution fallback |
128 | 114 | for tool_use in tool_uses: |
129 | | - succeeded, result = _handle_tool_execution(tool_use) |
130 | | - if result is not None: |
131 | | - tool_results.append(result) |
132 | | - if not succeeded: |
133 | | - any_tool_failed = True |
134 | | - |
135 | | - return any_tool_failed |
| 115 | + result = yield from handle(tool_use) |
| 116 | + tool_results.append(result) |
136 | 117 |
|
137 | 118 |
|
138 | 119 | def validate_and_prepare_tools( |
139 | 120 | message: Message, |
140 | | - tool_uses: List[ToolUse], |
141 | | - tool_results: List[ToolResult], |
142 | | - invalid_tool_use_ids: List[str], |
| 121 | + tool_uses: list[ToolUse], |
| 122 | + tool_results: list[ToolResult], |
| 123 | + invalid_tool_use_ids: list[str], |
143 | 124 | ) -> None: |
144 | 125 | """Validate tool uses and prepare them for execution. |
145 | 126 |
|
|
0 commit comments