diff --git a/graph.py b/graph.py index 2b1dbb6..2147f67 100644 --- a/graph.py +++ b/graph.py @@ -13,7 +13,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired +from typing import Optional, List, TypedDict, Annotated, Literal, TypeVar, Type, Protocol, cast, Any, Tuple, NotRequired, Iterable from langchain_core.messages import ToolMessage, AnyMessage, SystemMessage, HumanMessage, BaseMessage, AIMessage, RemoveMessage from langchain_core.tools import InjectedToolCallId, BaseTool from langchain_core.language_models.base import LanguageModelInput @@ -90,6 +90,19 @@ def tool_return( } ) +def tool_state_update( + tool_call_id: str, + content: str, + **state_diff +) -> Command: + update = { + "messages": [ + ToolMessage(tool_call_id=tool_call_id, content=content) + ], + **state_diff + } + return Command(update=update) + class FlowInput(TypedDict): """ Upper bound on any type used as an input to a workflow. @@ -258,10 +271,12 @@ def to_return(state: InputState) -> StateT: BoundLLM = Runnable[LanguageModelInput, BaseMessage] +SplitTool = tuple[dict[str, Any], BaseTool] + def build_workflow( state_class: Type[StateT], input_type: Type[InputState], - tools_list: List[BaseTool], + tools_list: Iterable[BaseTool | SplitTool], sys_prompt: str, initial_prompt: str, output_key: str, @@ -305,18 +320,33 @@ def should_end(state: StateT) -> Literal["__end__", "tool_result"]: if state.get(output_key, None) is not None: return "__end__" return "tool_result" - - if isinstance(unbound_llm, ChatAnthropic) and (beta_attr := getattr(unbound_llm, "betas", [])) is not None and "context-management-2025-06-27" in beta_attr: - llm = unbound_llm.bind_tools([{ - "type": "memory_20250818", - "name": "memory" - } if t.name == "memory" else t for t in tools_list]) - else: - llm = unbound_llm.bind_tools(tools_list) + + tool_schemas : list[BaseTool | dict] = [] + tool_impls : list[BaseTool] = [] + + supports_memory = isinstance(unbound_llm, ChatAnthropic) and \ + (beta_attr := getattr(unbound_llm, "betas", [])) is not None and \ + "context-management-2025-06-27" in beta_attr + + for t in tools_list: + if isinstance(t, tuple): + tool_schemas.append(t[0]) + tool_impls.append(t[1]) + elif t.name == "memory" and supports_memory: + tool_schemas.append({ + "type": "memory_20250818", + "name": "memory" + }) + tool_impls.append(t) + else: + tool_schemas.append(t) + tool_impls.append(t) + + llm = unbound_llm.bind_tools(tool_schemas) # Create initial node and tool node with curried LLM init_node = initial_node(input_type, state_class, sys_prompt=sys_prompt, initial_prompt=initial_prompt, llm=llm) - tool_node = ToolNode(tools_list, handle_tool_errors=False) + tool_node = ToolNode(tool_impls, handle_tool_errors=False) tool_result_node = tool_result_generator(state_class, llm) # Build the graph with fixed input schema, no context diff --git a/summary.py b/summary.py index ac357f2..4bd0b7c 100644 --- a/summary.py +++ b/summary.py @@ -34,7 +34,7 @@ def get_summarization_prompt(self, state: StateT) -> str: Your summary should include: 1. Your current progress on the task, what is done and what remains to be done -2. Any lessons you have learned from invoking the tools; what CVL syntax you've learned, and what lessons you've learned from verification failures (if any) +2. Any lessons you have learned from invoking the tools; e.g., what CVL syntax you've learned, and what lessons you've learned from verification failures (if any) 3. Any lessons you've learned about invoking the various tools; e.g., "using solc8.2 doesn't work because ...". Include any workarounds or advice when invoking the tools IMPORTANT: your summary must accurately capture the current state of the task. Do NOT include commentary describing past failures, **unless** it is diff --git a/tools/human.py b/tools/human.py new file mode 100644 index 0000000..f131374 --- /dev/null +++ b/tools/human.py @@ -0,0 +1,98 @@ +from typing import TypeVar, Callable, Literal, Annotated, get_args, get_origin, cast, Any + +from pydantic import create_model, Field, BaseModel +from pydantic.fields import FieldInfo + +from langchain_core.tools import BaseTool, InjectedToolCallId, tool +from langchain_core.messages import ToolMessage +from langgraph.prebuilt import InjectedState +from langgraph.types import Command, interrupt + +M = TypeVar("M") +S = TypeVar("S") + +_injected_state_name = "graphcore_injected_state" + +def _process_model_tydict( + t: type[dict] +) -> tuple[dict[str, Any], str | None]: + disc : str | None = None + fields : dict[str, Any] = {} + for (k, v) in t.__annotations__.items(): + if get_origin(v) is Literal: + if k != "type": + raise RuntimeError(f"Illegal type annotation: {v} on {k}") + disc = get_args(v)[0] + continue + elif get_origin(v) is Annotated: + a = get_args(v) + if len(a) != 2 or not isinstance(a[1], str): + raise RuntimeError(f"Illegal type annotation: {v} for {k}") + fields[k] = (a[0], Field(description=a[1])) + else: + raise RuntimeError(f"Illegal type annotation: {v} for {k}") + return (fields, disc) + +def _process_model_basem( + t: type[BaseModel] +) -> tuple[dict[str, Any], str | None]: + disc : str | None = None + fields : dict[str, Any] = {} + for (k, v) in t.model_fields.items(): + assert v.annotation is not None + ty = v.annotation + if get_origin(ty) is Literal: + if k != "type": + raise RuntimeError(f"Illegal type annotation: {v} on {k}") + disc = get_args(v)[0] + continue + fields[k] = (v.annotation, Field(description=v.description)) + return (fields, disc) + + +def human_interaction_tool( + t: type[M], + state: type[S], + name: str, + state_updater: Callable[[S, M, str], dict] = lambda x, y, z: {} +) -> BaseTool: + assert issubclass(t, BaseModel) or issubclass(t, dict) + fields : dict[str, Any] + disc : str | None + if issubclass(t, BaseModel): + (fields, disc) = _process_model_basem(t) + else: + (fields, disc) = _process_model_tydict(t) + + fields[_injected_state_name] = (Annotated[state, InjectedState], Field()) + fields["tool_call_id"] = (Annotated[str, InjectedToolCallId], Field()) + + model = create_model( + t.__name__, + __doc__ = t.__doc__, + **cast(dict[str, Any], fields) + ) + @tool(name, args_schema=model) + def interaction_tool( + **kwargs + ) -> Command: + dict_args = { + k: v for (k, v) in kwargs.items() if k != "tool_call_id" and k != _injected_state_name + } + if disc is not None: + dict_args["type"] = disc + payload : Any + if issubclass(t, BaseModel): + payload = t.model_validate(dict_args) + else: + payload = t(**dict_args) + response = interrupt(payload) + state_update = state_updater(kwargs[_injected_state_name], payload, response) + response_update = { + "messages": [ + ToolMessage(content=response, tool_call_id=kwargs["tool_call_id"]) + ] + } + response_update.update(state_update) + return Command(update=response_update) + return interaction_tool