"""Tests for the provider abstraction layer."""

import importlib
from unittest.mock import MagicMock, patch

import pytest

from video_processor.providers.base import (
    BaseProvider,
    ModelInfo,
    OpenAICompatibleProvider,
    ProviderRegistry,
)
from video_processor.providers.manager import ProviderManager

# ---------------------------------------------------------------------------
# ModelInfo
# ---------------------------------------------------------------------------


class TestModelInfo:
    """Behavior of the ModelInfo data model."""

    def test_basic(self):
        """Explicitly supplied fields are exposed directly."""
        info = ModelInfo(id="gpt-4o", provider="openai", capabilities=["chat", "vision"])
        assert info.id == "gpt-4o"
        assert "vision" in info.capabilities

    def test_round_trip(self):
        """Serializing to JSON and validating back reproduces an equal model."""
        original = ModelInfo(
            id="claude-sonnet-4-5-20250929",
            provider="anthropic",
            display_name="Claude Sonnet",
            capabilities=["chat", "vision"],
        )
        payload = original.model_dump_json()
        assert ModelInfo.model_validate_json(payload) == original

    def test_defaults(self):
        """Omitted optional fields default to empty string / empty list."""
        info = ModelInfo(id="x", provider="y")
        assert info.display_name == ""
        assert info.capabilities == []


# ---------------------------------------------------------------------------
# ProviderRegistry
# ---------------------------------------------------------------------------


