Skip to content

Commit d9d2e17

Browse files
committed
refactor upstream files and classes
1 parent 3d6bd65 commit d9d2e17

File tree

12 files changed

+194
-166
lines changed

12 files changed

+194
-166
lines changed

routstr/algorithm.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
if TYPE_CHECKING:
88
from .payment.models import Model
9-
from .upstream import UpstreamProvider
9+
from .upstream import BaseUpstreamProvider
1010

1111
logger = get_logger(__name__)
1212

@@ -59,7 +59,7 @@ def calculate_model_cost_score(model: "Model") -> float:
5959
return total_cost
6060

6161

62-
def get_provider_penalty(provider: "UpstreamProvider") -> float:
62+
def get_provider_penalty(provider: "BaseUpstreamProvider") -> float:
6363
"""Calculate a penalty multiplier for certain providers.
6464
6565
This allows applying policy-based adjustments beyond pure cost.
@@ -86,9 +86,9 @@ def get_provider_penalty(provider: "UpstreamProvider") -> float:
8686

8787
def should_prefer_model(
8888
candidate_model: "Model",
89-
candidate_provider: "UpstreamProvider",
89+
candidate_provider: "BaseUpstreamProvider",
9090
current_model: "Model",
91-
current_provider: "UpstreamProvider",
91+
current_provider: "BaseUpstreamProvider",
9292
alias: str,
9393
) -> bool:
9494
"""Determine if candidate model should replace current model for an alias.
@@ -166,10 +166,10 @@ def alias_priority(model: "Model") -> int:
166166

167167

168168
def create_model_mappings(
169-
upstreams: list["UpstreamProvider"],
169+
upstreams: list["BaseUpstreamProvider"],
170170
overrides_by_id: dict[str, tuple],
171171
disabled_model_ids: set[str],
172-
) -> tuple[dict[str, "Model"], dict[str, "UpstreamProvider"], dict[str, "Model"]]:
172+
) -> tuple[dict[str, "Model"], dict[str, "BaseUpstreamProvider"], dict[str, "Model"]]:
173173
"""Create optimal model mappings based on cost and provider preferences.
174174
175175
This is the main entry point for the algorithm. It processes all upstream providers
@@ -196,12 +196,12 @@ def create_model_mappings(
196196
from .upstream import resolve_model_alias
197197

198198
model_instances: dict[str, "Model"] = {}
199-
provider_map: dict[str, "UpstreamProvider"] = {}
199+
provider_map: dict[str, "BaseUpstreamProvider"] = {}
200200
unique_models: dict[str, "Model"] = {}
201201

202202
# Separate OpenRouter from other providers
203-
openrouter: "UpstreamProvider" | None = None
204-
other_upstreams: list["UpstreamProvider"] = []
203+
openrouter: "BaseUpstreamProvider" | None = None
204+
other_upstreams: list["BaseUpstreamProvider"] = []
205205

206206
for upstream in upstreams:
207207
base_url = getattr(upstream, "base_url", "")
@@ -215,7 +215,7 @@ def get_base_model_id(model_id: str) -> str:
215215
return model_id.split("/", 1)[1] if "/" in model_id else model_id
216216

217217
def _maybe_set_alias(
218-
alias: str, model: "Model", provider: "UpstreamProvider"
218+
alias: str, model: "Model", provider: "BaseUpstreamProvider"
219219
) -> None:
220220
"""Set alias to model/provider if not set or if new model is preferred."""
221221
existing_model = model_instances.get(alias)
@@ -233,7 +233,7 @@ def _maybe_set_alias(
233233
provider_map[alias] = provider
234234

235235
def process_provider_models(
236-
upstream: "UpstreamProvider", is_openrouter: bool = False
236+
upstream: "BaseUpstreamProvider", is_openrouter: bool = False
237237
) -> None:
238238
"""Process all models from a given provider."""
239239
upstream_prefix = getattr(upstream, "upstream_name", None)

routstr/proxy.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
get_max_cost_for_model,
2424
)
2525
from .payment.models import Model
26-
from .upstream import UpstreamProvider, init_upstreams
26+
from .upstream import BaseUpstreamProvider, init_upstreams
2727

2828
logger = get_logger(__name__)
2929
proxy_router = APIRouter()
3030

31-
_upstreams: list[UpstreamProvider] = []
31+
_upstreams: list[BaseUpstreamProvider] = []
3232
_model_instances: dict[str, Model] = {} # All aliases -> Model
33-
_provider_map: dict[str, UpstreamProvider] = {} # All aliases -> Provider
33+
_provider_map: dict[str, BaseUpstreamProvider] = {} # All aliases -> Provider
3434
_unique_models: dict[str, Model] = {} # Unique model.id -> Model (no duplicates)
3535

3636

@@ -53,7 +53,7 @@ async def reinitialize_upstreams() -> None:
5353
await refresh_model_maps()
5454

5555

56-
def get_upstreams() -> list[UpstreamProvider]:
56+
def get_upstreams() -> list[BaseUpstreamProvider]:
5757
"""Get the initialized upstream providers.
5858
5959
Returns:
@@ -67,7 +67,7 @@ def get_model_instance(model_id: str) -> Model | None:
6767
return _model_instances.get(model_id)
6868

