PlanOpticon

Blame History Raw 106 lines
1
"""Auto-discover available models across providers."""
2
3
import logging
4
import os
5
from typing import Optional
6
7
from dotenv import load_dotenv
8
9
from video_processor.providers.base import ModelInfo, ProviderRegistry
10
11
# Load variables from a .env file (if one is found) into the process
# environment at import time, so the os.getenv() lookups in
# discover_available_models() can see keys defined there.
load_dotenv()
12
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
13
14
# Module-level cache of the last discovery result. None means no result is
# cached (discovery has not run yet, or clear_discovery_cache() was called).
_cached_models: Optional[list[ModelInfo]] = None
15
16
17
def _ensure_providers_registered() -> None:
    """Make sure the built-in provider modules have been imported.

    The provider modules are imported only for their import-time side
    effects (presumably each one registers itself with ``ProviderRegistry``),
    which is why the imported names are never used and carry ``noqa: F401``.

    The imports are skipped whenever the registry already reports at least
    one registered provider.
    """
    if not ProviderRegistry.all_registered():
        import video_processor.providers.anthropic_provider  # noqa: F401
        import video_processor.providers.gemini_provider  # noqa: F401
        import video_processor.providers.ollama_provider  # noqa: F401
        import video_processor.providers.openai_provider  # noqa: F401
25
26
27
def discover_available_models(
    api_keys: Optional[dict[str, str]] = None,
    force_refresh: bool = False,
) -> list[ModelInfo]:
    """
    Collect the models offered by every usable registered provider.

    Args:
        api_keys: Optional mapping of provider name to API key. When omitted
            (or empty), keys for "openai", "anthropic" and "gemini" are read
            from OPENAI_API_KEY, ANTHROPIC_API_KEY and GEMINI_API_KEY. A
            provider without a key in this mapping falls back to the
            environment variable named by its registry entry ("env_var").
        force_refresh: When True, ignore any cached result and query the
            providers again.

    Returns:
        The combined model list, sorted by ``(provider, id)``. The list is
        stored in the module-level cache and that same list object is
        returned by later calls until ``force_refresh`` is set or the cache
        is cleared.

    Notes:
        - The cache is not keyed on ``api_keys``: a cached result is returned
          regardless of which keys a later call passes.
        - "ollama" needs no key; it is queried only if the provider class
          reports itself available.
        - "gemini" is also attempted when GOOGLE_APPLICATION_CREDENTIALS is
          set, even without an API key.
        - Any exception raised while querying a provider is logged at INFO
          level and that provider is skipped.
    """
    global _cached_models
    if not force_refresh and _cached_models is not None:
        return _cached_models

    _ensure_providers_registered()

    if api_keys:
        provided_keys = api_keys
    else:
        provided_keys = {
            "openai": os.getenv("OPENAI_API_KEY", ""),
            "anthropic": os.getenv("ANTHROPIC_API_KEY", ""),
            "gemini": os.getenv("GEMINI_API_KEY", ""),
        }

    discovered: list[ModelInfo] = []

    for provider_name, entry in ProviderRegistry.all_registered().items():
        fallback_env_var = entry.get("env_var", "")
        provider_cls = entry["class"]

        if provider_name == "ollama":
            # Key-less provider: probe the server instead of looking for a key.
            try:
                if provider_cls.is_available():
                    ollama_models = provider_cls().list_models()
                    logger.info(f"Discovered {len(ollama_models)} Ollama models")
                    discovered.extend(ollama_models)
            except Exception as exc:
                logger.info(f"Ollama discovery skipped: {exc}")
            continue

        # Explicit/default key mapping wins; the registry's env var is the fallback.
        api_key = provided_keys.get(provider_name, "")
        if fallback_env_var and not api_key:
            api_key = os.getenv(fallback_env_var, "")

        # Gemini may authenticate via a service-account credentials file instead.
        is_gemini = provider_name == "gemini"
        credentials_path = (
            os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "") if is_gemini else ""
        )

        # Nothing to authenticate with: skip this provider silently.
        if not (api_key or credentials_path):
            continue

        try:
            if is_gemini:
                # Empty strings are passed as None to the Gemini constructor.
                client = provider_cls(
                    api_key=api_key or None,
                    credentials_path=credentials_path or None,
                )
            else:
                client = provider_cls(api_key=api_key)
            found = client.list_models()
            logger.info(f"Discovered {len(found)} {provider_name.capitalize()} models")
            discovered.extend(found)
        except Exception as exc:
            logger.info(f"{provider_name.capitalize()} discovery skipped: {exc}")

    # Order by provider first, then by model id within each provider.
    discovered.sort(key=lambda model: (model.provider, model.id))
    _cached_models = discovered
    logger.info(f"Total discovered models: {len(discovered)}")
    return discovered
100
101
102
def clear_discovery_cache() -> None:
    """Forget the cached discovery result.

    Resets the module-level ``_cached_models`` to ``None``, so the next call
    to ``discover_available_models()`` queries the providers again instead
    of returning the stored list.
    """
    global _cached_models
    # None (rather than an empty list) is the "no result cached" marker.
    _cached_models = None
106

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button