PlanOpticon

planopticon / tests / test_providers.py
Blame History Raw 477 lines
1
"""Tests for the provider abstraction layer."""
2
3
import importlib
4
from unittest.mock import MagicMock, patch
5
6
import pytest
7
8
from video_processor.providers.base import (
9
BaseProvider,
10
ModelInfo,
11
OpenAICompatibleProvider,
12
ProviderRegistry,
13
)
14
from video_processor.providers.manager import ProviderManager
15
16
# ---------------------------------------------------------------------------
17
# ModelInfo
18
# ---------------------------------------------------------------------------
19
20
21
class TestModelInfo:
    """Checks on ModelInfo construction, defaults and JSON round-tripping."""

    def test_basic(self):
        """Constructor arguments are readable back from the instance."""
        info = ModelInfo(
            id="gpt-4o",
            provider="openai",
            capabilities=["chat", "vision"],
        )
        assert info.id == "gpt-4o"
        assert "vision" in info.capabilities

    def test_round_trip(self):
        """Dumping to JSON and validating it back yields an equal model."""
        source = ModelInfo(
            id="claude-sonnet-4-5-20250929",
            provider="anthropic",
            display_name="Claude Sonnet",
            capabilities=["chat", "vision"],
        )
        payload = source.model_dump_json()
        restored = ModelInfo.model_validate_json(payload)
        assert restored == source

    def test_defaults(self):
        """Omitted optional fields fall back to an empty string / empty list."""
        info = ModelInfo(id="x", provider="y")
        assert info.display_name == ""
        assert info.capabilities == []
41
42
43
# ---------------------------------------------------------------------------
44
# ProviderRegistry
45
# ---------------------------------------------------------------------------
46
47
48
class TestProviderRegistry:
    """Tests for the class-level API of ProviderRegistry.

    An autouse fixture snapshots ``ProviderRegistry._providers`` before each
    test and puts the snapshot back afterwards, so a provider registered by
    one test is not visible to the next.
    """

    @pytest.fixture(autouse=True)
    def _save_restore_registry(self):
        # Shallow copy taken before the test; rebound onto the class after it.
        snapshot = dict(ProviderRegistry._providers)
        yield
        ProviderRegistry._providers = snapshot

    def test_register_and_get(self):
        """A registered class is returned by ``get`` under the same name."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register("test_prov", stub_cls, env_var="TEST_KEY")
        assert ProviderRegistry.get("test_prov") is stub_cls

    def test_get_unknown_raises(self):
        """Looking up an unregistered name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown provider"):
            ProviderRegistry.get("nonexistent_provider_xyz")

    def test_get_by_model_prefix(self):
        """Model ids are matched to a provider through registered prefixes."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register("myprov", stub_cls, model_prefixes=["mymodel-"])
        assert ProviderRegistry.get_by_model("mymodel-7b") == "myprov"
        assert ProviderRegistry.get_by_model("othermodel-7b") is None

    def test_get_by_model_returns_none_for_no_match(self):
        """A model id matching no prefix resolves to None."""
        assert ProviderRegistry.get_by_model("totally_unknown_model_xyz") is None

    def test_available_with_env_var(self):
        """A provider with an env var is listed only while that var is set."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register("envprov", stub_cls, env_var="ENVPROV_KEY")

        # Environment emptied: the provider must be absent.
        with patch.dict("os.environ", {}, clear=True):
            listed = ProviderRegistry.available()
            assert "envprov" not in listed

        # Key present: the provider must be listed.
        with patch.dict("os.environ", {"ENVPROV_KEY": "secret"}):
            listed = ProviderRegistry.available()
            assert "envprov" in listed

    def test_available_no_env_var_required(self):
        """A provider registered with an empty env var name is always listed."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register("noenvprov", stub_cls, env_var="")
        listed = ProviderRegistry.available()
        assert "noenvprov" in listed

    def test_all_registered(self):
        """``all_registered`` exposes the entry, including its class."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register(
            "regprov",
            stub_cls,
            env_var="X",
            default_models={"chat": "m1"},
        )
        registry_view = ProviderRegistry.all_registered()
        assert "regprov" in registry_view
        assert registry_view["regprov"]["class"] is stub_cls

    def test_get_default_models(self):
        """Default models supplied at registration are returned unchanged."""
        stub_cls = type("Dummy", (), {})
        ProviderRegistry.register(
            "defprov",
            stub_cls,
            default_models={"chat": "c1", "vision": "v1"},
        )
        found = ProviderRegistry.get_default_models("defprov")
        assert found == {"chat": "c1", "vision": "v1"}

    def test_get_default_models_unknown(self):
        """An unregistered provider has an empty default-model mapping."""
        assert ProviderRegistry.get_default_models("unknown_prov_xyz") == {}
