|
0981a08…
|
noreply
|
1 |
"""Abstract base class, registry, and shared types for provider implementations.""" |
|
a94205b…
|
leo
|
2 |
|
|
0981a08…
|
noreply
|
3 |
import base64 |
|
0981a08…
|
noreply
|
4 |
import logging |
|
0981a08…
|
noreply
|
5 |
import os |
|
a94205b…
|
leo
|
6 |
from abc import ABC, abstractmethod |
|
a94205b…
|
leo
|
7 |
from pathlib import Path |
|
0981a08…
|
noreply
|
8 |
from typing import Dict, List, Optional |
|
a94205b…
|
leo
|
9 |
|
|
a94205b…
|
leo
|
10 |
from pydantic import BaseModel, Field |
|
0981a08…
|
noreply
|
11 |
|
|
0981a08…
|
noreply
|
12 |
# Module-level logger, following the standard one-logger-per-module convention.
logger = logging.getLogger(__name__)
|
a94205b…
|
leo
|
13 |
|
|
a94205b…
|
leo
|
14 |
|
|
a94205b…
|
leo
|
15 |
class ModelInfo(BaseModel):
    """Metadata describing one model exposed by a provider.

    Instances are produced by provider ``list_models`` implementations,
    either from a live model-listing API call or from static defaults.
    """

    # Identifier exactly as the provider's API expects it.
    id: str = Field(description="Model identifier (e.g. gpt-4o)")
    # Short name of the owning provider.
    provider: str = Field(description="Provider name (openai, anthropic, gemini)")
    # Friendly label; empty string when the provider supplies none.
    display_name: str = Field(default="", description="Human-readable name")
    # Capability tags; empty list when capabilities are unknown.
    capabilities: List[str] = Field(
        default_factory=list,
        description="Model capabilities: chat, vision, audio, embedding",
    )
|
a94205b…
|
leo
|
24 |
|
|
a94205b…
|
leo
|
25 |
|
|
a94205b…
|
leo
|
26 |
class BaseProvider(ABC):
    """Abstract base for all provider implementations.

    Concrete providers implement chat, vision, audio, and model-discovery
    operations against a single backend service.
    """

    # Short provider identifier (e.g. "openai"); subclasses override this.
    provider_name: str = ""

    @abstractmethod
    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Send a chat completion request. Returns the assistant text.

        Args:
            messages: Conversation as a list of role/content dicts.
            max_tokens: Upper bound on generated tokens.
            temperature: Sampling temperature.
            model: Model ID override; provider picks its default when None.
        """

    @abstractmethod
    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Analyze an image with a prompt. Returns the assistant text.

        Args:
            image_bytes: Raw image data.
            prompt: Question or instruction about the image.
            max_tokens: Upper bound on generated tokens.
            model: Model ID override; provider picks its default when None.
        """

    @abstractmethod
    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe an audio file. Returns dict with 'text', 'segments', etc.

        Args:
            audio_path: Path to the audio file on disk.
            language: Optional language hint for the transcriber.
            model: Model ID override; provider picks its default when None.
        """

    @abstractmethod
    def list_models(self) -> list[ModelInfo]:
        """Discover available models from this provider's API."""
