|
a94205b…
|
leo
|
1 |
"""Auto-discover available models across providers.""" |
|
a94205b…
|
leo
|
2 |
|
|
a94205b…
|
leo
|
3 |
import logging |
|
a94205b…
|
leo
|
4 |
import os |
|
a94205b…
|
leo
|
5 |
from typing import Optional |
|
a94205b…
|
leo
|
6 |
|
|
a94205b…
|
leo
|
7 |
from dotenv import load_dotenv |
|
a94205b…
|
leo
|
8 |
|
|
0981a08…
|
noreply
|
9 |
from video_processor.providers.base import ModelInfo, ProviderRegistry |
|
a94205b…
|
leo
|
10 |
|
|
a94205b…
|
leo
|
11 |
# Pull variables from a local .env file into the process environment so the
# os.getenv() lookups below can see locally-configured API keys.
load_dotenv()
logger = logging.getLogger(__name__)

# Session-level cache populated by discover_available_models() and reset by
# clear_discovery_cache(); None means no discovery has run yet this session.
_cached_models: Optional[list[ModelInfo]] = None
|
0981a08…
|
noreply
|
15 |
|
|
0981a08…
|
noreply
|
16 |
|
|
0981a08…
|
noreply
|
17 |
def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves."""
    if not ProviderRegistry.all_registered():
        # Importing each module runs its registration side effect; the noqa
        # markers keep linters from flagging the apparently-unused imports.
        import video_processor.providers.anthropic_provider  # noqa: F401
        import video_processor.providers.gemini_provider  # noqa: F401
        import video_processor.providers.ollama_provider  # noqa: F401
        import video_processor.providers.openai_provider  # noqa: F401
|
a94205b…
|
leo
|
25 |
|
|
a94205b…
|
leo
|
26 |
|
|
a94205b…
|
leo
|
27 |
def discover_available_models(
    api_keys: Optional[dict[str, str]] = None,
    force_refresh: bool = False,
) -> list[ModelInfo]:
    """
    Discover available models from all configured providers.

    For each registered provider that is reachable (Ollama) or has a
    credential (key-based providers), calls list_models() and returns a
    unified list sorted by (provider, model id).

    Args:
        api_keys: Optional mapping of provider name -> API key. When falsy
            (including an empty dict), keys are read from the standard
            environment variables instead.
        force_refresh: When True, bypass the session cache and re-discover.

    Returns:
        All discovered ModelInfo entries. A provider whose discovery fails
        is skipped (logged at INFO) rather than raising to the caller.

    Note:
        Results are cached at module level for the session. The cache is NOT
        keyed on ``api_keys`` — a later call with different keys returns the
        cached list unless ``force_refresh=True`` is passed.
    """
    global _cached_models
    if _cached_models is not None and not force_refresh:
        return _cached_models

    _ensure_providers_registered()

    keys = api_keys or {
        "openai": os.getenv("OPENAI_API_KEY", ""),
        "anthropic": os.getenv("ANTHROPIC_API_KEY", ""),
        "gemini": os.getenv("GEMINI_API_KEY", ""),
    }

    all_models: list[ModelInfo] = []

    for name, info in ProviderRegistry.all_registered().items():
        env_var = info.get("env_var", "")
        provider_class = info["class"]

        if name == "ollama":
            # Ollama is a local server: no API key, probe availability instead.
            try:
                if provider_class.is_available():
                    provider = provider_class()
                    models = provider.list_models()
                    # Lazy %-args: no formatting cost when INFO is disabled.
                    logger.info("Discovered %d Ollama models", len(models))
                    all_models.extend(models)
            except Exception as e:
                # Best-effort: an unreachable local server is not an error.
                logger.info("Ollama discovery skipped: %s", e)
            continue

        # For key-based providers, check the api_keys dict first, then env var.
        key = keys.get(name, "")
        if not key and env_var:
            key = os.getenv(env_var, "")

        # Special case: Gemini also supports service account credentials.
        gemini_creds = ""
        if name == "gemini":
            gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")

        if not key and not gemini_creds:
            continue  # no credentials at all -> provider not configured

        try:
            # Handle provider-specific constructor args.
            if name == "gemini":
                provider = provider_class(
                    api_key=key or None,
                    credentials_path=gemini_creds or None,
                )
            else:
                provider = provider_class(api_key=key)
            models = provider.list_models()
            logger.info("Discovered %d %s models", len(models), name.capitalize())
            all_models.extend(models)
        except Exception as e:
            # A bad key / network failure for one provider must not break the rest.
            logger.info("%s discovery skipped: %s", name.capitalize(), e)

    # Deterministic ordering for display and stable caching.
    all_models.sort(key=lambda m: (m.provider, m.id))
    _cached_models = all_models
    logger.info("Total discovered models: %d", len(all_models))
    return all_models
|
a94205b…
|
leo
|
100 |
|
|
a94205b…
|
leo
|
101 |
|
|
a94205b…
|
leo
|
102 |
def clear_discovery_cache() -> None:
    """Drop the session model cache so the next discovery call re-fetches."""
    global _cached_models
    _cached_models = None