Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ We built Fantasy to power [Crush](https://github.com/charmbracelet/crush), a hot
- Image models
- Audio models
- PDF uploads
- Provider tools (e.g. web_search)
- Provider tools (partial: Anthropic computer use is supported; others, e.g. web_search, are not yet supported)

For things you’d like to see supported, PRs are welcome.

Expand Down
121 changes: 88 additions & 33 deletions agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,9 @@ type agentSettings struct {
userAgent string
providerOptions ProviderOptions

// TODO: add support for provider tools
tools []AgentTool
maxRetries *int
providerDefinedTools []ProviderDefinedTool
tools []AgentTool
maxRetries *int

model LanguageModel

Expand Down Expand Up @@ -429,7 +429,7 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
}
}

preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)

retryOptions := DefaultRetryOptions()
if opts.MaxRetries != nil {
Expand Down Expand Up @@ -464,7 +464,12 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err
if !ok {
continue
}

// Provider-executed tool calls (e.g. web search) are
// handled by the provider and should not be validated
// or executed by the agent.
if toolCall.ProviderExecuted {
continue
}
// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, stepSystemPrompt, stepInputMessages, a.settings.repairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
Expand All @@ -473,22 +478,26 @@ func (a *agent) Generate(ctx context.Context, opts AgentCall) (*AgentResult, err

toolResults, err := a.executeTools(ctx, stepTools, stepToolCalls, nil)

// Build step content with validated tool calls and tool results
// Build step content with validated tool calls and tool results.
// Provider-executed tool calls are kept as-is.
stepContent := []Content{}
toolCallIndex := 0
for _, content := range result.Content {
if content.GetType() == ContentTypeToolCall {
// Replace with validated tool call
tc, ok := AsContentType[ToolCallContent](content)
if ok && tc.ProviderExecuted {
stepContent = append(stepContent, content)
continue
}
// Replace with validated tool call.
if toolCallIndex < len(stepToolCalls) {
stepContent = append(stepContent, stepToolCalls[toolCallIndex])
toolCallIndex++
}
} else {
// Keep other content as-is
stepContent = append(stepContent, content)
}
}
// Add tool results
} // Add tool results
for _, result := range toolResults {
stepContent = append(stepContent, result)
}
Expand Down Expand Up @@ -602,9 +611,10 @@ func toResponseMessages(content []Content) []Message {
continue
}
toolParts = append(toolParts, ToolResultPart{
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderOptions: ProviderOptions(result.ProviderMetadata),
ToolCallID: result.ToolCallID,
Output: result.Result,
ProviderExecuted: result.ProviderExecuted,
ProviderOptions: ProviderOptions(result.ProviderMetadata),
})
}
}
Expand Down Expand Up @@ -813,7 +823,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
}
}

preparedTools := a.prepareTools(stepTools, stepActiveTools, disableAllTools)
preparedTools := a.prepareTools(stepTools, a.settings.providerDefinedTools, stepActiveTools, disableAllTools)

// Start step stream
if opts.OnStepStart != nil {
Expand Down Expand Up @@ -902,7 +912,7 @@ func (a *agent) Stream(ctx context.Context, opts AgentStreamCall) (*AgentResult,
return agentResult, nil
}

func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAllTools bool) []Tool {
func (a *agent) prepareTools(tools []AgentTool, providerDefinedTools []ProviderDefinedTool, activeTools []string, disableAllTools bool) []Tool {
preparedTools := make([]Tool, 0, len(tools))

// If explicitly disabling all tools, return no tools
Expand Down Expand Up @@ -930,6 +940,9 @@ func (a *agent) prepareTools(tools []AgentTool, activeTools []string, disableAll
ProviderOptions: tool.ProviderOptions(),
})
}
for _, tool := range providerDefinedTools {
preparedTools = append(preparedTools, tool)
}
return preparedTools
}