115
116
117
# ---------------------------------------------------------------------------
118
# ProviderManager
119
# ---------------------------------------------------------------------------
120
121
122
class TestProviderManager:
    """Tests for ProviderManager model selection, routing and usage tracking."""

    def _make_mock_provider(self, name="openai"):
        """Build a BaseProvider-spec'd mock with canned return values."""
        transcript = {
            "text": "hello world",
            "segments": [],
            "provider": name,
            "model": "test",
        }
        stub = MagicMock(spec=BaseProvider)
        stub.provider_name = name
        stub.chat.return_value = "test response"
        stub.analyze_image.return_value = "image analysis"
        stub.transcribe_audio.return_value = transcript
        return stub

    def test_init_with_explicit_models(self):
        """Models passed to the constructor are stored as given."""
        chosen = {
            "vision_model": "gpt-4o",
            "chat_model": "claude-sonnet-4-5-20250929",
            "transcription_model": "whisper-1",
        }
        manager = ProviderManager(**chosen)
        for attr, expected in chosen.items():
            assert getattr(manager, attr) == expected

    def test_init_forced_provider(self):
        """Forcing ``gemini`` selects the same model for every task."""
        manager = ProviderManager(provider="gemini")
        for attr in ("vision_model", "chat_model", "transcription_model"):
            assert getattr(manager, attr) == "gemini-2.5-flash"

    def test_init_forced_provider_ollama(self):
        """Forcing ``ollama`` leaves every model as an empty string."""
        manager = ProviderManager(provider="ollama")
        for attr in ("vision_model", "chat_model", "transcription_model"):
            assert getattr(manager, attr) == ""

    def test_init_no_overrides(self):
        """With no arguments all models are None and ``auto`` is True."""
        manager = ProviderManager()
        for attr in ("vision_model", "chat_model", "transcription_model"):
            assert getattr(manager, attr) is None
        assert manager.auto is True

    def test_default_for_provider_gemini(self):
        """The gemini vision default is ``gemini-2.5-flash``."""
        default = ProviderManager._default_for_provider("gemini", "vision")
        assert default == "gemini-2.5-flash"

    def test_default_for_provider_openai(self):
        """The openai chat default is some non-empty string."""
        default = ProviderManager._default_for_provider("openai", "chat")
        assert isinstance(default, str)
        assert len(default) > 0

    def test_default_for_provider_unknown(self):
        """An unknown provider yields an empty default."""
        default = ProviderManager._default_for_provider("nonexistent_xyz", "chat")
        assert default == ""

    def test_provider_for_model(self):
        """Well-known model ids map to their owning provider."""
        manager = ProviderManager()
        owner_by_model = {
            "gpt-4o": "openai",
            "claude-sonnet-4-5-20250929": "anthropic",
            "gemini-2.5-flash": "gemini",
            "whisper-1": "openai",
        }
        for model_id, owner in owner_by_model.items():
            assert manager._provider_for_model(model_id) == owner

    def test_provider_for_model_ollama_via_discovery(self):
        """An exact id present in the discovered models resolves to ollama."""
        manager = ProviderManager()
        discovered = ModelInfo(
            id="llama3.2:latest", provider="ollama", capabilities=["chat"]
        )
        manager._available_models = [discovered]
        assert manager._provider_for_model("llama3.2:latest") == "ollama"

    def test_provider_for_model_ollama_fuzzy_tag(self):
        """An id given without its ``:latest`` tag still resolves to ollama."""
        manager = ProviderManager()
        discovered = ModelInfo(
            id="llama3.2:latest", provider="ollama", capabilities=["chat"]
        )
        manager._available_models = [discovered]
        assert manager._provider_for_model("llama3.2") == "ollama"

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_chat_routes_to_provider(self):
        """``chat`` is forwarded to the provider registered for the chat model."""
        manager = ProviderManager(chat_model="gpt-4o")
        openai_stub = self._make_mock_provider("openai")
        manager._providers["openai"] = openai_stub

        reply = manager.chat([{"role": "user", "content": "hello"}])

        assert reply == "test response"
        openai_stub.chat.assert_called_once()

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_analyze_image_routes(self):
        """``analyze_image`` is forwarded to the vision model's provider."""
        manager = ProviderManager(vision_model="gpt-4o")
        openai_stub = self._make_mock_provider("openai")
        manager._providers["openai"] = openai_stub

        description = manager.analyze_image(b"fake-image", "describe this")

        assert description == "image analysis"
        openai_stub.analyze_image.assert_called_once()

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_transcribe_routes(self):
        """``transcribe_audio`` is forwarded to the transcription provider."""
        manager = ProviderManager(transcription_model="whisper-1")
        openai_stub = self._make_mock_provider("openai")
        manager._providers["openai"] = openai_stub

        transcript = manager.transcribe_audio("/tmp/test.wav")

        assert transcript["text"] == "hello world"
        openai_stub.transcribe_audio.assert_called_once()

    def test_get_models_used(self):
        """``get_models_used`` reports ``provider/model`` per task."""
        manager = ProviderManager(
            vision_model="gpt-4o",
            chat_model="claude-sonnet-4-5-20250929",
            transcription_model="whisper-1",
        )
        for provider_name in ("openai", "anthropic"):
            manager._providers[provider_name] = self._make_mock_provider(provider_name)

        report = manager.get_models_used()

        assert "vision" in report
        assert report["vision"] == "openai/gpt-4o"
        assert report["chat"] == "anthropic/claude-sonnet-4-5-20250929"

    def test_track_records_usage(self):
        """Token counts from the provider's ``_last_usage`` reach ``usage``."""
        manager = ProviderManager(chat_model="gpt-4o")
        openai_stub = self._make_mock_provider("openai")
        openai_stub._last_usage = {"input_tokens": 10, "output_tokens": 20}
        manager._providers["openai"] = openai_stub

        manager.chat([{"role": "user", "content": "hi"}])

        assert manager.usage.total_input_tokens == 10
        assert manager.usage.total_output_tokens == 20
