"""Auto-discover available models across providers."""

import logging
import os
from typing import Optional

from dotenv import load_dotenv

from video_processor.providers.base import ModelInfo, ProviderRegistry

load_dotenv()
logger = logging.getLogger(__name__)

_cached_models: Optional[list[ModelInfo]] = None


def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves."""
    if ProviderRegistry.all_registered():
        return
    import video_processor.providers.anthropic_provider  # noqa: F401
    import video_processor.providers.gemini_provider  # noqa: F401
    import video_processor.providers.ollama_provider  # noqa: F401
    import video_processor.providers.openai_provider  # noqa: F401


def discover_available_models(
    api_keys: Optional[dict[str, str]] = None,
    force_refresh: bool = False,
) -> list[ModelInfo]:
    """
    Discover available models from all configured providers.

    For each provider with a valid API key, calls list_models() and returns
    a unified list. Results are cached for the session.
    """
    global _cached_models
    if _cached_models is not None and not force_refresh:
        return _cached_models

    _ensure_providers_registered()

    keys = api_keys or {
        "openai": os.getenv("OPENAI_API_KEY", ""),
        "anthropic": os.getenv("ANTHROPIC_API_KEY", ""),
        "gemini": os.getenv("GEMINI_API_KEY", ""),
    }

    all_models: list[ModelInfo] = []

    for name, info in ProviderRegistry.all_registered().items():
        env_var = info.get("env_var", "")
        provider_class = info["class"]

        if name == "ollama":
            # Ollama: no API key, check server availability
            try:
                if provider_class.is_available():
                    provider = provider_class()
                    models = provider.list_models()
                    logger.info(f"Discovered {len(models)} Ollama models")
                    all_models.extend(models)
            except Exception as e:
                logger.info(f"Ollama discovery skipped: {e}")
            continue

        # For key-based providers, check the api_keys dict first, then env var
        key = keys.get(name, "")
        if not key and env_var:
            key = os.getenv(env_var, "")

        # Special case: Gemini also supports service account credentials
        gemini_creds = ""
        if name == "gemini":
            gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")

        if not key and not gemini_creds:
            continue

        try:
            # Handle provider-specific constructor args
            if name == "gemini":
                provider = provider_class(
                    api_key=key or None,
                    credentials_path=gemini_creds or None,
                )
            else:
                provider = provider_class(api_key=key)
            models = provider.list_models()
            logger.info(f"Discovered {len(models)} {name.capitalize()} models")
            all_models.extend(models)
        except Exception as e:
            logger.info(f"{name.capitalize()} discovery skipped: {e}")

    # Sort by provider then id
    all_models.sort(key=lambda m: (m.provider, m.id))
    _cached_models = all_models
    logger.info(f"Total discovered models: {len(all_models)}")
    return all_models


def clear_discovery_cache() -> None:
    """Clear the cached model list."""
    global _cached_models
    _cached_models = None
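

# Illustrative usage sketch (not part of the module's API): shows how discovery
# and cache clearing fit together. Assumes at least one provider key
# (e.g. OPENAI_API_KEY) is set in the environment or the .env file loaded above;
# the `provider` and `id` fields come from ModelInfo, as used in the sort key.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # First call hits the providers; subsequent calls return the cached list.
    for model in discover_available_models():
        print(f"{model.provider}: {model.id}")

    # Drop the cache (or pass force_refresh=True) to re-query providers.
    clear_discovery_cache()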