# PlanOpticon
"""Abstract base class, registry, and shared types for provider implementations."""

import base64
import logging
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)
class ModelInfo(BaseModel):
    """Metadata record describing one model offered by a provider."""

    # Identifier used when calling the provider API, e.g. "gpt-4o".
    id: str = Field(description="Model identifier (e.g. gpt-4o)")
    # Name of the backend that serves this model.
    provider: str = Field(description="Provider name (openai, anthropic, gemini)")
    # Friendly label for display; empty string when the API gives none.
    display_name: str = Field(default="", description="Human-readable name")
    # Feature tags such as "chat", "vision", "audio", "embedding".
    capabilities: List[str] = Field(
        default_factory=list, description="Model capabilities: chat, vision, audio, embedding"
    )
class BaseProvider(ABC):
    """Interface every concrete provider must implement.

    Subclasses wrap a vendor SDK and expose chat, vision, audio
    transcription, and model discovery through a uniform surface.
    """

    # Short machine name ("openai", "anthropic", ...); set by each subclass.
    provider_name: str = ""

    @abstractmethod
    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Run a chat completion and return the assistant's reply text."""

    @abstractmethod
    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Answer *prompt* about the given image bytes; returns assistant text."""

    @abstractmethod
    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Transcribe an audio file; returns a dict with 'text', 'segments', etc."""

    @abstractmethod
    def list_models(self) -> list[ModelInfo]:
        """Query the provider's API for its currently available models."""
class ProviderRegistry:
    """Class-level catalogue mapping provider names to classes plus metadata.

    Providers call :meth:`register` at import time; lookups resolve either by
    name or by model-ID prefix.
    """

    # name -> {"class", "env_var", "model_prefixes", "default_models"}
    _providers: Dict[str, Dict] = {}

    @classmethod
    def register(
        cls,
        name: str,
        provider_class: type,
        env_var: str = "",
        model_prefixes: Optional[List[str]] = None,
        default_models: Optional[Dict[str, str]] = None,
    ) -> None:
        """Add *provider_class* to the registry under *name* with its metadata."""
        entry = {
            "class": provider_class,
            "env_var": env_var,
            "model_prefixes": model_prefixes or [],
            "default_models": default_models or {},
        }
        cls._providers[name] = entry

    @classmethod
    def get(cls, name: str) -> type:
        """Look up and return the provider class registered as *name*.

        Raises:
            ValueError: if *name* was never registered.
        """
        entry = cls._providers.get(name)
        if entry is None:
            raise ValueError(f"Unknown provider: {name}")
        return entry["class"]

    @classmethod
    def get_by_model(cls, model_id: str) -> Optional[str]:
        """Resolve *model_id* to a provider name via prefix matching, else None."""
        return next(
            (
                name
                for name, info in cls._providers.items()
                if any(model_id.startswith(p) for p in info["model_prefixes"])
            ),
            None,
        )

    @classmethod
    def get_default_models(cls, name: str) -> Dict[str, str]:
        """Return *name*'s default-model mapping, or {} when unregistered."""
        entry = cls._providers.get(name)
        return entry.get("default_models", {}) if entry is not None else {}

    @classmethod
    def available(cls) -> List[str]:
        """Names of usable providers: env var set, or no env var required.

        Providers registered without an env var (e.g. ollama) are always
        listed; they perform their own availability checks elsewhere.
        """
        return [
            name
            for name, info in cls._providers.items()
            if not (var := info.get("env_var", "")) or os.getenv(var, "")
        ]

    @classmethod
    def all_registered(cls) -> Dict[str, Dict]:
        """Return a shallow copy of the full name -> metadata mapping."""
        return {**cls._providers}
class OpenAICompatibleProvider(BaseProvider):
    """Base for providers using OpenAI-compatible APIs.

    Suitable for Together, Fireworks, Cerebras, xAI, Azure, and similar
    services. Subclasses set :attr:`provider_name`, :attr:`base_url`, and
    :attr:`env_var`, and may override :attr:`default_model` with a model
    their service actually serves.
    """

    provider_name: str = ""
    # Endpoint root of the OpenAI-compatible API; subclasses override.
    base_url: str = ""
    # Environment variable that holds the API key for this service.
    env_var: str = ""
    # Fallback model when callers pass model=None.
    default_model: str = "gpt-4o"

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Create the underlying OpenAI SDK client.

        Args:
            api_key: Explicit key; falls back to ``os.getenv(self.env_var)``.
            base_url: Explicit endpoint; falls back to the class's ``base_url``.
        """
        # Imported lazily so the module can be imported without the SDK installed.
        from openai import OpenAI

        self._api_key = api_key or os.getenv(self.env_var, "")
        self._base_url = base_url or self.base_url
        self._client = OpenAI(api_key=self._api_key, base_url=self._base_url)
        # Token counts from the most recent chat/vision call, or None.
        self._last_usage: Optional[Dict[str, int]] = None

    def _record_usage(self, response) -> None:
        """Stash input/output token counts from *response* in ``_last_usage``."""
        usage = getattr(response, "usage", None)
        self._last_usage = {
            "input_tokens": getattr(usage, "prompt_tokens", 0) if usage else 0,
            "output_tokens": getattr(usage, "completion_tokens", 0) if usage else 0,
        }

    def chat(
        self,
        messages: list[dict],
        max_tokens: int = 4096,
        temperature: float = 0.7,
        model: Optional[str] = None,
    ) -> str:
        """Send a chat completion request and return the assistant text.

        Args:
            messages: OpenAI-style message dicts (role/content).
            max_tokens: Cap on completion length passed to the API.
            temperature: Sampling temperature.
            model: Model ID; defaults to :attr:`default_model`.
        """
        response = self._client.chat.completions.create(
            model=model or self.default_model,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
        )
        self._record_usage(response)
        return response.choices[0].message.content or ""

    def analyze_image(
        self,
        image_bytes: bytes,
        prompt: str,
        max_tokens: int = 4096,
        model: Optional[str] = None,
    ) -> str:
        """Analyze an image with *prompt* and return the assistant text.

        The image is sent inline as a base64 data URL (declared image/jpeg,
        which OpenAI-compatible endpoints accept for common formats).
        """
        b64 = base64.b64encode(image_bytes).decode()
        response = self._client.chat.completions.create(
            model=model or self.default_model,
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": prompt},
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:image/jpeg;base64,{b64}"},
                        },
                    ],
                }
            ],
            max_tokens=max_tokens,
        )
        self._record_usage(response)
        return response.choices[0].message.content or ""

    def transcribe_audio(
        self,
        audio_path: str | Path,
        language: Optional[str] = None,
        model: Optional[str] = None,
    ) -> dict:
        """Not supported by generic OpenAI-compatible endpoints.

        Raises:
            NotImplementedError: always; subclasses with an audio endpoint
                must override this.
        """
        raise NotImplementedError(f"{self.provider_name} does not support audio transcription")

    def list_models(self) -> list[ModelInfo]:
        """Discover models via the /models endpoint, sorted by ID.

        The generic endpoint reports no capability metadata, so every model
        is tagged with "chat" only. Failures degrade to an empty list.
        """
        models: list[ModelInfo] = []
        try:
            for m in self._client.models.list():
                models.append(
                    ModelInfo(
                        id=m.id,
                        provider=self.provider_name,
                        display_name=m.id,
                        capabilities=["chat"],
                    )
                )
        except Exception as e:
            # Best-effort discovery: log lazily (no f-string) and return [].
            logger.warning("Failed to list %s models: %s", self.provider_name, e)
        return sorted(models, key=lambda m: m.id)
