Skip to content

Commit fce544d

Browse files
authored
Merge pull request #69 from grasp-technologies/release-0.6.2
routing cleanup
2 parents 2d7086e + 62faf2f commit fce544d

File tree

7 files changed

+85
-58
lines changed

7 files changed

+85
-58
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "grasp_agents"
3-
version = "0.6.1"
3+
version = "0.6.2"
44
description = "Grasp Agents Library"
55
readme = "README.md"
66
requires-python = ">=3.11.4,<4"

src/grasp_agents/llm_agent.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __call__(
4141
self,
4242
memory: "LLMAgentMemory",
4343
*,
44+
instructions: LLMPrompt | None = None,
4445
in_args: Any | None,
4546
ctx: RunContext[Any],
4647
call_id: str,
@@ -90,6 +91,7 @@ def __init__(
9091
react_mode: bool = False,
9192
final_answer_as_tool_call: bool = False,
9293
# Agent memory management
94+
memory: LLMAgentMemory | None = None,
9395
reset_memory_on_run: bool = False,
9496
# Agent run retries
9597
max_retries: int = 0,
@@ -100,7 +102,7 @@ def __init__(
100102

101103
# Agent memory
102104

103-
self._memory: LLMAgentMemory = LLMAgentMemory()
105+
self._memory: LLMAgentMemory = memory or LLMAgentMemory()
104106
self._reset_memory_on_run = reset_memory_on_run
105107

106108
# LLM policy executor
@@ -174,13 +176,18 @@ def prepare_memory(
174176
self,
175177
memory: LLMAgentMemory,
176178
*,
179+
instructions: LLMPrompt | None = None,
177180
in_args: InT | None = None,
178181
ctx: RunContext[Any],
179182
call_id: str,
180183
) -> None:
181184
if is_method_overridden("prepare_memory_impl", self, LLMAgent[Any, Any, Any]):
182185
return self.prepare_memory_impl(
183-
memory=memory, in_args=in_args, ctx=ctx, call_id=call_id
186+
memory=memory,
187+
instructions=instructions,
188+
in_args=in_args,
189+
ctx=ctx,
190+
call_id=call_id,
184191
)
185192

186193
def _memorize_inputs(
@@ -203,7 +210,11 @@ def _memorize_inputs(
203210
system_message = cast("SystemMessage", memory.messages[0])
204211
else:
205212
self.prepare_memory(
206-
memory=memory, in_args=in_args, ctx=ctx, call_id=call_id
213+
memory=memory,
214+
instructions=formatted_sys_prompt,
215+
in_args=in_args,
216+
ctx=ctx,
217+
call_id=call_id,
207218
)
208219

209220
input_message = self._prompt_builder.build_input_message(
@@ -332,6 +343,7 @@ def prepare_memory_impl(
332343
self,
333344
memory: LLMAgentMemory,
334345
*,
346+
instructions: LLMPrompt | None = None,
335347
in_args: InT | None = None,
336348
ctx: RunContext[Any],
337349
call_id: str,

src/grasp_agents/packet.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Sequence
2-
from typing import Generic, Self, TypeVar
2+
from typing import Any, Generic, Self, TypeVar
33
from uuid import uuid4
44

55
from pydantic import BaseModel, ConfigDict, Field, model_validator
@@ -33,6 +33,19 @@ class Packet(BaseModel, Generic[_PayloadT_co]):
3333
def uniform_routing(self) -> Sequence[ProcName] | None:
3434
return is_uniform_routing(self.routing)
3535

36+
@model_validator(mode="before")
37+
@classmethod
38+
def _normalize_routing(cls, data: dict[str, Any]) -> dict[str, Any]:
39+
routing = data.get("routing")
40+
if (
41+
routing
42+
and isinstance(routing, (list, tuple))
43+
and all(isinstance(r, str) for r in routing) # type: ignore[misc]
44+
):
45+
payloads = data.get("payloads", [])
46+
data["routing"] = [routing for _ in range(len(payloads))]
47+
return data
48+
3649
@model_validator(mode="after")
3750
def _validate_routing(self) -> Self:
3851
if self.routing is not None and len(self.payloads) != len(self.routing):
@@ -45,9 +58,11 @@ def split_per_payload(self) -> Sequence["Packet[_PayloadT_co]"] | None:
4558
if self.routing is None:
4659
return None
4760

61+
single_payload = len(self.payloads) == 1
62+
4863
return [
4964
Packet(
50-
id=f"{self.id}/{i}",
65+
id=f"{self.id}/{i}" if not single_payload else self.id,
5166
payloads=[payload],
5267
routing=[recipients],
5368
sender=self.sender,
@@ -70,9 +85,10 @@ def split_by_recipient(
7085
recipient_to_payloads[recipient] = []
7186
recipient_to_payloads[recipient].append(payload)
7287

88+
single_recipient = len(recipient_to_payloads) == 1
7389
return [
7490
Packet(
75-
id=f"{self.id}/{recipient}",
91+
id=f"{self.id}/{recipient}" if not single_recipient else self.id,
7692
payloads=payloads,
7793
routing=[[recipient] for _ in range(len(payloads))],
7894
sender=self.sender,

src/grasp_agents/processors/base_processor.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
class RecipientSelector(Protocol[_OutT_contra, CtxT]):
4646
def __call__(
47-
self, output: _OutT_contra, *, ctx: RunContext[CtxT]
47+
self, output: _OutT_contra, *, ctx: RunContext[CtxT], call_id: str
4848
) -> Sequence[ProcName] | None: ...
4949

5050

@@ -253,7 +253,7 @@ def _validate_recipients(
253253
)
254254

255255
def select_recipients_impl(
256-
self, output: OutT, *, ctx: RunContext[CtxT]
256+
self, output: OutT, *, ctx: RunContext[CtxT], call_id: str
257257
) -> Sequence[ProcName] | None:
258258
raise NotImplementedError
259259

@@ -266,11 +266,15 @@ def add_recipient_selector(
266266

267267
@final
268268
def select_recipients(
269-
self, output: OutT, ctx: RunContext[CtxT]
269+
self, output: OutT, ctx: RunContext[CtxT], call_id: str
270270
) -> Sequence[ProcName] | None:
271271
base_cls = BaseProcessor[Any, Any, Any, Any]
272272
if is_method_overridden("select_recipients_impl", self, base_cls):
273-
return self.select_recipients_impl(output=output, ctx=ctx)
273+
recipients = self.select_recipients_impl(
274+
output=output, ctx=ctx, call_id=call_id
275+
)
276+
self._validate_recipients(recipients, call_id=call_id)
277+
return recipients
274278

275279
return self.recipients
276280

src/grasp_agents/processors/mapping_processor.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,23 +83,32 @@ def _preprocess(
8383

8484
return val_in_args, memory
8585

86+
def _join_routings(
87+
self, routings: list[Sequence[ProcName] | None]
88+
) -> Sequence[Sequence[ProcName]] | None:
89+
if all(r is None for r in routings):
90+
joined_routing = None
91+
else:
92+
joined_routing = [r or [] for r in routings]
93+
return joined_routing
94+
8695
def _postprocess(
8796
self, outputs: list[OutT], call_id: str, ctx: RunContext[CtxT]
8897
) -> Packet[OutT]:
8998
payloads: list[OutT] = []
90-
routing: list[Sequence[ProcName]] | None = []
99+
routings: list[Sequence[ProcName] | None] = []
91100
for output in outputs:
92101
val_output = self._validate_output(output, call_id=call_id)
93102
payloads.append(val_output)
94103

95-
selected_recipients = self.select_recipients(output=val_output, ctx=ctx)
96-
self._validate_recipients(selected_recipients, call_id=call_id)
97-
routing.append(selected_recipients or [])
104+
selected_recipients = self.select_recipients(
105+
output=val_output, ctx=ctx, call_id=call_id
106+
)
107+
routings.append(selected_recipients)
98108

99-
if all(len(r) == 0 for r in routing):
100-
routing = None
109+
routing = self._join_routings(routings)
101110

102-
return Packet(payloads=payloads, sender=self.name, routing=routing)
111+
return Packet(sender=self.name, payloads=payloads, routing=routing)
103112

104113
@agent(name="processor") # type: ignore
105114
@with_retry

src/grasp_agents/processors/parallel_processor.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from grasp_agents.tracing_decorators import agent
77

8-
from ..errors import PacketRoutingError
98
from ..memory import MemT
109
from ..packet import Packet
1110
from ..run_context import CtxT, RunContext
@@ -20,6 +19,8 @@
2019
class ParallelProcessor(
2120
BaseProcessor[InT, OutT, MemT, CtxT], Generic[InT, OutT, MemT, CtxT]
2221
):
22+
"""Processor that runs multiple inputs in parallel, each producing one output."""
23+
2324
_generic_arg_to_instance_attr_map: ClassVar[dict[int, str]] = {
2425
0: "_in_type",
2526
1: "_out_type",
@@ -34,6 +35,7 @@ async def _process(
3435
call_id: str,
3536
ctx: RunContext[CtxT],
3637
) -> OutT:
38+
"""Process a single input and return a single output."""
3739
return cast("OutT", in_args)
3840

3941
async def _process_stream(
@@ -48,28 +50,6 @@ async def _process_stream(
4850
output = cast("OutT", in_args)
4951
yield ProcPayloadOutputEvent(data=output, proc_name=self.name, call_id=call_id)
5052

51-
def _validate_parallel_recipients(
52-
self, out_packets: Sequence[Packet[OutT]], call_id: str
53-
) -> None:
54-
if not out_packets:
55-
return
56-
57-
first_packet = out_packets[0]
58-
first_recipients_set = set((first_packet.routing or [[]])[0])
59-
err_kwargs = dict(proc_name=self.name, call_id=call_id)
60-
61-
for p in out_packets[:1]:
62-
recipients_set = set((p.routing or [[]])[0])
63-
if recipients_set != first_recipients_set:
64-
raise PacketRoutingError(
65-
message=(
66-
"Parallel processor requires all parallel outputs to "
67-
"have the same recipients "
68-
f"[proc_name={self.name}; call_id={call_id}]"
69-
),
70-
**err_kwargs, # type: ignore
71-
)
72-
7353
@with_retry
7454
async def _run_single(
7555
self,
@@ -91,15 +71,30 @@ async def _run_single(
9171
)
9272
val_output = self._validate_output(output, call_id=call_id)
9373

94-
recipients = self.select_recipients(output=val_output, ctx=ctx)
95-
self._validate_recipients(recipients, call_id=call_id)
74+
recipients = self.select_recipients(output=val_output, ctx=ctx, call_id=call_id)
9675

9776
return Packet(
9877
payloads=[val_output],
9978
sender=self.name,
10079
routing=[recipients] if recipients is not None else None,
10180
)
10281

82+
def _join_payloads(self, packets: Sequence[Packet[OutT]]) -> list[OutT]:
83+
return [p.payloads[0] for p in packets]
84+
85+
def _join_routings(
86+
self, packets: Sequence[Packet[OutT]]
87+
) -> Sequence[Sequence[str]] | None:
88+
if not packets:
89+
return None
90+
routings = [p.routing[0] if p.routing is not None else None for p in packets]
91+
if all(r is None for r in routings):
92+
joined_routing = None
93+
else:
94+
joined_routing = [r or [] for r in routings]
95+
96+
return joined_routing
97+
10398
async def _run_parallel(
10499
self, in_args: list[InT], call_id: str, ctx: RunContext[CtxT]
105100
) -> Packet[OutT]:
@@ -110,12 +105,11 @@ async def _run_parallel(
110105
for idx, inp in enumerate(in_args)
111106
]
112107
out_packets = await asyncio.gather(*tasks)
113-
self._validate_parallel_recipients(out_packets, call_id=call_id)
114108

115109
return Packet(
116-
payloads=[out_packet.payloads[0] for out_packet in out_packets],
117110
sender=self.name,
118-
routing=out_packets[0].routing,
111+
payloads=self._join_payloads(out_packets),
112+
routing=self._join_routings(out_packets),
119113
)
120114

121115
@agent(name="processor") # type: ignore
@@ -180,8 +174,7 @@ async def _run_single_stream(
180174

181175
val_output = self._validate_output(output, call_id=call_id)
182176

183-
recipients = self.select_recipients(output=val_output, ctx=ctx)
184-
self._validate_recipients(recipients, call_id=call_id)
177+
recipients = self.select_recipients(output=val_output, ctx=ctx, call_id=call_id)
185178

186179
out_packet = Packet[OutT](
187180
payloads=[val_output],
@@ -213,17 +206,10 @@ async def _run_parallel_stream(
213206
else:
214207
yield event
215208

216-
self._validate_parallel_recipients(
217-
out_packets=list(out_packets_map.values()), call_id=call_id
218-
)
219-
220209
out_packet = Packet(
221-
payloads=[
222-
out_packet.payloads[0]
223-
for _, out_packet in sorted(out_packets_map.items())
224-
],
225210
sender=self.name,
226-
routing=out_packets_map[0].routing,
211+
payloads=self._join_payloads(list(out_packets_map.values())),
212+
routing=self._join_routings(list(out_packets_map.values())),
227213
)
228214

229215
yield ProcPacketOutputEvent(

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)