class TestProviderRegistry:
    """Test ProviderRegistry class methods.

    We save and restore the internal _providers dict around each test so that
    registrations from one test don't leak into another.
    """

    @pytest.fixture(autouse=True)
    def _save_restore_registry(self):
        # Snapshot the registry, then restore it IN PLACE.  Rebinding
        # ``ProviderRegistry._providers`` to a fresh dict (the previous
        # behavior) would leave any alias of the original dict object —
        # e.g. one captured by module-level code — pointing at stale data.
        original = dict(ProviderRegistry._providers)
        yield
        ProviderRegistry._providers.clear()
        ProviderRegistry._providers.update(original)

    def test_register_and_get(self):
        """A registered class is returned verbatim (identity) by get()."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register("test_prov", dummy_cls, env_var="TEST_KEY")
        assert ProviderRegistry.get("test_prov") is dummy_cls

    def test_get_unknown_raises(self):
        """Looking up an unregistered name raises ValueError."""
        with pytest.raises(ValueError, match="Unknown provider"):
            ProviderRegistry.get("nonexistent_provider_xyz")

    def test_get_by_model_prefix(self):
        """A model-id prefix maps back to the provider that registered it."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register("myprov", dummy_cls, model_prefixes=["mymodel-"])
        assert ProviderRegistry.get_by_model("mymodel-7b") == "myprov"
        assert ProviderRegistry.get_by_model("othermodel-7b") is None

    def test_get_by_model_returns_none_for_no_match(self):
        """Unknown model ids resolve to None rather than raising."""
        assert ProviderRegistry.get_by_model("totally_unknown_model_xyz") is None

    def test_available_with_env_var(self):
        """Providers gated on an env var appear only when the var is set."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register("envprov", dummy_cls, env_var="ENVPROV_KEY")
        # Not in env -> should not appear
        with patch.dict("os.environ", {}, clear=True):
            avail = ProviderRegistry.available()
            assert "envprov" not in avail

        # In env -> should appear
        with patch.dict("os.environ", {"ENVPROV_KEY": "secret"}):
            avail = ProviderRegistry.available()
            assert "envprov" in avail

    def test_available_no_env_var_required(self):
        """An empty env_var means the provider is unconditionally available."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register("noenvprov", dummy_cls, env_var="")
        avail = ProviderRegistry.available()
        assert "noenvprov" in avail

    def test_all_registered(self):
        """all_registered() exposes registration metadata including the class."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register("regprov", dummy_cls, env_var="X", default_models={"chat": "m1"})
        all_reg = ProviderRegistry.all_registered()
        assert "regprov" in all_reg
        assert all_reg["regprov"]["class"] is dummy_cls

    def test_get_default_models(self):
        """Registered default models round-trip through get_default_models()."""
        dummy_cls = type("Dummy", (), {})
        ProviderRegistry.register(
            "defprov", dummy_cls, default_models={"chat": "c1", "vision": "v1"}
        )
        defaults = ProviderRegistry.get_default_models("defprov")
        assert defaults == {"chat": "c1", "vision": "v1"}

    def test_get_default_models_unknown(self):
        """Unknown providers yield an empty default-model mapping."""
        assert ProviderRegistry.get_default_models("unknown_prov_xyz") == {}


# ---------------------------------------------------------------------------
# ProviderManager
# ---------------------------------------------------------------------------


class TestProviderManager:
    """Tests for ProviderManager: model selection, routing, and usage tracking."""

    def _make_mock_provider(self, name="openai"):
        # Build a BaseProvider-shaped mock with canned responses for every
        # capability the manager can route to (chat, vision, transcription).
        provider = MagicMock(spec=BaseProvider)
        provider.provider_name = name
        provider.chat.return_value = "test response"
        provider.analyze_image.return_value = "image analysis"
        provider.transcribe_audio.return_value = {
            "text": "hello world",
            "segments": [],
            "provider": name,
            "model": "test",
        }
        return provider

    def test_init_with_explicit_models(self):
        """Explicit per-capability model overrides are stored verbatim."""
        mgr = ProviderManager(
            vision_model="gpt-4o",
            chat_model="claude-sonnet-4-5-20250929",
            transcription_model="whisper-1",
        )
        assert mgr.vision_model == "gpt-4o"
        assert mgr.chat_model == "claude-sonnet-4-5-20250929"
        assert mgr.transcription_model == "whisper-1"

    def test_init_forced_provider(self):
        """Forcing a provider fills all three slots with its default model."""
        mgr = ProviderManager(provider="gemini")
        assert mgr.vision_model == "gemini-2.5-flash"
        assert mgr.chat_model == "gemini-2.5-flash"
        assert mgr.transcription_model == "gemini-2.5-flash"

    def test_init_forced_provider_ollama(self):
        """Ollama has no registered defaults, so the slots stay empty strings."""
        mgr = ProviderManager(provider="ollama")
        assert mgr.vision_model == ""
        assert mgr.chat_model == ""
        assert mgr.transcription_model == ""

    def test_init_no_overrides(self):
        """With no overrides the manager stays in auto-selection mode."""
        mgr = ProviderManager()
        assert mgr.vision_model is None
        assert mgr.chat_model is None
        assert mgr.transcription_model is None
        assert mgr.auto is True

    def test_default_for_provider_gemini(self):
        result = ProviderManager._default_for_provider("gemini", "vision")
        assert result == "gemini-2.5-flash"

    def test_default_for_provider_openai(self):
        # Only assert shape, not the exact model id, so the test survives
        # upstream default-model bumps.
        result = ProviderManager._default_for_provider("openai", "chat")
        assert isinstance(result, str)
        assert len(result) > 0

    def test_default_for_provider_unknown(self):
        """Unknown providers fall back to an empty-string default."""
        result = ProviderManager._default_for_provider("nonexistent_xyz", "chat")
        assert result == ""

    def test_provider_for_model(self):
        """Well-known model-id prefixes resolve to their provider names."""
        mgr = ProviderManager()
        assert mgr._provider_for_model("gpt-4o") == "openai"
        assert mgr._provider_for_model("claude-sonnet-4-5-20250929") == "anthropic"
        assert mgr._provider_for_model("gemini-2.5-flash") == "gemini"
        assert mgr._provider_for_model("whisper-1") == "openai"

    def test_provider_for_model_ollama_via_discovery(self):
        # Models with no registered prefix are resolved via the discovered
        # model list (seeded directly on the private attribute here).
        mgr = ProviderManager()
        mgr._available_models = [
            ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]),
        ]
        assert mgr._provider_for_model("llama3.2:latest") == "ollama"

    def test_provider_for_model_ollama_fuzzy_tag(self):
        # A model id without its ":tag" suffix should still fuzzy-match the
        # discovered tagged id.
        mgr = ProviderManager()
        mgr._available_models = [
            ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]),
        ]
        assert mgr._provider_for_model("llama3.2") == "ollama"

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_chat_routes_to_provider(self):
        """chat() dispatches to the provider owning the configured chat model."""
        mgr = ProviderManager(chat_model="gpt-4o")
        mock_prov = self._make_mock_provider("openai")
        # Inject the mock into the provider cache so no real client is built.
        mgr._providers["openai"] = mock_prov

        result = mgr.chat([{"role": "user", "content": "hello"}])
        assert result == "test response"
        mock_prov.chat.assert_called_once()

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_analyze_image_routes(self):
        """analyze_image() dispatches via the configured vision model."""
        mgr = ProviderManager(vision_model="gpt-4o")
        mock_prov = self._make_mock_provider("openai")
        mgr._providers["openai"] = mock_prov

        result = mgr.analyze_image(b"fake-image", "describe this")
        assert result == "image analysis"
        mock_prov.analyze_image.assert_called_once()

    @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"})
    def test_transcribe_routes(self):
        """transcribe_audio() dispatches via the configured transcription model."""
        mgr = ProviderManager(transcription_model="whisper-1")
        mock_prov = self._make_mock_provider("openai")
        mgr._providers["openai"] = mock_prov

        result = mgr.transcribe_audio("/tmp/test.wav")
        assert result["text"] == "hello world"
        mock_prov.transcribe_audio.assert_called_once()

    def test_get_models_used(self):
        """get_models_used() reports 'provider/model' per capability."""
        mgr = ProviderManager(
            vision_model="gpt-4o",
            chat_model="claude-sonnet-4-5-20250929",
            transcription_model="whisper-1",
        )
        for name in ["openai", "anthropic"]:
            mgr._providers[name] = self._make_mock_provider(name)

        used = mgr.get_models_used()
        assert "vision" in used
        assert used["vision"] == "openai/gpt-4o"
        assert used["chat"] == "anthropic/claude-sonnet-4-5-20250929"

    def test_track_records_usage(self):
        """Token counts reported by a provider accumulate on mgr.usage."""
        mgr = ProviderManager(chat_model="gpt-4o")
        mock_prov = self._make_mock_provider("openai")
        # The manager reads the provider's _last_usage after each call.
        mock_prov._last_usage = {"input_tokens": 10, "output_tokens": 20}
        mgr._providers["openai"] = mock_prov

        mgr.chat([{"role": "user", "content": "hi"}])
        assert mgr.usage.total_input_tokens == 10
        assert mgr.usage.total_output_tokens == 20


# ---------------------------------------------------------------------------
# OpenAICompatibleProvider
# ---------------------------------------------------------------------------


class TestOpenAICompatibleProvider:
    """Tests for the generic OpenAI-compatible provider."""

    @staticmethod
    def _chat_client(content, prompt_tokens, completion_tokens):
        """Build a mock OpenAI client whose chat completion yields *content*."""
        client = MagicMock()
        choice = MagicMock()
        choice.message.content = content
        response = MagicMock()
        response.choices = [choice]
        response.usage.prompt_tokens = prompt_tokens
        response.usage.completion_tokens = completion_tokens
        client.chat.completions.create.return_value = response
        return client

    @patch("openai.OpenAI")
    def test_chat(self, mock_openai_cls):
        """chat() returns the completion text and records token usage."""
        mock_openai_cls.return_value = self._chat_client("hello back", 5, 10)

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        reply = provider.chat([{"role": "user", "content": "hi"}], model="test-model")
        assert reply == "hello back"
        assert provider._last_usage == {"input_tokens": 5, "output_tokens": 10}

    @patch("openai.OpenAI")
    def test_analyze_image(self, mock_openai_cls):
        """analyze_image() returns the completion text and tracks input tokens."""
        mock_openai_cls.return_value = self._chat_client("a cat", 100, 5)

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        answer = provider.analyze_image(b"\x89PNG", "what is this?")
        assert answer == "a cat"
        assert provider._last_usage["input_tokens"] == 100

    @patch("openai.OpenAI")
    def test_transcribe_raises(self, mock_openai_cls):
        """Transcription is unsupported on the generic provider."""
        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        with pytest.raises(NotImplementedError):
            provider.transcribe_audio("/tmp/audio.wav")

    @patch("openai.OpenAI")
    def test_list_models(self, mock_openai_cls):
        """list_models() wraps upstream ids, stamping the provider name."""
        client = MagicMock()
        mock_openai_cls.return_value = client

        fake_model = MagicMock()
        fake_model.id = "test-model-1"
        client.models.list.return_value = [fake_model]

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        provider.provider_name = "testprov"
        models = provider.list_models()
        assert len(models) == 1
        assert models[0].id == "test-model-1"
        assert models[0].provider == "testprov"

    @patch("openai.OpenAI")
    def test_list_models_handles_error(self, mock_openai_cls):
        """Upstream failures degrade to an empty model list, not an exception."""
        client = MagicMock()
        mock_openai_cls.return_value = client
        client.models.list.side_effect = Exception("connection error")

        provider = OpenAICompatibleProvider(api_key="test", base_url="http://test")
        assert provider.list_models() == []


# ---------------------------------------------------------------------------
# Discovery
# ---------------------------------------------------------------------------


class TestDiscovery:
    """Tests for model discovery and its module-level cache.

    NOTE: stacked ``@patch`` decorators apply bottom-up; only the
    ``is_available`` patch injects a mock argument (``patch.dict`` and the
    attribute patch with an inline value do not).
    """

    @patch("video_processor.providers.discovery._cached_models", None)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch.dict("os.environ", {}, clear=True)
    def test_discover_skips_missing_keys(self, mock_ollama):
        # With empty API keys, no env vars, and Ollama down, nothing is found.
        from video_processor.providers.discovery import discover_available_models

        models = discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""})
        assert models == []

    @patch.dict("os.environ", {}, clear=True)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch("video_processor.providers.discovery._cached_models", None)
    def test_discover_caches_results(self, mock_ollama):
        from video_processor.providers import discovery

        models = discovery.discover_available_models(
            api_keys={"openai": "", "anthropic": "", "gemini": ""}
        )
        assert models == []
        # Second call should use cache
        models2 = discovery.discover_available_models(api_keys={"openai": "key"})
        assert models2 == []  # Still cached empty result

        # Clean up so the cached [] doesn't leak into other tests.
        discovery.clear_discovery_cache()

    @patch("video_processor.providers.discovery._cached_models", None)
    @patch(
        "video_processor.providers.ollama_provider.OllamaProvider.is_available",
        return_value=False,
    )
    @patch.dict("os.environ", {}, clear=True)
    def test_force_refresh_clears_cache(self, mock_ollama):
        from video_processor.providers import discovery

        # Warm the cache
        discovery.discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""})
        # Force refresh should re-run
        models = discovery.discover_available_models(
            api_keys={"openai": "", "anthropic": "", "gemini": ""},
            force_refresh=True,
        )
        assert models == []

    def test_clear_discovery_cache(self):
        """clear_discovery_cache() resets the module-level cache to None."""
        from video_processor.providers import discovery

        discovery._cached_models = [ModelInfo(id="x", provider="y")]
        discovery.clear_discovery_cache()
        assert discovery._cached_models is None


# ---------------------------------------------------------------------------
# OllamaProvider
# ---------------------------------------------------------------------------


class TestOllamaProvider:
    """Tests for the local Ollama provider (availability checks + model listing)."""

    @patch("video_processor.providers.ollama_provider.requests")
    def test_is_available_when_running(self, mock_requests):
        """A 200 from the local daemon means Ollama is available."""
        ok_response = MagicMock()
        ok_response.status_code = 200
        mock_requests.get.return_value = ok_response

        from video_processor.providers.ollama_provider import OllamaProvider

        assert OllamaProvider.is_available()

    @patch("video_processor.providers.ollama_provider.requests")
    def test_is_available_when_not_running(self, mock_requests):
        """A connection failure means Ollama is reported unavailable."""
        mock_requests.get.side_effect = ConnectionError

        from video_processor.providers.ollama_provider import OllamaProvider

        assert not OllamaProvider.is_available()

    @patch("video_processor.providers.ollama_provider.requests")
    @patch("video_processor.providers.ollama_provider.OpenAI")
    def test_transcribe_raises(self, mock_openai, mock_requests):
        """Audio transcription is unsupported on Ollama."""
        from video_processor.providers.ollama_provider import OllamaProvider

        provider = OllamaProvider()
        with pytest.raises(NotImplementedError):
            provider.transcribe_audio("/tmp/test.wav")

    @patch("video_processor.providers.ollama_provider.requests")
    @patch("video_processor.providers.ollama_provider.OpenAI")
    def test_list_models(self, mock_openai, mock_requests):
        """list_models() maps local tags to ModelInfo with capability hints."""
        tags_response = MagicMock()
        tags_response.status_code = 200
        tags_response.json.return_value = {
            "models": [
                {"name": "llama3.2:latest", "details": {"family": "llama"}},
                {"name": "llava:13b", "details": {"family": "llava"}},
            ]
        }
        mock_requests.get.return_value = tags_response

        from video_processor.providers.ollama_provider import OllamaProvider

        provider = OllamaProvider()
        models = provider.list_models()
        assert len(models) == 2
        assert models[0].provider == "ollama"

        # llava-family models advertise vision ...
        llava_model = next(m for m in models if "llava" in m.id)
        assert "vision" in llava_model.capabilities

        # ... plain llama models are chat-only.
        llama_model = next(m for m in models if "llama" in m.id)
        assert "chat" in llama_model.capabilities
        assert "vision" not in llama_model.capabilities


# ---------------------------------------------------------------------------
# Provider module imports
# ---------------------------------------------------------------------------


class TestProviderImports:
    """Verify that all provider modules import without errors."""

    # Every concrete provider module shipped with the package.
    PROVIDER_MODULES = [
        "video_processor.providers.openai_provider",
        "video_processor.providers.anthropic_provider",
        "video_processor.providers.gemini_provider",
        "video_processor.providers.ollama_provider",
        "video_processor.providers.azure_provider",
        "video_processor.providers.together_provider",
        "video_processor.providers.fireworks_provider",
        "video_processor.providers.cerebras_provider",
        "video_processor.providers.xai_provider",
    ]

    @pytest.mark.parametrize("module_name", PROVIDER_MODULES)
    def test_import(self, module_name):
        """Importing each provider module succeeds and yields a module object."""
        imported = importlib.import_module(module_name)
        assert imported is not None