PlanOpticon

Source Blame History 105 lines
a94205b… leo 1 """Auto-discover available models across providers."""
a94205b… leo 2
a94205b… leo 3 import logging
a94205b… leo 4 import os
a94205b… leo 5 from typing import Optional
a94205b… leo 6
a94205b… leo 7 from dotenv import load_dotenv
a94205b… leo 8
0981a08… noreply 9 from video_processor.providers.base import ModelInfo, ProviderRegistry
a94205b… leo 10
logger: logging.Logger = logging.getLogger(__name__)

# Load variables from a .env file into the process environment so the
# provider API keys read later via os.getenv() can come from that file.
load_dotenv()

# Session-wide cache filled by discover_available_models(); None means
# discovery has not run yet (clear_discovery_cache() resets it to None).
_cached_models: Optional[list[ModelInfo]] = None
0981a08… noreply 15
0981a08… noreply 16
def _ensure_providers_registered() -> None:
    """Import all built-in provider modules so they register themselves.

    Every module is imported unconditionally. Python caches imported modules
    in ``sys.modules``, so repeat calls are cheap no-ops. Importing each one
    (instead of returning early once the registry is non-empty) means a
    registry that was already partially populated still gets every built-in.
    Examples: a custom provider registered first, or a single built-in module
    imported elsewhere.

    A module that fails to import (e.g. a provider whose SDK is not
    installed) is logged and skipped, so one unavailable provider does not
    prevent the others from being discovered.
    """
    import importlib

    for module_name in (
        "video_processor.providers.anthropic_provider",
        "video_processor.providers.gemini_provider",
        "video_processor.providers.ollama_provider",
        "video_processor.providers.openai_provider",
    ):
        try:
            importlib.import_module(module_name)
        except ImportError as e:
            logger.warning("Could not load provider module %s: %s", module_name, e)
a94205b… leo 25
a94205b… leo 26
def discover_available_models(
    api_keys: Optional[dict[str, str]] = None,
    force_refresh: bool = False,
) -> list[ModelInfo]:
    """
    Discover available models from all configured providers.

    Every registered provider is considered:

    - Ollama needs no key. It is queried only when
      ``provider_class.is_available()`` returns True.
    - Any other provider is queried when it has a key: an entry in
      ``api_keys``, else its registered ``env_var`` (if any). Gemini may
      instead authenticate via ``GOOGLE_APPLICATION_CREDENTIALS``.

    A failure in any single provider is logged and skipped, so one
    misconfigured provider does not abort discovery. Results are sorted by
    ``(provider, id)`` and cached for the session.

    Args:
        api_keys: Optional mapping of provider name to API key. When omitted
            or empty, keys are read from OPENAI_API_KEY, ANTHROPIC_API_KEY
            and GEMINI_API_KEY.
        force_refresh: Ignore any cached result and query providers again.

    Returns:
        A new list of ``ModelInfo``. Mutating it does not affect the cache.

    Note:
        The cache is not keyed on ``api_keys``. A later call with different
        keys returns the cached result unless ``force_refresh`` is True.
    """
    global _cached_models
    if _cached_models is not None and not force_refresh:
        # Return a copy so callers cannot mutate the shared session cache.
        return list(_cached_models)

    _ensure_providers_registered()

    keys = api_keys or {
        "openai": os.getenv("OPENAI_API_KEY", ""),
        "anthropic": os.getenv("ANTHROPIC_API_KEY", ""),
        "gemini": os.getenv("GEMINI_API_KEY", ""),
    }

    all_models: list[ModelInfo] = []

    for name, info in ProviderRegistry.all_registered().items():
        env_var = info.get("env_var", "")
        provider_class = info["class"]

        if name == "ollama":
            # Ollama: no API key, so check that its server is reachable first.
            try:
                if provider_class.is_available():
                    provider = provider_class()
                    models = provider.list_models()
                    logger.info("Discovered %d Ollama models", len(models))
                    all_models.extend(models)
            except Exception as e:
                logger.info("Ollama discovery skipped: %s", e)
            continue

        # Key-based providers: explicit api_keys entry first, then env var.
        key = keys.get(name, "")
        if not key and env_var:
            key = os.getenv(env_var, "")

        # Gemini can alternatively authenticate with service-account credentials.
        gemini_creds = ""
        if name == "gemini":
            gemini_creds = os.getenv("GOOGLE_APPLICATION_CREDENTIALS", "")

        if not key and not gemini_creds:
            continue

        try:
            # Gemini's constructor takes an extra credentials argument.
            if name == "gemini":
                provider = provider_class(
                    api_key=key or None,
                    credentials_path=gemini_creds or None,
                )
            else:
                provider = provider_class(api_key=key)
            models = provider.list_models()
            logger.info("Discovered %d %s models", len(models), name.capitalize())
            all_models.extend(models)
        except Exception as e:
            logger.info("%s discovery skipped: %s", name.capitalize(), e)

    # Sort by provider then id for a stable, grouped listing.
    all_models.sort(key=lambda m: (m.provider, m.id))
    _cached_models = all_models
    logger.info("Total discovered models: %d", len(all_models))
    # Hand back a copy; the cache keeps its own list.
    return list(all_models)
a94205b… leo 100
a94205b… leo 101
def clear_discovery_cache() -> None:
    """Forget the session's discovered models.

    After this, the next discover_available_models() call queries the
    providers again instead of returning the cached list.
    """
    global _cached_models
    _cached_models = None  # None marks "not yet discovered"

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button