Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired
from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired, Iterable
from langchain_core.messages import ToolMessage, AnyMessage, SystemMessage, HumanMessage, BaseMessage, AIMessage, RemoveMessage
from langchain_core.tools import InjectedToolCallId, BaseTool
from langchain_core.language_models.base import LanguageModelInput
Expand Down Expand Up @@ -90,6 +90,19 @@ def tool_return(
}
)

def tool_state_update(
    tool_call_id: str,
    content: str,
    **state_diff
) -> Command:
    """
    Build a Command that answers the tool call *tool_call_id* with *content*
    and simultaneously merges *state_diff* into the graph state.

    Any key present in ``state_diff`` (including ``"messages"``) takes
    precedence over the generated ToolMessage entry, matching dict-merge
    semantics.
    """
    reply = ToolMessage(tool_call_id=tool_call_id, content=content)
    merged: dict = {"messages": [reply]}
    merged.update(state_diff)
    return Command(update=merged)

class FlowInput(TypedDict):
"""
Upper bound on any type used as an input to a workflow.
Expand Down Expand Up @@ -258,10 +271,12 @@ def to_return(state: InputState) -> StateT:

BoundLLM = Runnable[LanguageModelInput, BaseMessage]

SplitTool = tuple[dict[str, Any], BaseTool]

def build_workflow(
state_class: Type[StateT],
input_type: Type[InputState],
tools_list: List[BaseTool],
tools_list: Iterable[BaseTool | SplitTool],
sys_prompt: str,
initial_prompt: str,
output_key: str,
Expand Down Expand Up @@ -305,18 +320,33 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]:
if state.get(output_key, None) is not None:
return "__end__"
return "tool_result"

if isinstance(unbound_llm, ChatAnthropic) and (beta_attr := getattr(unbound_llm, "betas", [])) is not None and "context-management-2025-06-27" in beta_attr:
llm = unbound_llm.bind_tools([{
"type": "memory_20250818",
"name": "memory"
} if t.name == "memory" else t for t in tools_list])
else:
llm = unbound_llm.bind_tools(tools_list)

tool_schemas : list[BaseTool | dict] = []
tool_impls : list[BaseTool] = []

supports_memory = isinstance(unbound_llm, ChatAnthropic) and \
(beta_attr := getattr(unbound_llm, "betas", [])) is not None and \
"context-management-2025-06-27" in beta_attr

for t in tools_list:
if isinstance(t, tuple):
tool_schemas.append(t[0])
tool_impls.append(t[1])
elif t.name == "memory" and supports_memory:
tool_schemas.append({
"type": "memory_20250818",
"name": "memory"
})
tool_impls.append(t)
else:
tool_schemas.append(t)
tool_impls.append(t)

llm = unbound_llm.bind_tools(tool_schemas)

# Create initial node and tool node with curried LLM
init_node = initial_node(input_type, state_class, sys_prompt=sys_prompt, initial_prompt=initial_prompt, llm=llm)
tool_node = ToolNode(tools_list, handle_tool_errors=False)
tool_node = ToolNode(tool_impls, handle_tool_errors=False)
tool_result_node = tool_result_generator(state_class, llm)

# Build the graph with fixed input schema, no context
Expand Down
2 changes: 1 addition & 1 deletion summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_summarization_prompt(self, state: StateT) -> str:

Your summary should include:
1. Your current progress on the task, what is done and what remains to be done
2. Any lessons you have learned from invoking the tools; what CVL syntax you've learned, and what lessons you've learned from verification failures (if any)
2. Any lessons you have learned from invoking the tools; e.g., what CVL syntax you've learned, and what lessons you've learned from verification failures (if any)
3. Any lessons you've learned about invoking the various tools; e.g., "using solc8.2 doesn't work because ...". Include any workarounds or advice when invoking the tools

IMPORTANT: your summary must accurately capture the current state of the task. Do NOT include commentary describing past failures, **unless** it is
Expand Down
98 changes: 98 additions & 0 deletions tools/human.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from typing import TypeVar, Callable, Literal, Annotated, get_args, get_origin, cast, Any

from pydantic import create_model, Field, BaseModel
from pydantic.fields import FieldInfo

from langchain_core.tools import BaseTool, InjectedToolCallId, tool
from langchain_core.messages import ToolMessage
from langgraph.prebuilt import InjectedState
from langgraph.types import Command, interrupt

M = TypeVar("M")
S = TypeVar("S")

_injected_state_name = "graphcore_injected_state"

def _process_model_tydict(
    t: type[dict]
) -> tuple[dict[str, Any], str | None]:
    """
    Convert a TypedDict's annotations into pydantic ``create_model`` specs.

    A ``Literal[...]`` annotation is only legal on a key named ``"type"``;
    its first literal value becomes the discriminator and the key is omitted
    from the returned fields. Every other key must be ``Annotated[T, "desc"]``
    with exactly one string metadata element, which becomes the field
    description.

    Returns:
        ``(fields, discriminator)`` where ``fields`` maps key name to an
        ``(annotation, Field)`` pair and ``discriminator`` is the ``"type"``
        literal value, or ``None`` when absent.

    Raises:
        RuntimeError: on any annotation outside the two accepted shapes.
    """
    discriminator : str | None = None
    field_specs : dict[str, Any] = {}
    for (key, annotation) in t.__annotations__.items():
        origin = get_origin(annotation)
        if origin is Literal:
            if key != "type":
                raise RuntimeError(f"Illegal type annotation: {annotation} on {key}")
            discriminator = get_args(annotation)[0]
        elif origin is Annotated:
            meta = get_args(annotation)
            if len(meta) != 2 or not isinstance(meta[1], str):
                raise RuntimeError(f"Illegal type annotation: {annotation} for {key}")
            field_specs[key] = (meta[0], Field(description=meta[1]))
        else:
            raise RuntimeError(f"Illegal type annotation: {annotation} for {key}")
    return (field_specs, discriminator)

def _process_model_basem(
    t: type[BaseModel]
) -> tuple[dict[str, Any], str | None]:
    """
    Convert a pydantic BaseModel's fields into ``create_model`` specs.

    A field annotated ``Literal[...]`` must be named ``"type"``; its first
    literal value becomes the discriminator and the field is omitted from
    the returned specs. Every other field is carried over with its
    annotation and description.

    Returns:
        ``(fields, discriminator)`` where ``fields`` maps field name to an
        ``(annotation, Field)`` pair and ``discriminator`` is the ``"type"``
        literal value, or ``None`` when the model has no such field.

    Raises:
        RuntimeError: if a Literal-annotated field is not named ``"type"``.
    """
    disc : str | None = None
    fields : dict[str, Any] = {}
    for (k, v) in t.model_fields.items():
        assert v.annotation is not None
        ty = v.annotation
        if get_origin(ty) is Literal:
            if k != "type":
                # Report the annotation itself, not the FieldInfo wrapper.
                raise RuntimeError(f"Illegal type annotation: {ty} on {k}")
            # BUG FIX: extract the literal from the annotation `ty`, not the
            # FieldInfo `v` — get_args(v) on a FieldInfo returns (), so the
            # original get_args(v)[0] raised IndexError for every model with
            # a discriminator field. Mirrors _process_model_tydict.
            disc = get_args(ty)[0]
            continue
        fields[k] = (v.annotation, Field(description=v.description))
    return (fields, disc)


def human_interaction_tool(
    t: type[M],
    state: type[S],
    name: str,
    state_updater: Callable[[S, M, str], dict] = lambda x, y, z: {}
) -> BaseTool:
    """
    Build a LangGraph tool that pauses the graph to ask a human a question.

    The payload type ``t`` (a pydantic BaseModel or a TypedDict) describes
    the arguments the LLM supplies when calling the tool. Those arguments
    are validated into an instance of ``t``, handed to ``interrupt`` so a
    human can respond, and the human's response is returned to the LLM as a
    ToolMessage. ``state_updater`` may derive extra state-diff entries from
    (current state, payload, response); by default it contributes nothing.

    A ``Literal``-typed ``"type"`` member of ``t`` is treated as a
    discriminator: it is hidden from the LLM-facing schema and re-inserted
    before validation.
    """
    assert issubclass(t, BaseModel) or issubclass(t, dict)
    fields : dict[str, Any]
    disc : str | None
    if issubclass(t, BaseModel):
        (fields, disc) = _process_model_basem(t)
    else:
        (fields, disc) = _process_model_tydict(t)

    # Injected fields: filled in by LangGraph at call time, not by the LLM.
    fields[_injected_state_name] = (Annotated[state, InjectedState], Field())
    fields["tool_call_id"] = (Annotated[str, InjectedToolCallId], Field())

    # Rebuild a pydantic schema carrying t's name/docstring so the LLM sees
    # the same tool description as the original payload type.
    model = create_model(
        t.__name__,
        __doc__ = t.__doc__,
        **cast(dict[str, Any], fields)
    )
    @tool(name, args_schema=model)
    def interaction_tool(
        **kwargs
    ) -> Command:
        # Strip the injected bookkeeping args; what remains are the
        # LLM-supplied payload fields.
        dict_args = {
            k: v for (k, v) in kwargs.items() if k != "tool_call_id" and k != _injected_state_name
        }
        if disc is not None:
            # Restore the discriminator that was hidden from the schema.
            dict_args["type"] = disc
        payload : Any
        if issubclass(t, BaseModel):
            payload = t.model_validate(dict_args)
        else:
            payload = t(**dict_args)
        # Suspend the graph until a human supplies a response.
        # NOTE(review): response is forwarded as ToolMessage content —
        # presumably a string; confirm against the interrupt resume value.
        response = interrupt(payload)
        state_update = state_updater(kwargs[_injected_state_name], payload, response)
        response_update = {
            "messages": [
                ToolMessage(content=response, tool_call_id=kwargs["tool_call_id"])
            ]
        }
        # state_updater entries win over the default "messages" entry.
        response_update.update(state_update)
        return Command(update=response_update)
    return interaction_tool