252
253
254
# ---------------------------------------------------------------------------
255
# OpenAICompatibleProvider
256
# ---------------------------------------------------------------------------
257
258
259
class TestOpenAICompatibleProvider:
    """Tests for OpenAICompatibleProvider with the ``openai.OpenAI`` client patched."""

    @staticmethod
    def _fake_completion(content, prompt_tokens, completion_tokens):
        """Return a mock completion response with one choice and usage counts."""
        choice = MagicMock()
        choice.message.content = content
        response = MagicMock()
        response.choices = [choice]
        response.usage.prompt_tokens = prompt_tokens
        response.usage.completion_tokens = completion_tokens
        return response

    @patch("openai.OpenAI")
    def test_chat(self, mock_openai_cls):
        """``chat`` returns the message content and records token usage."""
        client = MagicMock()
        mock_openai_cls.return_value = client
        client.chat.completions.create.return_value = self._fake_completion(
            "hello back", 5, 10
        )

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        reply = provider.chat([{"role": "user", "content": "hi"}], model="test-model")

        assert reply == "hello back"
        assert provider._last_usage == {"input_tokens": 5, "output_tokens": 10}

    @patch("openai.OpenAI")
    def test_analyze_image(self, mock_openai_cls):
        """``analyze_image`` returns the message content and records input tokens."""
        client = MagicMock()
        mock_openai_cls.return_value = client
        client.chat.completions.create.return_value = self._fake_completion(
            "a cat", 100, 5
        )

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        description = provider.analyze_image(b"\x89PNG", "what is this?")

        assert description == "a cat"
        assert provider._last_usage["input_tokens"] == 100

    @patch("openai.OpenAI")
    def test_transcribe_raises(self, mock_openai_cls):
        """``transcribe_audio`` raises NotImplementedError on this provider."""
        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        with pytest.raises(NotImplementedError):
            provider.transcribe_audio("/tmp/audio.wav")

    @patch("openai.OpenAI")
    def test_list_models(self, mock_openai_cls):
        """Listed models carry the client's model id and the provider's name."""
        client = MagicMock()
        mock_openai_cls.return_value = client

        remote_model = MagicMock()
        remote_model.id = "test-model-1"
        client.models.list.return_value = [remote_model]

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        provider.provider_name = "testprov"
        listed = provider.list_models()

        assert len(listed) == 1
        first = listed[0]
        assert first.id == "test-model-1"
        assert first.provider == "testprov"

    @patch("openai.OpenAI")
    def test_list_models_handles_error(self, mock_openai_cls):
        """An exception from the client's model listing yields an empty list."""
        client = MagicMock()
        mock_openai_cls.return_value = client
        client.models.list.side_effect = Exception("connection error")

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        listed = provider.list_models()

        assert listed == []