Expand Down Expand Up @@ -1063,6 +1076,15 @@ func WithTools(tools ...AgentTool) AgentOption {
}
}

// WithProviderDefinedTools returns an AgentOption that registers
// provider-defined tools on the agent. The given tools are appended to any
// provider-defined tools already present in the settings, so the option can
// be applied more than once without discarding earlier tools.
//
// Provider-defined tools are intended to be executed by the provider
// (e.g. web search) rather than by the client.
func WithProviderDefinedTools(tools ...ProviderDefinedTool) AgentOption {
	register := func(settings *agentSettings) {
		settings.providerDefinedTools = append(settings.providerDefinedTools, tools...)
	}
	return register
}

// WithStopConditions sets the stop conditions for the agent.
func WithStopConditions(conditions ...StopCondition) AgentOption {
return func(s *agentSettings) {
Expand Down Expand Up @@ -1311,29 +1333,62 @@ func (a *agent) processStepStream(ctx context.Context, stream StreamResponse, op
ProviderMetadata: part.ProviderMetadata,
}

// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
stepContent = append(stepContent, validatedToolCall)
// Provider-executed tool calls are handled by the provider
// and should not be validated or executed by the agent.
if toolCall.ProviderExecuted {
stepContent = append(stepContent, toolCall)
if opts.OnToolCall != nil {
err := opts.OnToolCall(toolCall)
if err != nil {
return stepExecutionResult{}, err
}
}
delete(activeToolCalls, part.ID)
} else {
// Validate and potentially repair the tool call
validatedToolCall := a.validateAndRepairToolCall(ctx, toolCall, stepTools, a.settings.systemPrompt, nil, opts.RepairToolCall)
stepToolCalls = append(stepToolCalls, validatedToolCall)
stepContent = append(stepContent, validatedToolCall)

if opts.OnToolCall != nil {
err := opts.OnToolCall(validatedToolCall)
if err != nil {
return stepExecutionResult{}, err
if opts.OnToolCall != nil {
err := opts.OnToolCall(validatedToolCall)
if err != nil {
return stepExecutionResult{}, err
}
}
}

// Determine if tool can run in parallel
isParallel := false
if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
isParallel = tool.Info().Parallel
}
// Determine if tool can run in parallel
isParallel := false
if tool, exists := toolMap[validatedToolCall.ToolName]; exists {
isParallel = tool.Info().Parallel
}

// Send tool call to execution channel
toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}
// Send tool call to execution channel
toolChan <- toolExecutionRequest{toolCall: validatedToolCall, parallel: isParallel}

// Clean up active tool call
delete(activeToolCalls, part.ID)
// Clean up active tool call
delete(activeToolCalls, part.ID)
}

case StreamPartTypeToolResult:
// Provider-executed tool results (e.g. web search)
// are emitted by the provider and added directly
// to the step content for multi-turn round-tripping.
if part.ProviderExecuted {
resultContent := ToolResultContent{
ToolCallID: part.ID,
ToolName: part.ToolCallName,
ProviderExecuted: true,
ProviderMetadata: part.ProviderMetadata,
}
stepContent = append(stepContent, resultContent)
if opts.OnToolResult != nil {
err := opts.OnToolResult(resultContent)
if err != nil {
return stepExecutionResult{}, err
}
}
}

case StreamPartTypeSource:
sourceContent := SourceContent{
Expand Down
7 changes: 4 additions & 3 deletions content.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ func (t ToolCallPart) Options() ProviderOptions {

// ToolResultPart represents a tool result in a message.
type ToolResultPart struct {
ToolCallID string `json:"tool_call_id"`
Output ToolResultOutputContent `json:"output"`
ProviderOptions ProviderOptions `json:"provider_options"`
ToolCallID string `json:"tool_call_id"`
Output ToolResultOutputContent `json:"output"`
ProviderExecuted bool `json:"provider_executed"`
ProviderOptions ProviderOptions `json:"provider_options"`
}

// GetType returns the type of the tool result part.
Expand Down
22 changes: 13 additions & 9 deletions content_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -711,13 +711,15 @@ func (t *ToolCallPart) UnmarshalJSON(data []byte) error {
// MarshalJSON implements json.Marshaler for ToolResultPart.
func (t ToolResultPart) MarshalJSON() ([]byte, error) {
dataBytes, err := json.Marshal(struct {
ToolCallID string `json:"tool_call_id"`
Output ToolResultOutputContent `json:"output"`
ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
ToolCallID string `json:"tool_call_id"`
Output ToolResultOutputContent `json:"output"`
ProviderExecuted bool `json:"provider_executed"`
ProviderOptions ProviderOptions `json:"provider_options,omitempty"`
}{
ToolCallID: t.ToolCallID,
Output: t.Output,
ProviderOptions: t.ProviderOptions,
ToolCallID: t.ToolCallID,
Output: t.Output,
ProviderExecuted: t.ProviderExecuted,
ProviderOptions: t.ProviderOptions,
})
if err != nil {
return nil, err
Expand All @@ -737,16 +739,18 @@ func (t *ToolResultPart) UnmarshalJSON(data []byte) error {
}

var aux struct {
ToolCallID string `json:"tool_call_id"`
Output json.RawMessage `json:"output"`
ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
ToolCallID string `json:"tool_call_id"`
Output json.RawMessage `json:"output"`
ProviderExecuted bool `json:"provider_executed"`
ProviderOptions map[string]json.RawMessage `json:"provider_options,omitempty"`
}

if err := json.Unmarshal(mpj.Data, &aux); err != nil {
return err
}

t.ToolCallID = aux.ToolCallID
t.ProviderExecuted = aux.ProviderExecuted

// Unmarshal the Output field
output, err := UnmarshalToolResultOutputContent(aux.Output)
Expand Down
90 changes: 90 additions & 0 deletions examples/computer-use/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package main

// This is a minimal example showing the API plumbing for Anthropic's computer
// use tool. It demonstrates how to wire up the provider, model, and tool, then
// inspect the tool calls that Claude returns.
//
// In a real implementation the caller would execute each action (screenshot,
// click, type, etc.) inside a sandboxed environment (VM, container, or VNC
// session) and feed the results back. The loop would continue with tool results
// until Claude signals it is done.

import (
"context"
"encoding/json"
"fmt"
"os"

"charm.land/fantasy"
"charm.land/fantasy/providers/anthropic"
)

// main wires up the Anthropic provider, a language model, and a computer use
// tool, issues a single Generate call, and prints the tool calls and text the
// model returned. Any failure is reported on stderr and exits with status 1.
func main() {
	// die prints a fatal error to stderr and terminates the process.
	die := func(what string, err error) {
		fmt.Fprintln(os.Stderr, what, err)
		os.Exit(1)
	}

	// Build the Anthropic provider from the API key in the environment.
	apiKey := os.Getenv("ANTHROPIC_API_KEY")
	provider, err := anthropic.New(anthropic.WithAPIKey(apiKey))
	if err != nil {
		die("could not create provider:", err)
	}

	ctx := context.Background()

	// Resolve the language model to talk to.
	model, err := provider.LanguageModel(ctx, "claude-opus-4-6")
	if err != nil {
		die("could not get language model:", err)
	}

	// The computer use tool tells Claude the dimensions of the virtual
	// display it will be controlling.
	toolOpts := anthropic.ComputerUseToolOptions{
		DisplayWidthPx:  1920,
		DisplayHeightPx: 1080,
		ToolVersion:     anthropic.ComputerUse20251124,
	}
	computerTool := anthropic.NewComputerUseTool(toolOpts)

	// Send one request: a simple user prompt plus the computer use tool.
	resp, err := model.Generate(ctx, fantasy.Call{
		Prompt: fantasy.Prompt{
			fantasy.NewUserMessage("Take a screenshot of the desktop"),
		},
		Tools: []fantasy.Tool{computerTool},
	})
	if err != nil {
		die("generate failed:", err)
	}

	// Walk the tool calls in the response. Claude will typically reply with
	// one or more tool calls describing the actions it wants to perform.
	for _, tc := range resp.Content.ToolCalls() {
		fmt.Printf("Tool call: %s (id=%s)\n", tc.ToolName, tc.ToolCallID)

		// Input is a JSON string describing the requested action, e.g.
		// {"action": "screenshot"} or
		// {"action": "click", "coordinate": [100, 200]}.
		var action map[string]any
		if parseErr := json.Unmarshal([]byte(tc.Input), &action); parseErr != nil {
			die("could not parse tool input:", parseErr)
		}
		fmt.Printf(" Action: %v\n", action)

		// A real agent loop would, at this point:
		//  1. Execute the action in a sandboxed environment.
		//  2. Capture the result (e.g. a screenshot as a base64 image).
		//  3. Build a new Call that includes the tool result and send it
		//     back to model.Generate.
		//  4. Repeat until Claude stops requesting tool calls.
		fmt.Println(" -> (stub) would execute action and return screenshot")
	}

	// Finally, show any text Claude included alongside the tool calls.
	text := resp.Content.Text()
	if text == "" {
		return
	}
	fmt.Println("\nClaude said:", text)
}
6 changes: 3 additions & 3 deletions examples/go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
cel.dev/expr v0.25.1 h1:1KrZg61W6TWSxuNZ37Xy49ps13NUovb66QLprthtwi4=
cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaJeYoNhuNEM9kM8SjgRgdeOJUSE=
cloud.google.com/go v0.123.0 h1:2NAUJwPR47q+E35uaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
cloud.google.com/go v0.123.0/go.mod h1:xBoMV08QcqUGuPW65Qfm1o9Y4zKZBpGS+7bImXLTAZU=
cloud.google.com/go/auth v0.18.2 h1:+Nbt5Ev0xEqxlNjd6c+yYUeosQ5TtEUaNcN/3FozlaM=
cloud.google.com/go/auth v0.18.2/go.mod h1:xD+oY7gcahcu7G2SG2DsBerfFxgPAJz17zz2joOFF3M=
Expand All @@ -22,10 +22,10 @@ cloud.google.com/go/trace v1.11.7 h1:kDNDX8JkaAG3R2nq1lIdkb7FCSi1rCmsEtKVsty7p+U
cloud.google.com/go/trace v1.11.7/go.mod h1:TNn9d5V3fQVf6s4SCveVMIBS2LJUqo73GACmq/Tky0s=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0 h1:DHa2U07rk8syqvCge0QIGMCE1WxGj9njT44GH7zNJLQ=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.31.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zFfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0 h1:UnDZ/zZfG1JhH/DqxIZYU/1CUAlTUScoXD/LcM2Ykk8=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.55.0/go.mod h1:IA1C1U7jO/ENqm/vhi7V9YYpBsp+IMyqNrEN94N7tVc=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0 h1:7t/qx5Ost0s0wbA/VDrByOooURhp+ikYwv20i9Y07TQ=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zdzgUIRvA=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0.55.0/go.mod h1:vB2GH9GAYYJTO3mEn8oYwzEdhlayZIdQz6zxJH/yteYEYCFa8=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0 h1:0s6TxfCu2KHkkZPnBfsQ2y5qia0jl3MMrmBhu3nCOYk=
github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.55.0/go.mod h1:Mf6O40IAyB9zR/1J8nGDDPirZQQPbYJni8Yisy7NTMc=
github.com/MakeNowJust/heredoc v1.0.0 h1:cXCdzVdstXyiTqTvfqk9SDHpKNjxuom+DOlyEeQ4pzQ=
Expand Down
Loading