diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index f942a27..b1eabd2 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -25,7 +25,8 @@ jobs: run: | python -m pip install --upgrade pip pip install -r requirements.txt + pip install pytest - name: Run tests run: | - python -m unittest discover -s . -p "test_*.py" \ No newline at end of file + python -m pytest tests/ -v diff --git a/.gitignore b/.gitignore index df9adf7..c8b5dc9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,4 +5,5 @@ metrics prometheus.yml myserver.crt myserver.key -sandbox.py \ No newline at end of file +sandbox.py +memory-bank \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json index e065b89..b7911be 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -7,5 +7,8 @@ "test*.py" ], "python.testing.pytestEnabled": false, - "python.testing.unittestEnabled": true + "python.testing.unittestEnabled": true, + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda", + "python-envs.pythonProjects": [] } \ No newline at end of file diff --git a/argo_bridge.py b/argo_bridge.py index 480053d..82950a4 100644 --- a/argo_bridge.py +++ b/argo_bridge.py @@ -10,6 +10,14 @@ import httpx from functools import wraps +# Import tool calling functionality +from tool_calls import handle_tools, ToolInterceptor +from tool_calls.output_handle import tool_calls_to_openai, tool_calls_to_openai_stream +from tool_calls.utils import determine_model_family + +# Import centralized logging +from logging_config import get_logger, log_request_summary, log_response_summary, log_tool_processing, log_data_verbose + app = Flask(__name__) CORS(app, @@ -62,7 +70,8 @@ def after_request(response): 'o3mini': 'gpto3mini', 'gpto3mini': 'gpto3mini', 'gpto4mini': 'gpto4mini', - 'o4-mini' : 'gpto4mini', + 'o4-mini': 'gpto4mini', + 'o4mini': 'gpto4mini', 'gpto1': 'gpto1', 'o1': 'gpto1', @@ -217,29 +226,58 @@ def get_api_url(model, endpoint_type): @app.route('/api/chat/completions', methods=['POST']) @app.route('/v1/chat/completions', methods=['POST']) #LMStudio Compatibility def chat_completions(): - logging.info("Received chat completions request") - + logger = get_logger('chat') + data = request.get_json() - logging.info(f"Request Data: {data}") model_base = data.get("model", DEFAULT_MODEL) is_streaming = data.get("stream", False) temperature = data.get("temperature", 0.1) stop = data.get("stop", []) + # Check if request contains tool-related parameters + has_tools = "tools" in data or "tool_choice" in data + + # Log request summary + log_request_summary("/v1/chat/completions", model_base, has_tools) + log_data_verbose("Request data", data) + # Force non-streaming for specific models. Remove once Argo supports streaming for all models. # TODO: TEMP Fake streaming for the new models until Argo supports it is_fake_stream = False if model_base in NON_STREAMING_MODELS and is_streaming: is_fake_stream = True + logger.debug(f"Using fake streaming for {model_base}") + + # Also force fake streaming for tool calls until we implement streaming tool support + if has_tools and is_streaming: + is_fake_stream = True + logger.debug("Using fake streaming for tool calls") if model_base not in MODEL_MAPPING: + logger.error(f"Unsupported model: {model_base}") return jsonify({"error": { "message": f"Model '{model_base}' not supported." 
}}), 400 model = MODEL_MAPPING[model_base] - logging.debug(f"Received request: {data}") + # Process tool calls if present + if has_tools: + try: + # Determine if we should use native tools or prompt-based tools + model_family = determine_model_family(model) + use_native_tools = model_family in ["openai", "anthropic"] + + tool_count = len(data.get("tools", [])) + log_tool_processing(model_family, tool_count, use_native_tools) + + data = handle_tools(data, native_tools=use_native_tools) + log_data_verbose("Processed request with tools", data) + except Exception as e: + logger.error(f"Tool processing failed: {e}") + return jsonify({"error": { + "message": f"Tool processing failed: {str(e)}" + }}), 400 # Process multimodal content for Gemini models if model_base.startswith('gemini'): @@ -256,43 +294,70 @@ def chat_completions(): "user": user, "model": model, "messages": data['messages'], - "system": "", + "system": data.get("system", ""), "stop": stop, "temperature": temperature } - logging.debug(f"Argo Request {req_obj}") + # Add tool-related fields if they exist (for native tool calling) + if "tools" in data: + req_obj["tools"] = data["tools"] + if "tool_choice" in data: + req_obj["tool_choice"] = data["tool_choice"] + + log_data_verbose("Argo request", req_obj) if is_fake_stream: - logging.info(req_obj) response = requests.post(get_api_url(model, 'chat'), json=req_obj) if not response.ok: - logging.error(f"Internal API error: {response.status_code} {response.reason}") + logger.error(f"Argo API error: {response.status_code} {response.reason}") + log_response_summary("error", model_base) return jsonify({"error": { "message": f"Internal API error: {response.status_code} {response.reason}" }}), 500 json_response = response.json() text = json_response.get("response", "") - logging.debug(f"Response Text {text}") - return Response(_fake_stream_response(text, model), mimetype='text/event-stream') + log_data_verbose("Response text", text) + + # Process tool calls in response if present + if has_tools: + log_response_summary("success", model_base, "tool_calls") + return Response( + _fake_stream_response_with_tools(json_response, model, model_base), + mimetype='text/event-stream' + ) + else: + log_response_summary("success", model_base, "stop") + return Response(_fake_stream_response(text, model), mimetype='text/event-stream') elif is_streaming: - return Response(_stream_chat_response(model, req_obj), mimetype='text/event-stream') + if has_tools: + return Response(_stream_chat_response_with_tools(model, req_obj, model_base), mimetype='text/event-stream') + else: + return Response(_stream_chat_response(model, req_obj), mimetype='text/event-stream') else: response = requests.post(get_api_url(model, 'chat'), json=req_obj) if not response.ok: - logging.error(f"Internal API error: {response.status_code} {response.reason}") + logger.error(f"Argo API error: {response.status_code} {response.reason}") + log_response_summary("error", model_base) return jsonify({"error": { "message": f"Internal API error: {response.status_code} {response.reason}" }}), 500 json_response = response.json() text = json_response.get("response", "") - logging.debug(f"Response Text {text}") - return jsonify(_static_chat_response(text, model_base)) + log_data_verbose("Response text", text) + + # Process tool calls in response if present + if has_tools: + log_response_summary("success", model_base, "tool_calls") + return jsonify(_static_chat_response_with_tools(text, model_base, json_response)) + else: + log_response_summary("success", 
model_base, "stop") + return jsonify(_static_chat_response(text, model_base)) def _stream_chat_response(model, req_obj): @@ -457,6 +522,174 @@ def convert_multimodal_to_text(messages, model_base): return processed_messages +def _static_chat_response_with_tools(text, model_base, json_response): + """ + Generate static chat response with tool call processing. + """ + # Initialize tool interceptor + tool_interceptor = ToolInterceptor() + + # Determine model family for processing + model_family = determine_model_family(model_base) + + # Process response to extract tool calls + tool_calls, clean_text = tool_interceptor.process( + json_response.get("response", text), + model_family + ) + + # Determine finish reason + finish_reason = "tool_calls" if tool_calls else "stop" + + # Convert tool calls to OpenAI format if present + openai_tool_calls = None + if tool_calls: + openai_tool_calls = tool_calls_to_openai(tool_calls, api_format="chat_completion") + # Convert Pydantic models to dictionaries for JSON serialization + openai_tool_calls = [tool_call.model_dump() for tool_call in openai_tool_calls] + + return { + "id": "argo", + "object": "chat.completion", + "created": int(datetime.datetime.now().timestamp()), + "model": model_base, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": clean_text, + "tool_calls": openai_tool_calls, + }, + "logprobs": None, + "finish_reason": finish_reason + }] + } + + +def _fake_stream_response_with_tools(json_response, model, model_base): + """ + Generate fake streaming response with tool call processing. + """ + # Initialize tool interceptor + tool_interceptor = ToolInterceptor() + + # Determine model family for processing + model_family = determine_model_family(model_base) + + # Process response to extract tool calls + tool_calls, clean_text = tool_interceptor.process( + json_response, + model_family + ) + + # Start with role chunk + begin_chunk = { + "id": 'abc', + "object": "chat.completion.chunk", + "created": int(datetime.datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "delta": {'role': 'assistant', 'content': ''}, + "logprobs": None, + "finish_reason": None + }] + } + yield f"data: {json.dumps(begin_chunk)}\n\n" + + # Send text content if present + if clean_text: + content_chunk = { + "id": 'abc', + "object": "chat.completion.chunk", + "created": int(datetime.datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "delta": {'content': clean_text}, + "logprobs": None, + "finish_reason": None + }] + } + yield f"data: {json.dumps(content_chunk)}\n\n" + + # Send tool calls if present + if tool_calls: + for i, tool_call in enumerate(tool_calls): + tool_call_chunk = tool_calls_to_openai_stream( + tool_call, + tc_index=i, + api_format="chat_completion" + ) + chunk = { + "id": 'abc', + "object": "chat.completion.chunk", + "created": int(datetime.datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "delta": {'tool_calls': [tool_call_chunk.model_dump()]}, + "logprobs": None, + "finish_reason": None + }] + } + yield f"data: {json.dumps(chunk)}\n\n" + + # Send final chunk + finish_reason = "tool_calls" if tool_calls else "stop" + end_chunk = { + "id": 'argo', + "object": "chat.completion.chunk", + "created": int(datetime.datetime.now().timestamp()), + "model": model, + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "delta": {}, + "logprobs": None, + "finish_reason": finish_reason + }] + } + yield f"data: {json.dumps(end_chunk)}\n\n" + 
yield "data: [DONE]\n\n" + + +def _stream_chat_response_with_tools(model, req_obj, model_base): + """ + Generate streaming response with tool call processing. + Note: This is a placeholder for future real streaming tool support. + For now, it falls back to fake streaming. + """ + # For now, we'll use the non-streaming endpoint and fake stream the result + # TODO: Implement real streaming tool support when Argo supports it + + response = requests.post(get_api_url(model, 'chat'), json=req_obj) + + if not response.ok: + # Return error in streaming format + error_chunk = { + "id": 'error', + "object": "chat.completion.chunk", + "created": int(datetime.datetime.now().timestamp()), + "model": model, + "choices": [{ + "index": 0, + "delta": {'content': f"Error: {response.status_code} {response.reason}"}, + "logprobs": None, + "finish_reason": "stop" + }] + } + yield f"data: {json.dumps(error_chunk)}\n\n" + yield "data: [DONE]\n\n" + return + + json_response = response.json() + text = json_response.get("response", "") + + # Use fake streaming with tool processing + yield from _fake_stream_response_with_tools(text, model, model_base) + + """ ================================= Completions Endpoint @@ -467,7 +700,8 @@ def convert_multimodal_to_text(messages, model_base): @app.route('/completions', methods=['POST']) @app.route('/v1/completions', methods=['POST', 'OPTIONS']) #LMStudio Compatibility def completions(): - logging.info("Received completions request") + logger = get_logger('completions') + data = request.get_json() prompt = data.get("prompt", "") stop = data.get("stop", []) @@ -475,15 +709,17 @@ def completions(): model_base = data.get("model", DEFAULT_MODEL) is_streaming = data.get("stream", False) + log_request_summary("/v1/completions", model_base) + log_data_verbose("Request data", data) + if model_base not in MODEL_MAPPING: + logger.error(f"Unsupported model: {model_base}") return jsonify({"error": { "message": f"Model '{model_base}' not supported." 
}}), 400 model = MODEL_MAPPING[model_base] - logging.debug(f"Received request: {data}") - user = get_user_from_auth_header() req_obj = { @@ -495,22 +731,25 @@ def completions(): "temperature": temperature } - logging.debug(f"Argo Request {req_obj}") + log_data_verbose("Argo request", req_obj) response = requests.post(get_api_url(model, 'chat'), json=req_obj) if not response.ok: - logging.error(f"Internal API error: {response.status_code} {response.reason}") + logger.error(f"Argo API error: {response.status_code} {response.reason}") + log_response_summary("error", model_base) return jsonify({"error": { "message": f"Internal API error: {response.status_code} {response.reason}" }}), 500 json_response = response.json() text = json_response.get("response", "") - logging.debug(f"Response Text {text}") + log_data_verbose("Response text", text) if is_streaming: + log_response_summary("success", model_base, "stop") return Response(_stream_completions_response(text, model), mimetype='text/event-stream') else: + log_response_summary("success", model_base, "stop") return jsonify(_static_completions_response(text, model_base)) @@ -552,12 +791,17 @@ def _stream_completions_response(text, model): @app.route('/embeddings', methods=['POST']) @app.route('/v1/embeddings', methods=['POST']) def embeddings(): - logging.info("Recieved embeddings request") + logger = get_logger('embeddings') + data = request.get_json() model_base = data.get("model", "v3small") + log_request_summary("/v1/embeddings", model_base) + log_data_verbose("Request data", data) + # Check if the model is supported if model_base not in EMBEDDING_MODEL_MAPPING: + logger.error(f"Unsupported embedding model: {model_base}") return jsonify({"error": { "message": f"Embedding model '{model_base}' not supported." 
}}), 400 @@ -568,7 +812,15 @@ def embeddings(): input_data = [input_data] user = get_user_from_auth_header() - embedding_vectors = _get_embeddings_from_argo(input_data, model, user) + + try: + embedding_vectors = _get_embeddings_from_argo(input_data, model, user) + except Exception as e: + logger.error(f"Embedding processing failed: {e}") + log_response_summary("error", model_base) + return jsonify({"error": { + "message": f"Embedding processing failed: {str(e)}" + }}), 500 response_data = { "object": "list", @@ -587,10 +839,12 @@ def embeddings(): "index": i }) + log_response_summary("success", model_base) return jsonify(response_data) def _get_embeddings_from_argo(texts, model, user=BRIDGE_USER): + logger = get_logger('embeddings') BATCH_SIZE = 16 all_embeddings = [] @@ -603,18 +857,20 @@ def _get_embeddings_from_argo(texts, model, user=BRIDGE_USER): "prompt": batch_texts } - logging.debug(f"Sending embedding request for batch {i // BATCH_SIZE + 1}: {payload}") + logger.debug(f"Sending embedding request for batch {i // BATCH_SIZE + 1}") + log_data_verbose(f"Embedding batch {i // BATCH_SIZE + 1} payload", payload) response = requests.post(get_api_url(model, 'embed'), json=payload) if not response.ok: - logging.error(f"Embedding API error: {response.status_code} {response.reason}") + logger.error(f"Argo embedding API error: {response.status_code} {response.reason}") raise Exception(f"Embedding API error: {response.status_code} {response.reason}") embedding_response = response.json() batch_embeddings = embedding_response.get("embedding", []) all_embeddings.extend(batch_embeddings) + logger.debug(f"Successfully processed {len(all_embeddings)} embeddings") return all_embeddings """ diff --git a/grafana/provisioning/datasources/prometheus.yml b/grafana/provisioning/datasources/prometheus.yml deleted file mode 100644 index c4df67d..0000000 --- a/grafana/provisioning/datasources/prometheus.yml +++ /dev/null @@ -1,10 +0,0 @@ -apiVersion: 1 - -datasources: - - name: Prometheus - type: prometheus - access: proxy - url: http://prometheus:9090 - isDefault: true - editable: false - version: 1 diff --git a/logging_config.py b/logging_config.py new file mode 100644 index 0000000..f06285f --- /dev/null +++ b/logging_config.py @@ -0,0 +1,137 @@ +""" +Logging Configuration for Argo Bridge + +This module provides centralized logging configuration with support for: +- Different log levels for console and file output +- Environment variable configuration +- Structured logging with appropriate levels +- Optional verbose mode for debugging +""" + +import logging +import os +import sys +from typing import Optional + + +class ArgoLogger: + """Centralized logger configuration for Argo Bridge""" + _instance = None + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(ArgoLogger, cls).__new__(cls) + return cls._instance + + def __init__(self): + # Prevent re-initialization if already initialized + if hasattr(self, '_initialized') and self._initialized: + return + self.logger = None + self._setup_logging() + self._initialized = True + + def _setup_logging(self): + """Setup logging configuration based on environment variables and defaults""" + + # Get configuration from environment variables + log_level = os.getenv('ARGO_LOG_LEVEL', 'INFO').upper() + console_level = os.getenv('ARGO_CONSOLE_LOG_LEVEL', 'WARNING').upper() + file_level = os.getenv('ARGO_FILE_LOG_LEVEL', log_level).upper() + log_file = os.getenv('ARGO_LOG_FILE', 'log_bridge.log') + verbose_mode = 
os.getenv('ARGO_VERBOSE', 'false').lower() == 'true' + + # If verbose mode is enabled, make console more verbose + if verbose_mode: + console_level = 'DEBUG' + file_level = 'DEBUG' + + # Create logger + self.logger = logging.getLogger('argo_bridge') + self.logger.setLevel(logging.DEBUG) # Set to lowest level, handlers will filter + + # Clear any existing handlers + self.logger.handlers.clear() + + # Create formatters + detailed_formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(funcName)s:%(lineno)d - %(message)s' + ) + simple_formatter = logging.Formatter( + '%(asctime)s - %(levelname)s - %(message)s' + ) + + # File handler - detailed logging + file_handler = logging.FileHandler(log_file) + file_handler.setLevel(getattr(logging, file_level)) + file_handler.setFormatter(detailed_formatter) + self.logger.addHandler(file_handler) + + # Console handler - less verbose by default + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(getattr(logging, console_level)) + console_handler.setFormatter(simple_formatter) + self.logger.addHandler(console_handler) + + # Suppress noisy third-party loggers + logging.getLogger('watchdog').setLevel(logging.CRITICAL + 10) + logging.getLogger('urllib3').setLevel(logging.WARNING) + logging.getLogger('requests').setLevel(logging.WARNING) + + # Log the configuration + self.logger.info(f"Logging initialized - Console: {console_level}, File: {file_level}") + if verbose_mode: + self.logger.debug("Verbose mode enabled") + + def get_logger(self, name: Optional[str] = None) -> logging.Logger: + """Get a logger instance""" + if name: + return logging.getLogger(f'argo_bridge.{name}') + return self.logger + + def log_request_summary(self, endpoint: str, model: str, has_tools: bool = False): + """Log a summary of incoming requests without full payload""" + tools_info = " (with tools)" if has_tools else "" + self.logger.info(f"Request: {endpoint} - Model: {model}{tools_info}") + + def log_response_summary(self, status: str, model: str, finish_reason: str = None): + """Log a summary of responses without full payload""" + reason_info = f" - {finish_reason}" if finish_reason else "" + self.logger.info(f"Response: {status} - Model: {model}{reason_info}") + + def log_tool_processing(self, model_family: str, tool_count: int, native_tools: bool): + """Log tool processing information""" + tool_type = "native" if native_tools else "prompt-based" + self.logger.info(f"Processing {tool_count} tools for {model_family} model using {tool_type} approach") + + def log_data_verbose(self, label: str, data: any, max_length: int = 500): + """Log data only in verbose mode, with optional truncation""" + if self.logger.isEnabledFor(logging.DEBUG): + data_str = str(data) + if len(data_str) > max_length: + data_str = data_str[:max_length] + "... 
(truncated)" + self.logger.debug(f"{label}: {data_str}") + + +# Global logger instance +_argo_logger = ArgoLogger() + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """Get the Argo Bridge logger""" + return _argo_logger.get_logger(name) + +def log_request_summary(endpoint: str, model: str, has_tools: bool = False): + """Log a summary of incoming requests""" + _argo_logger.log_request_summary(endpoint, model, has_tools) + +def log_response_summary(status: str, model: str, finish_reason: str = None): + """Log a summary of responses""" + _argo_logger.log_response_summary(status, model, finish_reason) + +def log_tool_processing(model_family: str, tool_count: int, native_tools: bool): + """Log tool processing information""" + _argo_logger.log_tool_processing(model_family, tool_count, native_tools) + +def log_data_verbose(label: str, data: any, max_length: int = 500): + """Log data only in verbose mode""" + _argo_logger.log_data_verbose(label, data, max_length) diff --git a/bridge_prod.py b/prod/bridge_prod.py similarity index 100% rename from bridge_prod.py rename to prod/bridge_prod.py diff --git a/docker-compose.yaml b/prod/docker-compose.yaml similarity index 90% rename from docker-compose.yaml rename to prod/docker-compose.yaml index 5d82a46..0313b62 100644 --- a/docker-compose.yaml +++ b/prod/docker-compose.yaml @@ -1,6 +1,8 @@ services: argo_bridge: - build: . + build: + context: .. + dockerfile: ./dockerfile ports: - "443:443" restart: unless-stopped @@ -13,7 +15,7 @@ services: prometheus: image: prom/prometheus:latest volumes: - - ./prometheus.yml:/etc/prometheus/prometheus.yml + - ./prometheus.yml.template:/etc/prometheus/prometheus.yml - prometheus_data:/prometheus ports: - "127.0.0.1:9090:9090" @@ -34,7 +36,6 @@ services: - grafana_data:/var/lib/grafana - ./grafana/provisioning/datasources:/etc/grafana/provisioning/datasources - ./grafana/provisioning/dashboards:/etc/grafana/provisioning/dashboards - - ./grafana/dashboards:/var/lib/grafana/dashboards ports: - "127.0.0.1:3000:3000" restart: unless-stopped @@ -49,4 +50,4 @@ services: volumes: prometheus_data: grafana_data: - metrics_data: \ No newline at end of file + metrics_data: diff --git a/dockerfile b/prod/dockerfile similarity index 80% rename from dockerfile rename to prod/dockerfile index 5f84d3c..11ec3bf 100644 --- a/dockerfile +++ b/prod/dockerfile @@ -16,4 +16,4 @@ RUN mkdir -p /app/metrics && chmod 777 /app/metrics EXPOSE 80 # Command to run your application -CMD ["gunicorn", "--config", "gunicorn_config.py", "bridge_prod:prod_app"] \ No newline at end of file +CMD ["gunicorn", "--config", "prod/gunicorn_config.py", "prod.bridge_prod:prod_app"] diff --git a/grafana/dashboards/argo-bridge-dashboard.json b/prod/grafana/dashboards/argo-bridge-dashboard.json similarity index 100% rename from grafana/dashboards/argo-bridge-dashboard.json rename to prod/grafana/dashboards/argo-bridge-dashboard.json diff --git a/grafana/provisioning/dashboards/argo-bridge-dashboards.yml b/prod/grafana/provisioning/dashboards/argo-bridge-dashboards.yml similarity index 100% rename from grafana/provisioning/dashboards/argo-bridge-dashboards.yml rename to prod/grafana/provisioning/dashboards/argo-bridge-dashboards.yml diff --git a/gunicorn_config.py b/prod/gunicorn_config.py similarity index 100% rename from gunicorn_config.py rename to prod/gunicorn_config.py diff --git a/prometheus.yml.template b/prod/prometheus.yml.template similarity index 100% rename from prometheus.yml.template rename to prod/prometheus.yml.template diff 
--git a/readme.md b/readme.md index e54c198..34fba8e 100644 --- a/readme.md +++ b/readme.md @@ -5,7 +5,25 @@ This project provides a compatibility layer that transforms OpenAI-style API req ## Downstream Integration -Several tools have been tested with the bridge, including IDE integrations, web UI's, and python libraries. Setup guides for these tools tools are located in the [downstream_config.md](downstream_config.md). +Several tools have been tested with the bridge, including IDE integrations, web UIs, and Python libraries. Setup guides for these tools are located in [downstream_config.md](downstream_config.md). + +## Features + +### Tool Calling + +The bridge supports comprehensive tool calling, including: +- **Native Tool Calling**: For providers like OpenAI, Anthropic, and Google. +- **Prompt-Based Fallback**: For models without native tool support. +- **Streaming and Non-Streaming**: Support for both modes. + +Tool calling is integrated automatically. Simply include `tools` and `tool_choice` in your API requests. + +### Logging + +The server uses a configurable logging system with separate levels for console and file output. +- **Default**: `WARNING` on console, `INFO` to `log_bridge.log`. +- **Verbose Mode**: Set `ARGO_VERBOSE=true` for `DEBUG` level logging. +- **Customization**: Use environment variables like `ARGO_CONSOLE_LOG_LEVEL` and `ARGO_FILE_LOG_LEVEL` to control verbosity. ## Setup diff --git a/requirements.txt b/requirements.txt index d77842e..4e8e0a7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,8 @@ tqdm==4.67.1 flask-cors==5.0.1 httpx==0.28.1 gunicorn==23.0.0 -prometheus-client \ No newline at end of file +prometheus-client +pydantic>=2.0.0 +pytest +pytest-mock +openai \ No newline at end of file diff --git a/test_server.py b/tests/test_server.py similarity index 100% rename from test_server.py rename to tests/test_server.py diff --git a/tests/test_tool_calling.py b/tests/test_tool_calling.py new file mode 100644 index 0000000..25922cc --- /dev/null +++ b/tests/test_tool_calling.py @@ -0,0 +1,267 @@ +import json +import pytest +import requests +from openai import OpenAI + +# Configuration +BRIDGE_URL = "http://localhost:7285" # Default argo_bridge URL +API_KEY = "dummy" # argo_bridge doesn't require real API keys + +# Define tools for testing +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given city", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city name" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + "description": "Temperature unit" + } + }, + "required": ["location"] + } + } + }, + { + "type": "function", + "function": { + "name": "calculate", + "description": "Perform basic mathematical calculations", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Mathematical expression to evaluate (e.g., '2 + 3 * 4')" + } + }, + "required": ["expression"] + } + } + } +] + +@pytest.fixture(scope="module") +def openai_client(): + """Fixture to initialize OpenAI client pointing to argo_bridge.""" + return OpenAI( + api_key=API_KEY, + base_url=f"{BRIDGE_URL}/v1" + ) + +@pytest.mark.parametrize("test_case", [ + { + "name": "OpenAI GPT-4o with auto tool choice", + "model": "gpt-4o", + "tool_choice": "auto", + "message": "What's the weather like in Paris?", + "expected_tool": "get_weather" + }, + { + "name": "Claude
Sonnet with required tool choice", + "model": "claudesonnet35v2", + "tool_choice": "required", + "message": "Calculate 15 * 23 + 7", + "expected_tool": "calculate" + }, + { + "name": "Gemini with specific tool choice", + "model": "gemini25flash", + "tool_choice": {"type": "function", "function": {"name": "get_weather"}}, + "message": "Tell me about the weather in Tokyo", + "expected_tool": "get_weather" + } +]) +def test_with_requests(test_case, mocker): + """Test tool calling using raw HTTP requests.""" + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "tool_calls": [ + { + "function": { + "name": test_case["expected_tool"], + "arguments": "{}" + } + } + ] + }, + "finish_reason": "tool_calls" + } + ] + } + mocker.patch('requests.post', return_value=mock_response) + + payload = { + "model": test_case["model"], + "messages": [ + {"role": "user", "content": test_case["message"]} + ], + "tools": TOOLS, + "tool_choice": test_case["tool_choice"], + "temperature": 0.1 + } + + response = requests.post( + f"{BRIDGE_URL}/v1/chat/completions", + json=payload, + headers={"Content-Type": "application/json"}, + timeout=30 + ) + + assert response.status_code == 200 + result = response.json() + assert "choices" in result + choice = result["choices"][0] + message = choice["message"] + + assert message.get("tool_calls") is not None + tool_call = message["tool_calls"][0] + assert tool_call["function"]["name"] == test_case["expected_tool"] + assert choice["finish_reason"] == "tool_calls" + +def test_conversation_with_tools(openai_client, mocker): + """Test a multi-turn conversation with tool calls.""" + # Mock the first call to create + mock_response1 = mocker.Mock() + mock_tool_call = mocker.Mock() + mock_tool_call.function.name = "get_weather" + mock_tool_call.function.arguments = '{"city": "New York"}' + mock_tool_call.id = "call_123" + mock_response1.choices = [mocker.Mock()] + mock_response1.choices[0].message.tool_calls = [mock_tool_call] + mock_response1.choices[0].message.content = None + + # Mock the second call to create + mock_response2 = mocker.Mock() + mock_response2.choices = [mocker.Mock()] + mock_response2.choices[0].message.content = "The weather in New York is Sunny, 22°C" + + mocker.patch.object(openai_client.chat.completions, 'create', side_effect=[mock_response1, mock_response2]) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather information for a city", + "parameters": { + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"] + } + } + } + ] + + messages = [ + {"role": "user", "content": "What's the weather like in New York?"} + ] + + # First request + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=messages, + tools=tools, + tool_choice="auto" + ) + + assistant_message = response.choices[0].message + messages.append({ + "role": "assistant", + "content": assistant_message.content, + "tool_calls": [tc.model_dump() for tc in assistant_message.tool_calls] if assistant_message.tool_calls else None + }) + + assert assistant_message.tool_calls is not None + tool_call = assistant_message.tool_calls[0] + assert tool_call.function.name == "get_weather" + + # Simulate tool execution result + tool_result = f"Weather in {json.loads(tool_call.function.arguments)['city']}: Sunny, 22°C" + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + 
"content": tool_result + }) + + # Follow-up request with tool results + response2 = openai_client.chat.completions.create( + model="gpt-4o", + messages=messages + ) + + final_message = response2.choices[0].message + assert final_message.content is not None + assert "Sunny" in final_message.content + + +def test_streaming_with_text_and_tool_call(openai_client, mocker): + """Test streaming response with both text and a tool call.""" + # Mock the streaming response + mock_stream = mocker.MagicMock() + + # Define the chunks to be returned by the stream + # Create tool call function mock + tool_call_function = mocker.Mock() + tool_call_function.name = "get_weather" + tool_call_function.arguments = '{"location": "Chicago"}' + + chunks = [ + # 1. Role chunk + mocker.Mock(choices=[mocker.Mock(delta=mocker.Mock(role='assistant', content=None, tool_calls=None))]), + # 2. Text content chunk + mocker.Mock(choices=[mocker.Mock(delta=mocker.Mock(content="Of course, I can help with that.", tool_calls=None))]), + # 3. Tool call chunk + mocker.Mock(choices=[mocker.Mock(delta=mocker.Mock(content=None, tool_calls=[ + mocker.Mock( + id="call_456", + function=tool_call_function + ) + ]))]), + # 4. Final empty chunk + mocker.Mock(choices=[mocker.Mock(delta=mocker.Mock(content=None, tool_calls=None), finish_reason="tool_calls")]) + ] + + mock_stream.__iter__.return_value = iter(chunks) + mocker.patch.object(openai_client.chat.completions, 'create', return_value=mock_stream) + + # Make the streaming request + stream = openai_client.chat.completions.create( + model="claudesonnet35v2", + messages=[{"role": "user", "content": "What is the weather in Chicago?"}], + tools=TOOLS, + stream=True, + ) + + # Process the stream and check the order + received_text = None + received_tool_call = None + + for chunk in stream: + if chunk.choices[0].delta.content: + assert received_tool_call is None, "Text chunk received after tool_call chunk" + received_text = chunk.choices[0].delta.content + + if chunk.choices[0].delta.tool_calls: + assert received_text is not None, "Tool_call chunk received before text chunk" + received_tool_call = chunk.choices[0].delta.tool_calls[0] + + # Final assertions + assert received_text == "Of course, I can help with that." + assert received_tool_call.function.name == "get_weather" + assert "Chicago" in received_tool_call.function.arguments diff --git a/tool_calls/__init__.py b/tool_calls/__init__.py new file mode 100644 index 0000000..79cbf69 --- /dev/null +++ b/tool_calls/__init__.py @@ -0,0 +1,46 @@ +""" +Tool Calling Module +=================== + +This module provides a comprehensive toolkit for handling tool calls in Large Language Models (LLMs). +It offers a suite of utilities for converting tool calls, definitions, and choices between different +API formats, including OpenAI, Anthropic, and Google Gemini. + +Core functionalities include: +- Universal middleware for seamless conversion of tool-related data structures. +- Robust input and output handling for both native and prompt-based tool calling. +- Pydantic-based type definitions for clear, validated data models. + +Key Classes and Functions: +- `ToolCall`: A universal representation of a tool call. +- `Tool`: A universal representation of a tool definition. +- `ToolChoice`: A universal representation of a tool choice strategy. +- `handle_tools`: A function to process and convert incoming tool-related requests. +- `ToolInterceptor`: A class to process and extract tool calls from model responses. 
+ +Usage Example: + from tool_calls import Tool, ToolCall, handle_tools, ToolInterceptor + + # Define a tool + my_tool = Tool(name="get_weather", description="Fetches weather data.", parameters={...}) + + # Process an incoming request + processed_request = handle_tools(request_data) + + # Intercept and process a model's response + interceptor = ToolInterceptor() + tool_calls, text_content = interceptor.process(response_content) +""" + +from .handler import Tool, ToolCall, ToolChoice +from .input_handle import handle_tools +from .output_handle import ToolInterceptor +from .types import * + +__all__ = [ + "Tool", + "ToolCall", + "ToolChoice", + "handle_tools", + "ToolInterceptor", +] diff --git a/tool_calls/handler.py b/tool_calls/handler.py new file mode 100644 index 0000000..e37440e --- /dev/null +++ b/tool_calls/handler.py @@ -0,0 +1,534 @@ +""" +Universal Tool Call Middleware Module + +This module provides universal middleware classes for converting tool calls, tool definitions, +and tool choice data between different API formats. + +Supported API formats include: +- OpenAI Chat Completions API +- OpenAI Responses API +- Anthropic Claude API +- Google Gemini API (partial support) + +Main classes: +- ToolCall: Universal representation of tool call data +- Tool: Universal representation of tool definition data +- ToolChoice: Universal representation of tool choice strategy +- NamedTool: Simple representation of named tools + +Usage example: + # Create tool call from OpenAI format + tool_call = ToolCall.from_entry(openai_data, api_format="openai-chatcompletion") + + # Convert to Anthropic format + anthropic_data = tool_call.to_tool_call("anthropic") + + # Serialize to dictionary + serialized = tool_call.serialize("anthropic") +""" + +import json +from typing import Any, Dict, Literal, Union + +from pydantic import BaseModel + +from .types import ( + ChatCompletionMessageToolCall, + ChatCompletionNamedToolChoiceParam, + ChatCompletionToolParam, + Function, + FunctionDefinition, + FunctionDefinitionCore, + FunctionTool, + ResponseFunctionToolCall, + ToolChoiceAnyParam, + ToolChoiceAutoParam, + ToolChoiceFunctionParam, + ToolChoiceNoneParam, + ToolChoiceToolParam, + ToolParam, + ToolUseBlock, +) +from .utils import API_FORMATS + + +class ToolCall(BaseModel): + """ + Universal tool call middleware class supporting conversion between multiple API formats. + + This class serves as a bridge between different API formats (OpenAI, Anthropic, Google, etc.), + allowing loading tool call data from any supported format and converting to other formats. + + Attributes: + id: Unique identifier for the tool call + name: Name of the function to be called + arguments: Function arguments stored as JSON string format + """ + + id: str + """Unique identifier for the tool call""" + name: str + """Name of the function to be called""" + arguments: str + """Function arguments stored as JSON string format""" + + @classmethod + def from_entry( + cls, + tool_call: Dict[str, Any], + *, + api_format: API_FORMATS = "openai-chatcompletion", + ) -> "ToolCall": + """ + Create a ToolCall instance from dictionary data in the specified API format. + + Args: + tool_call: Dictionary containing tool call information + api_format: API format type, supports openai, openai-response, anthropic, etc. 
+ + Returns: + ToolCall: Created tool call instance + + Raises: + ValueError: When API format is not supported + NotImplementedError: When API format is not yet implemented + """ + if api_format in ["openai", "openai-chatcompletion"]: + origin_tool_call = ChatCompletionMessageToolCall.model_validate(tool_call) + return cls( + id=origin_tool_call.id, + name=origin_tool_call.function.name, + arguments=origin_tool_call.function.arguments, + ) + elif api_format == "openai-response": + origin_tool_call = ResponseFunctionToolCall.model_validate(tool_call) + return cls( + id=origin_tool_call.call_id, + name=origin_tool_call.name, + arguments=origin_tool_call.arguments, + ) + elif api_format == "anthropic": + origin_tool_call = ToolUseBlock.model_validate(tool_call) + arguments_str = ( + json.dumps(origin_tool_call.input) + if not isinstance(origin_tool_call.input, str) + else origin_tool_call.input + ) + return cls( + id=origin_tool_call.id, + name=origin_tool_call.name, + arguments=arguments_str, + ) + elif api_format == "google": + # TODO: Implement Google API format + raise NotImplementedError("Google API format is not supported yet.") + else: + raise ValueError(f"Unsupported API format: {api_format}") + + from_dict = from_entry + + def to_tool_call( + self, api_format: Union[API_FORMATS, Literal["general"]] = "general" + ) -> Union[ + "ToolCall", + ChatCompletionMessageToolCall, + ResponseFunctionToolCall, + ToolUseBlock, + ]: + if api_format in ["openai", "openai-chatcompletion"]: + tool_call = ChatCompletionMessageToolCall( + id=self.id, + function=Function( + name=self.name, + arguments=self.arguments, + ), + ) + + elif api_format == "openai-response": + tool_call = ResponseFunctionToolCall( + call_id=self.id, + name=self.name, + arguments=self.arguments, + ) + + elif api_format == "anthropic": + try: + input_data = ( + json.loads(self.arguments) + if isinstance(self.arguments, str) + else self.arguments + ) + except json.JSONDecodeError: + input_data = self.arguments + + tool_call = ToolUseBlock( + id=self.id, + name=self.name, + input=input_data, + ) + + elif api_format == "google": + raise NotImplementedError("Google API format is not supported yet.") + + elif api_format == "general": + return self + else: + raise ValueError(f"Unsupported API format: {api_format}") + + return tool_call + + def serialize( + self, api_format: Union[API_FORMATS, Literal["general"]] = "general" + ) -> Dict[str, Any]: + return self.to_tool_call(api_format).model_dump() + + def __str__(self) -> str: + return f"ToolCall(id={self.id}, name={self.name}, arguments={self.arguments})" + + def __repr__(self) -> str: + return self.__str__() + + +class Tool(BaseModel): + """ + Universal tool definition middleware class supporting conversion between multiple API formats. + + This class represents tool/function definition information, including name, description, and parameter schema. + It can load tool definitions from different API formats and convert to other formats. 
+ + Attributes: + name: Name of the tool/function + description: Description of the tool/function + parameters: Parameter schema of the tool/function, usually in JSON Schema format + """ + + name: str + """Name of the tool/function""" + description: str + """Description of the tool/function""" + parameters: Dict[str, Any] + """Parameter schema of the tool/function, usually in JSON Schema format""" + + @classmethod + def from_entry( + cls, tool: Dict[str, Any], *, api_format: API_FORMATS = "openai-chatcompletion" + ) -> "Tool": + if api_format in ["openai", "openai-chatcompletion"]: + # For OpenAI format, tool should be ChatCompletionToolParam format + origin_tool = ChatCompletionToolParam.model_validate(tool) + return Tool( + name=origin_tool.function.name, + description=origin_tool.function.description, + parameters=origin_tool.function.parameters, + ) + elif api_format == "openai-response": + origin_tool = FunctionTool.model_validate(tool) + return Tool( + name=origin_tool.name, + description=origin_tool.description, + parameters=origin_tool.parameters, + ) + elif api_format == "anthropic": + origin_tool = ToolParam.model_validate(tool) + # Ensure input_schema is in dictionary format + if hasattr(origin_tool.input_schema, "model_dump"): + parameters = origin_tool.input_schema.model_dump() + elif isinstance(origin_tool.input_schema, dict): + parameters = origin_tool.input_schema + else: + parameters = dict(origin_tool.input_schema) + + return Tool( + name=origin_tool.name, + description=origin_tool.description, + parameters=parameters, + ) + elif api_format == "google": + # TODO: Implement Google tool format + raise NotImplementedError("Google tool format not implemented") + else: + raise ValueError(f"Invalid API format: {api_format}") + + from_dict = from_entry + + def to_tool( + self, api_format: Union[API_FORMATS, Literal["general"]] = "general" + ) -> Union[ + "Tool", + ChatCompletionToolParam, + FunctionTool, + ToolParam, + ]: + if api_format in ["openai", "openai-chatcompletion"]: + tool = ChatCompletionToolParam( + function=FunctionDefinition( + name=self.name, + description=self.description, + parameters=self.parameters, + ) + ) + elif api_format == "openai-response": + tool = FunctionTool( + name=self.name, + description=self.description, + parameters=self.parameters, + strict=False, + ) + elif api_format == "anthropic": + tool = ToolParam( + name=self.name, + description=self.description, + input_schema=self.parameters, + ) + elif api_format == "google": + # TODO: Implement Google tool format + raise NotImplementedError("Google tool format not implemented") + + elif api_format == "general": + tool = self + + else: + raise ValueError(f"Invalid API format: {api_format}") + + return tool + + def serialize( + self, api_format: Union[API_FORMATS, Literal["general"]] = "general" + ) -> Dict[str, Any]: + return self.to_tool(api_format).model_dump() + + def __str__(self) -> str: + return f"Tool(name={self.name}, description={self.description}, parameters={self.parameters})" + + def __repr__(self) -> str: + return self.__str__() + + +class NamedTool(BaseModel): + name: str + + def __str__(self) -> str: + return f"NamedTool(name={self.name})" + + def __repr__(self) -> str: + return self.__str__() + + +class ToolChoice(BaseModel): + """ + Universal tool choice middleware class supporting conversion between multiple API formats. + + This class represents tool choice strategy, which can be string-type choices (like auto, required, none) + or specify a specific tool name. 
Supports conversion between different API formats. + + Attributes: + choice: Tool choice strategy, can be "optional" (auto), "none" (don't use), + "any" (must use) or NamedTool instance (specific tool) + """ + + choice: Union[Literal["optional", "none", "any"], NamedTool] + """Tool choice strategy""" + + @staticmethod + def _str_triage(data: str) -> "ToolChoice": + if data == "auto": + return ToolChoice(choice="optional") + elif data == "required": + return ToolChoice(choice="any") + elif data == "none": + return ToolChoice(choice="none") + else: + raise ValueError(f"Invalid tool choice: {data}") + + @classmethod + def from_entry( + cls, + data: Union[str, Dict[str, Any]], + *, + api_format: API_FORMATS = "openai-chatcompletion", + ) -> "ToolChoice": + """ + Create a ToolChoice instance from data in the specified API format. + + Args: + data: Tool choice data, can be string or dictionary + api_format: API format type + + Returns: + ToolChoice: Created tool choice instance + + Raises: + ValueError: When data format is invalid or API format is not supported + NotImplementedError: When API format is not yet implemented + """ + if api_format in ["openai", "openai-chatcompletion"]: + return cls._handle_openai_chatcompletion(data) + elif api_format == "openai-response": + return cls._handle_openai_response(data) + elif api_format == "anthropic": + return cls._handle_anthropic(data) + elif api_format == "google": + raise NotImplementedError("Google API format is not supported yet.") + else: + raise ValueError(f"Unsupported API format: {api_format}") + + @classmethod + def _handle_openai_chatcompletion( + cls, data: Union[str, Dict[str, Any]] + ) -> "ToolChoice": + """Handle OpenAI Chat Completions API format tool_choice""" + if isinstance(data, str): + return cls._str_triage(data) + elif isinstance(data, dict): + # ChatCompletionNamedToolChoiceParam format: {"type": "function", "function": {"name": "..."}} + if "function" in data and "name" in data["function"]: + return cls(choice=NamedTool(name=data["function"]["name"])) + else: + raise ValueError( + f"Invalid OpenAI chat completion tool choice format: {data}" + ) + else: + raise ValueError(f"Invalid tool choice data type: {type(data)}") + + @classmethod + def _handle_openai_response(cls, data: Union[str, Dict[str, Any]]) -> "ToolChoice": + """Handle OpenAI Responses API format tool_choice""" + if isinstance(data, str): + return cls._str_triage(data) + elif isinstance(data, dict): + # ToolChoiceFunctionParam format: {"type": "function", "name": "..."} + if "name" in data: + return cls(choice=NamedTool(name=data["name"])) + else: + raise ValueError(f"Invalid OpenAI response tool choice format: {data}") + else: + raise ValueError(f"Invalid tool choice data type: {type(data)}") + + @classmethod + def _handle_anthropic(cls, data: Union[str, Dict[str, Any]]) -> "ToolChoice": + """Handle Anthropic API format tool_choice""" + if isinstance(data, dict): + tool_type = data.get("type") + if tool_type == "auto": + return cls(choice="optional") + elif tool_type == "any": + return cls(choice="any") + elif tool_type == "none": + return cls(choice="none") + elif tool_type == "tool": + if "name" in data: + return cls(choice=NamedTool(name=data["name"])) + else: + raise ValueError( + "Anthropic tool choice with type 'tool' must have 'name' field" + ) + else: + raise ValueError(f"Invalid Anthropic tool choice type: {tool_type}") + else: + raise ValueError( + f"Anthropic tool choice must be a dictionary, got: {type(data)}" + ) + + def to_tool_choice( + self, + 
api_format: Union[API_FORMATS, Literal["general"]] = "general", + ) -> Union[str, Dict[str, Any], BaseModel, "ToolChoice"]: + """ + Convert ToolChoice instance to data in the specified API format. + + Args: + api_format: Target API format + + Returns: + Converted tool choice data + + Raises: + ValueError: When tool choice is invalid or API format is not supported + NotImplementedError: When API format is not yet implemented + """ + if api_format in ["openai", "openai-chatcompletion"]: + return self._to_openai_chatcompletion() + elif api_format == "openai-response": + return self._to_openai_response() + elif api_format == "anthropic": + return self._to_anthropic() + elif api_format == "google": + raise NotImplementedError("Google API format not implemented yet") + elif api_format == "general": + return self + else: + raise ValueError(f"Invalid API format: {api_format}") + + def _to_openai_chatcompletion( + self, + ) -> Union[str, ChatCompletionNamedToolChoiceParam]: + """Convert to OpenAI Chat Completions API format""" + if isinstance(self.choice, str): + if self.choice == "optional": + return "auto" + elif self.choice == "any": + return "required" + elif self.choice == "none": + return "none" + else: + raise ValueError(f"Invalid tool choice: {self.choice}") + elif isinstance(self.choice, NamedTool): + return ChatCompletionNamedToolChoiceParam( + function=FunctionDefinitionCore(name=self.choice.name) + ) + else: + raise ValueError(f"Invalid tool choice type: {type(self.choice)}") + + def _to_openai_response(self) -> Union[str, ToolChoiceFunctionParam]: + """Convert to OpenAI Responses API format""" + if isinstance(self.choice, str): + if self.choice == "optional": + return "auto" + elif self.choice == "any": + return "required" + elif self.choice == "none": + return "none" + else: + raise ValueError(f"Invalid tool choice: {self.choice}") + elif isinstance(self.choice, NamedTool): + return ToolChoiceFunctionParam(name=self.choice.name) + else: + raise ValueError(f"Invalid tool choice type: {type(self.choice)}") + + def _to_anthropic( + self, + ) -> Union[ + ToolChoiceAutoParam, + ToolChoiceAnyParam, + ToolChoiceNoneParam, + ToolChoiceToolParam, + ]: + """Convert to Anthropic API format""" + if isinstance(self.choice, str): + if self.choice == "optional": + return ToolChoiceAutoParam() + elif self.choice == "any": + return ToolChoiceAnyParam() + elif self.choice == "none": + return ToolChoiceNoneParam() + else: + raise ValueError(f"Invalid tool choice: {self.choice}") + elif isinstance(self.choice, NamedTool): + return ToolChoiceToolParam(name=self.choice.name) + else: + raise ValueError(f"Invalid tool choice type: {type(self.choice)}") + + def serialize( + self, + api_format: Union[API_FORMATS, Literal["general"]] = "general", + ) -> Union[Dict[str, Any], str]: + serialized = self.to_tool_choice(api_format) + return ( + serialized.model_dump() if hasattr(serialized, "model_dump") else serialized + ) + + def __str__(self): + return f"ToolChoice(choice={self.choice})" + + def __repr__(self): + return self.__str__() diff --git a/tool_calls/input_handle.py b/tool_calls/input_handle.py new file mode 100644 index 0000000..1a6032d --- /dev/null +++ b/tool_calls/input_handle.py @@ -0,0 +1,510 @@ +""" +input_handle.py +--------------- + +Tool call input handling module for converting between different LLM provider formats. + +This module provides functionality for: +1. Prompt-based tool handling (for models without native tool support) +2. 
Native tool format conversion between providers (OpenAI, Anthropic, Google) +3. Validation and error handling + +Usage +===== +>>> from tool_calls.input_handle import handle_tools +>>> processed_data = handle_tools(request_data, native_tools=True) +""" + +import json +import logging +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import ValidationError + +from .utils import determine_model_family +from .tool_prompts import get_prompt_skeleton + +# Get logger for this module +logger = logging.getLogger('argo_bridge.tool_calls.input_handle') + +# ====================================================================== +# TYPE ALIASES +# ====================================================================== + +Tools = List[Dict[str, Any]] +ToolChoice = Union[str, Dict[str, Any], None] + +# ====================================================================== +# PROMPT-BASED TOOL HANDLING +# ====================================================================== + + +def build_tool_prompt( + tools: Tools, + tool_choice: ToolChoice = None, + *, + parallel_tool_calls: bool = False, + json_indent: Optional[int] = None, + model_family: Literal["openai", "anthropic", "google"] = "openai", +) -> str: + """ + Return a system-prompt string embedding `tools`, `tool_choice` + and `parallel_tool_calls`. + + Parameters + ---------- + tools : list[dict] + The exact array you would pass to the OpenAI API. + tool_choice : str | dict | None + "none", "auto", or an object with "name", etc. + parallel_tool_calls : bool + Whether multiple tool calls may be returned in one turn. + json_indent : int | None + Pretty-print indentation for embedded JSON blobs. Defaults to None for most compact output. + + Returns + ------- + str + A fully formatted system prompt. + """ + # Dump JSON with stable key order for readability + tools_json = json.dumps(tools, indent=json_indent, ensure_ascii=False) + tool_choice_json = json.dumps( + tool_choice if tool_choice is not None else "none", + indent=json_indent, + ensure_ascii=False, + ) + parallel_flag = "true" if parallel_tool_calls else "false" + + PROMPT_SKELETON = get_prompt_skeleton(model_family) + return PROMPT_SKELETON.format( + tools_json=tools_json, + tool_choice_json=tool_choice_json, + parallel_flag=parallel_flag, + ) + + +def handle_tools_prompt(data: Dict[str, Any]) -> Dict[str, Any]: + """ + Process input data containing tool calls using prompt-based approach. + + This function will: + 1. Check if input data contains tool-related fields (tools, tool_choice, parallel_tool_calls) + 2. If present, generate tool call system prompt and add it to system messages + 3. 
Return processed data + + Parameters + ---------- + data : dict + Dictionary containing request data, may include: + - tools: List of tool definitions + - tool_choice: Tool selection preference + - parallel_tool_calls: Whether to allow parallel tool calls + - messages: Message list + - system: System message + + Returns + ------- + dict + Processed data dictionary + """ + # Check if there are tool-related fields + tools = data.get("tools") + if not tools: + return data + + # Get tool call related parameters + tool_choice = data.get("tool_choice") + parallel_tool_calls = data.get("parallel_tool_calls", False) + + # Determine model family for appropriate prompt + model_family = determine_model_family(data.get("model", "gpt-4")) + + # Generate tool call prompt + tool_prompt = build_tool_prompt( + tools=tools, + tool_choice=tool_choice, + parallel_tool_calls=parallel_tool_calls, + model_family=model_family + ) + + # Add tool prompt to system messages + if "messages" in data: + # Handle messages format + messages = data["messages"] + + # Find existing system message + system_msg_found = False + for _, msg in enumerate(messages): + if msg.get("role") == "system": + # Add tool prompt to existing system message + existing_content = msg.get("content", "") + msg["content"] = f"{existing_content}\n\n{tool_prompt}".strip() + system_msg_found = True + break + + # If no system message found, add one at the beginning + if not system_msg_found: + system_message = {"role": "system", "content": tool_prompt} + messages.insert(0, system_message) + + elif "system" in data: + # Handle direct system field + existing_system = data["system"] + if isinstance(existing_system, str): + data["system"] = f"{existing_system}\n\n{tool_prompt}".strip() + elif isinstance(existing_system, list): + data["system"] = existing_system + [tool_prompt] + else: + # If no system message, create one + data["system"] = tool_prompt + + # Remove original tool-related fields as they've been converted to prompts + data.pop("tools", None) + data.pop("tool_choice", None) + data.pop("parallel_tool_calls", None) + + return data + + +# ====================================================================== +# NATIVE TOOL HANDLING +# ====================================================================== + + +def handle_tools_native(data: Dict[str, Any]) -> Dict[str, Any]: + """Handles tool calls by converting them to the appropriate format for the target model. + + Uses middleware classes from handler.py to process tool-related parameters in the request data + and converts them from OpenAI format to the native format required by the target model + (OpenAI, Anthropic, or Google). Also handles tool_calls in messages for different model families. + + Args: + data: Request data dictionary containing model parameters. May include: + - tools: List of tool definitions in OpenAI format + - tool_choice: Tool choice parameter ("auto", "none", "required", or dict) + - parallel_tool_calls: Whether to enable parallel tool calls (removed for now) + - model: Model identifier used to determine the target format + - messages: List of messages that may contain tool_calls + + Returns: + Modified request data with tools and tool_calls converted to the appropriate format for the + target model. If no tools are present, returns the original data unchanged. 
+ + Note: + - Uses middleware classes Tool, ToolChoice, and ToolCall from handler.py + - parallel_tool_calls parameter is currently removed and not implemented + - Tool conversion is performed based on the model family detected from the model name + - OpenAI format tools are passed through unchanged for OpenAI models + - Converts tool_calls in messages between different API formats + """ + from .handler import Tool, ToolCall, ToolChoice + + # Check if there are tool-related fields + tools = data.get("tools") + messages = data.get("messages", []) + + # Determine target model family + model_type = determine_model_family(data.get("model", "gpt-4")) + + # Process tools if present + if tools: + # Get tool call related parameters + tool_choice = data.get("tool_choice", "auto") + + # Remove parallel_tool_calls from data for now + # TODO: Implement parallel tool calls handling later + parallel_tool_calls = data.pop("parallel_tool_calls", False) + + try: + # Convert tools using middleware classes + converted_tools = [] + for tool_dict in tools: + # Validate and convert each tool using Tool middleware + tool_obj = Tool.from_entry( + tool_dict, api_format="openai-chatcompletion" + ) + + if model_type == "openai": + # Keep OpenAI format + converted_tools.append(tool_obj.serialize("openai-chatcompletion")) + elif model_type == "anthropic": + # Convert to Anthropic format + converted_tools.append(tool_obj.serialize("anthropic")) + elif model_type == "google": + # Convert to Google format (when implemented) + converted_tools.append(tool_obj.serialize("google")) + else: + # Default to OpenAI format + converted_tools.append(tool_obj.serialize("openai-chatcompletion")) + + # Convert tool_choice using ToolChoice middleware + if tool_choice is not None: + tool_choice_obj = ToolChoice.from_entry( + tool_choice, api_format="openai-chatcompletion" + ) + + if model_type == "openai": + converted_tool_choice = tool_choice_obj.serialize( + "openai-chatcompletion" + ) + elif model_type == "anthropic": + converted_tool_choice = tool_choice_obj.serialize("anthropic") + elif model_type == "google": + converted_tool_choice = tool_choice_obj.serialize("google") + else: + converted_tool_choice = tool_choice_obj.serialize( + "openai-chatcompletion" + ) + else: + converted_tool_choice = None + + data["tools"] = converted_tools + data["tool_choice"] = converted_tool_choice + + logger.debug(f"{model_type.title()} model detected, converted tools") + logger.debug(f"Converted tools: {converted_tools}") + logger.debug(f"Converted tool_choice: {converted_tool_choice}") + + except (ValueError, ValidationError) as e: + logger.error(f"Tool validation/conversion failed: {e}") + raise ValueError(f"Tool validation/conversion failed: {e}") + + # Process tool_calls and tool messages if present + if messages: + converted_messages = [] + for message in messages: + converted_message = message.copy() + + # Check if message contains tool_calls (assistant messages) + if "tool_calls" in message and message["tool_calls"]: + try: + if model_type == "openai": + # Keep OpenAI format with tool_calls field + converted_tool_calls = [] + for tool_call_dict in message["tool_calls"]: + tool_call_obj = ToolCall.from_entry( + tool_call_dict, api_format="openai-chatcompletion" + ) + converted_tool_calls.append( + tool_call_obj.serialize("openai-chatcompletion") + ) + converted_message["tool_calls"] = converted_tool_calls + logger.debug(f"Converted tool_calls in message: {converted_tool_calls}") + + elif model_type == "anthropic": + # For Anthropic, convert 
tool_calls to content array format + content_blocks = [] + + # Add text content if present + if message.get("content", ""): + content_blocks.append( + {"type": "text", "text": message["content"]} + ) + + # Convert tool_calls to tool_use blocks in content + for tool_call_dict in message["tool_calls"]: + tool_call_obj = ToolCall.from_entry( + tool_call_dict, api_format="openai-chatcompletion" + ) + anthropic_tool_call = tool_call_obj.serialize("anthropic") + content_blocks.append(anthropic_tool_call) + + # Replace tool_calls with content array + converted_message["content"] = content_blocks + converted_message.pop( + "tool_calls", None + ) # Remove tool_calls field + logger.debug(f"Converted tool_calls to Anthropic content format: {content_blocks}") + + elif model_type == "google": + # TODO: Implement Google format conversion + converted_tool_calls = [] + for tool_call_dict in message["tool_calls"]: + tool_call_obj = ToolCall.from_entry( + tool_call_dict, api_format="openai-chatcompletion" + ) + converted_tool_calls.append( + tool_call_obj.serialize("google") + ) + converted_message["tool_calls"] = converted_tool_calls + logger.debug(f"Converted tool_calls in message: {converted_tool_calls}") + + else: + # Default to OpenAI format + converted_tool_calls = [] + for tool_call_dict in message["tool_calls"]: + tool_call_obj = ToolCall.from_entry( + tool_call_dict, api_format="openai-chatcompletion" + ) + converted_tool_calls.append( + tool_call_obj.serialize("openai-chatcompletion") + ) + converted_message["tool_calls"] = converted_tool_calls + logger.debug(f"Converted tool_calls in message: {converted_tool_calls}") + + except (ValueError, ValidationError) as e: + logger.warning(f"Tool call conversion failed in message: {e}") + # Keep original tool_calls if conversion fails + pass + + # Check if message is a tool result message (role: tool) + elif message.get("role") == "tool": + if model_type == "anthropic": + # For Anthropic, tool results should be in user messages with tool_result content + # Convert OpenAI tool message format to Anthropic format + tool_call_id = message.get("tool_call_id") + content = message.get("content", "") + + # Create Anthropic-style tool result message + converted_message = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": tool_call_id, + "content": content, + } + ], + } + logger.debug(f"Converted tool message to Anthropic format: {converted_message}") + elif model_type == "google": + # TODO: Implement Google tool result format conversion + logger.debug("Google tool result conversion not implemented yet") + # For OpenAI, keep the original format + + converted_messages.append(converted_message) + + data["messages"] = converted_messages + + return data + + +# ====================================================================== +# MAIN ENTRY POINT +# ====================================================================== + + +def handle_tools(data: Dict[str, Any], *, native_tools: bool = True) -> Dict[str, Any]: + """ + Process input data containing tool calls with fallback strategy. + + This function will: + 1. If native_tools=True: attempt native tool handling (handle_tools_native) + 2. If native handling validation fails or native_tools=False: fallback to prompt-based handling (handle_tools_prompt) + 3. 
Return processed data + + Parameters + ---------- + data : dict + Dictionary containing request data, may include: + - tools: List of tool definitions + - tool_choice: Tool selection preference + - parallel_tool_calls: Whether to allow parallel tool calls + - messages: Message list + - system: System message + - model: Model identifier + native_tools : bool, optional + Whether to use native tools or prompt-based tools, by default True + + Returns + ------- + dict + Processed data dictionary + """ + # Check if there are tool-related fields + tools = data.get("tools") + if not tools: + return data + + if native_tools: + try: + # First attempt: try native tool handling + return handle_tools_native(data) + except (ValueError, ValidationError, NotImplementedError) as e: + # Fallback: use prompt-based handling if native handling fails + # This handles validation errors, unsupported model types, or unimplemented conversions + logger.warning(f"Native tool handling failed, falling back to prompt-based: {e}") + return handle_tools_prompt(data) + else: + # Directly use prompt-based handling when native_tools=False + return handle_tools_prompt(data) + + +# ====================================================================== +# EXAMPLE USAGE +# ====================================================================== + +if __name__ == "__main__": # pragma: no cover + # --- 1. Define tools exactly as you would for the OpenAI API ------------ + tools_example = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given city.", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string", "description": "City name"}, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + }, + { + "type": "function", + "function": { + "name": "news_headlines", + "description": "Fetch top news headlines.", + "parameters": { + "type": "object", + "properties": { + "category": { + "type": "string", + "enum": ["politics", "technology", "sports"], + }, + "limit": {"type": "integer", "minimum": 1, "maximum": 10}, + }, + "required": ["category"], + }, + } + }, + ] + + # --- 2. (Optional) choose preferred tool or "auto"/"none" -------------- + tool_choice_example = "auto" # could also be {"name": "get_weather"} or "none" + + # --- 3. Build the prompt ------------------------------------------------ + prompt = build_tool_prompt( + tools_example, + tool_choice_example, + parallel_tool_calls=True, + ) + + print("=== Direct Tool Prompt Building ===") + print(prompt) + print("\n" + "=" * 50 + "\n") + + # --- 4. 
Demonstrate handle_tools function -------------------------------- + print("=== Demonstrate handle_tools Function ===") + + # Example input data (similar to OpenAI API request) + input_data = { + "messages": [ + {"role": "user", "content": "What's the weather like in Beijing today?"} + ], + "tools": tools_example, + "tool_choice": tool_choice_example, + "parallel_tool_calls": True, + } + + print("Original input data:") + print(json.dumps(input_data, indent=2, ensure_ascii=False)) + + # Process tool calls + processed_data = handle_tools(input_data.copy()) + + print("\nProcessed data:") + print(json.dumps(processed_data, indent=2, ensure_ascii=False)) diff --git a/tool_calls/output_handle.py b/tool_calls/output_handle.py new file mode 100644 index 0000000..b9bead0 --- /dev/null +++ b/tool_calls/output_handle.py @@ -0,0 +1,402 @@ +import json +import re +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Tuple, + Union, + overload, +) + +from pydantic import ValidationError + +from .types import ( + ChatCompletionMessageToolCall, + ChoiceDeltaToolCall, + ChoiceDeltaToolCallFunction, + Function, + ResponseFunctionToolCall, +) +from .utils import generate_id +from .handler import ToolCall +from logging_config import get_logger + +# Create module-specific logger +logger = get_logger(__name__) + + +class ToolInterceptor: + """ + Tool interceptor that handles both prompt-based and native tool calling responses. + + This class can process: + 1. Legacy prompt-based responses with tags + 2. Native tool calling responses from different model providers + """ + + def __init__(self): + pass + + def process( + self, + response_content: Union[str, Dict[str, Any]], + model_family: Literal["openai", "anthropic", "google"] = "openai", + ) -> Tuple[Optional[List[ToolCall]], str]: + """ + Process response content and extract tool calls. + + Args: + response_content: Either a string (legacy format) or dict (native format) + model_family: Model family to determine the processing strategy + + Returns: + Tuple of (list of tool calls or None, text content) + """ + if isinstance(response_content, str): + # Legacy prompt-based format + return self._process_prompt_based(response_content) + elif isinstance(response_content, dict): + # Native tool calling format + return self._process_native(response_content, model_family) + else: + logger.warning(f"Unexpected response content type: {type(response_content)}") + return None, str(response_content) + + def _process_prompt_based(self, text: str) -> Tuple[Optional[List[ToolCall]], str]: + """ + Process prompt-based responses with tags. 
+
+        Args:
+            text: Text content containing potential <tool_call> tags
+
+        Returns:
+            Tuple of (list of ToolCall objects or None, concatenated text from outside tool calls)
+        """
+        tool_calls = []
+        text_parts = []
+        last_end = 0
+
+        # Prompt-based tool calls are emitted as JSON wrapped in <tool_call>...</tool_call> tags,
+        # the format the prompt skeletons in tool_prompts.py ask the model to produce.
+        for match in re.finditer(r"<tool_call>(.*?)</tool_call>", text, re.DOTALL):
+            # Add text before this tool call
+            if match.start() > last_end:
+                text_parts.append(text[last_end : match.start()])
+
+            # Process the tool call
+            try:
+                tool_call_dict = json.loads(match.group(1).strip())
+                # Convert dict to ToolCall object
+                tool_call = ToolCall(
+                    id=generate_id(mode="general"),
+                    name=tool_call_dict.get("name", ""),
+                    arguments=json.dumps(tool_call_dict.get("arguments", {}))
+                    if isinstance(tool_call_dict.get("arguments"), dict)
+                    else str(tool_call_dict.get("arguments", "")),
+                )
+                tool_calls.append(tool_call)
+            except json.JSONDecodeError:
+                # On JSON error, keep the raw block as plain text
+                text_parts.append(f"<tool_call>{match.group(1)}</tool_call>")
+
+            last_end = match.end()
+
+        # Add any remaining text after last tool call
+        if last_end < len(text):
+            text_parts.append(text[last_end:])
+
+        return (
+            tool_calls if tool_calls else None,
+            "".join(
+                text_parts
+            ).lstrip(),  # Combine all text parts and strip leading whitespace
+        )
+
+    def _process_native(
+        self,
+        response_data: Dict[str, Any],
+        model_family: Literal["openai", "anthropic", "google"] = "openai",
+    ) -> Tuple[Optional[List[ToolCall]], str]:
+        """
+        Process native tool calling responses from different model providers.
+
+        Args:
+            response_data: Response data containing content and tool_calls
+            model_family: Model family to determine the processing strategy
+
+        Returns:
+            Tuple of (list of tool calls or None, text content)
+        """
+        logger.debug(f"Processing native tool calling response with {len(response_data)} keys")
+        logger.debug(f"Response data: {response_data}")
+
+        if model_family == "openai":
+            logger.debug("Using OpenAI native tool calling format")
+            return self._process_openai_native(response_data)
+        elif model_family == "anthropic":
+            logger.debug("Using Anthropic native tool calling format")
+            return self._process_anthropic_native(response_data)
+        elif model_family == "google":
+            logger.debug("Using Google native tool calling format")
+            return self._process_google_native(response_data)
+        else:
+            logger.warning(f"Unknown model family: {model_family}, falling back to OpenAI format")
+            return self._process_openai_native(response_data)
+
+    def _process_openai_native(
+        self, response_data: Dict[str, Any]
+    ) -> Tuple[Optional[List[ToolCall]], str]:
+        """
+        Process OpenAI native tool calling response format.
+
+        Expected format:
+        {
+            "content": "text response",
+            "tool_calls": [
+                {"name": "function_name", "arguments": {...}}
+            ]
+        }
+
+        Args:
+            response_data: OpenAI format response data
+
+        Returns:
+            Tuple of (list of ToolCall objects or None, text content)
+        """
+        content = response_data.get("content", "")
+        tool_calls_data = response_data.get("tool_calls", [])
+
+        # Convert tool calls to ToolCall objects
+        tool_calls = None
+        if tool_calls_data:
+            tool_calls = []
+            for tool_call_dict in tool_calls_data:
+                # Use ToolCall.from_entry to convert from OpenAI format
+                tool_call = ToolCall.from_entry(
+                    tool_call_dict, api_format="openai-chatcompletion"
+                )
+                tool_calls.append(tool_call)
+
+        return tool_calls, content
+
+    def _process_anthropic_native(
+        self, response_data: Dict[str, Any]
+    ) -> Tuple[Optional[List[ToolCall]], str]:
+        """
+        Process Anthropic native tool calling response format.
+ + Expected in-house gateway format for Anthropic models: + { + "response": { + "content": "I'll get the current stock price...", + "tool_calls": [ + { + "id": "toolu_vrtx_01X1tcW6qR1uUoUkfpZMiXnH", + "input": {"ticker": "MSFT"}, + "name": "get_stock_price", + "type": "tool_use" + } + ] + } + } + + Args: + response_data: Anthropic format response data + + Returns: + Tuple of (list of ToolCall objects or None, text content) + """ + # Extract response object if present + response = response_data.get("response", response_data) + + # Get text content directly + text_content = response.get("content", "") + + # Get tool calls array + claude_tool_calls = response.get("tool_calls", []) + + logger.debug(f"Anthropic tool calls: {claude_tool_calls}") + logger.debug(f"Anthropic text content: {text_content}") + + # Convert Claude tool calls to ToolCall objects + tool_calls = None + if claude_tool_calls: + tool_calls = [] + for claude_tool_call in claude_tool_calls: + # Use ToolCall.from_entry to convert from Anthropic format + tool_call = ToolCall.from_entry( + claude_tool_call, api_format="anthropic" + ) + tool_calls.append(tool_call) + logger.debug(f"Converted {len(tool_calls)} ToolCall objects") + + return tool_calls, text_content + + def _process_google_native( + self, response_data: Dict[str, Any] + ) -> Tuple[Optional[List[ToolCall]], str]: + """ + Process Google native tool calling response format. + + TODO: Implement Google-specific tool calling format processing. + + Args: + response_data: Google format response data + + Returns: + Tuple of (list of ToolCall objects or None, text content) + """ + # Placeholder implementation - to be implemented later + logger.warning("Google native tool calling not implemented yet, falling back to OpenAI format") + raise NotImplementedError("Google native tool calling is not yet implemented. Please implement Google-specific tool calling format processing.") + + +def chat_completion_to_response_tool_call( + chat_tool_call: ChatCompletionMessageToolCall, +) -> ResponseFunctionToolCall: + """Converts a ChatCompletionMessageToolCall to ResponseFunctionToolCall. + + Args: + chat_tool_call: The ChatCompletionMessageToolCall to convert. + + Returns: + ResponseFunctionToolCall with corresponding data. + """ + return ResponseFunctionToolCall( + arguments=chat_tool_call.function.arguments, + call_id=chat_tool_call.id, + name=chat_tool_call.function.name, + id=generate_id(mode="openai-response"), + status="completed", + ) + + +@overload +def tool_calls_to_openai( + tool_calls: List[Union[Dict[str, Any], ChatCompletionMessageToolCall, ToolCall]], + *, + api_format: Literal["chat_completion"] = "chat_completion", +) -> List[ChatCompletionMessageToolCall]: ... + + +@overload +def tool_calls_to_openai( + tool_calls: List[Union[Dict[str, Any], ChatCompletionMessageToolCall, ToolCall]], + *, + api_format: Literal["response"], +) -> List[ResponseFunctionToolCall]: ... + + +def tool_calls_to_openai( + tool_calls: List[Union[Dict[str, Any], ChatCompletionMessageToolCall, ToolCall]], + *, + api_format: Literal["chat_completion", "response"] = "chat_completion", +) -> List[Union[ChatCompletionMessageToolCall, ResponseFunctionToolCall]]: + """Converts parsed tool calls to OpenAI API format. + + Args: + tool_calls: List of parsed tool calls. Can be either dictionaries, + ChatCompletionMessageToolCall objects, or ToolCall objects. + api_format: Output format type, either "chat_completion" or "response". + Defaults to "chat_completion". 
+ + Returns: + List of tool calls in OpenAI function call object type. The specific type + depends on the api_format parameter: + - ChatCompletionMessageToolCall for "chat_completion" + - ResponseFunctionToolCall for "response" + """ + openai_tool_calls = [] + + for call in tool_calls: + # Handle ToolCall, dict and ChatCompletionMessageToolCall inputs + if isinstance(call, ChatCompletionMessageToolCall): + chat_tool_call = call + elif isinstance(call, ToolCall): + # Convert ToolCall to ChatCompletionMessageToolCall + chat_tool_call = call.to_tool_call("openai-chatcompletion") + elif isinstance(call, dict): + # Check if it's already in ChatCompletionMessageToolCall format + try: + # Try to parse as ChatCompletionMessageToolCall using Pydantic + chat_tool_call = ChatCompletionMessageToolCall.model_validate(call) + except (ValidationError, TypeError): + # Legacy format - create from name/arguments + arguments = json.dumps(call.get("arguments", "")) + name = call.get("name", "") + chat_tool_call = ChatCompletionMessageToolCall( + id=generate_id(mode="openai-chatcompletion"), + function=Function(name=name, arguments=arguments), + ) + else: + raise ValueError(f"Unsupported tool call type: {type(call)}") + + if api_format == "chat_completion": + openai_tool_calls.append(chat_tool_call) + else: + # Convert to ResponseFunctionToolCall using helper function + response_tool_call = chat_completion_to_response_tool_call(chat_tool_call) + openai_tool_calls.append(response_tool_call) + + return openai_tool_calls + + +def tool_calls_to_openai_stream( + tool_call: Union[Dict[str, Any], ChatCompletionMessageToolCall, ToolCall], + *, + tc_index: int = 0, + api_format: Literal["chat_completion", "response"] = "chat_completion", +) -> ChoiceDeltaToolCall: + """ + Converts a tool call to OpenAI-compatible tool call objects for streaming. + + Args: + tool_call: Single tool call to convert. Can be either a dictionary, + ChatCompletionMessageToolCall object, or ToolCall object. + tc_index: The index of the tool call. + api_format: The format to convert the tool calls to. Can be "chat_completion" or "response". + + Returns: + An OpenAI-compatible stream tool call object. 
+ """ + + # Handle ToolCall, dict and ChatCompletionMessageToolCall inputs + if isinstance(tool_call, ChatCompletionMessageToolCall): + chat_tool_call = tool_call + elif isinstance(tool_call, ToolCall): + # Convert ToolCall to ChatCompletionMessageToolCall + chat_tool_call = tool_call.to_tool_call("openai-chatcompletion") + elif isinstance(tool_call, dict): + # Check if it's already in ChatCompletionMessageToolCall format + try: + # Try to parse as ChatCompletionMessageToolCall using Pydantic + chat_tool_call = ChatCompletionMessageToolCall.model_validate(tool_call) + except (ValidationError, TypeError): + # Legacy format - create from name/arguments + arguments = json.dumps(tool_call.get("arguments", "")) + name = tool_call.get("name", "") + chat_tool_call = ChatCompletionMessageToolCall( + id=generate_id(mode="openai-chatcompletion"), + function=Function( + name=name, + arguments=arguments, + ), + ) + else: + raise ValueError(f"Unsupported tool call type: {type(tool_call)}") + + if api_format == "chat_completion": + tool_call_obj = ChoiceDeltaToolCall( + id=chat_tool_call.id, + function=ChoiceDeltaToolCallFunction( + name=chat_tool_call.function.name, + arguments=chat_tool_call.function.arguments, + ), + index=tc_index, + ) + else: + # TODO: Implement response format + raise NotImplementedError("response format is not implemented yet.") + + return tool_call_obj diff --git a/tool_calls/tool_prompts.py b/tool_calls/tool_prompts.py new file mode 100644 index 0000000..0da7429 --- /dev/null +++ b/tool_calls/tool_prompts.py @@ -0,0 +1,242 @@ +from typing import Literal + +OPENAI_PROMPT_SKELETON = """You are an AI assistant that can call pre-defined tools when needed. + +### Available Tools +{tools_json} + +### Tool Usage Policy +Tool choice: {tool_choice_json} +- "none": Do not use tools, respond with text only +- "auto": Use tools only when necessary to answer the user's request +- "required": You MUST use at least one tool - cannot respond with text only +- {{"name": "tool_name"}}: Use the specified tool if relevant + +Parallel calls allowed: {parallel_flag} + +### CRITICAL: Response Format Rules + +You have TWO response modes: + +**MODE 1: Tool Call Response** +- Start IMMEDIATELY with (no text before) +- Contains ONLY valid JSON with "name" and "arguments" fields +- End with +- After the tool call, you MUST wait for the tool result before continuing +- Do NOT simulate tool results or continue the conversation + +Format: + +{{"name": "tool_name", "arguments": {{"param": "value"}}}} + + +**MODE 2: Text Response** +- Pure natural language response +- Use when no tools are needed or after receiving tool results +- Never include tags in text responses + +### Important Constraints +- NEVER start a tool call with explanatory text like "I'll help you..." or "Let me search..." +- NEVER simulate or imagine tool results - always wait for actual results +- NEVER use tags like , , or any other XML tags +- If parallel_tool_calls is false, make only ONE tool call per response +- If you start with , you cannot add text before it +- If you don't start with , you cannot use tools in that response + +### Decision Process +Before responding, ask yourself: +1. Is tool choice "required"? → You MUST use a tool +2. Is tool choice "none"? → You MUST NOT use tools +3. Does the user's request require a tool to answer properly? +4. If yes → Start immediately with +5. If no → Respond with natural language only + +Remember: Your first character determines your response mode. 
Choose wisely.""" + +CLAUDE_PROMPT_SKELETON = """You are an AI assistant that can call pre-defined tools to help answer questions. + +## When to Use Tools vs Your Knowledge + +**Use tools ONLY when:** +- You need real-time, current information (stock prices, weather, news) +- You need to perform calculations beyond simple mental math +- You need to access specific external data or APIs +- The user explicitly requests you to use a particular tool +- You genuinely cannot answer accurately with your existing knowledge + +**Do NOT use tools when:** +- You can answer from your training knowledge (general facts, explanations, advice) +- The question is about concepts, definitions, or well-established information +- You can provide helpful guidance without external data +- The user is asking for your opinion, analysis, or creative input +- Simple calculations you can do mentally (basic arithmetic) + +**Remember:** Your training data is extensive and valuable. Use it first, tools second. + +## CRITICAL: Planning Tool Calls with Dependencies + +**BEFORE making any tool calls, think through:** +1. What information do I need to answer this question? +2. What order must I get this information in? +3. Does tool B need the result from tool A? +4. Can I make these calls in parallel, or must they be sequential? + +**If there are data dependencies:** +- Make ONE tool call at a time +- Wait for the result before planning your next call +- Explain your plan to the user: "First I'll get X, then use that to get Y" + +**Examples of dependencies:** +- ❌ BAD: Call `get_user_id(email)` AND `get_user_profile(user_id)` simultaneously +- ✅ GOOD: Call `get_user_id(email)` first, wait for result, then call `get_user_profile(user_id)` + +- ❌ BAD: Call `search_products(query)` AND `get_product_details(product_id)` together +- ✅ GOOD: Search first, get results, then get details for specific products + +**When parallel calls ARE appropriate:** +- Getting independent information (weather in 3 different cities) +- Performing separate calculations that don't depend on each other +- Only when parallel_tool_calls is true AND there are no dependencies + +## How to Use Tools +When you genuinely need information beyond your knowledge, use this format anywhere in your response: + + +{{"name": "tool_name", "arguments": {{"param": "value"}}}} + + +You can explain what you're doing, ask for clarification, or provide context - just include the tool call when needed. 
+ +## CRITICAL: Do NOT use these formats +``` +// WRONG - Don't use Anthropic's API format: +{{"type": "tool_use", "id": "...", "name": "...", "input": {{...}}}} + +// WRONG - Don't use Anthropic's internal XML format: + + +value1 + + + +// WRONG - Don't use OpenAI's tool calling format: +{{ + "tool_calls": [ + {{ + "id": "call_abc123", + "type": "function", + "function": {{ + "name": "tool_name", + "arguments": "{{\\"param\\": \\"value\\"}}" + }} + }} + ] +}} +``` + +## Available Tools +{tools_json} + +## Tool Settings +- Tool choice: {tool_choice_json} + - "auto": decide carefully when tools are truly needed + - "none": answer without tools unless absolutely necessary + - "required": you must use at least one tool in your response + - {{"name": "tool_name"}}: prefer using the specified tool when relevant +- Parallel calls: {parallel_flag} + - true: you may use multiple tools in one response (only if no dependencies) + - false: use only one tool per response + +## Examples of Good Planning + +**Good - Sequential with dependencies:** +User: "Get me details about user john@example.com's recent orders" +Response: "I'll help you with that. First, I need to find the user ID for that email, then I can get their order details: + + +{{"name": "get_user_id", "arguments": {{"email": "john@example.com"}}}} +" + +**Good - Explaining the plan:** +User: "Compare the weather in New York and London" +Response: "I'll get the current weather for both cities: + + +{{"name": "get_weather", "arguments": {{"city": "New York"}}}} + + +{{"name": "get_weather", "arguments": {{"city": "London"}}}} +" + +**Good - Sequential planning:** +User: "Find the most expensive product in the electronics category" +Response: "I'll search for electronics products first, then analyze the results to find the most expensive one: + + +{{"name": "search_products", "arguments": {{"category": "electronics"}}}} +" + +Remember: Think before you call. Plan your sequence. Respect data dependencies.""" + +GEMINI_PROMPT_SKELETON = """You are an AI assistant with access to tools. Your goal is to assist the user by answering their questions and calling tools when necessary. + +### Available Tools +{tools_json} + +### Tool Policy +- Your current tool policy is: {tool_choice_json} +- "none": You are not allowed to call any tools. +- "auto": You can choose to call one or more tools if they are useful. +- "required": You must call at least one tool. +- {{"name": "X"}}: You must call tool X. + +### How to Respond (VERY IMPORTANT) +You have two options for responding. + +**OPTION 1: Call one or more tools** +If you need to gather information, your ENTIRE response must be one or more `` blocks. + +*Single tool call example:* + +{{"name": "tool_name", "arguments": {{"param": "value"}}}} + + +**OPTION 2: Answer the user directly** +If you have enough information (either from the conversation or from a tool result you just received), write a standard, conversational response in natural language. + +### Using Tool Results +When you call a tool, the system will run it and give you the output in a `` block. You must then use this information to provide a final answer to the user (using Option 2). + +**Example Flow:** +1. **User:** What's the temperature in Shanghai in Fahrenheit? +2. **Your response (Option 1):** + + {{"name": "web_search_google-search", "arguments": {{"query": "temperature in Shanghai celsius"}}}} + +3. **System provides result:** `{{"tool_name": "web_search_google-search", "result": "29°C"}}` +4. 
**Your next response (Option 1 again):** + + {{"name": "unit_converter-celsius_to_fahrenheit", "arguments": {{"celsius": 29}}}} + +5. **System provides result:** `{{"tool_name": "unit_converter-celsius_to_fahrenheit", "result": 84.2}}` +6. **Your final response (Option 2):** + The temperature in Shanghai is 29°C, which is 84.2°F. + +### Critical Rules to Follow +- **NEVER** use ``. The correct tag is ``. +- When calling tools, your response must ONLY contain `` blocks. No extra text. +- After receiving a ``, use the information to answer the user in plain text. Do not just repeat the call or the raw result. +- You are only REQUESTING tool calls. You do not run them. Wait for the ``. +""" + + +def get_prompt_skeleton(model_family: Literal["openai", "anthropic", "google"]) -> str: + """Get the appropriate prompt skeleton based on model type.""" + + if model_family == "anthropic": + return CLAUDE_PROMPT_SKELETON + elif model_family == "google": + return GEMINI_PROMPT_SKELETON + else: + # Default to OpenAI format for other models + return OPENAI_PROMPT_SKELETON diff --git a/tool_calls/types.py b/tool_calls/types.py new file mode 100644 index 0000000..349be52 --- /dev/null +++ b/tool_calls/types.py @@ -0,0 +1,292 @@ +""" +function_call.py + +Type definitions for function calling APIs used by LLM providers. +This file contains Pydantic models for use with OpenAI's chat-completion +and responses APIs. Types for additional providers (Anthropic, Gemini, etc.) +are also included. + +Sections: + - OpenAI Types (Chat Completions & Responses) + - Anthropic Types + - Google Gemini Types (TODO) +""" + +from typing import Dict, List, Literal, Optional, TypeAlias, Union + +from pydantic import BaseModel + +# ====================================================================== +# 1. OPENAI TYPES (CHAT COMPLETION & RESPONSES API) +# ====================================================================== + +# =========================== +# Chat Completion API SECTION +# =========================== + + +# --------- API INPUT --------- +class FunctionDefinitionCore(BaseModel): + name: str + """The name of the function to be called. + + Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length + of 64. + """ + + +class FunctionDefinition(FunctionDefinitionCore): + description: Optional[str] = None + """ + A description of what the function does, used by the model to choose when and + how to call the function. + """ + parameters: Optional[Dict[str, object]] = None + """The parameters the functions accepts, described as a JSON Schema object. + + See the [guide](https://platform.openai.com/docs/guides/function-calling) for + examples, and the + [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for + documentation about the format. + + Omitting `parameters` defines a function with an empty parameter list. + """ + strict: Optional[bool] = None + """Whether to enable strict schema adherence when generating the function call. + + If set to true, the model will follow the exact schema defined in the + `parameters` field. Only a subset of JSON Schema is supported when `strict` is + `true`. Learn more about Structured Outputs in the + [function calling guide](docs/guides/function-calling). + """ + + +# used in `tools` +class ChatCompletionToolParam(BaseModel): + function: FunctionDefinition + type: Literal["function"] = "function" + """The type of the tool. 
Currently, only `function` is supported.""" + + +# used in `tool_choice` +class ChatCompletionNamedToolChoiceParam(BaseModel): + function: FunctionDefinitionCore + type: Literal["function"] = "function" + """The type of the tool. Currently, only `function` is supported.""" + + +ChatCompletionToolChoiceOptionParam: TypeAlias = Union[ + Literal["none", "auto", "required"], ChatCompletionNamedToolChoiceParam +] + + +# --------- LLM OUTPUT --------- +class Function(BaseModel): + arguments: str + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + name: str + """The name of the function to call.""" + + +# elements in `tool_calls` +class ChatCompletionMessageToolCall(BaseModel): + id: str + """The ID of the tool call.""" + function: Function + """The function that the model called.""" + type: Literal["function"] = "function" + """The type of the tool. Currently, only `function` is supported.""" + + +# function definition in stream deltas +class ChoiceDeltaToolCallFunction(BaseModel): + arguments: Optional[str] = None + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + name: Optional[str] = None + """The name of the function to call.""" + + +# used in `tool_calls` in stream deltas +class ChoiceDeltaToolCall(BaseModel): + index: int + + id: Optional[str] = None + """The ID of the tool call.""" + + function: Optional[ChoiceDeltaToolCallFunction] = None + + type: Optional[Literal["function"]] = None + """The type of the tool. Currently, only `function` is supported.""" + + +# ===================== +# Responses API SECTION +# ===================== + + +# --------- API INPUT --------- +class FunctionTool(BaseModel): + """API INPUT""" + + name: str + """The name of the function to call.""" + parameters: Optional[Dict[str, object]] = None + """A JSON schema object describing the parameters of the function.""" + strict: Optional[bool] = None + """Whether to enforce strict parameter validation. Default `true`.""" + type: Literal["function"] = "function" + """The type of the function tool. Always `function`.""" + description: Optional[str] = None + """A description of the function. + + Used by the model to determine whether or not to call the function. + """ + + +class ToolChoiceFunctionParam(BaseModel): + """API INPUT""" + + name: str + """The name of the function to call.""" + type: Literal["function"] = "function" + """For function calling, the type is always `function`.""" + + +ToolChoice: TypeAlias = Union[ + Literal["none", "auto", "required"], ToolChoiceFunctionParam +] +# (API INPUT: as tool_choice argument in responses API) + + +# --------- LLM OUTPUT --------- +class ResponseFunctionToolCall(BaseModel): + """LLM OUTPUT""" + + arguments: str + """A JSON string of the arguments to pass to the function.""" + call_id: str + """The unique ID of the function tool call generated by the model.""" + name: str + """The name of the function to run.""" + type: Literal["function_call"] = "function_call" + """The type of the function tool call. 
Always `function_call`.""" + id: Optional[str] = None + """The unique ID of the function tool call.""" + status: Optional[Literal["in_progress", "completed", "incomplete"]] = None + """The status of the item. + + One of `in_progress`, `completed`, or `incomplete`. Populated when items are + returned via API. + """ + + +# ====================================================================== +# 2. ANTHROPIC TYPES +# ====================================================================== + + +# --------- API INPUT --------- +class InputSchemaTyped(BaseModel): + type: Literal["object"] + properties: Optional[object] = None + required: Optional[List[str]] = None + + +InputSchema: TypeAlias = Union[InputSchemaTyped, Dict[str, object]] + + +class CacheControlEphemeralParam(BaseModel): + type: Literal["ephemeral"] = "ephemeral" + + +class ToolParam(BaseModel): + input_schema: InputSchema + """[JSON schema](https://json-schema.org/draft/2020-12) for this tool's input. + + This defines the shape of the `input` that your tool accepts and that the model + will produce. + """ + + name: str + """Name of the tool. + + This is how the tool will be called by the model and in `tool_use` blocks. + """ + + cache_control: Optional[CacheControlEphemeralParam] = None + """Create a cache control breakpoint at this content block.""" + + description: str + """Description of what this tool does. + + Tool descriptions should be as detailed as possible. The more information that + the model has about what the tool is and how to use it, the better it will + perform. You can use natural language descriptions to reinforce important + aspects of the tool input JSON schema. + """ + + type: Optional[Literal["custom"]] = "custom" + + +# used in `tool_choice` +class ToolChoiceShared(BaseModel): + disable_parallel_tool_use: bool = False + """Whether to disable parallel tool use. + + Defaults to `false`. If set to `true`, the model will output exactly one tool use. + """ + + +class ToolChoiceAnyParam(ToolChoiceShared): + type: Literal["any"] = "any" + + +class ToolChoiceAutoParam(ToolChoiceShared): + type: Literal["auto"] = "auto" + + +class ToolChoiceNoneParam(BaseModel): + type: Literal["none"] = "none" + + +class ToolChoiceToolParam(ToolChoiceShared): + name: str + """The name of the tool to use.""" + + type: Literal["tool"] = "tool" + + +ToolChoiceParam: TypeAlias = Union[ + ToolChoiceAutoParam, ToolChoiceAnyParam, ToolChoiceToolParam, ToolChoiceNoneParam +] + + +# --------- LLM OUTPUT --------- +# elements in `tool_calls` +class ToolUseBlock(BaseModel): + id: str + + input: object + + name: str + + type: Literal["tool_use"] = "tool_use" + + cache_control: Optional[CacheControlEphemeralParam] = None + """Create a cache control breakpoint at this content block.""" + + +# ====================================================================== +# 3. GOOGLE GEMINI TYPES (TODO) +# ====================================================================== +# Add Google Gemini-compatible function call types here... 
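The Pydantic models above can be exercised on their own to sanity-check tool definitions before they pass through the bridge. A minimal sketch, assuming the models are importable as `tool_calls.types` (the import path is inferred from the file layout in this diff); names and argument values are illustrative:

```python
from tool_calls.types import (
    ChatCompletionMessageToolCall,
    ChatCompletionToolParam,
    Function,
    FunctionDefinition,
    ToolParam,
)

# Validate an OpenAI-style tool definition (the `type` field defaults to "function").
weather_tool = ChatCompletionToolParam(
    function=FunctionDefinition(
        name="get_weather",
        description="Get the current weather in a given city.",
        parameters={
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    )
)
print(weather_tool.model_dump(exclude_none=True))

# Parse a model-generated tool call; note that `arguments` is a JSON string, not a dict.
call = ChatCompletionMessageToolCall.model_validate(
    {
        "id": "call_abc123",
        "type": "function",
        "function": {"name": "get_weather", "arguments": '{"location": "Chicago"}'},
    }
)
print(call.function.name, call.function.arguments)

# The Anthropic equivalent uses `input_schema` in place of `parameters` and requires a description.
anthropic_tool = ToolParam(
    name="get_weather",
    description="Get the current weather in a given city.",
    input_schema={"type": "object", "properties": {"location": {"type": "string"}}},
)
print(anthropic_tool.model_dump(exclude_none=True))
```

Converting between these two shapes by hand is what the `Tool.from_entry(...)` / `Tool.serialize(...)` middleware in `handler.py` is meant to automate.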
diff --git a/tool_calls/utils.py b/tool_calls/utils.py new file mode 100644 index 0000000..4913695 --- /dev/null +++ b/tool_calls/utils.py @@ -0,0 +1,115 @@ +import secrets +import string +from typing import Any, Dict, Literal, Union + +from pydantic import ValidationError + +from .types import ChatCompletionNamedToolChoiceParam + +API_FORMATS = Literal[ + "openai", # old default, alias to openai-chatcompletion + "openai-chatcompletion", # chat completion + "openai-response", + "anthropic", + "google", +] + + +def determine_model_family( + model: str = "gpt4o", +) -> Literal["openai", "anthropic", "google", "unknown"]: + """ + Determine the model family based on the model name. + """ + model_lower = model.lower() + if "gpt" in model_lower or "o1" in model_lower: + return "openai" + elif "claude" in model_lower: + return "anthropic" + elif "gemini" in model_lower: + return "google" + else: + return "unknown" + + +def generate_id( + *, + mode: Union[API_FORMATS, Literal["general"]] = "general", +) -> str: + """ + Return a random identifier. + + Parameters + ---------- + mode : + 'general' → <22-char base62 string> (default) + 'openai'/'openai-chatcompletion' → call_<22-char base62 string> + 'openai-response' → fc_<48-char hex string> + 'anthropic' → toolu_<24-char base62 string> + + Examples + -------- + >>> generate_id() + 'b9krJaIcuBM4lej3IyI5heVc' + + >>> generate_id(mode='openai') + 'call_b9krJaIcuBM4lej3IyI5heVc' + + >>> generate_id(mode='openai-response') + 'fc_68600a8868248199a436492a47a75e440766032408f75a09' + + >>> generate_id(mode='anthropic') + 'toolu_vrtx_01LiZkD1myhnDz7gcoEe4Y5A' + """ + ALPHANUM = string.ascii_letters + string.digits + if mode == "general": + # Generate 22-char base62 string for general use + return "".join(secrets.choice(ALPHANUM) for _ in range(22)) + + elif mode in ["openai", "openai-chatcompletion"]: + suffix = "".join(secrets.choice(ALPHANUM) for _ in range(22)) + return f"call_{suffix}" + + elif mode == "openai-response": + # 24 bytes → 48 hex chars (matches your example) + return f"fc_{secrets.token_hex(24)}" + + elif mode == "anthropic": + # Generate 24-char base62 string to match the pattern + suffix = "".join(secrets.choice(ALPHANUM) for _ in range(24)) + return f"toolu_{suffix}" + + elif mode == "google": + # TODO: Implement Google-specific ID generation if needed + raise NotImplementedError("Google-specific ID generation not implemented") + + else: + raise ValueError(f"Unknown mode: {mode!r}") + + +def validate_tool_choice(tool_choice: Union[str, Dict[str, Any]]) -> None: + """Helper function to validate tool_choice parameter. + + Args: + tool_choice: The tool choice parameter to validate. + + Raises: + ValueError: If tool_choice is invalid. + """ + if isinstance(tool_choice, str): + valid_strings = ["none", "auto", "required"] + if tool_choice not in valid_strings: + raise ValueError( + f"Invalid tool_choice string '{tool_choice}'. " + f"Must be one of: {', '.join(valid_strings)}" + ) + elif isinstance(tool_choice, dict): + try: + ChatCompletionNamedToolChoiceParam.model_validate(tool_choice, strict=False) + except ValidationError as e: + raise ValueError(f"Invalid tool_choice dict structure: {e}") + else: + raise ValueError( + f"Invalid tool_choice type '{type(tool_choice).__name__}'. " + f"Must be str or dict" + )
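The helpers in `tool_calls/utils.py` are small enough to try directly. A short sketch under the same import-path assumption; the model-name strings and printed IDs are illustrative only:

```python
from tool_calls.utils import determine_model_family, generate_id, validate_tool_choice

# Family detection drives both prompt selection and native tool conversion.
print(determine_model_family("gpt4o"))          # -> "openai"
print(determine_model_family("claudesonnet4"))  # -> "anthropic"
print(determine_model_family("gemini25pro"))    # -> "google"
print(determine_model_family("llama3"))         # -> "unknown"

# IDs carry the provider-specific prefixes used elsewhere in tool_calls/.
print(generate_id(mode="openai"))     # call_ + 22 base62 characters
print(generate_id(mode="anthropic"))  # toolu_ + 24 base62 characters

# tool_choice validation accepts the documented strings or a named-tool dict...
validate_tool_choice("auto")
validate_tool_choice({"type": "function", "function": {"name": "get_weather"}})

# ...and raises ValueError for anything else.
try:
    validate_tool_choice("always")
except ValueError as exc:
    print(exc)
```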