327
328
329
# ---------------------------------------------------------------------------
330
# Discovery
331
# ---------------------------------------------------------------------------
332
333
334
class TestDiscovery:
    """Tests for model discovery and its module-level result cache.

    The discovery tests patch ``_cached_models`` to ``None``, report Ollama as
    unavailable and empty ``os.environ`` so that no provider is configured.
    Tests that mutate the cache directly reset it in a ``finally`` block so a
    failing assertion cannot leave a stale cache behind for later tests.
    """

    @patch("video_processor.providers.discovery._cached_models", None)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch.dict("os.environ", {}, clear=True)
    def test_discover_skips_missing_keys(self, mock_ollama):
        """Empty API keys for every provider produce no discovered models."""
        from video_processor.providers.discovery import discover_available_models

        models = discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""})
        assert models == []

    @patch.dict("os.environ", {}, clear=True)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch("video_processor.providers.discovery._cached_models", None)
    def test_discover_caches_results(self, mock_ollama):
        """A second call returns the first call's result even with new keys."""
        from video_processor.providers import discovery

        try:
            models = discovery.discover_available_models(
                api_keys={"openai": "", "anthropic": "", "gemini": ""}
            )
            assert models == []
            # Second call should use cache
            models2 = discovery.discover_available_models(api_keys={"openai": "key"})
            assert models2 == []  # Still cached empty result
        finally:
            # Run the cleanup even when an assertion above fails.
            discovery.clear_discovery_cache()

    @patch("video_processor.providers.discovery._cached_models", None)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch.dict("os.environ", {}, clear=True)
    def test_force_refresh_clears_cache(self, mock_ollama):
        """``force_refresh=True`` returns an empty result after a warm cache."""
        from video_processor.providers import discovery

        # Warm the cache
        discovery.discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""})
        # Force refresh should re-run
        models = discovery.discover_available_models(
            api_keys={"openai": "", "anthropic": "", "gemini": ""},
            force_refresh=True,
        )
        assert models == []

    def test_clear_discovery_cache(self):
        """``clear_discovery_cache`` resets ``_cached_models`` to None."""
        from video_processor.providers import discovery

        discovery._cached_models = [ModelInfo(id="x", provider="y")]
        try:
            discovery.clear_discovery_cache()
            assert discovery._cached_models is None
        finally:
            # If the function under test is broken, do not leak the fake
            # entry into tests that run afterwards.
            discovery._cached_models = None
