"""ProviderManager - unified interface for routing API calls to the best available provider."""
2
3
import logging
4
from pathlib import Path
5
from typing import Optional
6
7
from dotenv import load_dotenv
8
9
from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry
10
from video_processor.providers.discovery import discover_available_models
11
from video_processor.utils.usage_tracker import UsageTracker
12
13
load_dotenv()
14
logger = logging.getLogger(__name__)
15
16
17
def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves."""
    if ProviderRegistry.all_registered():
        return
    import importlib

    # Importing each module has the side effect of calling
    # ProviderRegistry.register() for that provider.
    for stem in (
        "anthropic",
        "azure",
        "cerebras",
        "fireworks",
        "gemini",
        "ollama",
        "openai",
        "together",
        "xai",
    ):
        importlib.import_module(f"video_processor.providers.{stem}_provider")
# Default model preference rankings (tried in order)
34
_VISION_PREFERENCES = [
35
("gemini", "gemini-2.5-flash"),
36
("openai", "gpt-4o-mini"),
37
("anthropic", "claude-haiku-4-5-20251001"),
38
]
39
40
_CHAT_PREFERENCES = [
41
("anthropic", "claude-haiku-4-5-20251001"),
42
("openai", "gpt-4o-mini"),
43
("gemini", "gemini-2.5-flash"),
44
]
45
46
_TRANSCRIPTION_PREFERENCES = [
47
("openai", "whisper-1"),
48
("gemini", "gemini-2.5-flash"),
49
]
50
51
52
class ProviderManager:
    """
    Routes API calls to the best available provider/model.

    Supports explicit model selection or auto-routing based on
    discovered available models.
    """

    def __init__(
        self,
        vision_model: Optional[str] = None,
        chat_model: Optional[str] = None,
        transcription_model: Optional[str] = None,
        provider: Optional[str] = None,
        auto: bool = True,
    ):
        """
        Initialize the ProviderManager.

        Parameters
        ----------
        vision_model : override model for vision tasks (e.g. 'gpt-4o')
        chat_model : override model for chat/LLM tasks
        transcription_model : override model for transcription
        provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini')
        auto : if True and no model specified, pick the best available
        """
        _ensure_providers_registered()
        self.auto = auto
        # Provider instances are created lazily and cached here by name.
        self._providers: dict[str, BaseProvider] = {}
        # Discovery results, cached after the first lookup.
        self._available_models: Optional[list[ModelInfo]] = None
        self.usage = UsageTracker()

        # If a single provider is forced, fill any unspecified model slot with
        # that provider's default for the capability.
        if provider:
            self.vision_model = vision_model or self._default_for_provider(provider, "vision")
            self.chat_model = chat_model or self._default_for_provider(provider, "chat")
            self.transcription_model = transcription_model or self._default_for_provider(
                provider, "audio"
            )
        else:
            self.vision_model = vision_model
            self.chat_model = chat_model
            self.transcription_model = transcription_model

        self._forced_provider = provider

    @staticmethod
    def _default_for_provider(provider: str, capability: str) -> str:
        """Return the default model for a provider/capability combo."""
        defaults = ProviderRegistry.get_default_models(provider)
        if defaults:
            return defaults.get(capability, "")
        # Fallback for unregistered providers
        return ""

    def _get_provider(self, provider_name: str) -> BaseProvider:
        """Lazily initialize and cache a provider instance.

        Provider constructors may raise (e.g. on a missing API key); callers
        that probe availability catch those exceptions.
        """
        if provider_name not in self._providers:
            _ensure_providers_registered()
            provider_class = ProviderRegistry.get(provider_name)
            self._providers[provider_name] = provider_class()
        return self._providers[provider_name]

    def _provider_for_model(self, model_id: str) -> str:
        """Infer the provider from a model id.

        Raises
        ------
        ValueError : if no registered or discovered provider serves the model.
        """
        _ensure_providers_registered()
        # Check registry prefix matching first
        provider_name = ProviderRegistry.get_by_model(model_id)
        if provider_name:
            return provider_name
        # Try discovery (exact match, then prefix match for ollama name:tag format)
        models = self._get_available_models()
        for m in models:
            if m.id == model_id:
                return m.provider
        for m in models:
            if m.id.startswith(model_id + ":"):
                return m.provider
        raise ValueError(f"Cannot determine provider for model: {model_id}")

    def _get_available_models(self) -> list[ModelInfo]:
        """Run model discovery once and cache the result for this instance."""
        if self._available_models is None:
            self._available_models = discover_available_models()
        return self._available_models

    def _resolve_model(
        self, explicit: Optional[str], capability: str, preferences: list[tuple[str, str]]
    ) -> tuple[str, str]:
        """
        Resolve which (provider, model) to use for a capability.

        Returns (provider_name, model_id).

        Raises ValueError if an explicit model id cannot be mapped to a
        provider, and RuntimeError if auto-routing finds no usable provider.
        """
        if explicit:
            prov = self._provider_for_model(explicit)
            return prov, explicit

        if self.auto:
            # Try preference order, picking the first provider that has an API key
            for prov, model in preferences:
                try:
                    self._get_provider(prov)
                    return prov, model
                except (ValueError, ImportError):
                    continue

            # Fallback: try Ollama if available (no API key needed)
            try:
                from video_processor.providers.ollama_provider import OllamaProvider

                if OllamaProvider.is_available():
                    provider = self._get_provider("ollama")
                    models = provider.list_models()
                    for m in models:
                        if capability in m.capabilities:
                            return "ollama", m.id
            except Exception:
                # Deliberate best-effort: a broken local Ollama install must
                # not mask the clearer RuntimeError below.
                pass

        raise RuntimeError(
            f"No provider available for capability '{capability}'. "
            "Set an API key for at least one provider, or start Ollama."
        )

    def _track(self, provider: BaseProvider, prov_name: str, model: str) -> None:
        """Record usage from the last API call on a provider."""
        # Providers stash token counts on themselves after each call; consume
        # and clear so the same usage is never double-counted.
        last = getattr(provider, "_last_usage", None)
        if last:
            self.usage.record(
                provider=prov_name,
                model=model,
                input_tokens=last.get("input_tokens", 0),
                output_tokens=last.get("output_tokens", 0),
            )
            provider._last_usage = None

    # --- Public API ---

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> str:
        """Send a chat completion to the best available provider."""
        prov_name, model = self._resolve_model(self.chat_model, "chat", _CHAT_PREFERENCES)
        logger.info(f"Chat: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.chat(
            messages, max_tokens=max_tokens, temperature=temperature, model=model
        )
        self._track(provider, prov_name, model)
        return result

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
    ) -> str:
        """Analyze an image using the best available vision provider."""
        prov_name, model = self._resolve_model(self.vision_model, "vision", _VISION_PREFERENCES)
        logger.info(f"Vision: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.analyze_image(image_bytes, prompt, max_tokens=max_tokens, model=model)
        self._track(provider, prov_name, model)
        return result

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        speaker_hints: Optional[list[str]] = None,
    ) -> dict:
        """Transcribe audio using local Whisper if available, otherwise API."""
        # Prefer local Whisper — no file size limits, no API costs
        if not self.transcription_model or self.transcription_model.startswith("whisper-local"):
            try:
                from video_processor.providers.whisper_local import WhisperLocal

                if WhisperLocal.is_available():
                    # Parse model size from "whisper-local:large" or default to "large"
                    size = "large"
                    if self.transcription_model and ":" in self.transcription_model:
                        size = self.transcription_model.split(":", 1)[1]
                    # NOTE: the local model is cached on first use; a later
                    # change of size on this instance will not reload it.
                    if not hasattr(self, "_whisper_local"):
                        self._whisper_local = WhisperLocal(model_size=size)
                    logger.info(f"Transcription: using local whisper-{size}")
                    # Pass speaker names as initial prompt hint for Whisper
                    whisper_kwargs = {"language": language}
                    if speaker_hints:
                        whisper_kwargs["initial_prompt"] = (
                            "Speakers: " + ", ".join(speaker_hints) + "."
                        )
                    result = self._whisper_local.transcribe(audio_path, **whisper_kwargs)
                    duration = result.get("duration") or 0
                    self.usage.record(
                        provider="local",
                        model=f"whisper-{size}",
                        audio_minutes=duration / 60 if duration else 0,
                    )
                    return result
            except ImportError:
                # Local whisper not installed — fall through to API transcription.
                pass

        # Fall back to API-based transcription
        prov_name, model = self._resolve_model(
            self.transcription_model, "audio", _TRANSCRIPTION_PREFERENCES
        )
        logger.info(f"Transcription: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        # Build transcription kwargs, passing speaker hints where supported
        transcribe_kwargs: dict = {"language": language, "model": model}
        if speaker_hints:
            if prov_name == "openai":
                # OpenAI Whisper supports a 'prompt' parameter for hints
                transcribe_kwargs["prompt"] = "Speakers: " + ", ".join(speaker_hints) + "."
            else:
                transcribe_kwargs["speaker_hints"] = speaker_hints
        result = provider.transcribe_audio(audio_path, **transcribe_kwargs)
        duration = result.get("duration") or 0
        self.usage.record(
            provider=prov_name,
            model=model,
            audio_minutes=duration / 60 if duration else 0,
        )
        return result

    def get_models_used(self) -> dict[str, str]:
        """Return a dict mapping capability to 'provider/model' for tracking."""
        result = {}
        for cap, explicit, prefs in [
            ("vision", self.vision_model, _VISION_PREFERENCES),
            ("chat", self.chat_model, _CHAT_PREFERENCES),
            ("transcription", self.transcription_model, _TRANSCRIPTION_PREFERENCES),
        ]:
            try:
                prov, model = self._resolve_model(explicit, cap, prefs)
                result[cap] = f"{prov}/{model}"
            except (RuntimeError, ValueError):
                # Best-effort reporting: RuntimeError means no provider is
                # available for the capability; ValueError means an explicit
                # model id could not be mapped to a provider (previously this
                # escaped and crashed the method). Omit the entry either way.
                pass
        return result
