PlanOpticon
feat(providers): add pluggable provider registry and OpenAICompatibleProvider base class Refactor the provider system from hardcoded if/elif chains to a self-registering plugin architecture. Each provider module registers itself with ProviderRegistry on import, declaring its env var, model prefixes, and default models. ProviderManager and discovery now use the registry for instantiation and model routing instead of hardcoded logic. Add OpenAICompatibleProvider base class for future providers that use the OpenAI-compatible API format (Together, Fireworks, Cerebras, etc.). Closes #77
Commit
ae85592da49c4ccc5f62070f82f3000bf0c49bd1ef3434ecb2d2100189e14fda
Parent
613bc9ca307adee…
8 files changed
+13
-2
+14
-1
+171
-2
+56
-53
+14
-1
+27
-50
+10
-1
+10
-1
~
video_processor/providers/__init__.py
~
video_processor/providers/anthropic_provider.py
~
video_processor/providers/base.py
~
video_processor/providers/discovery.py
~
video_processor/providers/gemini_provider.py
~
video_processor/providers/manager.py
~
video_processor/providers/ollama_provider.py
~
video_processor/providers/openai_provider.py
| --- video_processor/providers/__init__.py | ||
| +++ video_processor/providers/__init__.py | ||
| @@ -1,6 +1,17 @@ | ||
| 1 | 1 | """Provider abstraction layer for LLM, vision, and transcription APIs.""" |
| 2 | 2 | |
| 3 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 3 | +from video_processor.providers.base import ( | |
| 4 | + BaseProvider, | |
| 5 | + ModelInfo, | |
| 6 | + OpenAICompatibleProvider, | |
| 7 | + ProviderRegistry, | |
| 8 | +) | |
| 4 | 9 | from video_processor.providers.manager import ProviderManager |
| 5 | 10 | |
| 6 | -__all__ = ["BaseProvider", "ModelInfo", "ProviderManager"] | |
| 11 | +__all__ = [ | |
| 12 | + "BaseProvider", | |
| 13 | + "ModelInfo", | |
| 14 | + "OpenAICompatibleProvider", | |
| 15 | + "ProviderManager", | |
| 16 | + "ProviderRegistry", | |
| 17 | +] | |
| 7 | 18 |
| --- video_processor/providers/__init__.py | |
| +++ video_processor/providers/__init__.py | |
| @@ -1,6 +1,17 @@ | |
| 1 | """Provider abstraction layer for LLM, vision, and transcription APIs.""" |
| 2 | |
| 3 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 4 | from video_processor.providers.manager import ProviderManager |
| 5 | |
| 6 | __all__ = ["BaseProvider", "ModelInfo", "ProviderManager"] |
| 7 |
| --- video_processor/providers/__init__.py | |
| +++ video_processor/providers/__init__.py | |
| @@ -1,6 +1,17 @@ | |
| 1 | """Provider abstraction layer for LLM, vision, and transcription APIs.""" |
| 2 | |
| 3 | from video_processor.providers.base import ( |
| 4 | BaseProvider, |
| 5 | ModelInfo, |
| 6 | OpenAICompatibleProvider, |
| 7 | ProviderRegistry, |
| 8 | ) |
| 9 | from video_processor.providers.manager import ProviderManager |
| 10 | |
| 11 | __all__ = [ |
| 12 | "BaseProvider", |
| 13 | "ModelInfo", |
| 14 | "OpenAICompatibleProvider", |
| 15 | "ProviderManager", |
| 16 | "ProviderRegistry", |
| 17 | ] |
| 18 |
| --- video_processor/providers/anthropic_provider.py | ||
| +++ video_processor/providers/anthropic_provider.py | ||
| @@ -7,11 +7,11 @@ | ||
| 7 | 7 | from typing import Optional |
| 8 | 8 | |
| 9 | 9 | import anthropic |
| 10 | 10 | from dotenv import load_dotenv |
| 11 | 11 | |
| 12 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 12 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 13 | 13 | |
| 14 | 14 | load_dotenv() |
| 15 | 15 | logger = logging.getLogger(__name__) |
| 16 | 16 | |
| 17 | 17 | |
| @@ -108,5 +108,18 @@ | ||
| 108 | 108 | ) |
| 109 | 109 | ) |
| 110 | 110 | except Exception as e: |
| 111 | 111 | logger.warning(f"Failed to list Anthropic models: {e}") |
| 112 | 112 | return sorted(models, key=lambda m: m.id) |
| 113 | + | |
| 114 | + | |
| 115 | +ProviderRegistry.register( | |
| 116 | + name="anthropic", | |
| 117 | + provider_class=AnthropicProvider, | |
| 118 | + env_var="ANTHROPIC_API_KEY", | |
| 119 | + model_prefixes=["claude-"], | |
| 120 | + default_models={ | |
| 121 | + "chat": "claude-sonnet-4-5-20250929", | |
| 122 | + "vision": "claude-sonnet-4-5-20250929", | |
| 123 | + "audio": "", | |
| 124 | + }, | |
| 125 | +) | |
| 113 | 126 |
| --- video_processor/providers/anthropic_provider.py | |
| +++ video_processor/providers/anthropic_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | import anthropic |
| 10 | from dotenv import load_dotenv |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 13 | |
| 14 | load_dotenv() |
| 15 | logger = logging.getLogger(__name__) |
| 16 | |
| 17 | |
| @@ -108,5 +108,18 @@ | |
| 108 | ) |
| 109 | ) |
| 110 | except Exception as e: |
| 111 | logger.warning(f"Failed to list Anthropic models: {e}") |
| 112 | return sorted(models, key=lambda m: m.id) |
| 113 |
| --- video_processor/providers/anthropic_provider.py | |
| +++ video_processor/providers/anthropic_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | import anthropic |
| 10 | from dotenv import load_dotenv |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 13 | |
| 14 | load_dotenv() |
| 15 | logger = logging.getLogger(__name__) |
| 16 | |
| 17 | |
| @@ -108,5 +108,18 @@ | |
| 108 | ) |
| 109 | ) |
| 110 | except Exception as e: |
| 111 | logger.warning(f"Failed to list Anthropic models: {e}") |
| 112 | return sorted(models, key=lambda m: m.id) |
| 113 | |
| 114 | |
| 115 | ProviderRegistry.register( |
| 116 | name="anthropic", |
| 117 | provider_class=AnthropicProvider, |
| 118 | env_var="ANTHROPIC_API_KEY", |
| 119 | model_prefixes=["claude-"], |
| 120 | default_models={ |
| 121 | "chat": "claude-sonnet-4-5-20250929", |
| 122 | "vision": "claude-sonnet-4-5-20250929", |
| 123 | "audio": "", |
| 124 | }, |
| 125 | ) |
| 126 |
+171
-2
| --- video_processor/providers/base.py | ||
| +++ video_processor/providers/base.py | ||
| @@ -1,12 +1,17 @@ | ||
| 1 | -"""Abstract base class and shared types for provider implementations.""" | |
| 1 | +"""Abstract base class, registry, and shared types for provider implementations.""" | |
| 2 | 2 | |
| 3 | +import base64 | |
| 4 | +import logging | |
| 5 | +import os | |
| 3 | 6 | from abc import ABC, abstractmethod |
| 4 | 7 | from pathlib import Path |
| 5 | -from typing import List, Optional | |
| 8 | +from typing import Dict, List, Optional | |
| 6 | 9 | |
| 7 | 10 | from pydantic import BaseModel, Field |
| 11 | + | |
| 12 | +logger = logging.getLogger(__name__) | |
| 8 | 13 | |
| 9 | 14 | |
| 10 | 15 | class ModelInfo(BaseModel): |
| 11 | 16 | """Information about an available model.""" |
| 12 | 17 | |
| @@ -53,5 +58,169 @@ | ||
| 53 | 58 | """Transcribe an audio file. Returns dict with 'text', 'segments', etc.""" |
| 54 | 59 | |
| 55 | 60 | @abstractmethod |
| 56 | 61 | def list_models(self) -> list[ModelInfo]: |
| 57 | 62 | """Discover available models from this provider's API.""" |
| 63 | + | |
| 64 | + | |
| 65 | +class ProviderRegistry: | |
| 66 | + """Registry for provider classes. Providers register themselves with metadata.""" | |
| 67 | + | |
| 68 | + _providers: Dict[str, Dict] = {} | |
| 69 | + | |
| 70 | + @classmethod | |
| 71 | + def register( | |
| 72 | + cls, | |
| 73 | + name: str, | |
| 74 | + provider_class: type, | |
| 75 | + env_var: str = "", | |
| 76 | + model_prefixes: Optional[List[str]] = None, | |
| 77 | + default_models: Optional[Dict[str, str]] = None, | |
| 78 | + ) -> None: | |
| 79 | + """Register a provider class with its metadata.""" | |
| 80 | + cls._providers[name] = { | |
| 81 | + "class": provider_class, | |
| 82 | + "env_var": env_var, | |
| 83 | + "model_prefixes": model_prefixes or [], | |
| 84 | + "default_models": default_models or {}, | |
| 85 | + } | |
| 86 | + | |
| 87 | + @classmethod | |
| 88 | + def get(cls, name: str) -> type: | |
| 89 | + """Return the provider class for a given name.""" | |
| 90 | + if name not in cls._providers: | |
| 91 | + raise ValueError(f"Unknown provider: {name}") | |
| 92 | + return cls._providers[name]["class"] | |
| 93 | + | |
| 94 | + @classmethod | |
| 95 | + def get_by_model(cls, model_id: str) -> Optional[str]: | |
| 96 | + """Return provider name for a model ID based on prefix matching.""" | |
| 97 | + for name, info in cls._providers.items(): | |
| 98 | + for prefix in info["model_prefixes"]: | |
| 99 | + if model_id.startswith(prefix): | |
| 100 | + return name | |
| 101 | + return None | |
| 102 | + | |
| 103 | + @classmethod | |
| 104 | + def get_default_models(cls, name: str) -> Dict[str, str]: | |
| 105 | + """Return the default models dict for a provider.""" | |
| 106 | + if name not in cls._providers: | |
| 107 | + return {} | |
| 108 | + return cls._providers[name].get("default_models", {}) | |
| 109 | + | |
| 110 | + @classmethod | |
| 111 | + def available(cls) -> List[str]: | |
| 112 | + """Return names of providers whose env var is set (or have no env var requirement).""" | |
| 113 | + result = [] | |
| 114 | + for name, info in cls._providers.items(): | |
| 115 | + env_var = info.get("env_var", "") | |
| 116 | + if not env_var: | |
| 117 | + # Providers without an env var (e.g. ollama) need special availability checks | |
| 118 | + result.append(name) | |
| 119 | + elif os.getenv(env_var, ""): | |
| 120 | + result.append(name) | |
| 121 | + return result | |
| 122 | + | |
| 123 | + @classmethod | |
| 124 | + def all_registered(cls) -> Dict[str, Dict]: | |
| 125 | + """Return all registered providers and their metadata.""" | |
| 126 | + return dict(cls._providers) | |
| 127 | + | |
| 128 | + | |
| 129 | +class OpenAICompatibleProvider(BaseProvider): | |
| 130 | + """Base for providers using OpenAI-compatible APIs. | |
| 131 | + | |
| 132 | + Suitable for Together, Fireworks, Cerebras, xAI, Azure, and similar services. | |
| 133 | + """ | |
| 134 | + | |
| 135 | + provider_name: str = "" | |
| 136 | + base_url: str = "" | |
| 137 | + env_var: str = "" | |
| 138 | + | |
| 139 | + def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): | |
| 140 | + from openai import OpenAI | |
| 141 | + | |
| 142 | + self._api_key = api_key or os.getenv(self.env_var, "") | |
| 143 | + self._base_url = base_url or self.base_url | |
| 144 | + self._client = OpenAI(api_key=self._api_key, base_url=self._base_url) | |
| 145 | + self._last_usage = None | |
| 146 | + | |
| 147 | + def chat( | |
| 148 | + self, | |
| 149 | + messages: list[dict], | |
| 150 | + max_tokens: int = 4096, | |
| 151 | + temperature: float = 0.7, | |
| 152 | + model: Optional[str] = None, | |
| 153 | + ) -> str: | |
| 154 | + model = model or "gpt-4o" | |
| 155 | + response = self._client.chat.completions.create( | |
| 156 | + model=model, | |
| 157 | + messages=messages, | |
| 158 | + max_tokens=max_tokens, | |
| 159 | + temperature=temperature, | |
| 160 | + ) | |
| 161 | + self._last_usage = { | |
| 162 | + "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, | |
| 163 | + "output_tokens": getattr(response.usage, "completion_tokens", 0) | |
| 164 | + if response.usage | |
| 165 | + else 0, | |
| 166 | + } | |
| 167 | + return response.choices[0].message.content or "" | |
| 168 | + | |
| 169 | + def analyze_image( | |
| 170 | + self, | |
| 171 | + image_bytes: bytes, | |
| 172 | + prompt: str, | |
| 173 | + max_tokens: int = 4096, | |
| 174 | + model: Optional[str] = None, | |
| 175 | + ) -> str: | |
| 176 | + model = model or "gpt-4o" | |
| 177 | + b64 = base64.b64encode(image_bytes).decode() | |
| 178 | + response = self._client.chat.completions.create( | |
| 179 | + model=model, | |
| 180 | + messages=[ | |
| 181 | + { | |
| 182 | + "role": "user", | |
| 183 | + "content": [ | |
| 184 | + {"type": "text", "text": prompt}, | |
| 185 | + { | |
| 186 | + "type": "image_url", | |
| 187 | + "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, | |
| 188 | + }, | |
| 189 | + ], | |
| 190 | + } | |
| 191 | + ], | |
| 192 | + max_tokens=max_tokens, | |
| 193 | + ) | |
| 194 | + self._last_usage = { | |
| 195 | + "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, | |
| 196 | + "output_tokens": getattr(response.usage, "completion_tokens", 0) | |
| 197 | + if response.usage | |
| 198 | + else 0, | |
| 199 | + } | |
| 200 | + return response.choices[0].message.content or "" | |
| 201 | + | |
| 202 | + def transcribe_audio( | |
| 203 | + self, | |
| 204 | + audio_path: str | Path, | |
| 205 | + language: Optional[str] = None, | |
| 206 | + model: Optional[str] = None, | |
| 207 | + ) -> dict: | |
| 208 | + raise NotImplementedError(f"{self.provider_name} does not support audio transcription") | |
| 209 | + | |
| 210 | + def list_models(self) -> list[ModelInfo]: | |
| 211 | + models = [] | |
| 212 | + try: | |
| 213 | + for m in self._client.models.list(): | |
| 214 | + mid = m.id | |
| 215 | + caps = ["chat"] | |
| 216 | + models.append( | |
| 217 | + ModelInfo( | |
| 218 | + id=mid, | |
| 219 | + provider=self.provider_name, | |
| 220 | + display_name=mid, | |
| 221 | + capabilities=caps, | |
| 222 | + ) | |
| 223 | + ) | |
| 224 | + except Exception as e: | |
| 225 | + logger.warning(f"Failed to list {self.provider_name} models: {e}") | |
| 226 | + return sorted(models, key=lambda m: m.id) | |
| 58 | 227 |
| --- video_processor/providers/base.py | |
| +++ video_processor/providers/base.py | |
| @@ -1,12 +1,17 @@ | |
| 1 | """Abstract base class and shared types for provider implementations.""" |
| 2 | |
| 3 | from abc import ABC, abstractmethod |
| 4 | from pathlib import Path |
| 5 | from typing import List, Optional |
| 6 | |
| 7 | from pydantic import BaseModel, Field |
| 8 | |
| 9 | |
| 10 | class ModelInfo(BaseModel): |
| 11 | """Information about an available model.""" |
| 12 | |
| @@ -53,5 +58,169 @@ | |
| 53 | """Transcribe an audio file. Returns dict with 'text', 'segments', etc.""" |
| 54 | |
| 55 | @abstractmethod |
| 56 | def list_models(self) -> list[ModelInfo]: |
| 57 | """Discover available models from this provider's API.""" |
| 58 |
| --- video_processor/providers/base.py | |
| +++ video_processor/providers/base.py | |
| @@ -1,12 +1,17 @@ | |
| 1 | """Abstract base class, registry, and shared types for provider implementations.""" |
| 2 | |
| 3 | import base64 |
| 4 | import logging |
| 5 | import os |
| 6 | from abc import ABC, abstractmethod |
| 7 | from pathlib import Path |
| 8 | from typing import Dict, List, Optional |
| 9 | |
| 10 | from pydantic import BaseModel, Field |
| 11 | |
| 12 | logger = logging.getLogger(__name__) |
| 13 | |
| 14 | |
| 15 | class ModelInfo(BaseModel): |
| 16 | """Information about an available model.""" |
| 17 | |
| @@ -53,5 +58,169 @@ | |
| 58 | """Transcribe an audio file. Returns dict with 'text', 'segments', etc.""" |
| 59 | |
| 60 | @abstractmethod |
| 61 | def list_models(self) -> list[ModelInfo]: |
| 62 | """Discover available models from this provider's API.""" |
| 63 | |
| 64 | |
| 65 | class ProviderRegistry: |
| 66 | """Registry for provider classes. Providers register themselves with metadata.""" |
| 67 | |
| 68 | _providers: Dict[str, Dict] = {} |
| 69 | |
| 70 | @classmethod |
| 71 | def register( |
| 72 | cls, |
| 73 | name: str, |
| 74 | provider_class: type, |
| 75 | env_var: str = "", |
| 76 | model_prefixes: Optional[List[str]] = None, |
| 77 | default_models: Optional[Dict[str, str]] = None, |
| 78 | ) -> None: |
| 79 | """Register a provider class with its metadata.""" |
| 80 | cls._providers[name] = { |
| 81 | "class": provider_class, |
| 82 | "env_var": env_var, |
| 83 | "model_prefixes": model_prefixes or [], |
| 84 | "default_models": default_models or {}, |
| 85 | } |
| 86 | |
| 87 | @classmethod |
| 88 | def get(cls, name: str) -> type: |
| 89 | """Return the provider class for a given name.""" |
| 90 | if name not in cls._providers: |
| 91 | raise ValueError(f"Unknown provider: {name}") |
| 92 | return cls._providers[name]["class"] |
| 93 | |
| 94 | @classmethod |
| 95 | def get_by_model(cls, model_id: str) -> Optional[str]: |
| 96 | """Return provider name for a model ID based on prefix matching.""" |
| 97 | for name, info in cls._providers.items(): |
| 98 | for prefix in info["model_prefixes"]: |
| 99 | if model_id.startswith(prefix): |
| 100 | return name |
| 101 | return None |
| 102 | |
| 103 | @classmethod |
| 104 | def get_default_models(cls, name: str) -> Dict[str, str]: |
| 105 | """Return the default models dict for a provider.""" |
| 106 | if name not in cls._providers: |
| 107 | return {} |
| 108 | return cls._providers[name].get("default_models", {}) |
| 109 | |
| 110 | @classmethod |
| 111 | def available(cls) -> List[str]: |
| 112 | """Return names of providers whose env var is set (or have no env var requirement).""" |
| 113 | result = [] |
| 114 | for name, info in cls._providers.items(): |
| 115 | env_var = info.get("env_var", "") |
| 116 | if not env_var: |
| 117 | # Providers without an env var (e.g. ollama) need special availability checks |
| 118 | result.append(name) |
| 119 | elif os.getenv(env_var, ""): |
| 120 | result.append(name) |
| 121 | return result |
| 122 | |
| 123 | @classmethod |
| 124 | def all_registered(cls) -> Dict[str, Dict]: |
| 125 | """Return all registered providers and their metadata.""" |
| 126 | return dict(cls._providers) |
| 127 | |
| 128 | |
| 129 | class OpenAICompatibleProvider(BaseProvider): |
| 130 | """Base for providers using OpenAI-compatible APIs. |
| 131 | |
| 132 | Suitable for Together, Fireworks, Cerebras, xAI, Azure, and similar services. |
| 133 | """ |
| 134 | |
| 135 | provider_name: str = "" |
| 136 | base_url: str = "" |
| 137 | env_var: str = "" |
| 138 | |
| 139 | def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None): |
| 140 | from openai import OpenAI |
| 141 | |
| 142 | self._api_key = api_key or os.getenv(self.env_var, "") |
| 143 | self._base_url = base_url or self.base_url |
| 144 | self._client = OpenAI(api_key=self._api_key, base_url=self._base_url) |
| 145 | self._last_usage = None |
| 146 | |
| 147 | def chat( |
| 148 | self, |
| 149 | messages: list[dict], |
| 150 | max_tokens: int = 4096, |
| 151 | temperature: float = 0.7, |
| 152 | model: Optional[str] = None, |
| 153 | ) -> str: |
| 154 | model = model or "gpt-4o" |
| 155 | response = self._client.chat.completions.create( |
| 156 | model=model, |
| 157 | messages=messages, |
| 158 | max_tokens=max_tokens, |
| 159 | temperature=temperature, |
| 160 | ) |
| 161 | self._last_usage = { |
| 162 | "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, |
| 163 | "output_tokens": getattr(response.usage, "completion_tokens", 0) |
| 164 | if response.usage |
| 165 | else 0, |
| 166 | } |
| 167 | return response.choices[0].message.content or "" |
| 168 | |
| 169 | def analyze_image( |
| 170 | self, |
| 171 | image_bytes: bytes, |
| 172 | prompt: str, |
| 173 | max_tokens: int = 4096, |
| 174 | model: Optional[str] = None, |
| 175 | ) -> str: |
| 176 | model = model or "gpt-4o" |
| 177 | b64 = base64.b64encode(image_bytes).decode() |
| 178 | response = self._client.chat.completions.create( |
| 179 | model=model, |
| 180 | messages=[ |
| 181 | { |
| 182 | "role": "user", |
| 183 | "content": [ |
| 184 | {"type": "text", "text": prompt}, |
| 185 | { |
| 186 | "type": "image_url", |
| 187 | "image_url": {"url": f"data:image/jpeg;base64,{b64}"}, |
| 188 | }, |
| 189 | ], |
| 190 | } |
| 191 | ], |
| 192 | max_tokens=max_tokens, |
| 193 | ) |
| 194 | self._last_usage = { |
| 195 | "input_tokens": getattr(response.usage, "prompt_tokens", 0) if response.usage else 0, |
| 196 | "output_tokens": getattr(response.usage, "completion_tokens", 0) |
| 197 | if response.usage |
| 198 | else 0, |
| 199 | } |
| 200 | return response.choices[0].message.content or "" |
| 201 | |
| 202 | def transcribe_audio( |
| 203 | self, |
| 204 | audio_path: str | Path, |
| 205 | language: Optional[str] = None, |
| 206 | model: Optional[str] = None, |
| 207 | ) -> dict: |
| 208 | raise NotImplementedError(f"{self.provider_name} does not support audio transcription") |
| 209 | |
| 210 | def list_models(self) -> list[ModelInfo]: |
| 211 | models = [] |
| 212 | try: |
| 213 | for m in self._client.models.list(): |
| 214 | mid = m.id |
| 215 | caps = ["chat"] |
| 216 | models.append( |
| 217 | ModelInfo( |
| 218 | id=mid, |
| 219 | provider=self.provider_name, |
| 220 | display_name=mid, |
| 221 | capabilities=caps, |
| 222 | ) |
| 223 | ) |
| 224 | except Exception as e: |
| 225 | logger.warning(f"Failed to list {self.provider_name} models: {e}") |
| 226 | return sorted(models, key=lambda m: m.id) |
| 227 |
+56
-53
| --- video_processor/providers/discovery.py | ||
| +++ video_processor/providers/discovery.py | ||
| @@ -4,17 +4,27 @@ | ||
| 4 | 4 | import os |
| 5 | 5 | from typing import Optional |
| 6 | 6 | |
| 7 | 7 | from dotenv import load_dotenv |
| 8 | 8 | |
| 9 | -from video_processor.providers.base import ModelInfo | |
| 9 | +from video_processor.providers.base import ModelInfo, ProviderRegistry | |
| 10 | 10 | |
| 11 | 11 | load_dotenv() |
| 12 | 12 | logger = logging.getLogger(__name__) |
| 13 | 13 | |
| 14 | 14 | _cached_models: Optional[list[ModelInfo]] = None |
| 15 | 15 | |
| 16 | + | |
| 17 | +def _ensure_providers_registered() -> None: | |
| 18 | + """Import all built-in provider modules so they register themselves.""" | |
| 19 | + if ProviderRegistry.all_registered(): | |
| 20 | + return | |
| 21 | + import video_processor.providers.anthropic_provider # noqa: F401 | |
| 22 | + import video_processor.providers.gemini_provider # noqa: F401 | |
| 23 | + import video_processor.providers.ollama_provider # noqa: F401 | |
| 24 | + import video_processor.providers.openai_provider # noqa: F401 | |
| 25 | + | |
| 16 | 26 | |
| 17 | 27 | def discover_available_models( |
| 18 | 28 | api_keys: Optional[dict[str, str]] = None, |
| 19 | 29 | force_refresh: bool = False, |
| 20 | 30 | ) -> list[ModelInfo]: |
| @@ -26,70 +36,63 @@ | ||
| 26 | 36 | """ |
| 27 | 37 | global _cached_models |
| 28 | 38 | if _cached_models is not None and not force_refresh: |
| 29 | 39 | return _cached_models |
| 30 | 40 | |
| 41 | + _ensure_providers_registered() | |
| 42 | + | |
| 31 | 43 | keys = api_keys or { |
| 32 | 44 | "openai": os.getenv("OPENAI_API_KEY", ""), |
| 33 | 45 | "anthropic": os.getenv("ANTHROPIC_API_KEY", ""), |
| 34 | 46 | "gemini": os.getenv("GEMINI_API_KEY", ""), |
| 35 | 47 | } |
| 36 | 48 | |
| 37 | 49 | all_models: list[ModelInfo] = [] |
| 38 | 50 | |
| 39 | - # OpenAI | |
| 40 | - if keys.get("openai"): | |
| 41 | - try: | |
| 42 | - from video_processor.providers.openai_provider import OpenAIProvider | |
| 43 | - | |
| 44 | - provider = OpenAIProvider(api_key=keys["openai"]) | |
| 45 | - models = provider.list_models() | |
| 46 | - logger.info(f"Discovered {len(models)} OpenAI models") | |
| 47 | - all_models.extend(models) | |
| 48 | - except Exception as e: | |
| 49 | - logger.info(f"OpenAI discovery skipped: {e}") | |
| 50 | - | |
| 51 | - # Anthropic | |
| 52 | - if keys.get("anthropic"): | |
| 53 | - try: | |
| 54 | - from video_processor.providers.anthropic_provider import AnthropicProvider | |
| 55 | - | |
| 56 | - provider = AnthropicProvider(api_key=keys["anthropic"]) | |
| 57 | - models = provider.list_models() | |
| 58 | - logger.info(f"Discovered {len(models)} Anthropic models") | |
| 59 | - all_models.extend(models) | |
| 60 | - except Exception as e: | |
| 61 | - logger.info(f"Anthropic discovery skipped: {e}") | |
| 62 | - | |
| 63 | - # Gemini (API key or service account) | |
| 64 | - gemini_key = keys.get("gemini") | |
| 65 | - gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") | |
| 66 | - if gemini_key or gemini_creds: | |
| 67 | - try: | |
| 68 | - from video_processor.providers.gemini_provider import GeminiProvider | |
| 69 | - | |
| 70 | - provider = GeminiProvider( | |
| 71 | - api_key=gemini_key or None, | |
| 72 | - credentials_path=gemini_creds or None, | |
| 73 | - ) | |
| 74 | - models = provider.list_models() | |
| 75 | - logger.info(f"Discovered {len(models)} Gemini models") | |
| 76 | - all_models.extend(models) | |
| 77 | - except Exception as e: | |
| 78 | - logger.warning(f"Gemini discovery failed: {e}") | |
| 79 | - | |
| 80 | - # Ollama (local, no API key needed) | |
| 81 | - try: | |
| 82 | - from video_processor.providers.ollama_provider import OllamaProvider | |
| 83 | - | |
| 84 | - if OllamaProvider.is_available(): | |
| 85 | - provider = OllamaProvider() | |
| 86 | - models = provider.list_models() | |
| 87 | - logger.info(f"Discovered {len(models)} Ollama models") | |
| 88 | - all_models.extend(models) | |
| 89 | - except Exception as e: | |
| 90 | - logger.info(f"Ollama discovery skipped: {e}") | |
| 51 | + for name, info in ProviderRegistry.all_registered().items(): | |
| 52 | + env_var = info.get("env_var", "") | |
| 53 | + provider_class = info["class"] | |
| 54 | + | |
| 55 | + if name == "ollama": | |
| 56 | + # Ollama: no API key, check server availability | |
| 57 | + try: | |
| 58 | + if provider_class.is_available(): | |
| 59 | + provider = provider_class() | |
| 60 | + models = provider.list_models() | |
| 61 | + logger.info(f"Discovered {len(models)} Ollama models") | |
| 62 | + all_models.extend(models) | |
| 63 | + except Exception as e: | |
| 64 | + logger.info(f"Ollama discovery skipped: {e}") | |
| 65 | + continue | |
| 66 | + | |
| 67 | + # For key-based providers, check the api_keys dict first, then env var | |
| 68 | + key = keys.get(name, "") | |
| 69 | + if not key and env_var: | |
| 70 | + key = os.getenv(env_var, "") | |
| 71 | + | |
| 72 | + # Special case: Gemini also supports service account credentials | |
| 73 | + gemini_creds = "" | |
| 74 | + if name == "gemini": | |
| 75 | + gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") | |
| 76 | + | |
| 77 | + if not key and not gemini_creds: | |
| 78 | + continue | |
| 79 | + | |
| 80 | + try: | |
| 81 | + # Handle provider-specific constructor args | |
| 82 | + if name == "gemini": | |
| 83 | + provider = provider_class( | |
| 84 | + api_key=key or None, | |
| 85 | + credentials_path=gemini_creds or None, | |
| 86 | + ) | |
| 87 | + else: | |
| 88 | + provider = provider_class(api_key=key) | |
| 89 | + models = provider.list_models() | |
| 90 | + logger.info(f"Discovered {len(models)} {name.capitalize()} models") | |
| 91 | + all_models.extend(models) | |
| 92 | + except Exception as e: | |
| 93 | + logger.info(f"{name.capitalize()} discovery skipped: {e}") | |
| 91 | 94 | |
| 92 | 95 | # Sort by provider then id |
| 93 | 96 | all_models.sort(key=lambda m: (m.provider, m.id)) |
| 94 | 97 | _cached_models = all_models |
| 95 | 98 | logger.info(f"Total discovered models: {len(all_models)}") |
| 96 | 99 |
| --- video_processor/providers/discovery.py | |
| +++ video_processor/providers/discovery.py | |
| @@ -4,17 +4,27 @@ | |
| 4 | import os |
| 5 | from typing import Optional |
| 6 | |
| 7 | from dotenv import load_dotenv |
| 8 | |
| 9 | from video_processor.providers.base import ModelInfo |
| 10 | |
| 11 | load_dotenv() |
| 12 | logger = logging.getLogger(__name__) |
| 13 | |
| 14 | _cached_models: Optional[list[ModelInfo]] = None |
| 15 | |
| 16 | |
| 17 | def discover_available_models( |
| 18 | api_keys: Optional[dict[str, str]] = None, |
| 19 | force_refresh: bool = False, |
| 20 | ) -> list[ModelInfo]: |
| @@ -26,70 +36,63 @@ | |
| 26 | """ |
| 27 | global _cached_models |
| 28 | if _cached_models is not None and not force_refresh: |
| 29 | return _cached_models |
| 30 | |
| 31 | keys = api_keys or { |
| 32 | "openai": os.getenv("OPENAI_API_KEY", ""), |
| 33 | "anthropic": os.getenv("ANTHROPIC_API_KEY", ""), |
| 34 | "gemini": os.getenv("GEMINI_API_KEY", ""), |
| 35 | } |
| 36 | |
| 37 | all_models: list[ModelInfo] = [] |
| 38 | |
| 39 | # OpenAI |
| 40 | if keys.get("openai"): |
| 41 | try: |
| 42 | from video_processor.providers.openai_provider import OpenAIProvider |
| 43 | |
| 44 | provider = OpenAIProvider(api_key=keys["openai"]) |
| 45 | models = provider.list_models() |
| 46 | logger.info(f"Discovered {len(models)} OpenAI models") |
| 47 | all_models.extend(models) |
| 48 | except Exception as e: |
| 49 | logger.info(f"OpenAI discovery skipped: {e}") |
| 50 | |
| 51 | # Anthropic |
| 52 | if keys.get("anthropic"): |
| 53 | try: |
| 54 | from video_processor.providers.anthropic_provider import AnthropicProvider |
| 55 | |
| 56 | provider = AnthropicProvider(api_key=keys["anthropic"]) |
| 57 | models = provider.list_models() |
| 58 | logger.info(f"Discovered {len(models)} Anthropic models") |
| 59 | all_models.extend(models) |
| 60 | except Exception as e: |
| 61 | logger.info(f"Anthropic discovery skipped: {e}") |
| 62 | |
| 63 | # Gemini (API key or service account) |
| 64 | gemini_key = keys.get("gemini") |
| 65 | gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") |
| 66 | if gemini_key or gemini_creds: |
| 67 | try: |
| 68 | from video_processor.providers.gemini_provider import GeminiProvider |
| 69 | |
| 70 | provider = GeminiProvider( |
| 71 | api_key=gemini_key or None, |
| 72 | credentials_path=gemini_creds or None, |
| 73 | ) |
| 74 | models = provider.list_models() |
| 75 | logger.info(f"Discovered {len(models)} Gemini models") |
| 76 | all_models.extend(models) |
| 77 | except Exception as e: |
| 78 | logger.warning(f"Gemini discovery failed: {e}") |
| 79 | |
| 80 | # Ollama (local, no API key needed) |
| 81 | try: |
| 82 | from video_processor.providers.ollama_provider import OllamaProvider |
| 83 | |
| 84 | if OllamaProvider.is_available(): |
| 85 | provider = OllamaProvider() |
| 86 | models = provider.list_models() |
| 87 | logger.info(f"Discovered {len(models)} Ollama models") |
| 88 | all_models.extend(models) |
| 89 | except Exception as e: |
| 90 | logger.info(f"Ollama discovery skipped: {e}") |
| 91 | |
| 92 | # Sort by provider then id |
| 93 | all_models.sort(key=lambda m: (m.provider, m.id)) |
| 94 | _cached_models = all_models |
| 95 | logger.info(f"Total discovered models: {len(all_models)}") |
| 96 |
| --- video_processor/providers/discovery.py | |
| +++ video_processor/providers/discovery.py | |
| @@ -4,17 +4,27 @@ | |
| 4 | import os |
| 5 | from typing import Optional |
| 6 | |
| 7 | from dotenv import load_dotenv |
| 8 | |
| 9 | from video_processor.providers.base import ModelInfo, ProviderRegistry |
| 10 | |
| 11 | load_dotenv() |
| 12 | logger = logging.getLogger(__name__) |
| 13 | |
| 14 | _cached_models: Optional[list[ModelInfo]] = None |
| 15 | |
| 16 | |
| 17 | def _ensure_providers_registered() -> None: |
| 18 | """Import all built-in provider modules so they register themselves.""" |
| 19 | if ProviderRegistry.all_registered(): |
| 20 | return |
| 21 | import video_processor.providers.anthropic_provider # noqa: F401 |
| 22 | import video_processor.providers.gemini_provider # noqa: F401 |
| 23 | import video_processor.providers.ollama_provider # noqa: F401 |
| 24 | import video_processor.providers.openai_provider # noqa: F401 |
| 25 | |
| 26 | |
| 27 | def discover_available_models( |
| 28 | api_keys: Optional[dict[str, str]] = None, |
| 29 | force_refresh: bool = False, |
| 30 | ) -> list[ModelInfo]: |
| @@ -26,70 +36,63 @@ | |
| 36 | """ |
| 37 | global _cached_models |
| 38 | if _cached_models is not None and not force_refresh: |
| 39 | return _cached_models |
| 40 | |
| 41 | _ensure_providers_registered() |
| 42 | |
| 43 | keys = api_keys or { |
| 44 | "openai": os.getenv("OPENAI_API_KEY", ""), |
| 45 | "anthropic": os.getenv("ANTHROPIC_API_KEY", ""), |
| 46 | "gemini": os.getenv("GEMINI_API_KEY", ""), |
| 47 | } |
| 48 | |
| 49 | all_models: list[ModelInfo] = [] |
| 50 | |
| 51 | for name, info in ProviderRegistry.all_registered().items(): |
| 52 | env_var = info.get("env_var", "") |
| 53 | provider_class = info["class"] |
| 54 | |
| 55 | if name == "ollama": |
| 56 | # Ollama: no API key, check server availability |
| 57 | try: |
| 58 | if provider_class.is_available(): |
| 59 | provider = provider_class() |
| 60 | models = provider.list_models() |
| 61 | logger.info(f"Discovered {len(models)} Ollama models") |
| 62 | all_models.extend(models) |
| 63 | except Exception as e: |
| 64 | logger.info(f"Ollama discovery skipped: {e}") |
| 65 | continue |
| 66 | |
| 67 | # For key-based providers, check the api_keys dict first, then env var |
| 68 | key = keys.get(name, "") |
| 69 | if not key and env_var: |
| 70 | key = os.getenv(env_var, "") |
| 71 | |
| 72 | # Special case: Gemini also supports service account credentials |
| 73 | gemini_creds = "" |
| 74 | if name == "gemini": |
| 75 | gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") |
| 76 | |
| 77 | if not key and not gemini_creds: |
| 78 | continue |
| 79 | |
| 80 | try: |
| 81 | # Handle provider-specific constructor args |
| 82 | if name == "gemini": |
| 83 | provider = provider_class( |
| 84 | api_key=key or None, |
| 85 | credentials_path=gemini_creds or None, |
| 86 | ) |
| 87 | else: |
| 88 | provider = provider_class(api_key=key) |
| 89 | models = provider.list_models() |
| 90 | logger.info(f"Discovered {len(models)} {name.capitalize()} models") |
| 91 | all_models.extend(models) |
| 92 | except Exception as e: |
| 93 | logger.info(f"{name.capitalize()} discovery skipped: {e}") |
| 94 | |
| 95 | # Sort by provider then id |
| 96 | all_models.sort(key=lambda m: (m.provider, m.id)) |
| 97 | _cached_models = all_models |
| 98 | logger.info(f"Total discovered models: {len(all_models)}") |
| 99 |
| --- video_processor/providers/gemini_provider.py | ||
| +++ video_processor/providers/gemini_provider.py | ||
| @@ -5,11 +5,11 @@ | ||
| 5 | 5 | from pathlib import Path |
| 6 | 6 | from typing import Optional |
| 7 | 7 | |
| 8 | 8 | from dotenv import load_dotenv |
| 9 | 9 | |
| 10 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 10 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 11 | 11 | |
| 12 | 12 | load_dotenv() |
| 13 | 13 | logger = logging.getLogger(__name__) |
| 14 | 14 | |
| 15 | 15 | # Capabilities inferred from model id patterns |
| @@ -218,5 +218,18 @@ | ||
| 218 | 218 | ) |
| 219 | 219 | ) |
| 220 | 220 | except Exception as e: |
| 221 | 221 | logger.warning(f"Failed to list Gemini models: {e}") |
| 222 | 222 | return sorted(models, key=lambda m: m.id) |
| 223 | + | |
| 224 | + | |
| 225 | +ProviderRegistry.register( | |
| 226 | + name="gemini", | |
| 227 | + provider_class=GeminiProvider, | |
| 228 | + env_var="GEMINI_API_KEY", | |
| 229 | + model_prefixes=["gemini-"], | |
| 230 | + default_models={ | |
| 231 | + "chat": "gemini-2.5-flash", | |
| 232 | + "vision": "gemini-2.5-flash", | |
| 233 | + "audio": "gemini-2.5-flash", | |
| 234 | + }, | |
| 235 | +) | |
| 223 | 236 |
| --- video_processor/providers/gemini_provider.py | |
| +++ video_processor/providers/gemini_provider.py | |
| @@ -5,11 +5,11 @@ | |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
| 15 | # Capabilities inferred from model id patterns |
| @@ -218,5 +218,18 @@ | |
| 218 | ) |
| 219 | ) |
| 220 | except Exception as e: |
| 221 | logger.warning(f"Failed to list Gemini models: {e}") |
| 222 | return sorted(models, key=lambda m: m.id) |
| 223 |
| --- video_processor/providers/gemini_provider.py | |
| +++ video_processor/providers/gemini_provider.py | |
| @@ -5,11 +5,11 @@ | |
| 5 | from pathlib import Path |
| 6 | from typing import Optional |
| 7 | |
| 8 | from dotenv import load_dotenv |
| 9 | |
| 10 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 11 | |
| 12 | load_dotenv() |
| 13 | logger = logging.getLogger(__name__) |
| 14 | |
| 15 | # Capabilities inferred from model id patterns |
| @@ -218,5 +218,18 @@ | |
| 218 | ) |
| 219 | ) |
| 220 | except Exception as e: |
| 221 | logger.warning(f"Failed to list Gemini models: {e}") |
| 222 | return sorted(models, key=lambda m: m.id) |
| 223 | |
| 224 | |
| 225 | ProviderRegistry.register( |
| 226 | name="gemini", |
| 227 | provider_class=GeminiProvider, |
| 228 | env_var="GEMINI_API_KEY", |
| 229 | model_prefixes=["gemini-"], |
| 230 | default_models={ |
| 231 | "chat": "gemini-2.5-flash", |
| 232 | "vision": "gemini-2.5-flash", |
| 233 | "audio": "gemini-2.5-flash", |
| 234 | }, |
| 235 | ) |
| 236 |
+27
-50
| --- video_processor/providers/manager.py | ||
| +++ video_processor/providers/manager.py | ||
| @@ -4,16 +4,28 @@ | ||
| 4 | 4 | from pathlib import Path |
| 5 | 5 | from typing import Optional |
| 6 | 6 | |
| 7 | 7 | from dotenv import load_dotenv |
| 8 | 8 | |
| 9 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 9 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 10 | 10 | from video_processor.providers.discovery import discover_available_models |
| 11 | 11 | from video_processor.utils.usage_tracker import UsageTracker |
| 12 | 12 | |
| 13 | 13 | load_dotenv() |
| 14 | 14 | logger = logging.getLogger(__name__) |
| 15 | + | |
| 16 | + | |
| 17 | +def _ensure_providers_registered() -> None: | |
| 18 | + """Import all built-in provider modules so they register themselves.""" | |
| 19 | + if ProviderRegistry.all_registered(): | |
| 20 | + return | |
| 21 | + # Each module registers itself on import via ProviderRegistry.register() | |
| 22 | + import video_processor.providers.anthropic_provider # noqa: F401 | |
| 23 | + import video_processor.providers.gemini_provider # noqa: F401 | |
| 24 | + import video_processor.providers.ollama_provider # noqa: F401 | |
| 25 | + import video_processor.providers.openai_provider # noqa: F401 | |
| 26 | + | |
| 15 | 27 | |
| 16 | 28 | # Default model preference rankings (tried in order) |
| 17 | 29 | _VISION_PREFERENCES = [ |
| 18 | 30 | ("gemini", "gemini-2.5-flash"), |
| 19 | 31 | ("openai", "gpt-4o"), |
| @@ -57,10 +69,11 @@ | ||
| 57 | 69 | chat_model : override model for chat/LLM tasks |
| 58 | 70 | transcription_model : override model for transcription |
| 59 | 71 | provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini') |
| 60 | 72 | auto : if True and no model specified, pick the best available |
| 61 | 73 | """ |
| 74 | + _ensure_providers_registered() | |
| 62 | 75 | self.auto = auto |
| 63 | 76 | self._providers: dict[str, BaseProvider] = {} |
| 64 | 77 | self._available_models: Optional[list[ModelInfo]] = None |
| 65 | 78 | self.usage = UsageTracker() |
| 66 | 79 | |
| @@ -79,67 +92,31 @@ | ||
| 79 | 92 | self._forced_provider = provider |
| 80 | 93 | |
| 81 | 94 | @staticmethod |
| 82 | 95 | def _default_for_provider(provider: str, capability: str) -> str: |
| 83 | 96 | """Return the default model for a provider/capability combo.""" |
| 84 | - defaults = { | |
| 85 | - "openai": {"chat": "gpt-4o", "vision": "gpt-4o", "audio": "whisper-1"}, | |
| 86 | - "anthropic": { | |
| 87 | - "chat": "claude-sonnet-4-5-20250929", | |
| 88 | - "vision": "claude-sonnet-4-5-20250929", | |
| 89 | - "audio": "", | |
| 90 | - }, | |
| 91 | - "gemini": { | |
| 92 | - "chat": "gemini-2.5-flash", | |
| 93 | - "vision": "gemini-2.5-flash", | |
| 94 | - "audio": "gemini-2.5-flash", | |
| 95 | - }, | |
| 96 | - "ollama": { | |
| 97 | - "chat": "", | |
| 98 | - "vision": "", | |
| 99 | - "audio": "", | |
| 100 | - }, | |
| 101 | - } | |
| 102 | - return defaults.get(provider, {}).get(capability, "") | |
| 97 | + defaults = ProviderRegistry.get_default_models(provider) | |
| 98 | + if defaults: | |
| 99 | + return defaults.get(capability, "") | |
| 100 | + # Fallback for unregistered providers | |
| 101 | + return "" | |
| 103 | 102 | |
| 104 | 103 | def _get_provider(self, provider_name: str) -> BaseProvider: |
| 105 | 104 | """Lazily initialize and cache a provider instance.""" |
| 106 | 105 | if provider_name not in self._providers: |
| 107 | - if provider_name == "openai": | |
| 108 | - from video_processor.providers.openai_provider import OpenAIProvider | |
| 109 | - | |
| 110 | - self._providers[provider_name] = OpenAIProvider() | |
| 111 | - elif provider_name == "anthropic": | |
| 112 | - from video_processor.providers.anthropic_provider import AnthropicProvider | |
| 113 | - | |
| 114 | - self._providers[provider_name] = AnthropicProvider() | |
| 115 | - elif provider_name == "gemini": | |
| 116 | - from video_processor.providers.gemini_provider import GeminiProvider | |
| 117 | - | |
| 118 | - self._providers[provider_name] = GeminiProvider() | |
| 119 | - elif provider_name == "ollama": | |
| 120 | - from video_processor.providers.ollama_provider import OllamaProvider | |
| 121 | - | |
| 122 | - self._providers[provider_name] = OllamaProvider() | |
| 123 | - else: | |
| 124 | - raise ValueError(f"Unknown provider: {provider_name}") | |
| 106 | + _ensure_providers_registered() | |
| 107 | + provider_class = ProviderRegistry.get(provider_name) | |
| 108 | + self._providers[provider_name] = provider_class() | |
| 125 | 109 | return self._providers[provider_name] |
| 126 | 110 | |
| 127 | 111 | def _provider_for_model(self, model_id: str) -> str: |
| 128 | 112 | """Infer the provider from a model id.""" |
| 129 | - if ( | |
| 130 | - model_id.startswith("gpt-") | |
| 131 | - or model_id.startswith("o1") | |
| 132 | - or model_id.startswith("o3") | |
| 133 | - or model_id.startswith("o4") | |
| 134 | - or model_id.startswith("whisper") | |
| 135 | - ): | |
| 136 | - return "openai" | |
| 137 | - if model_id.startswith("claude-"): | |
| 138 | - return "anthropic" | |
| 139 | - if model_id.startswith("gemini-"): | |
| 140 | - return "gemini" | |
| 113 | + _ensure_providers_registered() | |
| 114 | + # Check registry prefix matching first | |
| 115 | + provider_name = ProviderRegistry.get_by_model(model_id) | |
| 116 | + if provider_name: | |
| 117 | + return provider_name | |
| 141 | 118 | # Try discovery (exact match, then prefix match for ollama name:tag format) |
| 142 | 119 | models = self._get_available_models() |
| 143 | 120 | for m in models: |
| 144 | 121 | if m.id == model_id: |
| 145 | 122 | return m.provider |
| 146 | 123 |
| --- video_processor/providers/manager.py | |
| +++ video_processor/providers/manager.py | |
| @@ -4,16 +4,28 @@ | |
| 4 | from pathlib import Path |
| 5 | from typing import Optional |
| 6 | |
| 7 | from dotenv import load_dotenv |
| 8 | |
| 9 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 10 | from video_processor.providers.discovery import discover_available_models |
| 11 | from video_processor.utils.usage_tracker import UsageTracker |
| 12 | |
| 13 | load_dotenv() |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
| 16 | # Default model preference rankings (tried in order) |
| 17 | _VISION_PREFERENCES = [ |
| 18 | ("gemini", "gemini-2.5-flash"), |
| 19 | ("openai", "gpt-4o"), |
| @@ -57,10 +69,11 @@ | |
| 57 | chat_model : override model for chat/LLM tasks |
| 58 | transcription_model : override model for transcription |
| 59 | provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini') |
| 60 | auto : if True and no model specified, pick the best available |
| 61 | """ |
| 62 | self.auto = auto |
| 63 | self._providers: dict[str, BaseProvider] = {} |
| 64 | self._available_models: Optional[list[ModelInfo]] = None |
| 65 | self.usage = UsageTracker() |
| 66 | |
| @@ -79,67 +92,31 @@ | |
| 79 | self._forced_provider = provider |
| 80 | |
| 81 | @staticmethod |
| 82 | def _default_for_provider(provider: str, capability: str) -> str: |
| 83 | """Return the default model for a provider/capability combo.""" |
| 84 | defaults = { |
| 85 | "openai": {"chat": "gpt-4o", "vision": "gpt-4o", "audio": "whisper-1"}, |
| 86 | "anthropic": { |
| 87 | "chat": "claude-sonnet-4-5-20250929", |
| 88 | "vision": "claude-sonnet-4-5-20250929", |
| 89 | "audio": "", |
| 90 | }, |
| 91 | "gemini": { |
| 92 | "chat": "gemini-2.5-flash", |
| 93 | "vision": "gemini-2.5-flash", |
| 94 | "audio": "gemini-2.5-flash", |
| 95 | }, |
| 96 | "ollama": { |
| 97 | "chat": "", |
| 98 | "vision": "", |
| 99 | "audio": "", |
| 100 | }, |
| 101 | } |
| 102 | return defaults.get(provider, {}).get(capability, "") |
| 103 | |
| 104 | def _get_provider(self, provider_name: str) -> BaseProvider: |
| 105 | """Lazily initialize and cache a provider instance.""" |
| 106 | if provider_name not in self._providers: |
| 107 | if provider_name == "openai": |
| 108 | from video_processor.providers.openai_provider import OpenAIProvider |
| 109 | |
| 110 | self._providers[provider_name] = OpenAIProvider() |
| 111 | elif provider_name == "anthropic": |
| 112 | from video_processor.providers.anthropic_provider import AnthropicProvider |
| 113 | |
| 114 | self._providers[provider_name] = AnthropicProvider() |
| 115 | elif provider_name == "gemini": |
| 116 | from video_processor.providers.gemini_provider import GeminiProvider |
| 117 | |
| 118 | self._providers[provider_name] = GeminiProvider() |
| 119 | elif provider_name == "ollama": |
| 120 | from video_processor.providers.ollama_provider import OllamaProvider |
| 121 | |
| 122 | self._providers[provider_name] = OllamaProvider() |
| 123 | else: |
| 124 | raise ValueError(f"Unknown provider: {provider_name}") |
| 125 | return self._providers[provider_name] |
| 126 | |
| 127 | def _provider_for_model(self, model_id: str) -> str: |
| 128 | """Infer the provider from a model id.""" |
| 129 | if ( |
| 130 | model_id.startswith("gpt-") |
| 131 | or model_id.startswith("o1") |
| 132 | or model_id.startswith("o3") |
| 133 | or model_id.startswith("o4") |
| 134 | or model_id.startswith("whisper") |
| 135 | ): |
| 136 | return "openai" |
| 137 | if model_id.startswith("claude-"): |
| 138 | return "anthropic" |
| 139 | if model_id.startswith("gemini-"): |
| 140 | return "gemini" |
| 141 | # Try discovery (exact match, then prefix match for ollama name:tag format) |
| 142 | models = self._get_available_models() |
| 143 | for m in models: |
| 144 | if m.id == model_id: |
| 145 | return m.provider |
| 146 |
| --- video_processor/providers/manager.py | |
| +++ video_processor/providers/manager.py | |
| @@ -4,16 +4,28 @@ | |
| 4 | from pathlib import Path |
| 5 | from typing import Optional |
| 6 | |
| 7 | from dotenv import load_dotenv |
| 8 | |
| 9 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 10 | from video_processor.providers.discovery import discover_available_models |
| 11 | from video_processor.utils.usage_tracker import UsageTracker |
| 12 | |
| 13 | load_dotenv() |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
| 16 | |
| 17 | def _ensure_providers_registered() -> None: |
| 18 | """Import all built-in provider modules so they register themselves.""" |
| 19 | if ProviderRegistry.all_registered(): |
| 20 | return |
| 21 | # Each module registers itself on import via ProviderRegistry.register() |
| 22 | import video_processor.providers.anthropic_provider # noqa: F401 |
| 23 | import video_processor.providers.gemini_provider # noqa: F401 |
| 24 | import video_processor.providers.ollama_provider # noqa: F401 |
| 25 | import video_processor.providers.openai_provider # noqa: F401 |
| 26 | |
| 27 | |
| 28 | # Default model preference rankings (tried in order) |
| 29 | _VISION_PREFERENCES = [ |
| 30 | ("gemini", "gemini-2.5-flash"), |
| 31 | ("openai", "gpt-4o"), |
| @@ -57,10 +69,11 @@ | |
| 69 | chat_model : override model for chat/LLM tasks |
| 70 | transcription_model : override model for transcription |
| 71 | provider : force all tasks to a single provider ('openai', 'anthropic', 'gemini') |
| 72 | auto : if True and no model specified, pick the best available |
| 73 | """ |
| 74 | _ensure_providers_registered() |
| 75 | self.auto = auto |
| 76 | self._providers: dict[str, BaseProvider] = {} |
| 77 | self._available_models: Optional[list[ModelInfo]] = None |
| 78 | self.usage = UsageTracker() |
| 79 | |
| @@ -79,67 +92,31 @@ | |
| 92 | self._forced_provider = provider |
| 93 | |
| 94 | @staticmethod |
| 95 | def _default_for_provider(provider: str, capability: str) -> str: |
| 96 | """Return the default model for a provider/capability combo.""" |
| 97 | defaults = ProviderRegistry.get_default_models(provider) |
| 98 | if defaults: |
| 99 | return defaults.get(capability, "") |
| 100 | # Fallback for unregistered providers |
| 101 | return "" |
| 102 | |
| 103 | def _get_provider(self, provider_name: str) -> BaseProvider: |
| 104 | """Lazily initialize and cache a provider instance.""" |
| 105 | if provider_name not in self._providers: |
| 106 | _ensure_providers_registered() |
| 107 | provider_class = ProviderRegistry.get(provider_name) |
| 108 | self._providers[provider_name] = provider_class() |
| 109 | return self._providers[provider_name] |
| 110 | |
| 111 | def _provider_for_model(self, model_id: str) -> str: |
| 112 | """Infer the provider from a model id.""" |
| 113 | _ensure_providers_registered() |
| 114 | # Check registry prefix matching first |
| 115 | provider_name = ProviderRegistry.get_by_model(model_id) |
| 116 | if provider_name: |
| 117 | return provider_name |
| 118 | # Try discovery (exact match, then prefix match for ollama name:tag format) |
| 119 | models = self._get_available_models() |
| 120 | for m in models: |
| 121 | if m.id == model_id: |
| 122 | return m.provider |
| 123 |
| --- video_processor/providers/ollama_provider.py | ||
| +++ video_processor/providers/ollama_provider.py | ||
| @@ -7,11 +7,11 @@ | ||
| 7 | 7 | from typing import Optional |
| 8 | 8 | |
| 9 | 9 | import requests |
| 10 | 10 | from openai import OpenAI |
| 11 | 11 | |
| 12 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 12 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 13 | 13 | |
| 14 | 14 | logger = logging.getLogger(__name__) |
| 15 | 15 | |
| 16 | 16 | # Known vision-capable model families (base name before the colon/tag) |
| 17 | 17 | _VISION_FAMILIES = { |
| @@ -168,5 +168,14 @@ | ||
| 168 | 168 | ) |
| 169 | 169 | ) |
| 170 | 170 | except Exception as e: |
| 171 | 171 | logger.warning(f"Failed to list Ollama models: {e}") |
| 172 | 172 | return sorted(models, key=lambda m: m.id) |
| 173 | + | |
| 174 | + | |
| 175 | +ProviderRegistry.register( | |
| 176 | + name="ollama", | |
| 177 | + provider_class=OllamaProvider, | |
| 178 | + env_var="", | |
| 179 | + model_prefixes=[], | |
| 180 | + default_models={"chat": "", "vision": "", "audio": ""}, | |
| 181 | +) | |
| 173 | 182 |
| --- video_processor/providers/ollama_provider.py | |
| +++ video_processor/providers/ollama_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | import requests |
| 10 | from openai import OpenAI |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 13 | |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
| 16 | # Known vision-capable model families (base name before the colon/tag) |
| 17 | _VISION_FAMILIES = { |
| @@ -168,5 +168,14 @@ | |
| 168 | ) |
| 169 | ) |
| 170 | except Exception as e: |
| 171 | logger.warning(f"Failed to list Ollama models: {e}") |
| 172 | return sorted(models, key=lambda m: m.id) |
| 173 |
| --- video_processor/providers/ollama_provider.py | |
| +++ video_processor/providers/ollama_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | import requests |
| 10 | from openai import OpenAI |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 13 | |
| 14 | logger = logging.getLogger(__name__) |
| 15 | |
| 16 | # Known vision-capable model families (base name before the colon/tag) |
| 17 | _VISION_FAMILIES = { |
| @@ -168,5 +168,14 @@ | |
| 168 | ) |
| 169 | ) |
| 170 | except Exception as e: |
| 171 | logger.warning(f"Failed to list Ollama models: {e}") |
| 172 | return sorted(models, key=lambda m: m.id) |
| 173 | |
| 174 | |
| 175 | ProviderRegistry.register( |
| 176 | name="ollama", |
| 177 | provider_class=OllamaProvider, |
| 178 | env_var="", |
| 179 | model_prefixes=[], |
| 180 | default_models={"chat": "", "vision": "", "audio": ""}, |
| 181 | ) |
| 182 |
| --- video_processor/providers/openai_provider.py | ||
| +++ video_processor/providers/openai_provider.py | ||
| @@ -7,11 +7,11 @@ | ||
| 7 | 7 | from typing import Optional |
| 8 | 8 | |
| 9 | 9 | from dotenv import load_dotenv |
| 10 | 10 | from openai import OpenAI |
| 11 | 11 | |
| 12 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 12 | +from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry | |
| 13 | 13 | |
| 14 | 14 | load_dotenv() |
| 15 | 15 | logger = logging.getLogger(__name__) |
| 16 | 16 | |
| 17 | 17 | # Models known to have vision capability |
| @@ -225,5 +225,14 @@ | ||
| 225 | 225 | ) |
| 226 | 226 | ) |
| 227 | 227 | except Exception as e: |
| 228 | 228 | logger.warning(f"Failed to list OpenAI models: {e}") |
| 229 | 229 | return sorted(models, key=lambda m: m.id) |
| 230 | + | |
| 231 | + | |
| 232 | +ProviderRegistry.register( | |
| 233 | + name="openai", | |
| 234 | + provider_class=OpenAIProvider, | |
| 235 | + env_var="OPENAI_API_KEY", | |
| 236 | + model_prefixes=["gpt-", "o1", "o3", "o4", "whisper"], | |
| 237 | + default_models={"chat": "gpt-4o", "vision": "gpt-4o", "audio": "whisper-1"}, | |
| 238 | +) | |
| 230 | 239 |
| --- video_processor/providers/openai_provider.py | |
| +++ video_processor/providers/openai_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | from dotenv import load_dotenv |
| 10 | from openai import OpenAI |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 13 | |
| 14 | load_dotenv() |
| 15 | logger = logging.getLogger(__name__) |
| 16 | |
| 17 | # Models known to have vision capability |
| @@ -225,5 +225,14 @@ | |
| 225 | ) |
| 226 | ) |
| 227 | except Exception as e: |
| 228 | logger.warning(f"Failed to list OpenAI models: {e}") |
| 229 | return sorted(models, key=lambda m: m.id) |
| 230 |
| --- video_processor/providers/openai_provider.py | |
| +++ video_processor/providers/openai_provider.py | |
| @@ -7,11 +7,11 @@ | |
| 7 | from typing import Optional |
| 8 | |
| 9 | from dotenv import load_dotenv |
| 10 | from openai import OpenAI |
| 11 | |
| 12 | from video_processor.providers.base import BaseProvider, ModelInfo, ProviderRegistry |
| 13 | |
| 14 | load_dotenv() |
| 15 | logger = logging.getLogger(__name__) |
| 16 | |
| 17 | # Models known to have vision capability |
| @@ -225,5 +225,14 @@ | |
| 225 | ) |
| 226 | ) |
| 227 | except Exception as e: |
| 228 | logger.warning(f"Failed to list OpenAI models: {e}") |
| 229 | return sorted(models, key=lambda m: m.id) |
| 230 | |
| 231 | |
| 232 | ProviderRegistry.register( |
| 233 | name="openai", |
| 234 | provider_class=OpenAIProvider, |
| 235 | env_var="OPENAI_API_KEY", |
| 236 | model_prefixes=["gpt-", "o1", "o3", "o4", "whisper"], |
| 237 | default_models={"chat": "gpt-4o", "vision": "gpt-4o", "audio": "whisper-1"}, |
| 238 | ) |
| 239 |