390
391
392
# ---------------------------------------------------------------------------
393
# OllamaProvider
394
# ---------------------------------------------------------------------------
395
396
397
class TestOllamaProvider:
    """Tests for OllamaProvider with its ``requests`` (and ``OpenAI``) names patched."""

    @patch("video_processor.providers.ollama_provider.requests")
    def test_is_available_when_running(self, mock_requests):
        """A 200 response from the patched ``requests.get`` means available."""
        ok_response = MagicMock()
        ok_response.status_code = 200
        mock_requests.get.return_value = ok_response

        from video_processor.providers.ollama_provider import OllamaProvider

        assert OllamaProvider.is_available()

    @patch("video_processor.providers.ollama_provider.requests")
    def test_is_available_when_not_running(self, mock_requests):
        """A ConnectionError from the patched ``requests.get`` means unavailable."""
        mock_requests.get.side_effect = ConnectionError

        from video_processor.providers.ollama_provider import OllamaProvider

        assert not OllamaProvider.is_available()

    @patch("video_processor.providers.ollama_provider.requests")
    @patch("video_processor.providers.ollama_provider.OpenAI")
    def test_transcribe_raises(self, mock_openai, mock_requests):
        """``transcribe_audio`` raises NotImplementedError on this provider."""
        from video_processor.providers.ollama_provider import OllamaProvider

        ollama = OllamaProvider()
        with pytest.raises(NotImplementedError):
            ollama.transcribe_audio("/tmp/test.wav")

    @patch("video_processor.providers.ollama_provider.requests")
    @patch("video_processor.providers.ollama_provider.OpenAI")
    def test_list_models(self, mock_openai, mock_requests):
        """Both listed models are returned; only ``llava`` reports vision."""
        tags_payload = {
            "models": [
                {"name": "llama3.2:latest", "details": {"family": "llama"}},
                {"name": "llava:13b", "details": {"family": "llava"}},
            ]
        }
        tags_response = MagicMock()
        tags_response.status_code = 200
        tags_response.json.return_value = tags_payload
        mock_requests.get.return_value = tags_response

        from video_processor.providers.ollama_provider import OllamaProvider

        ollama = OllamaProvider()
        listed = ollama.list_models()
        assert len(listed) == 2
        assert listed[0].provider == "ollama"

        llava_matches = [entry for entry in listed if "llava" in entry.id]
        llava = llava_matches[0]
        assert "vision" in llava.capabilities

        llama_matches = [entry for entry in listed if "llama" in entry.id]
        llama = llama_matches[0]
        assert "chat" in llama.capabilities
        assert "vision" not in llama.capabilities
451
452
453
# ---------------------------------------------------------------------------
454
# Provider module imports
455
# ---------------------------------------------------------------------------
456
457
458
class TestProviderImports:
    """Smoke test: each listed provider module can be imported."""

    # Dotted module paths fed to the parametrized import test below.
    PROVIDER_MODULES = [
        "video_processor.providers.openai_provider",
        "video_processor.providers.anthropic_provider",
        "video_processor.providers.gemini_provider",
        "video_processor.providers.ollama_provider",
        "video_processor.providers.azure_provider",
        "video_processor.providers.together_provider",
        "video_processor.providers.fireworks_provider",
        "video_processor.providers.cerebras_provider",
        "video_processor.providers.xai_provider",
    ]

    @pytest.mark.parametrize("module_name", PROVIDER_MODULES)
    def test_import(self, module_name):
        """Importing ``module_name`` succeeds and yields a module object."""
        imported = importlib.import_module(module_name)
        assert imported is not None
477

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button