"""ProviderManager - unified interface for routing API calls to the best available provider."""

import logging
from pathlib import Path
from typing import Optional

from dotenv import load_dotenv

from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry
from video_processor.providers.discovery import discover_available_models
from video_processor.utils.usage_tracker import UsageTracker

load_dotenv()
logger = logging.getLogger(__name__)


def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves."""
    if ProviderRegistry.all_registered():
        return
    # Each module registers itself on import via ProviderRegistry.register()
    import video_processor.providers.anthropic_provider  # noqa: F401
    import video_processor.providers.azure_provider  # noqa: F401
    import video_processor.providers.cerebras_provider  # noqa: F401
    import video_processor.providers.fireworks_provider  # noqa: F401
    import video_processor.providers.gemini_provider  # noqa: F401
    import video_processor.providers.ollama_provider  # noqa: F401
    import video_processor.providers.openai_provider  # noqa: F401
    import video_processor.providers.together_provider  # noqa: F401
    import video_processor.providers.xai_provider  # noqa: F401
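

# Each built-in provider module presumably self-registers on import with a
# call along these lines (a hypothetical sketch: ProviderRegistry.register()
# is referenced above, but its exact signature is not shown in this file):
#
#     # in video_processor/providers/openai_provider.py
#     class OpenAIProvider(BaseProvider):
#         ...
#
#     ProviderRegistry.register("openai", OpenAIProvider)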


# Default model preference rankings (tried in order)
_VISION_PREFERENCES = [
    ("gemini", "gemini-2.5-flash"),
    ("openai", "gpt-4o-mini"),
    ("anthropic", "claude-haiku-4-5-20251001"),
]

_CHAT_PREFERENCES = [
    ("anthropic", "claude-haiku-4-5-20251001"),
    ("openai", "gpt-4o-mini"),
    ("gemini", "gemini-2.5-flash"),
]

_TRANSCRIPTION_PREFERENCES = [
    ("openai", "whisper-1"),
    ("gemini", "gemini-2.5-flash"),
]
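
# For example, with only a Gemini API key configured, auto-routed chat skips
# anthropic and openai and resolves to ("gemini", "gemini-2.5-flash").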


class ProviderManager:
    """
    Routes API calls to the best available provider/model.

    Supports explicit model selection or auto-routing based on
    discovered available models.
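
    Example (a minimal sketch; assumes at least one provider is configured):

        manager = ProviderManager()
        reply = manager.chat([{"role": "user", "content": "Summarize this."}])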
    """

    def __init__(
        self,
        vision_model: Optional[str] = None,
        chat_model: Optional[str] = None,
        transcription_model: Optional[str] = None,
        provider: Optional[str] = None,
        auto: bool = True,
    ):
        """
        Initialize the ProviderManager.

        Parameters
        ----------
        vision_model : str, optional
            Override model for vision tasks (e.g. 'gpt-4o').
        chat_model : str, optional
            Override model for chat/LLM tasks.
        transcription_model : str, optional
            Override model for transcription.
        provider : str, optional
            Force all tasks to a single provider ('openai', 'anthropic', 'gemini').
        auto : bool
            If True and no model is specified, pick the best available.
        """
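        # e.g. ProviderManager(provider="gemini") routes every capability to
        # Gemini, filling unset models from the provider's registered defaults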
        _ensure_providers_registered()
        self.auto = auto
        self._providers: dict[str, BaseProvider] = {}
        self._available_models: Optional[list[ModelInfo]] = None
        self.usage = UsageTracker()

        # If a single provider is forced, apply it
        if provider:
            self.vision_model = vision_model or self._default_for_provider(provider, "vision")
            self.chat_model = chat_model or self._default_for_provider(provider, "chat")
            self.transcription_model = transcription_model or self._default_for_provider(
                provider, "audio"
            )
        else:
            self.vision_model = vision_model
            self.chat_model = chat_model
            self.transcription_model = transcription_model

        self._forced_provider = provider

    @staticmethod
    def _default_for_provider(provider: str, capability: str) -> str:
        """Return the default model for a provider/capability combo."""
        defaults = ProviderRegistry.get_default_models(provider)
        if defaults:
            return defaults.get(capability, "")
        # Fallback for unregistered providers
        return ""

    def _get_provider(self, provider_name: str) -> BaseProvider:
        """Lazily initialize and cache a provider instance."""
        if provider_name not in self._providers:
            _ensure_providers_registered()
            provider_class = ProviderRegistry.get(provider_name)
            self._providers[provider_name] = provider_class()
        return self._providers[provider_name]

    def _provider_for_model(self, model_id: str) -> str:
        """Infer the provider from a model id."""
        _ensure_providers_registered()
        # Check registry prefix matching first
        provider_name = ProviderRegistry.get_by_model(model_id)
        if provider_name:
            return provider_name
        # Try discovery (exact match, then prefix match for ollama name:tag format)
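        # (e.g. an explicit "llava" would match a discovered "llava:latest" tag)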
        models = self._get_available_models()
        for m in models:
            if m.id == model_id:
                return m.provider
        for m in models:
            if m.id.startswith(model_id + ":"):
                return m.provider
        raise ValueError(f"Cannot determine provider for model: {model_id}")

    def _get_available_models(self) -> list[ModelInfo]:
        """Cache and return the models discovered across all providers."""
        if self._available_models is None:
            self._available_models = discover_available_models()
        return self._available_models

    def _resolve_model(
        self, explicit: Optional[str], capability: str, preferences: list[tuple[str, str]]
    ) -> tuple[str, str]:
        """
        Resolve which (provider, model) to use for a capability.

        Returns (provider_name, model_id).
        """
        if explicit:
            prov = self._provider_for_model(explicit)
            return prov, explicit

        if self.auto:
            # Try preference order, picking the first provider that has an API key
            for prov, model in preferences:
                try:
                    self._get_provider(prov)
                    return prov, model
                except (ValueError, ImportError):
                    # Provider can't initialize (missing API key or SDK); try the next one
                    continue

            # Fallback: try Ollama if available (no API key needed)
            try:
                from video_processor.providers.ollama_provider import OllamaProvider

                if OllamaProvider.is_available():
                    provider = self._get_provider("ollama")
                    models = provider.list_models()
                    for m in models:
                        if capability in m.capabilities:
                            return "ollama", m.id
            except Exception:
                # Best-effort fallback: any Ollama failure falls through to the error below
                pass

        raise RuntimeError(
            f"No provider available for capability '{capability}'. "
            "Set an API key for at least one provider, or start Ollama."
        )

    def _track(self, provider: BaseProvider, prov_name: str, model: str) -> None:
        """Record usage from the last API call on a provider."""
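        # Providers are assumed to stash a dict like
        # {"input_tokens": 123, "output_tokens": 456} on _last_usage after each call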
        last = getattr(provider, "_last_usage", None)
        if last:
            self.usage.record(
                provider=prov_name,
                model=model,
                input_tokens=last.get("input_tokens", 0),
                output_tokens=last.get("output_tokens", 0),
            )
            provider._last_usage = None

    # --- Public API ---

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
    ) -> str:
        """Send a chat completion to the best available provider."""
        prov_name, model = self._resolve_model(self.chat_model, "chat", _CHAT_PREFERENCES)
        logger.info(f"Chat: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.chat(
            messages, max_tokens=max_tokens, temperature=temperature, model=model
        )
        self._track(provider, prov_name, model)
        return result

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
    ) -> str:
        """Analyze an image using the best available vision provider."""
        prov_name, model = self._resolve_model(self.vision_model, "vision", _VISION_PREFERENCES)
        logger.info(f"Vision: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        result = provider.analyze_image(image_bytes, prompt, max_tokens=max_tokens, model=model)
        self._track(provider, prov_name, model)
        return result
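
    # Usage sketch: manager.analyze_image(Path("frame.jpg").read_bytes(), "Describe this frame")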

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        speaker_hints: Optional[list[str]] = None,
    ) -> dict:
        """
        Transcribe audio using local Whisper if available, otherwise API.

        Returns the provider's transcript dict; a "duration" key, when present,
        is used for usage accounting.
        """
        # Prefer local Whisper: no file size limits, no API costs
        if not self.transcription_model or self.transcription_model.startswith("whisper-local"):
            try:
                from video_processor.providers.whisper_local import WhisperLocal

                if WhisperLocal.is_available():
                    # Parse model size from "whisper-local:large" or default to "large"
                    size = "large"
                    if self.transcription_model and ":" in self.transcription_model:
                        size = self.transcription_model.split(":", 1)[1]
                    if not hasattr(self, "_whisper_local"):
                        self._whisper_local = WhisperLocal(model_size=size)
                    logger.info(f"Transcription: using local whisper-{size}")
                    # Pass speaker names as initial prompt hint for Whisper
                    whisper_kwargs = {"language": language}
                    if speaker_hints:
                        whisper_kwargs["initial_prompt"] = (
                            "Speakers: " + ", ".join(speaker_hints) + "."
                        )
                    result = self._whisper_local.transcribe(audio_path, **whisper_kwargs)
                    duration = result.get("duration") or 0
                    self.usage.record(
                        provider="local",
                        model=f"whisper-{size}",
                        audio_minutes=duration / 60 if duration else 0,
                    )
                    return result
            except ImportError:
                # Local whisper not installed; fall back to API transcription
                pass

        # Fall back to API-based transcription
        prov_name, model = self._resolve_model(
            self.transcription_model, "audio", _TRANSCRIPTION_PREFERENCES
        )
        logger.info(f"Transcription: using {prov_name}/{model}")
        provider = self._get_provider(prov_name)
        # Build transcription kwargs, passing speaker hints where supported
        transcribe_kwargs: dict = {"language": language, "model": model}
        if speaker_hints:
            if prov_name == "openai":
                # OpenAI Whisper supports a 'prompt' parameter for hints
                transcribe_kwargs["prompt"] = "Speakers: " + ", ".join(speaker_hints) + "."
            else:
                transcribe_kwargs["speaker_hints"] = speaker_hints
        result = provider.transcribe_audio(audio_path, **transcribe_kwargs)
        duration = result.get("duration") or 0
        self.usage.record(
            provider=prov_name,
            model=model,
            audio_minutes=duration / 60 if duration else 0,
        )
        return result

    def get_models_used(self) -> dict[str, str]:
        """Return a dict mapping capability to 'provider/model' for tracking."""
        result = {}
        for cap, explicit, prefs in [
            ("vision", self.vision_model, _VISION_PREFERENCES),
            ("chat", self.chat_model, _CHAT_PREFERENCES),
            ("transcription", self.transcription_model, _TRANSCRIPTION_PREFERENCES),
        ]:
            try:
                prov, model = self._resolve_model(explicit, cap, prefs)
                result[cap] = f"{prov}/{model}"
            except RuntimeError:
                # Capability has no available provider; omit it from the map
                pass
        return result
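

if __name__ == "__main__":
    # Minimal smoke test (a sketch: it needs at least one provider API key in
    # the environment, or a running Ollama instance)
    logging.basicConfig(level=logging.INFO)
    manager = ProviderManager()
    print("Resolved models:", manager.get_models_used())
    print(manager.chat([{"role": "user", "content": "Say hello in five words."}]))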