6969

70-
def get_provider_for_model(model_id: str) -> UpstreamProvider | None:
70+
def get_provider_for_model(model_id: str) -> BaseUpstreamProvider | None:
7171
"""Get UpstreamProvider for model ID from global cache."""
7272
return _provider_map.get(model_id)
7373

routstr/upstream/__init__.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from .anthropic import AnthropicUpstreamProvider
2+
from .azure import AzureUpstreamProvider
3+
from .base import BaseUpstreamProvider
4+
from .generic import GenericUpstreamProvider
5+
from .helpers import (
6+
_instantiate_provider,
7+
_seed_providers_from_settings,
8+
get_all_models_with_overrides,
9+
get_model_with_override,
10+
init_upstreams,
11+
refresh_upstreams_models_periodically,
12+
resolve_model_alias,
13+
)
14+
from .ollama import OllamaUpstreamProvider
15+
from .openai import OpenAIUpstreamProvider
16+
from .openrouter import OpenRouterUpstreamProvider
17+
18+
__all__ = [
19+
# upstreams
20+
"AnthropicUpstreamProvider",
21+
"AzureUpstreamProvider",
22+
"BaseUpstreamProvider",
23+
"GenericUpstreamProvider",
24+
"OllamaUpstreamProvider",
25+
"OpenAIUpstreamProvider",
26+
"OpenRouterUpstreamProvider",
27+
# helpers
28+
"resolve_model_alias",
29+
"get_all_models_with_overrides",
30+
"get_model_with_override",
31+
"refresh_upstreams_models_periodically",
32+
"init_upstreams",
33+
"_seed_providers_from_settings",
34+
"_instantiate_provider",
35+
]

routstr/upstream/anthropic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from ..payment.models import Model, async_fetch_openrouter_models
2+
from .base import BaseUpstreamProvider
3+
4+
5+
class AnthropicUpstreamProvider(BaseUpstreamProvider):
6+
"""Upstream provider specifically configured for Anthropic API."""
7+
8+
def __init__(self, api_key: str, provider_fee: float = 1.01):
9+
self.upstream_name = "anthropic"
10+
super().__init__(
11+
base_url="https://api.anthropic.com/v1",
12+
api_key=api_key,
13+
provider_fee=provider_fee,
14+
)
15+
16+
def transform_model_name(self, model_id: str) -> str:
17+
"""Strip 'anthropic/' prefix for Anthropic API compatibility."""
18+
return model_id.removeprefix("anthropic/")
19+
20+
async def fetch_models(self) -> list[Model]:
21+
"""Fetch Anthropic models from OpenRouter API filtered by anthropic source."""
22+
models_data = await async_fetch_openrouter_models(source_filter="anthropic")
23+
return [Model(**model) for model in models_data] # type: ignore

routstr/upstream/azure.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from typing import Mapping
2+
3+
from .base import BaseUpstreamProvider
4+
5+
6+
class AzureUpstreamProvider(BaseUpstreamProvider):
7+
"""Upstream provider specifically configured for Azure OpenAI Service."""
8+
9+
def __init__(
10+
self,
11+
base_url: str,
12+
api_key: str,
13+
api_version: str,
14+
provider_fee: float = 1.01,
15+
):
16+
"""Initialize Azure provider with API key and version.
17+
18+
Args:
19+
base_url: Azure OpenAI endpoint base URL
20+
api_key: Azure OpenAI API key for authentication
21+
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview")
22+
provider_fee: Provider fee multiplier (default 1.01 for 1% fee)
23+
"""
24+
super().__init__(
25+
base_url=base_url,
26+
api_key=api_key,
27+
provider_fee=provider_fee,
28+
)
29+
self.api_version = api_version
30+
31+
def prepare_params(
32+
self, path: str, query_params: Mapping[str, str] | None
33+
) -> Mapping[str, str]:
34+
"""Prepare query parameters for Azure OpenAI, adding API version.
35+
36+
Args:
37+
path: Request path
38+
query_params: Original query parameters from the client
39+
40+
Returns:
41+
Query parameters dict with Azure API version added for chat completions
42+
"""
43+
params = dict(query_params or {})
44+
if path.endswith("chat/completions"):
45+
params["api-version"] = self.api_version
46+
return params

routstr/upstreams/upstream.py renamed to routstr/upstream/base.py

Lines changed: 1 addition & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,14 @@
2525
Pricing,
2626
_calculate_usd_max_costs,
2727
_update_model_sats_pricing,
28-
async_fetch_openrouter_models,
2928
)
3029
from ..payment.price import sats_usd_price
3130
from ..wallet import recieve_token, send_token
3231

3332
logger = get_logger(__name__)
3433

3534

36-
class UpstreamProvider:
35+
class BaseUpstreamProvider:
3736
"""Provider for forwarding requests to an upstream AI service API."""
3837

