def get_user_from_auth_header():
    """
    Resolve the user identity to forward to the Argo API.

    Reads the request's Authorization header. A "Bearer <token>" header
    yields <token> as the user, except the sentinel token "noop", which
    maps to the default bridge user. A missing or malformed header also
    falls back to the default bridge user.

    Returns:
        str: The bearer token, or BRIDGE_USER as the default.
    """
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith("Bearer "):
        token = auth_header.split(" ")[1]
        # Log presence only — the bearer token may be a credential and
        # must not be written to the log file.
        logging.debug("Authorization header found; using bearer token as user")
        if token == 'noop':
            return BRIDGE_USER
        return token
    # No valid header: fall back to the default bridge user.
    return BRIDGE_USER
def _fake_stream_response(text, model):
    """
    Emulate OpenAI-style SSE streaming for models Argo cannot stream yet.

    The full completion is sent as a single content delta, framed by an
    opening role chunk and a closing finish chunk, followed by the
    terminal "[DONE]" marker.

    Args:
        text (str): Complete response text from the backend.
        model (str): Model name echoed in every chunk.

    Yields:
        str: "data: {json}\\n\\n" server-sent-event lines.
    """
    # One timestamp and one id shared by every chunk of this completion,
    # matching OpenAI streaming semantics (previously the frames mixed
    # 'abc' and 'argo' ids and could carry differing timestamps).
    created = int(datetime.datetime.now().timestamp())
    completion_id = 'argo'

    def _sse(delta, finish_reason, extra=None):
        # Build one chat.completion.chunk SSE frame.
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{
                "index": 0,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }],
        }
        if extra:
            payload.update(extra)
        return f"data: {json.dumps(payload)}\n\n"

    # Opening chunk announces the assistant role.
    yield _sse({'role': 'assistant', 'content': ''}, None)
    # The entire completion is delivered in one delta.
    yield _sse({'content': text}, None)
    # Closing chunk carries the finish_reason (and fingerprint, as before).
    yield _sse({}, "stop", {"system_fingerprint": "fp_44709d6fcb"})
    yield "data: [DONE]\n\n"


def convert_multimodal_to_text(messages, model_base):
    """
    Flatten OpenAI multimodal message content to plain text for Gemini.

    Args:
        messages (list): Chat messages; each "content" may be a plain
            string or a list of content-part dicts such as
            {"type": "text", "text": ...}.
        model_base (str): Requested model name.

    Returns:
        list: Messages with text-only content. Non-Gemini models get the
        original list back unchanged.

    Raises:
        ValueError: If a non-text content part is present.
    """
    if model_base not in ('gemini25pro', 'gemini25flash'):
        # Only Gemini needs the flattening; leave other models untouched.
        return messages

    processed_messages = []
    for message in messages:
        processed_message = message.copy()
        content = message.get("content")

        # Multimodal format: a list of content-part objects.
        if isinstance(content, list):
            text_parts = []
            for part in content:
                if isinstance(part, dict):
                    part_type = part.get("type")
                    if part_type != "text":
                        # Images/audio/etc. are not supported downstream.
                        raise ValueError(
                            "Gemini models only support text content. "
                            f"Found unsupported content type: '{part_type}'"
                        )
                    text_parts.append(part.get("text", ""))
                else:
                    # Non-dict parts are coerced to text verbatim.
                    text_parts.append(str(part))
            processed_message["content"] = " ".join(text_parts)

        processed_messages.append(processed_message)

    return processed_messages