"""ProviderManager - unified interface for routing API calls to the best available provider."""

import logging
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry
from video_processor.providers.discovery import discover_available_models
from video_processor.utils.usage_tracker import UsageTracker

load_dotenv()
logger = logging.getLogger(__name__)


def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves."""
    if ProviderRegistry.all_registered():
        return
    # Each module registers itself on import via ProviderRegistry.register()
    import video_processor.providers.anthropic_provider  # noqa: F401
    import video_processor.providers.azure_provider  # noqa: F401
    import video_processor.providers.cerebras_provider  # noqa: F401
    import video_processor.providers.fireworks_provider  # noqa: F401
    import video_processor.providers.gemini_provider  # noqa: F401
    import video_processor.providers.ollama_provider  # noqa: F401
    import video_processor.providers.openai_provider  # noqa: F401
    import video_processor.providers.together_provider  # noqa: F401
    import video_processor.providers.xai_provider  # noqa: F401
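
# How a provider module registers itself, as a minimal sketch. The exact
# ProviderRegistry.register() signature is assumed for illustration and may
# differ from the real registry API:
#
#     class MyProvider(BaseProvider):
#         ...
#
#     ProviderRegistry.register("myprovider", MyProvider)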


# Default model preference rankings (tried in order)
_VISION_PREFERENCES = [
    ("gemini", "gemini-2.5-flash"),
    ("openai", "gpt-4o-mini"),
    ("anthropic", "claude-haiku-4-5-20251001"),
]

_CHAT_PREFERENCES = [
    ("anthropic", "claude-haiku-4-5-20251001"),
    ("openai", "gpt-4o-mini"),
    ("gemini", "gemini-2.5-flash"),
]

_TRANSCRIPTION_PREFERENCES = [
    ("openai", "whisper-1"),
    ("gemini", "gemini-2.5-flash"),
]


class ProviderManager:
    """
    Routes API calls to the best available provider/model.

    Supports explicit model selection or auto-routing based on
    discovered available models.
    """
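
    # Typical usage, as an illustrative sketch (assumes at least one provider
    # API key is available in the environment or a .env file):
    #
    #     manager = ProviderManager()
    #     reply = manager.chat([{"role": "user", "content": "Summarize this."}])
    #
    # Or pin every capability to a single provider:
    #
    #     manager = ProviderManager(provider="openai")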

    def __init__(
        self,
        vision_model: Optional[str] = None,
        chat_model: Optional[str] = None,
        transcription_model: Optional[str] = None,
        provider: Optional[str] = None,
        auto: bool = True,
    ):
        """
        Initialize the ProviderManager.

        Parameters
        ----------
        vision_model : override model for vision tasks (e.g. 'gpt-4o')
        chat_model : override model for chat/LLM tasks
        transcription_model : override model for transcription
        provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini')
        auto : if True and no model is specified, pick the best available one
        """
        _ensure_providers_registered()
        self.auto = auto
        self._providers: dict[str, BaseProvider] = {}
        self._available_models: Optional[list[ModelInfo]] = None
        self.usage = UsageTracker()

        # If a single provider is forced, fill any unset model with that provider's default
        if provider:
            self.vision_model = vision_model or self._default_for_provider(provider, "vision")
            self.chat_model = chat_model or self._default_for_provider(provider, "chat")
            self.transcription_model = transcription_model or self._default_for_provider(
                provider, "audio"
            )
        else:
            self.vision_model = vision_model
            self.chat_model = chat_model
            self.transcription_model = transcription_model

        self._forced_provider = provider

    @staticmethod
    def _default_for_provider(provider: str, capability: str) -> str:
        """Return the default model for a provider/capability combo."""
        defaults = ProviderRegistry.get_default_models(provider)
        if defaults:
            return defaults.get(capability, "")
        # Fallback for unregistered providers
        return ""

    def _get_provider(self, provider_name: str) -> BaseProvider:
        """Lazily initialize and cache a provider instance."""
        if provider_name not in self._providers:
            _ensure_providers_registered()
            provider_class = ProviderRegistry.get(provider_name)
            self._providers[provider_name] = provider_class()
        return self._providers[provider_name]

    def _provider_for_model(self, model_id: str) -> str:
        """
        Infer the provider from a model id.

        Falls back to discovery when the registry has no prefix match: an
        exact id match first, then a bare Ollama name like 'llama3' matching
        an installed 'llama3:8b' (name:tag format).
        """
        _ensure_providers_registered()
        # Check registry prefix matching first
        provider_name = ProviderRegistry.get_by_model(model_id)
        if provider_name:
            return provider_name
        # Try discovery (exact match, then prefix match for ollama name:tag format)
        models = self._get_available_models()
        for m in models:
            if m.id == model_id:
                return m.provider
        for m in models:
            if m.id.startswith(model_id + ":"):
                return m.provider
        raise ValueError(f"Cannot determine provider for model: {model_id}")

    def _get_available_models(self) -> list[ModelInfo]:
        """Discover and cache the list of available models."""
        if self._available_models is None:
            self._available_models = discover_available_models()
        return self._available_models

    def _resolve_model(
        self, explicit: Optional[str], capability: str, preferences: list[tuple[str, str]]
    ) -> tuple[str, str]:
        """
        Resolve which (provider, model) to use for a capability.

        Returns (provider_name, model_id).
        """
        if explicit:
            prov = self._provider_for_model(explicit)
            return prov, explicit

        if self.auto:
            # Try preference order, picking the first provider that has an API key
            for prov, model in preferences:
                try:
                    self._get_provider(prov)
                    return prov, model
                except (ValueError, ImportError):
                    continue

        # Fallback: try Ollama if available (no API key needed)
        try:
            from video_processor.providers.ollama_provider import OllamaProvider

            if OllamaProvider.is_available():
                provider = self._get_provider("ollama")
                models = provider.list_models()
                for m in models:
                    if capability in m.capabilities:
                        return "ollama", m.id
        except Exception:
            pass

        raise RuntimeError(
            f"No provider available for capability '{capability}'. "
            "Set an API key for at least one provider, or start Ollama."
        )
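
    # Resolution order, illustrated for a hypothetical environment where only
    # OPENAI_API_KEY is set (this assumes provider construction raises
    # ValueError when its key is missing, per the except clause above, and
    # that registry prefix matching maps 'gemini-...' ids to 'gemini'):
    #
    #     _resolve_model(None, "chat", _CHAT_PREFERENCES)
    #         -> ("openai", "gpt-4o-mini")       # anthropic skipped, no key
    #     _resolve_model("gemini-2.5-flash", "chat", _CHAT_PREFERENCES)
    #         -> ("gemini", "gemini-2.5-flash")  # explicit wins; key checked at call time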

    def _track(self, provider: BaseProvider, prov_name: str, model: str) -> None:
        """Record usage from the last API call on a provider."""
        last = getattr(provider, "_last_usage", None)
        if last:
            self.usage.record(
                provider=prov_name,
                model=model,
                input_tokens=last.get("input_tokens", 0),
                output_tokens=last.get("output_tokens", 0),
            )
            provider._last_usage = None
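
    # Providers are expected to stash usage for their most recent call as a
    # dict like the one below (key names inferred from the reads above; the
    # actual contract lives in each provider implementation):
    #
    #     self._last_usage = {"input_tokens": 1234, "output_tokens": 56}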

    # --- Public API ---

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> str:
        """Send a chat completion to the best available provider."""
        prov_name, model = self._resolve_model(self.chat_model, "chat", _CHAT_PREFERENCES)
        logger.info(f"Chat: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.chat(
            messages, max_tokens=max_tokens, temperature=temperature, model=model
        )
        self._track(provider, prov_name, model)
        return result

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
    ) -> str:
        """Analyze an image using the best available vision provider."""
        prov_name, model = self._resolve_model(self.vision_model, "vision", _VISION_PREFERENCES)
        logger.info(f"Vision: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.analyze_image(image_bytes, prompt, max_tokens=max_tokens, model=model)
        self._track(provider, prov_name, model)
        return result
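
    # Example call, as a sketch (the file path is hypothetical):
    #
    #     frame = Path("frames/0001.jpg").read_bytes()
    #     description = manager.analyze_image(frame, "Describe this frame.")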

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        speaker_hints: Optional[list[str]] = None,
    ) -> dict:
        """Transcribe audio using local Whisper if available, otherwise an API."""
        # Prefer local Whisper: no file size limits, no API costs
        if not self.transcription_model or self.transcription_model.startswith("whisper-local"):
            try:
                from video_processor.providers.whisper_local import WhisperLocal

                if WhisperLocal.is_available():
                    # Parse model size from "whisper-local:large" or default to "large"
                    size = "large"
                    if self.transcription_model and ":" in self.transcription_model:
                        size = self.transcription_model.split(":", 1)[1]
                    # Note: the first requested size wins; the cached instance is reused
                    if not hasattr(self, "_whisper_local"):
                        self._whisper_local = WhisperLocal(model_size=size)
                    logger.info(f"Transcription: using local whisper-{size}")
                    # Pass speaker names as an initial prompt hint for Whisper
                    whisper_kwargs = {"language": language}
                    if speaker_hints:
                        whisper_kwargs["initial_prompt"] = (
                            "Speakers: " + ", ".join(speaker_hints) + "."
                        )
                    result = self._whisper_local.transcribe(audio_path, **whisper_kwargs)
                    duration = result.get("duration") or 0
                    self.usage.record(
                        provider="local",
                        model=f"whisper-{size}",
                        audio_minutes=duration / 60 if duration else 0,
                    )
                    return result
            except ImportError:
                pass

        # Fall back to API-based transcription
        prov_name, model = self._resolve_model(
            self.transcription_model, "audio", _TRANSCRIPTION_PREFERENCES
        )
        logger.info(f"Transcription: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        # Build transcription kwargs, passing speaker hints where supported
        transcribe_kwargs: dict = {"language": language, "model": model}
        if speaker_hints:
            if prov_name == "openai":
                # OpenAI Whisper supports a 'prompt' parameter for hints
                transcribe_kwargs["prompt"] = "Speakers: " + ", ".join(speaker_hints) + "."
            else:
                transcribe_kwargs["speaker_hints"] = speaker_hints
        result = provider.transcribe_audio(audio_path, **transcribe_kwargs)
        duration = result.get("duration") or 0
        self.usage.record(
            provider=prov_name,
            model=model,
            audio_minutes=duration / 60 if duration else 0,
        )
        return result
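
    # Example call, as a sketch (file name and speaker names are hypothetical;
    # the returned dict carries at least the "duration" field read above):
    #
    #     result = manager.transcribe_audio(
    #         "interview.wav",
    #         language="en",
    #         speaker_hints=["Ada", "Grace"],
    #     )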

    def get_models_used(self) -> dict[str, str]:
        """Return a dict mapping capability to 'provider/model' for tracking."""
        result: dict[str, str] = {}
        for cap, explicit, prefs in [
            ("vision", self.vision_model, _VISION_PREFERENCES),
            ("chat", self.chat_model, _CHAT_PREFERENCES),
            ("transcription", self.transcription_model, _TRANSCRIPTION_PREFERENCES),
        ]:
            try:
                prov, model = self._resolve_model(explicit, cap, prefs)
                result[cap] = f"{prov}/{model}"
            except RuntimeError:
                pass
        return result
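

if __name__ == "__main__":
    # Minimal smoke check, as a sketch: safe to run without credentials, but
    # it only resolves real models when a provider API key or a running
    # Ollama instance is present.
    logging.basicConfig(level=logging.INFO)
    manager = ProviderManager()
    print("Resolved models:", manager.get_models_used())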