3938
base_url: str
@@ -1702,111 +1701,3 @@ def get_cached_model_by_id(self, model_id: str) -> Model | None:
17021701
Model object or None if not found
17031702
"""
17041703
return self._models_by_id.get(model_id)
1705-
1706-
1707-
class OpenAIUpstreamProvider(UpstreamProvider):
1708-
"""Upstream provider specifically configured for OpenAI API."""
1709-
1710-
def __init__(self, api_key: str, provider_fee: float = 1.01):
1711-
self.upstream_name = "openai"
1712-
super().__init__(
1713-
base_url="https://api.openai.com/v1",
1714-
api_key=api_key,
1715-
provider_fee=provider_fee,
1716-
)
1717-
1718-
def transform_model_name(self, model_id: str) -> str:
1719-
"""Strip 'openai/' prefix for OpenAI API compatibility."""
1720-
return model_id.removeprefix("openai/")
1721-
1722-
async def fetch_models(self) -> list[Model]:
1723-
"""Fetch OpenAI models from OpenRouter API filtered by openai source."""
1724-
models_data = await async_fetch_openrouter_models(source_filter="openai")
1725-
return [Model(**model) for model in models_data] # type: ignore
1726-
1727-
1728-
class AnthropicUpstreamProvider(UpstreamProvider):
1729-
"""Upstream provider specifically configured for Anthropic API."""
1730-
1731-
def __init__(self, api_key: str, provider_fee: float = 1.01):
1732-
self.upstream_name = "anthropic"
1733-
super().__init__(
1734-
base_url="https://api.anthropic.com/v1",
1735-
api_key=api_key,
1736-
provider_fee=provider_fee,
1737-
)
1738-
1739-
def transform_model_name(self, model_id: str) -> str:
1740-
"""Strip 'anthropic/' prefix for Anthropic API compatibility."""
1741-
return model_id.removeprefix("anthropic/")
1742-
1743-
async def fetch_models(self) -> list[Model]:
1744-
"""Fetch Anthropic models from OpenRouter API filtered by anthropic source."""
1745-
models_data = await async_fetch_openrouter_models(source_filter="anthropic")
1746-
return [Model(**model) for model in models_data] # type: ignore
1747-
1748-
1749-
class AzureUpstreamProvider(UpstreamProvider):
1750-
"""Upstream provider specifically configured for Azure OpenAI Service."""
1751-
1752-
def __init__(
1753-
self,
1754-
base_url: str,
1755-
api_key: str,
1756-
api_version: str,
1757-
provider_fee: float = 1.01,
1758-
):
1759-
"""Initialize Azure provider with API key and version.
1760-
1761-
Args:
1762-
base_url: Azure OpenAI endpoint base URL
1763-
api_key: Azure OpenAI API key for authentication
1764-
api_version: Azure OpenAI API version (e.g., "2024-02-15-preview")
1765-
provider_fee: Provider fee multiplier (default 1.01 for 1% fee)
1766-
"""
1767-
super().__init__(
1768-
base_url=base_url,
1769-
api_key=api_key,
1770-
provider_fee=provider_fee,
1771-
)
1772-
self.api_version = api_version
1773-
1774-
def prepare_params(
1775-
self, path: str, query_params: Mapping[str, str] | None
1776-
) -> Mapping[str, str]:
1777-
"""Prepare query parameters for Azure OpenAI, adding API version.
1778-
1779-
Args:
1780-
path: Request path
1781-
query_params: Original query parameters from the client
1782-
1783-
Returns:
1784-
Query parameters dict with Azure API version added for chat completions
1785-
"""
1786-
params = dict(query_params or {})
1787-
if path.endswith("chat/completions"):
1788-
params["api-version"] = self.api_version
1789-
return params
1790-
1791-
1792-
class OpenRouterUpstreamProvider(UpstreamProvider):
1793-
"""Upstream provider specifically configured for OpenRouter API."""
1794-
1795-
def __init__(self, api_key: str, provider_fee: float = 1.06):
1796-
"""Initialize OpenRouter provider with API key.
1797-
1798-
Args:
1799-
api_key: OpenRouter API key for authentication
1800-
provider_fee: Provider fee multiplier (default 1.06 for 6% fee)
1801-
"""
1802-
self.upstream_name = "openrouter"
1803-
super().__init__(
1804-
base_url="https://openrouter.ai/api/v1",
1805-
api_key=api_key,
1806-
provider_fee=provider_fee,
1807-
)
1808-
1809-
async def fetch_models(self) -> list[Model]:
1810-
"""Fetch all OpenRouter models."""
1811-
models_data = await async_fetch_openrouter_models()
1812-
return [Model(**model) for model in models_data] # type: ignore

routstr/upstreams/generic.py renamed to routstr/upstream/generic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import httpx
66

7-
from .upstream import UpstreamProvider
7+
from .base import BaseUpstreamProvider
88

99
if TYPE_CHECKING:
1010
from ..payment.models import Model
@@ -14,7 +14,7 @@
1414
logger = get_logger(__name__)
1515

1616

17-
class GenericUpstreamProvider(UpstreamProvider):
17+
class GenericUpstreamProvider(BaseUpstreamProvider):
1818
"""Generic upstream provider that can fetch models from any OpenAI-compatible API."""
1919

2020
def __init__(

0 commit comments

Comments
 (0)