|
0981a08…
|
noreply
|
63 |
|
|
0981a08…
|
noreply
|
64 |
|
|
0981a08…
|
noreply
|
65 |
class ProviderRegistry:
    """Registry for provider classes. Providers register themselves with metadata."""

    # Shared class-level mapping: provider name -> metadata dict.
    _providers: Dict[str, Dict] = {}

    @classmethod
    def register(
        cls,
        name: str,
        provider_class: type,
        env_var: str = "",
        model_prefixes: Optional[List[str]] = None,
        default_models: Optional[Dict[str, str]] = None,
    ) -> None:
        """Register a provider class with its metadata."""
        # Normalize the optional collections so lookups never see None.
        cls._providers[name] = {
            "class": provider_class,
            "env_var": env_var,
            "model_prefixes": [] if model_prefixes is None else model_prefixes,
            "default_models": {} if default_models is None else default_models,
        }

    @classmethod
    def get(cls, name: str) -> type:
        """Return the provider class for a given name."""
        if name in cls._providers:
            return cls._providers[name]["class"]
        raise ValueError(f"Unknown provider: {name}")

    @classmethod
    def get_by_model(cls, model_id: str) -> Optional[str]:
        """Return provider name for a model ID based on prefix matching."""
        matches = (
            name
            for name, meta in cls._providers.items()
            if any(model_id.startswith(prefix) for prefix in meta["model_prefixes"])
        )
        # First registered provider with a matching prefix wins, else None.
        return next(matches, None)

    @classmethod
    def get_default_models(cls, name: str) -> Dict[str, str]:
        """Return the default models dict for a provider."""
        meta = cls._providers.get(name)
        if meta is None:
            return {}
        return meta.get("default_models", {})

    @classmethod
    def available(cls) -> List[str]:
        """Return names of providers whose env var is set (or have no env var requirement)."""
        names = []
        for name, meta in cls._providers.items():
            key = meta.get("env_var", "")
            # No env var (e.g. ollama) means availability is checked elsewhere;
            # otherwise the key must be present and non-empty in the environment.
            if not key or os.getenv(key, ""):
                names.append(name)
        return names

    @classmethod
    def all_registered(cls) -> Dict[str, Dict]:
        """Return all registered providers and their metadata."""
        # Shallow copy so callers cannot mutate the registry itself.
        return dict(cls._providers)
|
0981a08…
|
noreply
|
127 |
|
|
0981a08…
|
noreply
|
128 |
|
|
0981a08…
|
noreply
|
129 |
class OpenAICompatibleProvider(BaseProvider):
    """Base for providers using OpenAI-compatible APIs.

    Suitable for Together, Fireworks, Cerebras, xAI, Azure, and similar services.

    Subclasses set ``provider_name``, ``base_url``, ``env_var``, and — since
    non-OpenAI services will not serve ``gpt-4o`` — should also override
    ``default_model`` with a model ID valid on their service.
    """

    provider_name: str = ""
    # API root URL for the service; subclasses override.
    base_url: str = ""
    # Name of the environment variable holding the API key.
    env_var: str = ""
    # Fallback model used when callers do not pass one explicitly.
    default_model: str = "gpt-4o"

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Create the underlying OpenAI client.

        Args:
            api_key: Explicit key; falls back to the ``env_var`` environment variable.
            base_url: Explicit endpoint; falls back to the class ``base_url``.
        """
        # Imported lazily so this module can load without the openai package installed.
        from openai import OpenAI

        self._api_key = api_key or os.getenv(self.env_var, "")
        self._base_url = base_url or self.base_url
        self._client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        # Token counts from the most recent request; None until a call is made.
        self._last_usage = None

    def _record_usage(self, response) -> None:
        """Store input/output token counts from *response* in ``_last_usage``."""
        usage = response.usage
        self._last_usage = {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Send a chat completion request. Returns the assistant text."""
        response = self._client.chat.completions.create(
            model=model or self.default_model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        self._record_usage(response)
        return response.choices[0].message.content or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Analyze an image with a prompt. Returns the assistant text.

        The image is sent inline as a base64 data URL; the chosen model must
        support vision input.
        """
        b64 = base64.b64encode(image_bytes).decode()
        response = self._client.chat.completions.create(
            model=model or self.default_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )
        self._record_usage(response)
        return response.choices[0].message.content or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Audio transcription is not part of the generic OpenAI-compatible surface.

        Raises:
            NotImplementedError: always; subclasses with audio support override this.
        """
        raise NotImplementedError(f"{self.provider_name} does not support audio transcription")

    def list_models(self) -> list[ModelInfo]:
        """Discover available models from this provider's API.

        Returns an id-sorted list. On any API failure the error is logged and
        whatever was collected before the failure (usually nothing) is returned,
        so callers never see the exception.
        """
        models: list[ModelInfo] = []
        try:
            for m in self._client.models.list():
                models.append(
                    ModelInfo(
                        id=m.id,
                        provider=self.provider_name,
                        display_name=m.id,
                        # The generic endpoint exposes no capability data; assume chat.
                        capabilities=["chat"],
                    )
                )
        except Exception as e:  # best-effort: listing support varies by service
            logger.warning("Failed to list %s models: %s", self.provider_name, e)
        return sorted(models, key=lambda m: m.id)