|
a94205b…
|
leo
|
1 |
"""ProviderManager - unified interface for routing API calls to the best available provider.""" |
|
a94205b…
|
leo
|
2 |
|
|
a94205b…
|
leo
|
3 |
import logging |
|
a94205b…
|
leo
|
4 |
from pathlib import Path |
|
a94205b…
|
leo
|
5 |
from typing import Optional |
|
a94205b…
|
leo
|
6 |
|
|
a94205b…
|
leo
|
7 |
from dotenv import load_dotenv |
|
a94205b…
|
leo
|
8 |
|
|
0981a08…
|
noreply
|
9 |
from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
|
a94205b…
|
leo
|
10 |
from video_processor.providers.discovery import discover_available_models |
|
287a3bb…
|
leo
|
11 |
from video_processor.utils.usage_tracker import UsageTracker |
|
a94205b…
|
leo
|
12 |
|
|
a94205b…
|
leo
|
13 |
# Load environment variables (e.g. provider API keys) from a .env file, if present.
load_dotenv()
# Module-level logger, following the standard getLogger(__name__) convention.
logger = logging.getLogger(__name__)
|
a94205b…
|
leo
|
15 |
|
|
0981a08…
|
noreply
|
16 |
|
|
0981a08…
|
noreply
|
17 |
def _ensure_providers_registered() -> None:
    """Import every built-in provider module so each registers itself.

    Provider modules self-register on import via ProviderRegistry.register(),
    so the import side effect is all that matters here. The registry check
    makes repeated calls cheap no-ops.
    """
    if ProviderRegistry.all_registered():
        return
    # Import order matches the previous explicit import list (alphabetical).
    for name in (
        "anthropic",
        "azure",
        "cerebras",
        "fireworks",
        "gemini",
        "ollama",
        "openai",
        "together",
        "xai",
    ):
        __import__(f"video_processor.providers.{name}_provider")
|
0981a08…
|
noreply
|
31 |
|
|
0981a08…
|
noreply
|
32 |
|
|
a94205b…
|
leo
|
33 |
# Default model preference rankings (tried in order)
# Each entry is a (provider_name, model_id) pair; _resolve_model() picks the
# first provider that can be initialized (i.e. whose API key is configured).
_VISION_PREFERENCES: list[tuple[str, str]] = [
    ("gemini", "gemini-2.5-flash"),
    ("openai", "gpt-4o-mini"),
    ("anthropic", "claude-haiku-4-5-20251001"),
]

_CHAT_PREFERENCES: list[tuple[str, str]] = [
    ("anthropic", "claude-haiku-4-5-20251001"),
    ("openai", "gpt-4o-mini"),
    ("gemini", "gemini-2.5-flash"),
]

