.gitignore — 3 changes: 2 additions & 1 deletion
@@ -4,4 +4,5 @@ metrics
 .env
 prometheus.yml
 myserver.crt
-myserver.key
+myserver.key
+sandbox.py
argo_bridge.py — 195 changes: 187 additions & 8 deletions
@@ -65,6 +65,20 @@ def after_request(response):
 
     'gpto1': 'gpto1',
     'o1': 'gpto1',
+
+    'gpto3': 'gpto3',
+    'gpto4mini': 'gpto4mini',
+    'gpt41': 'gpt41',
+    'gpt41mini': 'gpt41mini',
+    'gpt41nano': 'gpt41nano',
+
+    'gemini25pro': 'gemini25pro',
+    'gemini25flash': 'gemini25flash',
+    'claudeopus4': 'claudeopus4',
+    'claudesonnet4': 'claudesonnet4',
+    'claudesonnet37': 'claudesonnet37',
+    'claudesonnet35v2': 'claudesonnet35v2',
 }
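(Editorial aside, not part of the commit: the table above is a flat alias map — request-facing names on the left, Argo's canonical names on the right, with most entries mapping to themselves. A minimal sketch of the lookup, using a hypothetical resolve_model helper over a subset of the entries:)

MODEL_MAPPING = {
    'o1': 'gpto1',                 # alias for the same backend model
    'gpto1': 'gpto1',              # canonical name maps to itself
    'claudeopus4': 'claudeopus4',
}

def resolve_model(name):
    # chat_completions() answers this case with an HTTP error body instead.
    if name not in MODEL_MAPPING:
        raise KeyError(f"Model '{name}' not supported.")
    return MODEL_MAPPING[name]

assert resolve_model('o1') == 'gpto1'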


@@ -108,9 +122,25 @@ def after_request(response):
     # Models using development environment
     'gpto3mini': 'dev',
     'gpto1mini': 'dev',
-    'gpto1': 'dev'
+    'gpto1': 'dev',
+    'gemini25pro': 'dev',
+    'gemini25flash': 'dev',
+    'claudeopus4': 'dev',
+    'claudesonnet4': 'dev',
+    'claudesonnet37': 'dev',
+    'claudesonnet35v2': 'dev',
+    'gpto3': 'dev',
+    'gpto4mini': 'dev',
+    'gpt41': 'dev',
+    'gpt41mini': 'dev',
+    'gpt41nano': 'dev',
 }
 
 
+NON_STREAMING_MODELS = ['gemini25pro', 'gemini25flash',
+                        'claudeopus4', 'claudesonnet4', 'claudesonnet37', 'claudesonnet35v2',
+                        'gpto3', 'gpto4mini', 'gpt41', 'gpt41mini', 'gpt41nano']
+
 # For models endpoint
 MODELS = {
     "object": "list",
@@ -131,10 +161,30 @@ def after_request(response):
 EMBED_ENV = 'prod'
 
 DEFAULT_MODEL = "gpt4o"
-ANL_USER = "APS"
+BRIDGE_USER = "ARGO_BRIDGE"
 ANL_STREAM_URL = "https://apps-dev.inside.anl.gov/argoapi/api/v1/resource/streamchat/"
 ANL_DEBUG_FP = 'log_bridge.log'
 
 
+def get_user_from_auth_header():
+    """
+    Extracts the user from the Authorization header.
+    If the header is present and valid, the bearer token is returned.
+    Otherwise, the default user is returned.
+    """
+    auth_header = request.headers.get("Authorization")
+    if auth_header and auth_header.startswith("Bearer "):
+        # Return the token part of the header
+        token = auth_header.split(" ")[1]
+        logging.debug(f"Authorization header found: {auth_header}")
+        if token == 'noop':
+            return BRIDGE_USER
+
+        return token
+    # Return the default user if no valid header is found
+    return BRIDGE_USER
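(Editorial aside, not part of the commit: the resolution rules above, restated as a self-contained function so the three cases are easy to see. This mirrors get_user_from_auth_header without Flask's request object; the 'noop' placeholder and the BRIDGE_USER default come from this hunk.)

def resolve_user(auth_header, bridge_user="ARGO_BRIDGE"):
    # Same logic as above, but takes the raw header value as an argument.
    if auth_header and auth_header.startswith("Bearer "):
        token = auth_header.split(" ")[1]
        if token == "noop":  # placeholder token falls back to the bridge user
            return bridge_user
        return token
    return bridge_user

assert resolve_user("Bearer alice") == "alice"       # token becomes the Argo user
assert resolve_user("Bearer noop") == "ARGO_BRIDGE"  # placeholder -> default
assert resolve_user(None) == "ARGO_BRIDGE"           # no header -> default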


 def get_api_url(model, endpoint_type):
     """
     Determine the correct API URL based on model and endpoint type
@@ -169,11 +219,18 @@ def chat_completions():
     logging.info("Received chat completions request")
 
     data = request.get_json()
+    logging.info(f"Request Data: {data}")
     model_base = data.get("model", DEFAULT_MODEL)
     is_streaming = data.get("stream", False)
     temperature = data.get("temperature", 0.1)
     stop = data.get("stop", [])
 
+    # Force non-streaming for specific models. Remove once Argo supports streaming for all of them.
+    # TODO: TEMP fake streaming for the new models until Argo supports it
+    is_fake_stream = False
+    if model_base in NON_STREAMING_MODELS and is_streaming:
+        is_fake_stream = True
+
     if model_base not in MODEL_MAPPING:
         return jsonify({"error": {
             "message": f"Model '{model_base}' not supported."
@@ -183,8 +240,19 @@ def chat_completions():
 
     logging.debug(f"Received request: {data}")
 
+    # Process multimodal content for Gemini models
+    if model_base.startswith('gemini'):
+        try:
+            data['messages'] = convert_multimodal_to_text(data['messages'], model_base)
+        except ValueError as e:
+            return jsonify({"error": {
+                "message": str(e)
+            }}), 400
+
+    user = get_user_from_auth_header()
+
     req_obj = {
-        "user": ANL_USER,
+        "user": user,
         "model": model,
         "messages": data['messages'],
         "system": "",
@@ -194,7 +262,22 @@ def chat_completions():
 
     logging.debug(f"Argo Request {req_obj}")
 
-    if is_streaming:
+    if is_fake_stream:
+        logging.info(req_obj)
+        response = requests.post(get_api_url(model, 'chat'), json=req_obj)
+
+        if not response.ok:
+            logging.error(f"Internal API error: {response.status_code} {response.reason}")
+            return jsonify({"error": {
+                "message": f"Internal API error: {response.status_code} {response.reason}"
+            }}), 500
+
+        json_response = response.json()
+        text = json_response.get("response", "")
+        logging.debug(f"Response Text {text}")
+        return Response(_fake_stream_response(text, model), mimetype='text/event-stream')
+
+    elif is_streaming:
         return Response(_stream_chat_response(model, req_obj), mimetype='text/event-stream')
     else:
         response = requests.post(get_api_url(model, 'chat'), json=req_obj)
@@ -280,6 +363,98 @@ def _static_chat_response(text, model):
         }]
     }
 
+def _fake_stream_response(text, model):
+    begin_chunk = {
+        "id": 'abc',
+        "object": "chat.completion.chunk",
+        "created": int(datetime.datetime.now().timestamp()),
+        "model": model,
+        "choices": [{
+            "index": 0,
+            "delta": {'role': 'assistant', 'content': ''},
+            "logprobs": None,
+            "finish_reason": None
+        }]
+    }
+    yield f"data: {json.dumps(begin_chunk)}\n\n"
+    chunk = {
+        "id": 'abc',
+        "object": "chat.completion.chunk",
+        "created": int(datetime.datetime.now().timestamp()),
+        "model": model,
+        "choices": [{
+            "index": 0,
+            "delta": {'content': text},
+            "logprobs": None,
+            "finish_reason": None
+        }]
+    }
+    yield f"data: {json.dumps(chunk)}\n\n"
+    end_chunk = {
+        "id": 'argo',
+        "object": "chat.completion.chunk",
+        "created": int(datetime.datetime.now().timestamp()),
+        "model": model,
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{
+            "index": 0,
+            "delta": {},
+            "logprobs": None,
+            "finish_reason": "stop"
+        }]
+    }
+    yield f"data: {json.dumps(end_chunk)}\n\n"
+    yield "data: [DONE]\n\n"

+def convert_multimodal_to_text(messages, model_base):
+    """
+    Convert multimodal content format to plain text for Gemini models.
+
+    Args:
+        messages (list): List of message objects
+        model_base (str): The model being used
+
+    Returns:
+        list: Processed messages with text-only content
+
+    Raises:
+        ValueError: If non-text content is found in multimodal format
+    """
+    # Only process for Gemini models
+    gemini_models = ['gemini25pro', 'gemini25flash']
+    if model_base not in gemini_models:
+        return messages
+
+    processed_messages = []
+
+    for message in messages:
+        processed_message = message.copy()
+        content = message.get("content")
+
+        # Check if content is in multimodal format (list of content objects)
+        if isinstance(content, list):
+            text_parts = []
+
+            for content_item in content:
+                if isinstance(content_item, dict):
+                    content_type = content_item.get("type")
+
+                    if content_type == "text":
+                        text_parts.append(content_item.get("text", ""))
+                    else:
+                        # Error out if non-text content is found
+                        raise ValueError(f"Gemini models only support text content. Found unsupported content type: '{content_type}'")
+                else:
+                    # If a content item is not a dict, treat it as plain text
+                    text_parts.append(str(content_item))
+
+            # Join all text parts and set as the content
+            processed_message["content"] = " ".join(text_parts)
+
+        processed_messages.append(processed_message)
+
+    return processed_messages
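(Editorial aside, not part of the commit: a quick illustration of the flattening above, using the function exactly as defined in this hunk. Text parts are joined with single spaces; any non-text part raises.)

msgs = [{"role": "user", "content": [
    {"type": "text", "text": "Summarize"},
    {"type": "text", "text": "this abstract."},
]}]
assert convert_multimodal_to_text(msgs, "gemini25pro") == \
    [{"role": "user", "content": "Summarize this abstract."}]

# Non-Gemini models pass through untouched:
assert convert_multimodal_to_text(msgs, "gpt4o") == msgs

# An image part is rejected:
#   {"type": "image_url", "image_url": {...}} -> ValueError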


"""
=================================
Expand Down Expand Up @@ -308,8 +483,10 @@ def completions():

logging.debug(f"Received request: {data}")

user = get_user_from_auth_header()

req_obj = {
"user": ANL_USER,
"user": user,
"model": model,
"prompt": [data['prompt']],
"system": "",
@@ -389,7 +566,8 @@ def embeddings():
     if isinstance(input_data, str):
         input_data = [input_data]
 
-    embedding_vectors = _get_embeddings_from_argo(input_data, model)
+    user = get_user_from_auth_header()
+    embedding_vectors = _get_embeddings_from_argo(input_data, model, user)
 
     response_data = {
         "object": "list",
@@ -411,15 +589,15 @@ def embeddings():
     return jsonify(response_data)
 
 
-def _get_embeddings_from_argo(texts, model):
+def _get_embeddings_from_argo(texts, model, user=BRIDGE_USER):
     BATCH_SIZE = 16
     all_embeddings = []
 
     for i in range(0, len(texts), BATCH_SIZE):
         batch_texts = texts[i:i + BATCH_SIZE]
 
         payload = {
-            "user": ANL_USER,
+            "user": user,
             "model": model,
             "prompt": batch_texts
        }
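(Editorial aside, not part of the commit: the slicing pattern above splits the inputs into chunks of at most BATCH_SIZE, each sent as its own Argo request.)

texts = [f"doc-{i}" for i in range(40)]
BATCH_SIZE = 16
batches = [texts[i:i + BATCH_SIZE] for i in range(0, len(texts), BATCH_SIZE)]
assert [len(b) for b in batches] == [16, 16, 8]  # 40 inputs -> 3 requests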
@@ -509,6 +687,7 @@ def parse_args():
         level=logging.DEBUG if debug_enabled else logging.INFO,
         format='%(asctime)s - %(levelname)s - %(message)s'
     )
+    logging.getLogger('watchdog').setLevel(logging.CRITICAL + 10)
 
     logging.info(f'Starting server with debug mode: {debug_enabled}')
     print(f'Starting server... | Port {args.port} | User {args.username} | Debug: {debug_enabled}')