# API-based transcription fallbacks; local Whisper is tried first by
# transcribe_audio() and does not appear here.
_TRANSCRIPTION_PREFERENCES: list[tuple[str, str]] = [
    ("openai", "whisper-1"),
    ("gemini", "gemini-2.5-flash"),
]
|
a94205b…
|
leo
|
50 |
|
|
a94205b…
|
leo
|
51 |
|
|
a94205b…
|
leo
|
52 |
class ProviderManager:
    """
    Routes API calls to the best available provider/model.

    Supports explicit model selection or auto-routing based on
    discovered available models.
    """

    def __init__(
        self,
        vision_model: Optional[str] = None,
        chat_model: Optional[str] = None,
        transcription_model: Optional[str] = None,
        provider: Optional[str] = None,
        auto: bool = True,
    ):
        """
        Initialize the ProviderManager.

        Parameters
        ----------
        vision_model : override model for vision tasks (e.g. 'gpt-4o')
        chat_model : override model for chat/LLM tasks
        transcription_model : override model for transcription
        provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini')
        auto : if True and no model specified, pick the best available
        """
        _ensure_providers_registered()
        self.auto = auto
        # Provider instances are created lazily and cached here by name.
        self._providers: dict[str, BaseProvider] = {}
        # Discovery results are cached after the first call (see _get_available_models).
        self._available_models: Optional[list[ModelInfo]] = None
        self.usage = UsageTracker()

        # If a single provider is forced, fill in its default model for any
        # capability the caller did not override explicitly.
        if provider:
            self.vision_model = vision_model or self._default_for_provider(provider, "vision")
            self.chat_model = chat_model or self._default_for_provider(provider, "chat")
            self.transcription_model = transcription_model or self._default_for_provider(
                provider, "audio"
            )
        else:
            self.vision_model = vision_model
            self.chat_model = chat_model
            self.transcription_model = transcription_model

        self._forced_provider = provider

    @staticmethod
    def _default_for_provider(provider: str, capability: str) -> str:
        """Return the default model for a provider/capability combo ('' if unknown)."""
        defaults = ProviderRegistry.get_default_models(provider)
        if defaults:
            return defaults.get(capability, "")
        # Fallback for unregistered providers
        return ""

    def _get_provider(self, provider_name: str) -> BaseProvider:
        """Lazily initialize and cache a provider instance.

        Propagates whatever the provider constructor raises (e.g. ValueError
        for a missing API key — see _resolve_model's except clause).
        """
        if provider_name not in self._providers:
            _ensure_providers_registered()
            provider_class = ProviderRegistry.get(provider_name)
            self._providers[provider_name] = provider_class()
        return self._providers[provider_name]

    def _provider_for_model(self, model_id: str) -> str:
        """Infer the provider from a model id.

        Raises
        ------
        ValueError
            If neither the registry nor discovery can match the model id.
        """
        _ensure_providers_registered()
        # Check registry prefix matching first
        provider_name = ProviderRegistry.get_by_model(model_id)
        if provider_name:
            return provider_name
        # Try discovery (exact match, then prefix match for ollama name:tag format)
        models = self._get_available_models()
        for m in models:
            if m.id == model_id:
                return m.provider
        for m in models:
            if m.id.startswith(model_id + ":"):
                return m.provider
        raise ValueError(f"Cannot determine provider for model: {model_id}")

    def _get_available_models(self) -> list[ModelInfo]:
        """Return discovered models, running discovery at most once."""
        if self._available_models is None:
            self._available_models = discover_available_models()
        return self._available_models

    def _resolve_model(
        self, explicit: Optional[str], capability: str, preferences: list[tuple[str, str]]
    ) -> tuple[str, str]:
        """
        Resolve which (provider, model) to use for a capability.

        Returns (provider_name, model_id).

        Raises
        ------
        ValueError
            If an explicit model's provider cannot be determined.
        RuntimeError
            If auto-routing finds no usable provider.
        """
        if explicit:
            prov = self._provider_for_model(explicit)
            return prov, explicit

        if self.auto:
            # Try preference order, picking the first provider that has an API key
            for prov, model in preferences:
                try:
                    self._get_provider(prov)
                    return prov, model
                except (ValueError, ImportError):
                    continue

            # Fallback: try Ollama if available (no API key needed)
            try:
                from video_processor.providers.ollama_provider import OllamaProvider

                if OllamaProvider.is_available():
                    provider = self._get_provider("ollama")
                    models = provider.list_models()
                    for m in models:
                        if capability in m.capabilities:
                            return "ollama", m.id
            except Exception as exc:
                # Best-effort only: a broken local daemon must never mask the
                # clearer "no provider available" error below.
                logger.debug(f"Ollama fallback failed: {exc}")

        raise RuntimeError(
            f"No provider available for capability '{capability}'. "
            "Set an API key for at least one provider, or start Ollama."
        )

    def _track(self, provider: BaseProvider, prov_name: str, model: str) -> None:
        """Record usage from the last API call on a provider."""
        last = getattr(provider, "_last_usage", None)
        if last:
            self.usage.record(
                provider=prov_name,
                model=model,
                input_tokens=last.get("input_tokens", 0),
                output_tokens=last.get("output_tokens", 0),
            )
            # Clear it so a later call that reports no usage is not double-counted.
            provider._last_usage = None

    # --- Public API ---

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> str:
        """Send a chat completion to the best available provider."""
        prov_name, model = self._resolve_model(self.chat_model, "chat", _CHAT_PREFERENCES)
        logger.info(f"Chat: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.chat(
            messages, max_tokens=max_tokens, temperature=temperature, model=model
        )
        self._track(provider, prov_name, model)
        return result

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
    ) -> str:
        """Analyze an image using the best available vision provider."""
        prov_name, model = self._resolve_model(self.vision_model, "vision", _VISION_PREFERENCES)
        logger.info(f"Vision: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.analyze_image(image_bytes, prompt, max_tokens=max_tokens, model=model)
        self._track(provider, prov_name, model)
        return result

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        speaker_hints: Optional[list[str]] = None,
    ) -> dict:
        """Transcribe audio using local Whisper if available, otherwise API."""
        # Prefer local Whisper — no file size limits, no API costs
        if not self.transcription_model or self.transcription_model.startswith("whisper-local"):
            try:
                from video_processor.providers.whisper_local import WhisperLocal

                if WhisperLocal.is_available():
                    # Parse model size from "whisper-local:large" or default to "large"
                    size = "large"
                    if self.transcription_model and ":" in self.transcription_model:
                        size = self.transcription_model.split(":", 1)[1]
                    if not hasattr(self, "_whisper_local"):
                        self._whisper_local = WhisperLocal(model_size=size)
                    logger.info(f"Transcription: using local whisper-{size}")
                    # Pass speaker names as initial prompt hint for Whisper
                    whisper_kwargs: dict = {"language": language}
                    if speaker_hints:
                        whisper_kwargs["initial_prompt"] = (
                            "Speakers: " + ", ".join(speaker_hints) + "."
                        )
                    result = self._whisper_local.transcribe(audio_path, **whisper_kwargs)
                    duration = result.get("duration") or 0
                    self.usage.record(
                        provider="local",
                        model=f"whisper-{size}",
                        audio_minutes=duration / 60 if duration else 0,
                    )
                    return result
            except ImportError:
                # Local whisper not installed; fall through to API providers.
                pass

        # Fall back to API-based transcription
        prov_name, model = self._resolve_model(
            self.transcription_model, "audio", _TRANSCRIPTION_PREFERENCES
        )
        logger.info(f"Transcription: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        # Build transcription kwargs, passing speaker hints where supported
        transcribe_kwargs: dict = {"language": language, "model": model}
        if speaker_hints:
            if prov_name == "openai":
                # OpenAI Whisper supports a 'prompt' parameter for hints
                transcribe_kwargs["prompt"] = "Speakers: " + ", ".join(speaker_hints) + "."
            else:
                transcribe_kwargs["speaker_hints"] = speaker_hints
        result = provider.transcribe_audio(audio_path, **transcribe_kwargs)
        duration = result.get("duration") or 0
        self.usage.record(
            provider=prov_name,
            model=model,
            audio_minutes=duration / 60 if duration else 0,
        )
        return result

    def get_models_used(self) -> dict[str, str]:
        """Return a dict mapping capability to 'provider/model' for tracking."""
        result = {}
        # Transcription is resolved with the "audio" capability — the same one
        # transcribe_audio() uses — so the report reflects actual routing.
        for cap, capability, explicit, prefs in [
            ("vision", "vision", self.vision_model, _VISION_PREFERENCES),
            ("chat", "chat", self.chat_model, _CHAT_PREFERENCES),
            ("transcription", "audio", self.transcription_model, _TRANSCRIPTION_PREFERENCES),
        ]:
            try:
                prov, model = self._resolve_model(explicit, capability, prefs)
                result[cap] = f"{prov}/{model}"
            except (RuntimeError, ValueError):
                # RuntimeError: no provider available. ValueError: an explicit
                # model id whose provider can't be determined. Either way, skip
                # the capability rather than crash a purely informational call.
                pass
        return result