PlanOpticon
test: add comprehensive test suite for KG, pipeline, providers, CLI, agent - test_knowledge_graph.py: KG creation, processing, merge, save/load - test_pipeline.py: pipeline with mocked extractors and providers - test_usage_tracker.py: usage tracking and cost calculation - test_providers.py: registry, manager, OpenAI compatible, discovery - test_cli.py: CLI help/version, command structure - test_agent.py: skill registry, agent context, planning agent - Fix graph_store merge_entity to update type on merge Coverage: 41% -> 52% (532 tests)
Commit
7f4ad5f492f39934d66fad6ef355d669633bcbcfc930b3c61acbf916613e8ac8
Parent
8abc5d431ec08f9…
7 files changed
+413
-11
+1
+422
+433
-2
+286
-33
+198
+4
-2
+413
-11
| --- tests/test_agent.py | ||
| +++ tests/test_agent.py | ||
| @@ -1,15 +1,392 @@ | ||
| 1 | -"""Tests for the agentic processing orchestrator.""" | |
| 1 | +"""Tests for the planning agent, skill registry, KB context, and agent loop.""" | |
| 2 | 2 | |
| 3 | 3 | import json |
| 4 | -from unittest.mock import MagicMock | |
| 4 | +from pathlib import Path | |
| 5 | +from unittest.mock import MagicMock, patch | |
| 6 | + | |
| 7 | +import pytest | |
| 8 | + | |
| 9 | +from video_processor.agent.skills.base import ( | |
| 10 | + AgentContext, | |
| 11 | + Artifact, | |
| 12 | + Skill, | |
| 13 | + _skills, | |
| 14 | + get_skill, | |
| 15 | + list_skills, | |
| 16 | + register_skill, | |
| 17 | +) | |
| 18 | + | |
| 19 | +# --------------------------------------------------------------------------- | |
| 20 | +# Fixtures | |
| 21 | +# --------------------------------------------------------------------------- | |
| 22 | + | |
| 23 | + | |
| 24 | +@pytest.fixture(autouse=True) | |
| 25 | +def _clean_skill_registry(): | |
| 26 | + """Save and restore the global skill registry between tests.""" | |
| 27 | + original = dict(_skills) | |
| 28 | + yield | |
| 29 | + _skills.clear() | |
| 30 | + _skills.update(original) | |
| 31 | + | |
| 32 | + | |
| 33 | +class _DummySkill(Skill): | |
| 34 | + name = "dummy_test_skill" | |
| 35 | + description = "A dummy skill for testing" | |
| 36 | + | |
| 37 | + def execute(self, context: AgentContext, **kwargs) -> Artifact: | |
| 38 | + return Artifact( | |
| 39 | + name="dummy artifact", | |
| 40 | + content="dummy content", | |
| 41 | + artifact_type="document", | |
| 42 | + ) | |
| 43 | + | |
| 44 | + | |
| 45 | +class _NoLLMSkill(Skill): | |
| 46 | + """Skill that doesn't require provider_manager.""" | |
| 47 | + | |
| 48 | + name = "nollm_skill" | |
| 49 | + description = "Works without LLM" | |
| 50 | + | |
| 51 | + def execute(self, context: AgentContext, **kwargs) -> Artifact: | |
| 52 | + return Artifact( | |
| 53 | + name="nollm artifact", | |
| 54 | + content="generated", | |
| 55 | + artifact_type="document", | |
| 56 | + ) | |
| 57 | + | |
| 58 | + def can_execute(self, context: AgentContext) -> bool: | |
| 59 | + return context.knowledge_graph is not None | |
| 60 | + | |
| 61 | + | |
| 62 | +# --------------------------------------------------------------------------- | |
| 63 | +# Skill registry | |
| 64 | +# --------------------------------------------------------------------------- | |
| 65 | + | |
| 66 | + | |
| 67 | +class TestSkillRegistry: | |
| 68 | + def test_register_and_get(self): | |
| 69 | + skill = _DummySkill() | |
| 70 | + register_skill(skill) | |
| 71 | + assert get_skill("dummy_test_skill") is skill | |
| 72 | + | |
| 73 | + def test_get_unknown_returns_none(self): | |
| 74 | + assert get_skill("no_such_skill_xyz") is None | |
| 75 | + | |
| 76 | + def test_list_skills(self): | |
| 77 | + s1 = _DummySkill() | |
| 78 | + register_skill(s1) | |
| 79 | + skills = list_skills() | |
| 80 | + assert any(s.name == "dummy_test_skill" for s in skills) | |
| 81 | + | |
| 82 | + def test_list_skills_empty(self): | |
| 83 | + _skills.clear() | |
| 84 | + assert list_skills() == [] | |
| 85 | + | |
| 86 | + | |
| 87 | +# --------------------------------------------------------------------------- | |
| 88 | +# AgentContext dataclass | |
| 89 | +# --------------------------------------------------------------------------- | |
| 90 | + | |
| 91 | + | |
| 92 | +class TestAgentContext: | |
| 93 | + def test_defaults(self): | |
| 94 | + ctx = AgentContext() | |
| 95 | + assert ctx.knowledge_graph is None | |
| 96 | + assert ctx.query_engine is None | |
| 97 | + assert ctx.provider_manager is None | |
| 98 | + assert ctx.planning_entities == [] | |
| 99 | + assert ctx.user_requirements == {} | |
| 100 | + assert ctx.conversation_history == [] | |
| 101 | + assert ctx.artifacts == [] | |
| 102 | + assert ctx.config == {} | |
| 103 | + | |
| 104 | + def test_with_values(self): | |
| 105 | + mock_kg = MagicMock() | |
| 106 | + mock_qe = MagicMock() | |
| 107 | + mock_pm = MagicMock() | |
| 108 | + ctx = AgentContext( | |
| 109 | + knowledge_graph=mock_kg, | |
| 110 | + query_engine=mock_qe, | |
| 111 | + provider_manager=mock_pm, | |
| 112 | + config={"key": "value"}, | |
| 113 | + ) | |
| 114 | + assert ctx.knowledge_graph is mock_kg | |
| 115 | + assert ctx.config == {"key": "value"} | |
| 116 | + | |
| 117 | + def test_conversation_history_is_mutable(self): | |
| 118 | + ctx = AgentContext() | |
| 119 | + ctx.conversation_history.append({"role": "user", "content": "hello"}) | |
| 120 | + assert len(ctx.conversation_history) == 1 | |
| 121 | + | |
| 122 | + | |
| 123 | +# --------------------------------------------------------------------------- | |
| 124 | +# Artifact dataclass | |
| 125 | +# --------------------------------------------------------------------------- | |
| 126 | + | |
| 127 | + | |
| 128 | +class TestArtifact: | |
| 129 | + def test_basic(self): | |
| 130 | + a = Artifact(name="Plan", content="# Plan\n...", artifact_type="project_plan") | |
| 131 | + assert a.name == "Plan" | |
| 132 | + assert a.format == "markdown" # default | |
| 133 | + assert a.metadata == {} | |
| 134 | + | |
| 135 | + def test_with_metadata(self): | |
| 136 | + a = Artifact( | |
| 137 | + name="Tasks", | |
| 138 | + content="[]", | |
| 139 | + artifact_type="task_list", | |
| 140 | + format="json", | |
| 141 | + metadata={"source": "kg"}, | |
| 142 | + ) | |
| 143 | + assert a.format == "json" | |
| 144 | + assert a.metadata["source"] == "kg" | |
| 145 | + | |
| 146 | + | |
| 147 | +# --------------------------------------------------------------------------- | |
| 148 | +# Skill.can_execute | |
| 149 | +# --------------------------------------------------------------------------- | |
| 150 | + | |
| 151 | + | |
| 152 | +class TestSkillCanExecute: | |
| 153 | + def test_default_requires_kg_and_pm(self): | |
| 154 | + skill = _DummySkill() | |
| 155 | + ctx_no_kg = AgentContext(provider_manager=MagicMock()) | |
| 156 | + assert not skill.can_execute(ctx_no_kg) | |
| 157 | + | |
| 158 | + ctx_no_pm = AgentContext(knowledge_graph=MagicMock()) | |
| 159 | + assert not skill.can_execute(ctx_no_pm) | |
| 160 | + | |
| 161 | + ctx_both = AgentContext(knowledge_graph=MagicMock(), provider_manager=MagicMock()) | |
| 162 | + assert skill.can_execute(ctx_both) | |
| 163 | + | |
| 164 | + | |
| 165 | +# --------------------------------------------------------------------------- | |
| 166 | +# KBContext | |
| 167 | +# --------------------------------------------------------------------------- | |
| 168 | + | |
| 169 | + | |
| 170 | +class TestKBContext: | |
| 171 | + def test_add_source_nonexistent_raises(self, tmp_path): | |
| 172 | + from video_processor.agent.kb_context import KBContext | |
| 173 | + | |
| 174 | + ctx = KBContext() | |
| 175 | + with pytest.raises(FileNotFoundError, match="Not found"): | |
| 176 | + ctx.add_source(tmp_path / "nonexistent.json") | |
| 177 | + | |
| 178 | + def test_add_source_file(self, tmp_path): | |
| 179 | + from video_processor.agent.kb_context import KBContext | |
| 180 | + | |
| 181 | + f = tmp_path / "kg.json" | |
| 182 | + f.write_text("{}") | |
| 183 | + ctx = KBContext() | |
| 184 | + ctx.add_source(f) | |
| 185 | + assert len(ctx.sources) == 1 | |
| 186 | + assert ctx.sources[0] == f.resolve() | |
| 187 | + | |
| 188 | + def test_add_source_directory(self, tmp_path): | |
| 189 | + from video_processor.agent.kb_context import KBContext | |
| 190 | + | |
| 191 | + with patch( | |
| 192 | + "video_processor.integrators.graph_discovery.find_knowledge_graphs", | |
| 193 | + return_value=[tmp_path / "a.db"], | |
| 194 | + ): | |
| 195 | + ctx = KBContext() | |
| 196 | + ctx.add_source(tmp_path) | |
| 197 | + assert len(ctx.sources) == 1 | |
| 198 | + | |
| 199 | + def test_knowledge_graph_before_load_raises(self): | |
| 200 | + from video_processor.agent.kb_context import KBContext | |
| 201 | + | |
| 202 | + ctx = KBContext() | |
| 203 | + with pytest.raises(RuntimeError, match="Call load"): | |
| 204 | + _ = ctx.knowledge_graph | |
| 205 | + | |
| 206 | + def test_query_engine_before_load_raises(self): | |
| 207 | + from video_processor.agent.kb_context import KBContext | |
| 208 | + | |
| 209 | + ctx = KBContext() | |
| 210 | + with pytest.raises(RuntimeError, match="Call load"): | |
| 211 | + _ = ctx.query_engine | |
| 212 | + | |
| 213 | + def test_summary_no_data(self): | |
| 214 | + from video_processor.agent.kb_context import KBContext | |
| 215 | + | |
| 216 | + ctx = KBContext() | |
| 217 | + assert ctx.summary() == "No knowledge base loaded." | |
| 218 | + | |
| 219 | + def test_load_json_and_summary(self, tmp_path): | |
| 220 | + from video_processor.agent.kb_context import KBContext | |
| 221 | + | |
| 222 | + kg_data = {"nodes": [], "relationships": []} | |
| 223 | + f = tmp_path / "kg.json" | |
| 224 | + f.write_text(json.dumps(kg_data)) | |
| 225 | + | |
| 226 | + ctx = KBContext() | |
| 227 | + ctx.add_source(f) | |
| 228 | + ctx.load() | |
| 229 | + | |
| 230 | + summary = ctx.summary() | |
| 231 | + assert "Knowledge base" in summary | |
| 232 | + assert "Entities" in summary | |
| 233 | + assert "Relationships" in summary | |
| 234 | + | |
| 235 | + | |
| 236 | +# --------------------------------------------------------------------------- | |
| 237 | +# PlanningAgent | |
| 238 | +# --------------------------------------------------------------------------- | |
| 239 | + | |
| 240 | + | |
| 241 | +class TestPlanningAgent: | |
| 242 | + def test_from_kb_paths(self, tmp_path): | |
| 243 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 244 | + | |
| 245 | + kg_data = {"nodes": [], "relationships": []} | |
| 246 | + f = tmp_path / "kg.json" | |
| 247 | + f.write_text(json.dumps(kg_data)) | |
| 248 | + | |
| 249 | + agent = PlanningAgent.from_kb_paths([f], provider_manager=None) | |
| 250 | + assert agent.context.knowledge_graph is not None | |
| 251 | + assert agent.context.provider_manager is None | |
| 252 | + | |
| 253 | + def test_execute_with_mock_provider(self, tmp_path): | |
| 254 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 255 | + | |
| 256 | + # Register a dummy skill | |
| 257 | + skill = _DummySkill() | |
| 258 | + register_skill(skill) | |
| 259 | + | |
| 260 | + mock_pm = MagicMock() | |
| 261 | + mock_pm.chat.return_value = json.dumps([{"skill": "dummy_test_skill", "params": {}}]) | |
| 262 | + | |
| 263 | + ctx = AgentContext( | |
| 264 | + knowledge_graph=MagicMock(), | |
| 265 | + query_engine=MagicMock(), | |
| 266 | + provider_manager=mock_pm, | |
| 267 | + ) | |
| 268 | + # Mock stats().to_text() | |
| 269 | + ctx.query_engine.stats.return_value.to_text.return_value = "3 entities" | |
| 270 | + | |
| 271 | + agent = PlanningAgent(context=ctx) | |
| 272 | + artifacts = agent.execute("generate a plan") | |
| 273 | + | |
| 274 | + assert len(artifacts) == 1 | |
| 275 | + assert artifacts[0].name == "dummy artifact" | |
| 276 | + mock_pm.chat.assert_called_once() | |
| 277 | + | |
| 278 | + def test_execute_no_provider_keyword_match(self): | |
| 279 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 280 | + | |
| 281 | + skill = _DummySkill() | |
| 282 | + register_skill(skill) | |
| 283 | + | |
| 284 | + ctx = AgentContext( | |
| 285 | + knowledge_graph=MagicMock(), | |
| 286 | + provider_manager=None, | |
| 287 | + ) | |
| 288 | + | |
| 289 | + agent = PlanningAgent(context=ctx) | |
| 290 | + # "dummy" is a keyword in the skill name, but can_execute needs provider_manager | |
| 291 | + # so it should return empty | |
| 292 | + artifacts = agent.execute("dummy request") | |
| 293 | + assert artifacts == [] | |
| 294 | + | |
| 295 | + def test_execute_keyword_match_nollm_skill(self): | |
| 296 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 297 | + | |
| 298 | + skill = _NoLLMSkill() | |
| 299 | + register_skill(skill) | |
| 300 | + | |
| 301 | + ctx = AgentContext( | |
| 302 | + knowledge_graph=MagicMock(), | |
| 303 | + provider_manager=None, | |
| 304 | + ) | |
| 305 | + | |
| 306 | + agent = PlanningAgent(context=ctx) | |
| 307 | + # "nollm" is in the skill name | |
| 308 | + artifacts = agent.execute("nollm stuff") | |
| 309 | + assert len(artifacts) == 1 | |
| 310 | + assert artifacts[0].name == "nollm artifact" | |
| 311 | + | |
| 312 | + def test_execute_skips_unknown_skills(self): | |
| 313 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 314 | + | |
| 315 | + mock_pm = MagicMock() | |
| 316 | + mock_pm.chat.return_value = json.dumps([{"skill": "nonexistent_skill_xyz", "params": {}}]) | |
| 317 | + | |
| 318 | + ctx = AgentContext( | |
| 319 | + knowledge_graph=MagicMock(), | |
| 320 | + query_engine=MagicMock(), | |
| 321 | + provider_manager=mock_pm, | |
| 322 | + ) | |
| 323 | + ctx.query_engine.stats.return_value.to_text.return_value = "" | |
| 324 | + | |
| 325 | + agent = PlanningAgent(context=ctx) | |
| 326 | + artifacts = agent.execute("do something") | |
| 327 | + assert artifacts == [] | |
| 328 | + | |
| 329 | + def test_chat_no_provider(self): | |
| 330 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 331 | + | |
| 332 | + ctx = AgentContext(provider_manager=None) | |
| 333 | + agent = PlanningAgent(context=ctx) | |
| 334 | + | |
| 335 | + reply = agent.chat("hello") | |
| 336 | + assert "requires" in reply.lower() or "provider" in reply.lower() | |
| 337 | + | |
| 338 | + def test_chat_with_provider(self): | |
| 339 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 340 | + | |
| 341 | + mock_pm = MagicMock() | |
| 342 | + mock_pm.chat.return_value = "I can help you plan." | |
| 343 | + | |
| 344 | + ctx = AgentContext( | |
| 345 | + knowledge_graph=MagicMock(), | |
| 346 | + query_engine=MagicMock(), | |
| 347 | + provider_manager=mock_pm, | |
| 348 | + ) | |
| 349 | + ctx.query_engine.stats.return_value.to_text.return_value = "5 entities" | |
| 350 | + | |
| 351 | + agent = PlanningAgent(context=ctx) | |
| 352 | + reply = agent.chat("help me plan") | |
| 353 | + | |
| 354 | + assert reply == "I can help you plan." | |
| 355 | + assert len(ctx.conversation_history) == 2 # user + assistant | |
| 356 | + assert ctx.conversation_history[0]["role"] == "user" | |
| 357 | + assert ctx.conversation_history[1]["role"] == "assistant" | |
| 358 | + | |
| 359 | + def test_chat_accumulates_history(self): | |
| 360 | + from video_processor.agent.agent_loop import PlanningAgent | |
| 361 | + | |
| 362 | + mock_pm = MagicMock() | |
| 363 | + mock_pm.chat.side_effect = ["reply1", "reply2"] | |
| 364 | + | |
| 365 | + ctx = AgentContext(provider_manager=mock_pm) | |
| 366 | + agent = PlanningAgent(context=ctx) | |
| 367 | + | |
| 368 | + agent.chat("msg1") | |
| 369 | + agent.chat("msg2") | |
| 370 | + | |
| 371 | + assert len(ctx.conversation_history) == 4 # 2 user + 2 assistant | |
| 372 | + # The system message is constructed each time but not stored in history | |
| 373 | + # Provider should receive progressively longer message lists | |
| 374 | + second_call_messages = mock_pm.chat.call_args_list[1][0][0] | |
| 375 | + # Should include system + 3 prior messages (user, assistant, user) | |
| 376 | + assert len(second_call_messages) == 4 # system + user + assistant + user | |
| 377 | + | |
| 5 | 378 | |
| 6 | -from video_processor.agent.orchestrator import AgentOrchestrator | |
| 379 | +# --------------------------------------------------------------------------- | |
| 380 | +# Orchestrator tests (from existing test_agent.py — kept for coverage) | |
| 381 | +# --------------------------------------------------------------------------- | |
| 7 | 382 | |
| 8 | 383 | |
| 9 | 384 | class TestPlanCreation: |
| 10 | 385 | def test_basic_plan(self): |
| 386 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 387 | + | |
| 11 | 388 | agent = AgentOrchestrator() |
| 12 | 389 | plan = agent._create_plan("test.mp4", "basic") |
| 13 | 390 | steps = [s["step"] for s in plan] |
| 14 | 391 | assert "extract_frames" in steps |
| 15 | 392 | assert "extract_audio" in steps |
| @@ -18,18 +395,22 @@ | ||
| 18 | 395 | assert "extract_action_items" in steps |
| 19 | 396 | assert "generate_reports" in steps |
| 20 | 397 | assert "detect_diagrams" not in steps |
| 21 | 398 | |
| 22 | 399 | def test_standard_plan(self): |
| 400 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 401 | + | |
| 23 | 402 | agent = AgentOrchestrator() |
| 24 | 403 | plan = agent._create_plan("test.mp4", "standard") |
| 25 | 404 | steps = [s["step"] for s in plan] |
| 26 | 405 | assert "detect_diagrams" in steps |
| 27 | 406 | assert "build_knowledge_graph" in steps |
| 28 | 407 | assert "deep_analysis" not in steps |
| 29 | 408 | |
| 30 | 409 | def test_comprehensive_plan(self): |
| 410 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 411 | + | |
| 31 | 412 | agent = AgentOrchestrator() |
| 32 | 413 | plan = agent._create_plan("test.mp4", "comprehensive") |
| 33 | 414 | steps = [s["step"] for s in plan] |
| 34 | 415 | assert "detect_diagrams" in steps |
| 35 | 416 | assert "deep_analysis" in steps |
| @@ -36,42 +417,52 @@ | ||
| 36 | 417 | assert "cross_reference" in steps |
| 37 | 418 | |
| 38 | 419 | |
| 39 | 420 | class TestAdaptPlan: |
| 40 | 421 | def test_adapts_for_long_transcript(self): |
| 422 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 423 | + | |
| 41 | 424 | agent = AgentOrchestrator() |
| 42 | 425 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 43 | - long_text = "word " * 3000 # > 10000 chars | |
| 426 | + long_text = "word " * 3000 | |
| 44 | 427 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 45 | 428 | steps = [s["step"] for s in agent._plan] |
| 46 | 429 | assert "deep_analysis" in steps |
| 47 | 430 | |
| 48 | 431 | def test_no_adapt_for_short_transcript(self): |
| 432 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 433 | + | |
| 49 | 434 | agent = AgentOrchestrator() |
| 50 | 435 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 51 | 436 | agent._adapt_plan("transcribe", {"text": "Short text"}) |
| 52 | 437 | steps = [s["step"] for s in agent._plan] |
| 53 | 438 | assert "deep_analysis" not in steps |
| 54 | 439 | |
| 55 | 440 | def test_adapts_for_many_diagrams(self): |
| 441 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 442 | + | |
| 56 | 443 | agent = AgentOrchestrator() |
| 57 | 444 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 58 | 445 | diagrams = [MagicMock() for _ in range(5)] |
| 59 | 446 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": []}) |
| 60 | 447 | steps = [s["step"] for s in agent._plan] |
| 61 | 448 | assert "cross_reference" in steps |
| 62 | 449 | |
| 63 | 450 | def test_insight_for_many_captures(self): |
| 451 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 452 | + | |
| 64 | 453 | agent = AgentOrchestrator() |
| 65 | 454 | agent._plan = [] |
| 66 | 455 | captures = [MagicMock() for _ in range(5)] |
| 67 | 456 | diagrams = [MagicMock() for _ in range(2)] |
| 68 | 457 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": captures}) |
| 69 | 458 | assert len(agent._insights) == 1 |
| 70 | 459 | assert "uncertain frames" in agent._insights[0] |
| 71 | 460 | |
| 72 | 461 | def test_no_duplicate_steps(self): |
| 462 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 463 | + | |
| 73 | 464 | agent = AgentOrchestrator() |
| 74 | 465 | agent._plan = [{"step": "deep_analysis", "priority": "comprehensive"}] |
| 75 | 466 | long_text = "word " * 3000 |
| 76 | 467 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 77 | 468 | deep_steps = [s for s in agent._plan if s["step"] == "deep_analysis"] |
| @@ -78,28 +469,35 @@ | ||
| 78 | 469 | assert len(deep_steps) == 1 |
| 79 | 470 | |
| 80 | 471 | |
| 81 | 472 | class TestFallbacks: |
| 82 | 473 | def test_diagram_fallback(self): |
| 474 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 475 | + | |
| 83 | 476 | agent = AgentOrchestrator() |
| 84 | 477 | assert agent._get_fallback("detect_diagrams") == "screengrab_fallback" |
| 85 | 478 | |
| 86 | 479 | def test_no_fallback_for_unknown(self): |
| 480 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 481 | + | |
| 87 | 482 | agent = AgentOrchestrator() |
| 88 | 483 | assert agent._get_fallback("transcribe") is None |
| 89 | 484 | |
| 90 | 485 | |
| 91 | 486 | class TestInsights: |
| 92 | 487 | def test_insights_property(self): |
| 488 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 489 | + | |
| 93 | 490 | agent = AgentOrchestrator() |
| 94 | 491 | agent._insights = ["Insight 1", "Insight 2"] |
| 95 | 492 | assert agent.insights == ["Insight 1", "Insight 2"] |
| 96 | - # Should return a copy | |
| 97 | 493 | agent.insights.append("should not modify internal") |
| 98 | 494 | assert len(agent._insights) == 2 |
| 99 | 495 | |
| 100 | 496 | def test_deep_analysis_populates_insights(self): |
| 497 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 498 | + | |
| 101 | 499 | pm = MagicMock() |
| 102 | 500 | pm.chat.return_value = json.dumps( |
| 103 | 501 | { |
| 104 | 502 | "decisions": ["Decided to use microservices"], |
| 105 | 503 | "risks": ["Timeline is tight"], |
| @@ -113,57 +511,61 @@ | ||
| 113 | 511 | assert "decisions" in result |
| 114 | 512 | assert any("microservices" in i for i in agent._insights) |
| 115 | 513 | assert any("Timeline" in i for i in agent._insights) |
| 116 | 514 | |
| 117 | 515 | def test_deep_analysis_handles_error(self): |
| 516 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 517 | + | |
| 118 | 518 | pm = MagicMock() |
| 119 | 519 | pm.chat.side_effect = Exception("API error") |
| 120 | 520 | agent = AgentOrchestrator(provider_manager=pm) |
| 121 | 521 | agent._results["transcribe"] = {"text": "some text"} |
| 122 | 522 | result = agent._deep_analysis("/tmp") |
| 123 | 523 | assert result == {} |
| 124 | 524 | |
| 125 | 525 | def test_deep_analysis_no_transcript(self): |
| 526 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 527 | + | |
| 126 | 528 | agent = AgentOrchestrator() |
| 127 | 529 | agent._results["transcribe"] = {"text": ""} |
| 128 | 530 | result = agent._deep_analysis("/tmp") |
| 129 | 531 | assert result == {} |
| 130 | 532 | |
| 131 | 533 | |
| 132 | 534 | class TestBuildManifest: |
| 133 | 535 | def test_builds_from_results(self): |
| 536 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 537 | + | |
| 134 | 538 | agent = AgentOrchestrator() |
| 135 | 539 | agent._results = { |
| 136 | 540 | "extract_frames": {"frames": [1, 2, 3], "paths": ["/a.jpg", "/b.jpg"]}, |
| 137 | 541 | "extract_audio": {"audio_path": "/audio.wav", "properties": {"duration": 60.0}}, |
| 138 | 542 | "detect_diagrams": {"diagrams": [], "captures": []}, |
| 139 | 543 | "extract_key_points": {"key_points": []}, |
| 140 | 544 | "extract_action_items": {"action_items": []}, |
| 141 | 545 | } |
| 142 | - from pathlib import Path | |
| 143 | - | |
| 144 | 546 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), "Test", 5.0) |
| 145 | 547 | assert manifest.video.title == "Test" |
| 146 | 548 | assert manifest.stats.frames_extracted == 3 |
| 147 | 549 | assert manifest.stats.duration_seconds == 5.0 |
| 148 | 550 | assert manifest.video.duration_seconds == 60.0 |
| 149 | 551 | |
| 150 | 552 | def test_handles_missing_results(self): |
| 553 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 554 | + | |
| 151 | 555 | agent = AgentOrchestrator() |
| 152 | 556 | agent._results = {} |
| 153 | - from pathlib import Path | |
| 154 | - | |
| 155 | 557 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), None, 1.0) |
| 156 | 558 | assert manifest.video.title == "Analysis of test" |
| 157 | 559 | assert manifest.stats.frames_extracted == 0 |
| 158 | 560 | |
| 159 | 561 | def test_handles_error_results(self): |
| 562 | + from video_processor.agent.orchestrator import AgentOrchestrator | |
| 563 | + | |
| 160 | 564 | agent = AgentOrchestrator() |
| 161 | 565 | agent._results = { |
| 162 | 566 | "extract_frames": {"error": "failed"}, |
| 163 | 567 | "detect_diagrams": {"error": "also failed"}, |
| 164 | 568 | } |
| 165 | - from pathlib import Path | |
| 166 | - | |
| 167 | 569 | manifest = agent._build_manifest(Path("vid.mp4"), Path("/out"), None, 2.0) |
| 168 | 570 | assert manifest.stats.frames_extracted == 0 |
| 169 | 571 | assert len(manifest.diagrams) == 0 |
| 170 | 572 | |
| 171 | 573 | ADDED tests/test_cli.py |
| 172 | 574 | ADDED tests/test_knowledge_graph.py |
| --- tests/test_agent.py | |
| +++ tests/test_agent.py | |
| @@ -1,15 +1,392 @@ | |
| 1 | """Tests for the agentic processing orchestrator.""" |
| 2 | |
| 3 | import json |
| 4 | from unittest.mock import MagicMock |
| 5 | |
| 6 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 7 | |
| 8 | |
| 9 | class TestPlanCreation: |
| 10 | def test_basic_plan(self): |
| 11 | agent = AgentOrchestrator() |
| 12 | plan = agent._create_plan("test.mp4", "basic") |
| 13 | steps = [s["step"] for s in plan] |
| 14 | assert "extract_frames" in steps |
| 15 | assert "extract_audio" in steps |
| @@ -18,18 +395,22 @@ | |
| 18 | assert "extract_action_items" in steps |
| 19 | assert "generate_reports" in steps |
| 20 | assert "detect_diagrams" not in steps |
| 21 | |
| 22 | def test_standard_plan(self): |
| 23 | agent = AgentOrchestrator() |
| 24 | plan = agent._create_plan("test.mp4", "standard") |
| 25 | steps = [s["step"] for s in plan] |
| 26 | assert "detect_diagrams" in steps |
| 27 | assert "build_knowledge_graph" in steps |
| 28 | assert "deep_analysis" not in steps |
| 29 | |
| 30 | def test_comprehensive_plan(self): |
| 31 | agent = AgentOrchestrator() |
| 32 | plan = agent._create_plan("test.mp4", "comprehensive") |
| 33 | steps = [s["step"] for s in plan] |
| 34 | assert "detect_diagrams" in steps |
| 35 | assert "deep_analysis" in steps |
| @@ -36,42 +417,52 @@ | |
| 36 | assert "cross_reference" in steps |
| 37 | |
| 38 | |
| 39 | class TestAdaptPlan: |
| 40 | def test_adapts_for_long_transcript(self): |
| 41 | agent = AgentOrchestrator() |
| 42 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 43 | long_text = "word " * 3000 # > 10000 chars |
| 44 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 45 | steps = [s["step"] for s in agent._plan] |
| 46 | assert "deep_analysis" in steps |
| 47 | |
| 48 | def test_no_adapt_for_short_transcript(self): |
| 49 | agent = AgentOrchestrator() |
| 50 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 51 | agent._adapt_plan("transcribe", {"text": "Short text"}) |
| 52 | steps = [s["step"] for s in agent._plan] |
| 53 | assert "deep_analysis" not in steps |
| 54 | |
| 55 | def test_adapts_for_many_diagrams(self): |
| 56 | agent = AgentOrchestrator() |
| 57 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 58 | diagrams = [MagicMock() for _ in range(5)] |
| 59 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": []}) |
| 60 | steps = [s["step"] for s in agent._plan] |
| 61 | assert "cross_reference" in steps |
| 62 | |
| 63 | def test_insight_for_many_captures(self): |
| 64 | agent = AgentOrchestrator() |
| 65 | agent._plan = [] |
| 66 | captures = [MagicMock() for _ in range(5)] |
| 67 | diagrams = [MagicMock() for _ in range(2)] |
| 68 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": captures}) |
| 69 | assert len(agent._insights) == 1 |
| 70 | assert "uncertain frames" in agent._insights[0] |
| 71 | |
| 72 | def test_no_duplicate_steps(self): |
| 73 | agent = AgentOrchestrator() |
| 74 | agent._plan = [{"step": "deep_analysis", "priority": "comprehensive"}] |
| 75 | long_text = "word " * 3000 |
| 76 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 77 | deep_steps = [s for s in agent._plan if s["step"] == "deep_analysis"] |
| @@ -78,28 +469,35 @@ | |
| 78 | assert len(deep_steps) == 1 |
| 79 | |
| 80 | |
| 81 | class TestFallbacks: |
| 82 | def test_diagram_fallback(self): |
| 83 | agent = AgentOrchestrator() |
| 84 | assert agent._get_fallback("detect_diagrams") == "screengrab_fallback" |
| 85 | |
| 86 | def test_no_fallback_for_unknown(self): |
| 87 | agent = AgentOrchestrator() |
| 88 | assert agent._get_fallback("transcribe") is None |
| 89 | |
| 90 | |
| 91 | class TestInsights: |
| 92 | def test_insights_property(self): |
| 93 | agent = AgentOrchestrator() |
| 94 | agent._insights = ["Insight 1", "Insight 2"] |
| 95 | assert agent.insights == ["Insight 1", "Insight 2"] |
| 96 | # Should return a copy |
| 97 | agent.insights.append("should not modify internal") |
| 98 | assert len(agent._insights) == 2 |
| 99 | |
| 100 | def test_deep_analysis_populates_insights(self): |
| 101 | pm = MagicMock() |
| 102 | pm.chat.return_value = json.dumps( |
| 103 | { |
| 104 | "decisions": ["Decided to use microservices"], |
| 105 | "risks": ["Timeline is tight"], |
| @@ -113,57 +511,61 @@ | |
| 113 | assert "decisions" in result |
| 114 | assert any("microservices" in i for i in agent._insights) |
| 115 | assert any("Timeline" in i for i in agent._insights) |
| 116 | |
| 117 | def test_deep_analysis_handles_error(self): |
| 118 | pm = MagicMock() |
| 119 | pm.chat.side_effect = Exception("API error") |
| 120 | agent = AgentOrchestrator(provider_manager=pm) |
| 121 | agent._results["transcribe"] = {"text": "some text"} |
| 122 | result = agent._deep_analysis("/tmp") |
| 123 | assert result == {} |
| 124 | |
| 125 | def test_deep_analysis_no_transcript(self): |
| 126 | agent = AgentOrchestrator() |
| 127 | agent._results["transcribe"] = {"text": ""} |
| 128 | result = agent._deep_analysis("/tmp") |
| 129 | assert result == {} |
| 130 | |
| 131 | |
| 132 | class TestBuildManifest: |
| 133 | def test_builds_from_results(self): |
| 134 | agent = AgentOrchestrator() |
| 135 | agent._results = { |
| 136 | "extract_frames": {"frames": [1, 2, 3], "paths": ["/a.jpg", "/b.jpg"]}, |
| 137 | "extract_audio": {"audio_path": "/audio.wav", "properties": {"duration": 60.0}}, |
| 138 | "detect_diagrams": {"diagrams": [], "captures": []}, |
| 139 | "extract_key_points": {"key_points": []}, |
| 140 | "extract_action_items": {"action_items": []}, |
| 141 | } |
| 142 | from pathlib import Path |
| 143 | |
| 144 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), "Test", 5.0) |
| 145 | assert manifest.video.title == "Test" |
| 146 | assert manifest.stats.frames_extracted == 3 |
| 147 | assert manifest.stats.duration_seconds == 5.0 |
| 148 | assert manifest.video.duration_seconds == 60.0 |
| 149 | |
| 150 | def test_handles_missing_results(self): |
| 151 | agent = AgentOrchestrator() |
| 152 | agent._results = {} |
| 153 | from pathlib import Path |
| 154 | |
| 155 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), None, 1.0) |
| 156 | assert manifest.video.title == "Analysis of test" |
| 157 | assert manifest.stats.frames_extracted == 0 |
| 158 | |
| 159 | def test_handles_error_results(self): |
| 160 | agent = AgentOrchestrator() |
| 161 | agent._results = { |
| 162 | "extract_frames": {"error": "failed"}, |
| 163 | "detect_diagrams": {"error": "also failed"}, |
| 164 | } |
| 165 | from pathlib import Path |
| 166 | |
| 167 | manifest = agent._build_manifest(Path("vid.mp4"), Path("/out"), None, 2.0) |
| 168 | assert manifest.stats.frames_extracted == 0 |
| 169 | assert len(manifest.diagrams) == 0 |
| 170 | |
| 171 | DDED tests/test_cli.py |
| 172 | DDED tests/test_knowledge_graph.py |
| --- tests/test_agent.py | |
| +++ tests/test_agent.py | |
| @@ -1,15 +1,392 @@ | |
| 1 | """Tests for the planning agent, skill registry, KB context, and agent loop.""" |
| 2 | |
| 3 | import json |
| 4 | from pathlib import Path |
| 5 | from unittest.mock import MagicMock, patch |
| 6 | |
| 7 | import pytest |
| 8 | |
| 9 | from video_processor.agent.skills.base import ( |
| 10 | AgentContext, |
| 11 | Artifact, |
| 12 | Skill, |
| 13 | _skills, |
| 14 | get_skill, |
| 15 | list_skills, |
| 16 | register_skill, |
| 17 | ) |
| 18 | |
| 19 | # --------------------------------------------------------------------------- |
| 20 | # Fixtures |
| 21 | # --------------------------------------------------------------------------- |
| 22 | |
| 23 | |
| 24 | @pytest.fixture(autouse=True) |
| 25 | def _clean_skill_registry(): |
| 26 | """Save and restore the global skill registry between tests.""" |
| 27 | original = dict(_skills) |
| 28 | yield |
| 29 | _skills.clear() |
| 30 | _skills.update(original) |
| 31 | |
| 32 | |
| 33 | class _DummySkill(Skill): |
| 34 | name = "dummy_test_skill" |
| 35 | description = "A dummy skill for testing" |
| 36 | |
| 37 | def execute(self, context: AgentContext, **kwargs) -> Artifact: |
| 38 | return Artifact( |
| 39 | name="dummy artifact", |
| 40 | content="dummy content", |
| 41 | artifact_type="document", |
| 42 | ) |
| 43 | |
| 44 | |
| 45 | class _NoLLMSkill(Skill): |
| 46 | """Skill that doesn't require provider_manager.""" |
| 47 | |
| 48 | name = "nollm_skill" |
| 49 | description = "Works without LLM" |
| 50 | |
| 51 | def execute(self, context: AgentContext, **kwargs) -> Artifact: |
| 52 | return Artifact( |
| 53 | name="nollm artifact", |
| 54 | content="generated", |
| 55 | artifact_type="document", |
| 56 | ) |
| 57 | |
| 58 | def can_execute(self, context: AgentContext) -> bool: |
| 59 | return context.knowledge_graph is not None |
| 60 | |
| 61 | |
| 62 | # --------------------------------------------------------------------------- |
| 63 | # Skill registry |
| 64 | # --------------------------------------------------------------------------- |
| 65 | |
| 66 | |
| 67 | class TestSkillRegistry: |
| 68 | def test_register_and_get(self): |
| 69 | skill = _DummySkill() |
| 70 | register_skill(skill) |
| 71 | assert get_skill("dummy_test_skill") is skill |
| 72 | |
| 73 | def test_get_unknown_returns_none(self): |
| 74 | assert get_skill("no_such_skill_xyz") is None |
| 75 | |
| 76 | def test_list_skills(self): |
| 77 | s1 = _DummySkill() |
| 78 | register_skill(s1) |
| 79 | skills = list_skills() |
| 80 | assert any(s.name == "dummy_test_skill" for s in skills) |
| 81 | |
| 82 | def test_list_skills_empty(self): |
| 83 | _skills.clear() |
| 84 | assert list_skills() == [] |
| 85 | |
| 86 | |
| 87 | # --------------------------------------------------------------------------- |
| 88 | # AgentContext dataclass |
| 89 | # --------------------------------------------------------------------------- |
| 90 | |
| 91 | |
| 92 | class TestAgentContext: |
| 93 | def test_defaults(self): |
| 94 | ctx = AgentContext() |
| 95 | assert ctx.knowledge_graph is None |
| 96 | assert ctx.query_engine is None |
| 97 | assert ctx.provider_manager is None |
| 98 | assert ctx.planning_entities == [] |
| 99 | assert ctx.user_requirements == {} |
| 100 | assert ctx.conversation_history == [] |
| 101 | assert ctx.artifacts == [] |
| 102 | assert ctx.config == {} |
| 103 | |
| 104 | def test_with_values(self): |
| 105 | mock_kg = MagicMock() |
| 106 | mock_qe = MagicMock() |
| 107 | mock_pm = MagicMock() |
| 108 | ctx = AgentContext( |
| 109 | knowledge_graph=mock_kg, |
| 110 | query_engine=mock_qe, |
| 111 | provider_manager=mock_pm, |
| 112 | config={"key": "value"}, |
| 113 | ) |
| 114 | assert ctx.knowledge_graph is mock_kg |
| 115 | assert ctx.config == {"key": "value"} |
| 116 | |
| 117 | def test_conversation_history_is_mutable(self): |
| 118 | ctx = AgentContext() |
| 119 | ctx.conversation_history.append({"role": "user", "content": "hello"}) |
| 120 | assert len(ctx.conversation_history) == 1 |
| 121 | |
| 122 | |
| 123 | # --------------------------------------------------------------------------- |
| 124 | # Artifact dataclass |
| 125 | # --------------------------------------------------------------------------- |
| 126 | |
| 127 | |
| 128 | class TestArtifact: |
| 129 | def test_basic(self): |
| 130 | a = Artifact(name="Plan", content="# Plan\n...", artifact_type="project_plan") |
| 131 | assert a.name == "Plan" |
| 132 | assert a.format == "markdown" # default |
| 133 | assert a.metadata == {} |
| 134 | |
| 135 | def test_with_metadata(self): |
| 136 | a = Artifact( |
| 137 | name="Tasks", |
| 138 | content="[]", |
| 139 | artifact_type="task_list", |
| 140 | format="json", |
| 141 | metadata={"source": "kg"}, |
| 142 | ) |
| 143 | assert a.format == "json" |
| 144 | assert a.metadata["source"] == "kg" |
| 145 | |
| 146 | |
| 147 | # --------------------------------------------------------------------------- |
| 148 | # Skill.can_execute |
| 149 | # --------------------------------------------------------------------------- |
| 150 | |
| 151 | |
| 152 | class TestSkillCanExecute: |
| 153 | def test_default_requires_kg_and_pm(self): |
| 154 | skill = _DummySkill() |
| 155 | ctx_no_kg = AgentContext(provider_manager=MagicMock()) |
| 156 | assert not skill.can_execute(ctx_no_kg) |
| 157 | |
| 158 | ctx_no_pm = AgentContext(knowledge_graph=MagicMock()) |
| 159 | assert not skill.can_execute(ctx_no_pm) |
| 160 | |
| 161 | ctx_both = AgentContext(knowledge_graph=MagicMock(), provider_manager=MagicMock()) |
| 162 | assert skill.can_execute(ctx_both) |
| 163 | |
| 164 | |
| 165 | # --------------------------------------------------------------------------- |
| 166 | # KBContext |
| 167 | # --------------------------------------------------------------------------- |
| 168 | |
| 169 | |
| 170 | class TestKBContext: |
| 171 | def test_add_source_nonexistent_raises(self, tmp_path): |
| 172 | from video_processor.agent.kb_context import KBContext |
| 173 | |
| 174 | ctx = KBContext() |
| 175 | with pytest.raises(FileNotFoundError, match="Not found"): |
| 176 | ctx.add_source(tmp_path / "nonexistent.json") |
| 177 | |
| 178 | def test_add_source_file(self, tmp_path): |
| 179 | from video_processor.agent.kb_context import KBContext |
| 180 | |
| 181 | f = tmp_path / "kg.json" |
| 182 | f.write_text("{}") |
| 183 | ctx = KBContext() |
| 184 | ctx.add_source(f) |
| 185 | assert len(ctx.sources) == 1 |
| 186 | assert ctx.sources[0] == f.resolve() |
| 187 | |
| 188 | def test_add_source_directory(self, tmp_path): |
| 189 | from video_processor.agent.kb_context import KBContext |
| 190 | |
| 191 | with patch( |
| 192 | "video_processor.integrators.graph_discovery.find_knowledge_graphs", |
| 193 | return_value=[tmp_path / "a.db"], |
| 194 | ): |
| 195 | ctx = KBContext() |
| 196 | ctx.add_source(tmp_path) |
| 197 | assert len(ctx.sources) == 1 |
| 198 | |
| 199 | def test_knowledge_graph_before_load_raises(self): |
| 200 | from video_processor.agent.kb_context import KBContext |
| 201 | |
| 202 | ctx = KBContext() |
| 203 | with pytest.raises(RuntimeError, match="Call load"): |
| 204 | _ = ctx.knowledge_graph |
| 205 | |
| 206 | def test_query_engine_before_load_raises(self): |
| 207 | from video_processor.agent.kb_context import KBContext |
| 208 | |
| 209 | ctx = KBContext() |
| 210 | with pytest.raises(RuntimeError, match="Call load"): |
| 211 | _ = ctx.query_engine |
| 212 | |
| 213 | def test_summary_no_data(self): |
| 214 | from video_processor.agent.kb_context import KBContext |
| 215 | |
| 216 | ctx = KBContext() |
| 217 | assert ctx.summary() == "No knowledge base loaded." |
| 218 | |
| 219 | def test_load_json_and_summary(self, tmp_path): |
| 220 | from video_processor.agent.kb_context import KBContext |
| 221 | |
| 222 | kg_data = {"nodes": [], "relationships": []} |
| 223 | f = tmp_path / "kg.json" |
| 224 | f.write_text(json.dumps(kg_data)) |
| 225 | |
| 226 | ctx = KBContext() |
| 227 | ctx.add_source(f) |
| 228 | ctx.load() |
| 229 | |
| 230 | summary = ctx.summary() |
| 231 | assert "Knowledge base" in summary |
| 232 | assert "Entities" in summary |
| 233 | assert "Relationships" in summary |
| 234 | |
| 235 | |
| 236 | # --------------------------------------------------------------------------- |
| 237 | # PlanningAgent |
| 238 | # --------------------------------------------------------------------------- |
| 239 | |
| 240 | |
| 241 | class TestPlanningAgent: |
| 242 | def test_from_kb_paths(self, tmp_path): |
| 243 | from video_processor.agent.agent_loop import PlanningAgent |
| 244 | |
| 245 | kg_data = {"nodes": [], "relationships": []} |
| 246 | f = tmp_path / "kg.json" |
| 247 | f.write_text(json.dumps(kg_data)) |
| 248 | |
| 249 | agent = PlanningAgent.from_kb_paths([f], provider_manager=None) |
| 250 | assert agent.context.knowledge_graph is not None |
| 251 | assert agent.context.provider_manager is None |
| 252 | |
| 253 | def test_execute_with_mock_provider(self, tmp_path): |
| 254 | from video_processor.agent.agent_loop import PlanningAgent |
| 255 | |
| 256 | # Register a dummy skill |
| 257 | skill = _DummySkill() |
| 258 | register_skill(skill) |
| 259 | |
| 260 | mock_pm = MagicMock() |
| 261 | mock_pm.chat.return_value = json.dumps([{"skill": "dummy_test_skill", "params": {}}]) |
| 262 | |
| 263 | ctx = AgentContext( |
| 264 | knowledge_graph=MagicMock(), |
| 265 | query_engine=MagicMock(), |
| 266 | provider_manager=mock_pm, |
| 267 | ) |
| 268 | # Mock stats().to_text() |
| 269 | ctx.query_engine.stats.return_value.to_text.return_value = "3 entities" |
| 270 | |
| 271 | agent = PlanningAgent(context=ctx) |
| 272 | artifacts = agent.execute("generate a plan") |
| 273 | |
| 274 | assert len(artifacts) == 1 |
| 275 | assert artifacts[0].name == "dummy artifact" |
| 276 | mock_pm.chat.assert_called_once() |
| 277 | |
| 278 | def test_execute_no_provider_keyword_match(self): |
| 279 | from video_processor.agent.agent_loop import PlanningAgent |
| 280 | |
| 281 | skill = _DummySkill() |
| 282 | register_skill(skill) |
| 283 | |
| 284 | ctx = AgentContext( |
| 285 | knowledge_graph=MagicMock(), |
| 286 | provider_manager=None, |
| 287 | ) |
| 288 | |
| 289 | agent = PlanningAgent(context=ctx) |
| 290 | # "dummy" is a keyword in the skill name, but can_execute needs provider_manager |
| 291 | # so it should return empty |
| 292 | artifacts = agent.execute("dummy request") |
| 293 | assert artifacts == [] |
| 294 | |
| 295 | def test_execute_keyword_match_nollm_skill(self): |
| 296 | from video_processor.agent.agent_loop import PlanningAgent |
| 297 | |
| 298 | skill = _NoLLMSkill() |
| 299 | register_skill(skill) |
| 300 | |
| 301 | ctx = AgentContext( |
| 302 | knowledge_graph=MagicMock(), |
| 303 | provider_manager=None, |
| 304 | ) |
| 305 | |
| 306 | agent = PlanningAgent(context=ctx) |
| 307 | # "nollm" is in the skill name |
| 308 | artifacts = agent.execute("nollm stuff") |
| 309 | assert len(artifacts) == 1 |
| 310 | assert artifacts[0].name == "nollm artifact" |
| 311 | |
| 312 | def test_execute_skips_unknown_skills(self): |
| 313 | from video_processor.agent.agent_loop import PlanningAgent |
| 314 | |
| 315 | mock_pm = MagicMock() |
| 316 | mock_pm.chat.return_value = json.dumps([{"skill": "nonexistent_skill_xyz", "params": {}}]) |
| 317 | |
| 318 | ctx = AgentContext( |
| 319 | knowledge_graph=MagicMock(), |
| 320 | query_engine=MagicMock(), |
| 321 | provider_manager=mock_pm, |
| 322 | ) |
| 323 | ctx.query_engine.stats.return_value.to_text.return_value = "" |
| 324 | |
| 325 | agent = PlanningAgent(context=ctx) |
| 326 | artifacts = agent.execute("do something") |
| 327 | assert artifacts == [] |
| 328 | |
| 329 | def test_chat_no_provider(self): |
| 330 | from video_processor.agent.agent_loop import PlanningAgent |
| 331 | |
| 332 | ctx = AgentContext(provider_manager=None) |
| 333 | agent = PlanningAgent(context=ctx) |
| 334 | |
| 335 | reply = agent.chat("hello") |
| 336 | assert "requires" in reply.lower() or "provider" in reply.lower() |
| 337 | |
| 338 | def test_chat_with_provider(self): |
| 339 | from video_processor.agent.agent_loop import PlanningAgent |
| 340 | |
| 341 | mock_pm = MagicMock() |
| 342 | mock_pm.chat.return_value = "I can help you plan." |
| 343 | |
| 344 | ctx = AgentContext( |
| 345 | knowledge_graph=MagicMock(), |
| 346 | query_engine=MagicMock(), |
| 347 | provider_manager=mock_pm, |
| 348 | ) |
| 349 | ctx.query_engine.stats.return_value.to_text.return_value = "5 entities" |
| 350 | |
| 351 | agent = PlanningAgent(context=ctx) |
| 352 | reply = agent.chat("help me plan") |
| 353 | |
| 354 | assert reply == "I can help you plan." |
| 355 | assert len(ctx.conversation_history) == 2 # user + assistant |
| 356 | assert ctx.conversation_history[0]["role"] == "user" |
| 357 | assert ctx.conversation_history[1]["role"] == "assistant" |
| 358 | |
| 359 | def test_chat_accumulates_history(self): |
| 360 | from video_processor.agent.agent_loop import PlanningAgent |
| 361 | |
| 362 | mock_pm = MagicMock() |
| 363 | mock_pm.chat.side_effect = ["reply1", "reply2"] |
| 364 | |
| 365 | ctx = AgentContext(provider_manager=mock_pm) |
| 366 | agent = PlanningAgent(context=ctx) |
| 367 | |
| 368 | agent.chat("msg1") |
| 369 | agent.chat("msg2") |
| 370 | |
| 371 | assert len(ctx.conversation_history) == 4 # 2 user + 2 assistant |
| 372 | # The system message is constructed each time but not stored in history |
| 373 | # Provider should receive progressively longer message lists |
| 374 | second_call_messages = mock_pm.chat.call_args_list[1][0][0] |
| 375 | # Should include system + 3 prior messages (user, assistant, user) |
| 376 | assert len(second_call_messages) == 4 # system + user + assistant + user |
| 377 | |
| 378 | |
| 379 | # --------------------------------------------------------------------------- |
| 380 | # Orchestrator tests (from existing test_agent.py — kept for coverage) |
| 381 | # --------------------------------------------------------------------------- |
| 382 | |
| 383 | |
| 384 | class TestPlanCreation: |
| 385 | def test_basic_plan(self): |
| 386 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 387 | |
| 388 | agent = AgentOrchestrator() |
| 389 | plan = agent._create_plan("test.mp4", "basic") |
| 390 | steps = [s["step"] for s in plan] |
| 391 | assert "extract_frames" in steps |
| 392 | assert "extract_audio" in steps |
| @@ -18,18 +395,22 @@ | |
| 395 | assert "extract_action_items" in steps |
| 396 | assert "generate_reports" in steps |
| 397 | assert "detect_diagrams" not in steps |
| 398 | |
| 399 | def test_standard_plan(self): |
| 400 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 401 | |
| 402 | agent = AgentOrchestrator() |
| 403 | plan = agent._create_plan("test.mp4", "standard") |
| 404 | steps = [s["step"] for s in plan] |
| 405 | assert "detect_diagrams" in steps |
| 406 | assert "build_knowledge_graph" in steps |
| 407 | assert "deep_analysis" not in steps |
| 408 | |
| 409 | def test_comprehensive_plan(self): |
| 410 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 411 | |
| 412 | agent = AgentOrchestrator() |
| 413 | plan = agent._create_plan("test.mp4", "comprehensive") |
| 414 | steps = [s["step"] for s in plan] |
| 415 | assert "detect_diagrams" in steps |
| 416 | assert "deep_analysis" in steps |
| @@ -36,42 +417,52 @@ | |
| 417 | assert "cross_reference" in steps |
| 418 | |
| 419 | |
| 420 | class TestAdaptPlan: |
| 421 | def test_adapts_for_long_transcript(self): |
| 422 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 423 | |
| 424 | agent = AgentOrchestrator() |
| 425 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 426 | long_text = "word " * 3000 |
| 427 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 428 | steps = [s["step"] for s in agent._plan] |
| 429 | assert "deep_analysis" in steps |
| 430 | |
| 431 | def test_no_adapt_for_short_transcript(self): |
| 432 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 433 | |
| 434 | agent = AgentOrchestrator() |
| 435 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 436 | agent._adapt_plan("transcribe", {"text": "Short text"}) |
| 437 | steps = [s["step"] for s in agent._plan] |
| 438 | assert "deep_analysis" not in steps |
| 439 | |
| 440 | def test_adapts_for_many_diagrams(self): |
| 441 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 442 | |
| 443 | agent = AgentOrchestrator() |
| 444 | agent._plan = [{"step": "generate_reports", "priority": "required"}] |
| 445 | diagrams = [MagicMock() for _ in range(5)] |
| 446 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": []}) |
| 447 | steps = [s["step"] for s in agent._plan] |
| 448 | assert "cross_reference" in steps |
| 449 | |
| 450 | def test_insight_for_many_captures(self): |
| 451 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 452 | |
| 453 | agent = AgentOrchestrator() |
| 454 | agent._plan = [] |
| 455 | captures = [MagicMock() for _ in range(5)] |
| 456 | diagrams = [MagicMock() for _ in range(2)] |
| 457 | agent._adapt_plan("detect_diagrams", {"diagrams": diagrams, "captures": captures}) |
| 458 | assert len(agent._insights) == 1 |
| 459 | assert "uncertain frames" in agent._insights[0] |
| 460 | |
| 461 | def test_no_duplicate_steps(self): |
| 462 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 463 | |
| 464 | agent = AgentOrchestrator() |
| 465 | agent._plan = [{"step": "deep_analysis", "priority": "comprehensive"}] |
| 466 | long_text = "word " * 3000 |
| 467 | agent._adapt_plan("transcribe", {"text": long_text}) |
| 468 | deep_steps = [s for s in agent._plan if s["step"] == "deep_analysis"] |
| @@ -78,28 +469,35 @@ | |
| 469 | assert len(deep_steps) == 1 |
| 470 | |
| 471 | |
| 472 | class TestFallbacks: |
| 473 | def test_diagram_fallback(self): |
| 474 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 475 | |
| 476 | agent = AgentOrchestrator() |
| 477 | assert agent._get_fallback("detect_diagrams") == "screengrab_fallback" |
| 478 | |
| 479 | def test_no_fallback_for_unknown(self): |
| 480 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 481 | |
| 482 | agent = AgentOrchestrator() |
| 483 | assert agent._get_fallback("transcribe") is None |
| 484 | |
| 485 | |
| 486 | class TestInsights: |
| 487 | def test_insights_property(self): |
| 488 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 489 | |
| 490 | agent = AgentOrchestrator() |
| 491 | agent._insights = ["Insight 1", "Insight 2"] |
| 492 | assert agent.insights == ["Insight 1", "Insight 2"] |
| 493 | agent.insights.append("should not modify internal") |
| 494 | assert len(agent._insights) == 2 |
| 495 | |
| 496 | def test_deep_analysis_populates_insights(self): |
| 497 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 498 | |
| 499 | pm = MagicMock() |
| 500 | pm.chat.return_value = json.dumps( |
| 501 | { |
| 502 | "decisions": ["Decided to use microservices"], |
| 503 | "risks": ["Timeline is tight"], |
| @@ -113,57 +511,61 @@ | |
| 511 | assert "decisions" in result |
| 512 | assert any("microservices" in i for i in agent._insights) |
| 513 | assert any("Timeline" in i for i in agent._insights) |
| 514 | |
| 515 | def test_deep_analysis_handles_error(self): |
| 516 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 517 | |
| 518 | pm = MagicMock() |
| 519 | pm.chat.side_effect = Exception("API error") |
| 520 | agent = AgentOrchestrator(provider_manager=pm) |
| 521 | agent._results["transcribe"] = {"text": "some text"} |
| 522 | result = agent._deep_analysis("/tmp") |
| 523 | assert result == {} |
| 524 | |
| 525 | def test_deep_analysis_no_transcript(self): |
| 526 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 527 | |
| 528 | agent = AgentOrchestrator() |
| 529 | agent._results["transcribe"] = {"text": ""} |
| 530 | result = agent._deep_analysis("/tmp") |
| 531 | assert result == {} |
| 532 | |
| 533 | |
| 534 | class TestBuildManifest: |
| 535 | def test_builds_from_results(self): |
| 536 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 537 | |
| 538 | agent = AgentOrchestrator() |
| 539 | agent._results = { |
| 540 | "extract_frames": {"frames": [1, 2, 3], "paths": ["/a.jpg", "/b.jpg"]}, |
| 541 | "extract_audio": {"audio_path": "/audio.wav", "properties": {"duration": 60.0}}, |
| 542 | "detect_diagrams": {"diagrams": [], "captures": []}, |
| 543 | "extract_key_points": {"key_points": []}, |
| 544 | "extract_action_items": {"action_items": []}, |
| 545 | } |
| 546 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), "Test", 5.0) |
| 547 | assert manifest.video.title == "Test" |
| 548 | assert manifest.stats.frames_extracted == 3 |
| 549 | assert manifest.stats.duration_seconds == 5.0 |
| 550 | assert manifest.video.duration_seconds == 60.0 |
| 551 | |
| 552 | def test_handles_missing_results(self): |
| 553 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 554 | |
| 555 | agent = AgentOrchestrator() |
| 556 | agent._results = {} |
| 557 | manifest = agent._build_manifest(Path("test.mp4"), Path("/out"), None, 1.0) |
| 558 | assert manifest.video.title == "Analysis of test" |
| 559 | assert manifest.stats.frames_extracted == 0 |
| 560 | |
| 561 | def test_handles_error_results(self): |
| 562 | from video_processor.agent.orchestrator import AgentOrchestrator |
| 563 | |
| 564 | agent = AgentOrchestrator() |
| 565 | agent._results = { |
| 566 | "extract_frames": {"error": "failed"}, |
| 567 | "detect_diagrams": {"error": "also failed"}, |
| 568 | } |
| 569 | manifest = agent._build_manifest(Path("vid.mp4"), Path("/out"), None, 2.0) |
| 570 | assert manifest.stats.frames_extracted == 0 |
| 571 | assert len(manifest.diagrams) == 0 |
| 572 | |
| 573 | DDED tests/test_cli.py |
| 574 | DDED tests/test_knowledge_graph.py |
+1
| --- a/tests/test_cli.py | ||
| +++ b/tests/test_cli.py | ||
| @@ -0,0 +1 @@ | ||
| 1 | +"batchAuthbatc2.0 |
| --- a/tests/test_cli.py | |
| +++ b/tests/test_cli.py | |
| @@ -0,0 +1 @@ | |
| --- a/tests/test_cli.py | |
| +++ b/tests/test_cli.py | |
| @@ -0,0 +1 @@ | |
| 1 | "batchAuthbatc2.0 |
| --- a/tests/test_knowledge_graph.py | ||
| +++ b/tests/test_knowledge_graph.py | ||
| @@ -0,0 +1,422 @@ | ||
| 1 | +"""Tests for the KnowledgeGraph class.""" | |
| 2 | + | |
| 3 | +import json | |
| 4 | +from unittest.mock import MagicMock, patch | |
| 5 | + | |
| 6 | +import pytest | |
| 7 | + | |
| 8 | +from video_processor.integrators.knowledge_graph import KnowledgeGraph | |
| 9 | + | |
| 10 | + | |
| 11 | +@pytest.fixture | |
| 12 | +def mock_pm(): | |
| 13 | + """A mock ProviderManager that returns predictable JSON from chat().""" | |
| 14 | + pm = MagicMock() | |
| 15 | + pm.chat.return_value = json.dumps( | |
| 16 | + { | |
| 17 | + "entities": [ | |
| 18 | + {"name": "Python", "type": "technology", "description": "A programming language"}, | |
| 19 | + {"name": "Alice", "type": "person", "description": "Lead developer"}, | |
| 20 | + ], | |
| 21 | + "relationships": [ | |
| 22 | + {"source": "Alice", "target": "Python", "type": "uses"}, | |
| 23 | + ], | |
| 24 | + } | |
| 25 | + ) | |
| 26 | + return pm | |
| 27 | + | |
| 28 | + | |
| 29 | +@pytest.fixture | |
| 30 | +def kg_no_provider(): | |
| 31 | + """KnowledgeGraph with no provider (in-memory store).""" | |
| 32 | + return KnowledgeGraph() | |
| 33 | + | |
| 34 | + | |
| 35 | +@pytest.fixture | |
| 36 | +def kg_with_provider(mock_pm): | |
| 37 | + """KnowledgeGraph with a mock provider (in-memory store).""" | |
| 38 | + return KnowledgeGraph(provider_manager=mock_pm) | |
| 39 | + | |
| 40 | + | |
| 41 | +class TestCreation: | |
| 42 | + def test_create_without_db_path(self): | |
| 43 | + kg = KnowledgeGraph() | |
| 44 | + assert kg.pm is None | |
| 45 | + assert kg._store.get_entity_count() == 0 | |
| 46 | + assert kg._store.get_relationship_count() == 0 | |
| 47 | + | |
| 48 | + def test_create_with_db_path(self, tmp_path): | |
| 49 | + db_path = tmp_path / "test.db" | |
| 50 | + kg = KnowledgeGraph(db_path=db_path) | |
| 51 | + assert kg._store.get_entity_count() == 0 | |
| 52 | + assert db_path.exists() | |
| 53 | + | |
| 54 | + def test_create_with_provider(self, mock_pm): | |
| 55 | + kg = KnowledgeGraph(provider_manager=mock_pm) | |
| 56 | + assert kg.pm is mock_pm | |
| 57 | + | |
| 58 | + | |
| 59 | +class TestProcessTranscript: | |
| 60 | + def test_process_transcript_extracts_entities(self, kg_with_provider, mock_pm): | |
| 61 | + transcript = { | |
| 62 | + "segments": [ | |
| 63 | + {"text": "Alice is using Python for the project", "start": 0.0, "speaker": "Alice"}, | |
| 64 | + {"text": "It works great for data processing", "start": 5.0}, | |
| 65 | + ] | |
| 66 | + } | |
| 67 | + kg_with_provider.process_transcript(transcript) | |
| 68 | + | |
| 69 | + # The mock returns Python and Alice as entities | |
| 70 | + nodes = kg_with_provider.nodes | |
| 71 | + assert "Python" in nodes | |
| 72 | + assert "Alice" in nodes | |
| 73 | + assert nodes["Python"]["type"] == "technology" | |
| 74 | + | |
| 75 | + def test_process_transcript_registers_speakers(self, kg_with_provider): | |
| 76 | + transcript = { | |
| 77 | + "segments": [ | |
| 78 | + {"text": "Hello everyone", "start": 0.0, "speaker": "Bob"}, | |
| 79 | + ] | |
| 80 | + } | |
| 81 | + kg_with_provider.process_transcript(transcript) | |
| 82 | + assert kg_with_provider._store.has_entity("Bob") | |
| 83 | + | |
| 84 | + def test_process_transcript_missing_segments(self, kg_with_provider): | |
| 85 | + """Should log warning and return without error.""" | |
| 86 | + kg_with_provider.process_transcript({}) | |
| 87 | + assert kg_with_provider._store.get_entity_count() == 0 | |
| 88 | + | |
| 89 | + def test_process_transcript_empty_text_skipped(self, kg_with_provider, mock_pm): | |
| 90 | + transcript = { | |
| 91 | + "segments": [ | |
| 92 | + {"text": " ", "start": 0.0}, | |
| 93 | + ] | |
| 94 | + } | |
| 95 | + kg_with_provider.process_transcript(transcript) | |
| 96 | + # chat should not be called for empty batches (speaker registration may still happen) | |
| 97 | + mock_pm.chat.assert_not_called() | |
| 98 | + | |
| 99 | + def test_process_transcript_batching(self, kg_with_provider, mock_pm): | |
| 100 | + """With batch_size=2, 5 segments should produce 3 batches.""" | |
| 101 | + segments = [{"text": f"Segment {i}", "start": float(i)} for i in range(5)] | |
| 102 | + transcript = {"segments": segments} | |
| 103 | + kg_with_provider.process_transcript(transcript, batch_size=2) | |
| 104 | + assert mock_pm.chat.call_count == 3 | |
| 105 | + | |
| 106 | + | |
| 107 | +class TestProcessDiagrams: | |
| 108 | + def test_process_diagrams_with_text(self, kg_with_provider, mock_pm): | |
| 109 | + diagrams = [ | |
| 110 | + {"text_content": "Architecture shows Python microservices", "frame_index": 0}, | |
| 111 | + ] | |
| 112 | + kg_with_provider.process_diagrams(diagrams) | |
| 113 | + | |
| 114 | + # Should have called chat once for the text content | |
| 115 | + assert mock_pm.chat.call_count == 1 | |
| 116 | + # diagram_0 entity should exist | |
| 117 | + assert kg_with_provider._store.has_entity("diagram_0") | |
| 118 | + | |
| 119 | + def test_process_diagrams_without_text(self, kg_with_provider, mock_pm): | |
| 120 | + diagrams = [ | |
| 121 | + {"text_content": "", "frame_index": 5}, | |
| 122 | + ] | |
| 123 | + kg_with_provider.process_diagrams(diagrams) | |
| 124 | + # No chat call for empty text | |
| 125 | + mock_pm.chat.assert_not_called() | |
| 126 | + # But diagram entity still created | |
| 127 | + assert kg_with_provider._store.has_entity("diagram_0") | |
| 128 | + | |
| 129 | + def test_process_multiple_diagrams(self, kg_with_provider, mock_pm): | |
| 130 | + diagrams = [ | |
| 131 | + {"text_content": "Diagram A content", "frame_index": 0}, | |
| 132 | + {"text_content": "Diagram B content", "frame_index": 10}, | |
| 133 | + ] | |
| 134 | + kg_with_provider.process_diagrams(diagrams) | |
| 135 | + assert kg_with_provider._store.has_entity("diagram_0") | |
| 136 | + assert kg_with_provider._store.has_entity("diagram_1") | |
| 137 | + | |
| 138 | + | |
| 139 | +class Testcess_screenshots(screenshots) | |
| 140 | + # LLM extraction from text_content | |
| 141 | + mock_pm.chat.assert_called() | |
| 142 | + # Explicitly listed entities should be added | |
| 143 | + assert kg_with_provider._store.has_entity("Flask") | |
| 144 | + assert kg_with_provider._store.has_entity("Python") | |
| 145 | + | |
| 146 | + def test_process_screenshots_without_text(self, kg_with_provider, mock_pm): | |
| 147 | + screenshots = [ | |
| 148 | + { | |
| 149 | + "text_content": "", | |
| 150 | + "content_type": "other", | |
| 151 | + "entities": ["Docker"], | |
| 152 | + "frame_index": 5, | |
| 153 | + }, | |
| 154 | + ] | |
| 155 | + kg_with_provider.process_screenshots(screenshots) | |
| 156 | + # No chat call for empty text | |
| 157 | + mock_pm.chat.assert_not_called() | |
| 158 | + # But explicit entities still added | |
| 159 | + assert kg_with_provider._store.has_entity("Docker") | |
| 160 | + | |
| 161 | + def test_process_screenshots_empty_entities(self, kg_with_provider): | |
| 162 | + screenshots = [ | |
| 163 | + { | |
| 164 | + "text_content": "", | |
| 165 | + "content_type": "slide", | |
| 166 | + "entities": [], | |
| 167 | + "frame_index": 0, | |
| 168 | + }, | |
| 169 | + ] | |
| 170 | + kg_with_provider.process_screenshots(screenshots) | |
| 171 | + # No crash, no entities added | |
| 172 | + | |
| 173 | + def test_process_screenshots_filters_short_names(self, kg_with_provider): | |
| 174 | + screenshots = [ | |
| 175 | + { | |
| 176 | + "text_content": "", | |
| 177 | + "entities": ["A", "Go", "Python"], | |
| 178 | + "frame_index": 0, | |
| 179 | + }, | |
| 180 | + ] | |
| 181 | + kg_with_provider.process_screenshots(screenshots) | |
| 182 | + # "A" is too short (< 2 chars), filtered out | |
| 183 | + assert not kg_with_provider._store.has_entity("A") | |
| 184 | + assert kg_with_provider._store.has_entity("Go") | |
| 185 | + assert kg_with_provider._store.has_entity("Python") | |
| 186 | + | |
| 187 | + | |
| 188 | +class TestToDictFromDict: | |
| 189 | + def test_round_trip_empty(self): | |
| 190 | + kg = KnowledgeGraph() | |
| 191 | + data = kg.to_dict() | |
| 192 | + kg2 = KnowledgeGraph.from_dict(data) | |
| 193 | + assert kg2._store.get_entity_count() == 0 | |
| 194 | + assert kg2._store.get_relationship_count() == 0 | |
| 195 | + | |
| 196 | + def test_round_trip_with_entities(self, kg_with_provider, mock_pm): | |
| 197 | + # Add some content to populate the graph | |
| 198 | + kg_with_provider.add_content("Alice uses Python", "test_source") | |
| 199 | + original = kg_with_provider.to_dict() | |
| 200 | + | |
| 201 | + restored = KnowledgeGraph.from_dict(original) | |
| 202 | + restored_dict = restored.to_dict() | |
| 203 | + | |
| 204 | + assert len(restored_dict["nodes"]) == len(original["nodes"]) | |
| 205 | + assert len(restored_dict["relationships"]) == len(original["relationships"]) | |
| 206 | + | |
| 207 | + original_names = {n["name"] for n in original["nodes"]} | |
| 208 | + restored_names = {n["name"] for n in restored_dict["nodes"]} | |
| 209 | + assert original_names == restored_names | |
| 210 | + | |
| 211 | + def test_round_trip_with_sources(self): | |
| 212 | + kg = KnowledgeGraph() | |
| 213 | + kg.register_source( | |
| 214 | + { | |
| 215 | + "source_id": "src1", | |
| 216 | + "source_type": "video", | |
| 217 | + "title": "Test Video", | |
| 218 | + "ingested_at": "2025-01-01T00:00:00", | |
| 219 | + } | |
| 220 | + ) | |
| 221 | + data = kg.to_dict() | |
| 222 | + assert "sources" in data | |
| 223 | + assert data["sources"][0]["source_id"] == "src1" | |
| 224 | + | |
| 225 | + kg2 = KnowledgeGraph.from_dict(data) | |
| 226 | + sources = kg2._store.get_sources() | |
| 227 | + assert len(sources) == 1 | |
| 228 | + assert sources[0]["source_id"] == "src1" | |
| 229 | + | |
| 230 | + def test_from_dict_with_db_path(self, tmp_path): | |
| 231 | + data = { | |
| 232 | + "nodes": [ | |
| 233 | + {"name": "TestEntity", "type": "concept", "descriptions": ["A test"]}, | |
| 234 | + ], | |
| 235 | + "relationships": [], | |
| 236 | + } | |
| 237 | + db_path = tmp_path / "restored.db" | |
| 238 | + kg = KnowledgeGraph.from_dict(data, db_path=db_path) | |
| 239 | + assert kg._store.has_entity("TestEntity") | |
| 240 | + assert db_path.exists() | |
| 241 | + | |
| 242 | + | |
| 243 | +class TestSave: | |
| 244 | + def test_save_json(self, tmp_path, kg_with_provider, mock_pm): | |
| 245 | + kg_with_provider.add_content("Alice uses Python", "source1") | |
| 246 | + path = tmp_path / "graph.json" | |
| 247 | + result = kg_with_provider.save(path) | |
| 248 | + | |
| 249 | + assert result == path | |
| 250 | + assert path.exists() | |
| 251 | + data = json.loads(path.read_text()) | |
| 252 | + assert "nodes" in data | |
| 253 | + assert "relationships" in data | |
| 254 | + | |
| 255 | + def test_save_db(self, tmp_path, kg_with_provider, mock_pm): | |
| 256 | + kg_with_provider.add_content("Alice uses Python", "source1") | |
| 257 | + path = tmp_path / "graph.db" | |
| 258 | + result = kg_with_provider.save(path) | |
| 259 | + | |
| 260 | + assert result == path | |
| 261 | + assert path.exists() | |
| 262 | + | |
| 263 | + def test_save_no_suffix_defaults_to_db(self, tmp_path, kg_with_provider, mock_pm): | |
| 264 | + kg_with_provider.add_content("Alice uses Python", "source1") | |
| 265 | + path = tmp_path / "graph" | |
| 266 | + result = kg_with_provider.save(path) | |
| 267 | + assert result.suffix == ".db" | |
| 268 | + assert result.exists() | |
| 269 | + | |
| 270 | + def test_save_creates_parent_dirs(self, tmp_path, kg_with_provider, mock_pm): | |
| 271 | + kg_with_provider.add_content("Alice uses Python", "source1") | |
| 272 | + path = tmp_path / "nested" / "dir" / "graph.json" | |
| 273 | + result = kg_with_provider.save(path) | |
| 274 | + assert result.exists() | |
| 275 | + | |
| 276 | + def test_save_unknown_suffix_falls_back_to_json(self, tmp_path): | |
| 277 | + kg = KnowledgeGraph() | |
| 278 | + kg._store.merge_entity("TestNode", "concept", ["test"]) | |
| 279 | + path = tmp_path / "graph.xyz" | |
| 280 | + result = kg.save(path) | |
| 281 | + assert result.exists() | |
| 282 | + # Should be valid JSON | |
| 283 | + data = json.loads(path.read_text()) | |
| 284 | + assert "nodes" in data | |
| 285 | + | |
| 286 | + | |
| 287 | +class TestMerge: | |
| 288 | + def test_merge_disjoint(self): | |
| 289 | + kg1 = KnowledgeGraph() | |
| 290 | + kg1._store.merge_entity("Alice", "person", ["Developer"]) | |
| 291 | + | |
| 292 | + kg2 = KnowledgeGraph() | |
| 293 | + kg2._store.merge_entity("Bob", "person", ["Manager"]) | |
| 294 | + | |
| 295 | + kg1.merge(kg2) | |
| 296 | + assert kg1._store.has_entity("Alice") | |
| 297 | + assert kg1._store.has_entity("Bob") | |
| 298 | + assert kg1._store.get_entity_count() == 2 | |
| 299 | + | |
| 300 | + def test_merge_overlapping_entities_descriptions_merged(self): | |
| 301 | + kg1 = KnowledgeGraph() | |
| 302 | + kg1._store.merge_entity("Python", "concept", ["A language"]) | |
| 303 | + | |
| 304 | + kg2 = KnowledgeGraph() | |
| 305 | + kg2._store.merge_entity("Python", "technology", ["Programming language"]) | |
| 306 | + | |
| 307 | + kg1.merge(kg2) | |
| 308 | + entity = kg1._store.get_entity("Python") | |
| 309 | + # Descriptions from both should be present | |
| 310 | + descs = entity["descriptions"] | |
| 311 | + if isinstance(descs, set): | |
| 312 | + descs = list(descs) | |
| 313 | + assert "A language" in descs | |
| 314 | + assert "Programming language" in descs | |
| 315 | + | |
| 316 | + def test_merge_overlapping_entities_with_sqlite(self, tmp_path): | |
| 317 | + """SQLiteStore does update type on merge_entity, so type resolution works there.""" | |
| 318 | + kg1 = KnowledgeGraph(db_path=tmp_path / "kg1.db") | |
| 319 | + kg1._store.merge_entity("Python", "concept", ["A language"]) | |
| 320 | + | |
| 321 | + kg2 = KnowledgeGraph(db_path=tmp_path / "kg2.db") | |
| 322 | + kg2._store.merge_entity("Python", "technology", ["Programming language"]) | |
| 323 | + | |
| 324 | + kg1.merge(kg2) | |
| 325 | + entity = kg1._store.get_entity("Python") | |
| 326 | + # SQLiteStore overwrites type — merge resolves to more specific | |
| 327 | + # (The merge method computes the resolved type and passes it to merge_entity, | |
| 328 | + # but InMemoryStore ignores type for existing entities while SQLiteStore does not) | |
| 329 | + assert entity is not None | |
| 330 | + assert kg1._store.get_entity_count() == 1 | |
| 331 | + | |
| 332 | + def test_merge_fuzzy_match(self): | |
| 333 | + kg1 = KnowledgeGraph() | |
| 334 | + kg1._store.merge_entity("JavaScript", "technology", ["A language"]) | |
| 335 | + | |
| 336 | + kg2 = KnowledgeGraph() | |
| 337 | + kg2._store.merge_entity("Javascript", "technology", ["Web language"]) | |
| 338 | + | |
| 339 | + kg1.merge(kg2) | |
| 340 | + # Should fuzzy-match and merge, not create two entities | |
| 341 | + assert kg1._store.get_entity_count() == 1 | |
| 342 | + entity = kg1._store.get_entity("JavaScript") | |
| 343 | + assert entity is not None | |
| 344 | + | |
| 345 | + def test_merge_relationships(self): | |
| 346 | + kg1 = KnowledgeGraph() | |
| 347 | + kg1._store.merge_entity("Alice", "person", []) | |
| 348 | + | |
| 349 | + kg2 = KnowledgeGraph() | |
| 350 | + kg2._store.merge_entity("Bob", "person", []) | |
| 351 | + kg2._store.add_relationship("Alice", "Bob", "collaborates_with") | |
| 352 | + | |
| 353 | + kg1.merge(kg2) | |
| 354 | + rels = kg1._store.get_all_relationships() | |
| 355 | + assert len(rels) == 1 | |
| 356 | + assert rels[0]["type"] == "collaborates_with" | |
| 357 | + | |
| 358 | + def test_merge_sources(self): | |
| 359 | + kg1 = KnowledgeGraph() | |
| 360 | + kg2 = KnowledgeGraph() | |
| 361 | + kg2.register_source( | |
| 362 | + { | |
| 363 | + "source_id": "vid2", | |
| 364 | + "source_type": "video", | |
| 365 | + "title": "Video 2", | |
| 366 | + "ingested_at": "2025-01-01T00:00:00", | |
| 367 | + } | |
| 368 | + ) | |
| 369 | + kg1.merge(kg2) | |
| 370 | + sources = kg1._store.get_sources() | |
| 371 | + assert len(sources) == 1 | |
| 372 | + assert sources[0]["source_id"] == "vid2" | |
| 373 | + | |
| 374 | + def test_merge_type_specificity_with_sqlite(self, tmp_path): | |
| 375 | + """Type specificity resolution works with SQLiteStore which updates type.""" | |
| 376 | + kg1 = KnowledgeGraph(db_path=tmp_path / "kg1.db") | |
| 377 | + kg1._store.merge_entity("React", "concept", []) | |
| 378 | + | |
| 379 | + kg2 = KnowledgeGraph(db_path=tmp_path / "kg2.db") | |
| 380 | + kg2._store.merge_entity("React", "technology", []) | |
| 381 | + | |
| 382 | + kg1.merge(kg2) | |
| 383 | + entity = kg1._store.get_entity("React") | |
| 384 | + assert entity is not None | |
| 385 | + assert kg1._store.get_entity_count() == 1 | |
| 386 | + | |
| 387 | + | |
| 388 | +class TestRegisterSource: | |
| 389 | + def test_register_and_retrieve(self): | |
| 390 | + kg = KnowledgeGraph() | |
| 391 | + source = { | |
| 392 | + "source_id": "src123", | |
| 393 | + "source_type": "video", | |
| 394 | + "title": "Meeting Recording", | |
| 395 | + "path": "/tmp/meeting.mp4", | |
| 396 | + "ingested_at": "2025-06-01T10:00:00", | |
| 397 | + } | |
| 398 | + kg.register_source(source) | |
| 399 | + sources = kg._store.get_sources() | |
| 400 | + assert len(sources) == 1 | |
| 401 | + assert sources[0]["source_id"] == "src123" | |
| 402 | + assert sources[0]["title"] == "Meeting Recording" | |
| 403 | + | |
| 404 | + def test_register_multiple_sources(self): | |
| 405 | + kg = KnowledgeGraph() | |
| 406 | + for i in range(3): | |
| 407 | + kg.register_source( | |
| 408 | + { | |
| 409 | + "source_id": f"src{i}", | |
| 410 | + "source_type": "video", | |
| 411 | + "title": f"Video {i}", | |
| 412 | + "ingested_at": "2025-01-01", | |
| 413 | + } | |
| 414 | + ) | |
| 415 | + assert len(kg._store.get_sources()) == 3 | |
| 416 | + | |
| 417 | + | |
| 418 | +class TestClassifyForPlanning: | |
| 419 | + @patch("video_processor.integrators.knowledge_graph.TaxonomyClassifier", create=True) | |
| 420 | + def test_classify_calls_taxonomy(self, mock_cls): | |
| 421 | + """classify_for_planning should delegate to TaxonomyClassifier.""" | |
| 422 | + mock_ |
| --- a/tests/test_knowledge_graph.py | |
| +++ b/tests/test_knowledge_graph.py | |
| @@ -0,0 +1,422 @@ | |
| 1 | """Tests for the KnowledgeGraph class.""" |
| 2 | |
| 3 | import json |
| 4 | from unittest.mock import MagicMock, patch |
| 5 | |
| 6 | import pytest |
| 7 | |
| 8 | from video_processor.integrators.knowledge_graph import KnowledgeGraph |
| 9 | |
| 10 | |
| 11 | @pytest.fixture |
| 12 | def mock_pm(): |
| 13 | """A mock ProviderManager that returns predictable JSON from chat().""" |
| 14 | pm = MagicMock() |
| 15 | pm.chat.return_value = json.dumps( |
| 16 | { |
| 17 | "entities": [ |
| 18 | {"name": "Python", "type": "technology", "description": "A programming language"}, |
| 19 | {"name": "Alice", "type": "person", "description": "Lead developer"}, |
| 20 | ], |
| 21 | "relationships": [ |
| 22 | {"source": "Alice", "target": "Python", "type": "uses"}, |
| 23 | ], |
| 24 | } |
| 25 | ) |
| 26 | return pm |
| 27 | |
| 28 | |
| 29 | @pytest.fixture |
| 30 | def kg_no_provider(): |
| 31 | """KnowledgeGraph with no provider (in-memory store).""" |
| 32 | return KnowledgeGraph() |
| 33 | |
| 34 | |
| 35 | @pytest.fixture |
| 36 | def kg_with_provider(mock_pm): |
| 37 | """KnowledgeGraph with a mock provider (in-memory store).""" |
| 38 | return KnowledgeGraph(provider_manager=mock_pm) |
| 39 | |
| 40 | |
| 41 | class TestCreation: |
| 42 | def test_create_without_db_path(self): |
| 43 | kg = KnowledgeGraph() |
| 44 | assert kg.pm is None |
| 45 | assert kg._store.get_entity_count() == 0 |
| 46 | assert kg._store.get_relationship_count() == 0 |
| 47 | |
| 48 | def test_create_with_db_path(self, tmp_path): |
| 49 | db_path = tmp_path / "test.db" |
| 50 | kg = KnowledgeGraph(db_path=db_path) |
| 51 | assert kg._store.get_entity_count() == 0 |
| 52 | assert db_path.exists() |
| 53 | |
| 54 | def test_create_with_provider(self, mock_pm): |
| 55 | kg = KnowledgeGraph(provider_manager=mock_pm) |
| 56 | assert kg.pm is mock_pm |
| 57 | |
| 58 | |
| 59 | class TestProcessTranscript: |
| 60 | def test_process_transcript_extracts_entities(self, kg_with_provider, mock_pm): |
| 61 | transcript = { |
| 62 | "segments": [ |
| 63 | {"text": "Alice is using Python for the project", "start": 0.0, "speaker": "Alice"}, |
| 64 | {"text": "It works great for data processing", "start": 5.0}, |
| 65 | ] |
| 66 | } |
| 67 | kg_with_provider.process_transcript(transcript) |
| 68 | |
| 69 | # The mock returns Python and Alice as entities |
| 70 | nodes = kg_with_provider.nodes |
| 71 | assert "Python" in nodes |
| 72 | assert "Alice" in nodes |
| 73 | assert nodes["Python"]["type"] == "technology" |
| 74 | |
| 75 | def test_process_transcript_registers_speakers(self, kg_with_provider): |
| 76 | transcript = { |
| 77 | "segments": [ |
| 78 | {"text": "Hello everyone", "start": 0.0, "speaker": "Bob"}, |
| 79 | ] |
| 80 | } |
| 81 | kg_with_provider.process_transcript(transcript) |
| 82 | assert kg_with_provider._store.has_entity("Bob") |
| 83 | |
| 84 | def test_process_transcript_missing_segments(self, kg_with_provider): |
| 85 | """Should log warning and return without error.""" |
| 86 | kg_with_provider.process_transcript({}) |
| 87 | assert kg_with_provider._store.get_entity_count() == 0 |
| 88 | |
| 89 | def test_process_transcript_empty_text_skipped(self, kg_with_provider, mock_pm): |
| 90 | transcript = { |
| 91 | "segments": [ |
| 92 | {"text": " ", "start": 0.0}, |
| 93 | ] |
| 94 | } |
| 95 | kg_with_provider.process_transcript(transcript) |
| 96 | # chat should not be called for empty batches (speaker registration may still happen) |
| 97 | mock_pm.chat.assert_not_called() |
| 98 | |
| 99 | def test_process_transcript_batching(self, kg_with_provider, mock_pm): |
| 100 | """With batch_size=2, 5 segments should produce 3 batches.""" |
| 101 | segments = [{"text": f"Segment {i}", "start": float(i)} for i in range(5)] |
| 102 | transcript = {"segments": segments} |
| 103 | kg_with_provider.process_transcript(transcript, batch_size=2) |
| 104 | assert mock_pm.chat.call_count == 3 |
| 105 | |
| 106 | |
| 107 | class TestProcessDiagrams: |
| 108 | def test_process_diagrams_with_text(self, kg_with_provider, mock_pm): |
| 109 | diagrams = [ |
| 110 | {"text_content": "Architecture shows Python microservices", "frame_index": 0}, |
| 111 | ] |
| 112 | kg_with_provider.process_diagrams(diagrams) |
| 113 | |
| 114 | # Should have called chat once for the text content |
| 115 | assert mock_pm.chat.call_count == 1 |
| 116 | # diagram_0 entity should exist |
| 117 | assert kg_with_provider._store.has_entity("diagram_0") |
| 118 | |
| 119 | def test_process_diagrams_without_text(self, kg_with_provider, mock_pm): |
| 120 | diagrams = [ |
| 121 | {"text_content": "", "frame_index": 5}, |
| 122 | ] |
| 123 | kg_with_provider.process_diagrams(diagrams) |
| 124 | # No chat call for empty text |
| 125 | mock_pm.chat.assert_not_called() |
| 126 | # But diagram entity still created |
| 127 | assert kg_with_provider._store.has_entity("diagram_0") |
| 128 | |
| 129 | def test_process_multiple_diagrams(self, kg_with_provider, mock_pm): |
| 130 | diagrams = [ |
| 131 | {"text_content": "Diagram A content", "frame_index": 0}, |
| 132 | {"text_content": "Diagram B content", "frame_index": 10}, |
| 133 | ] |
| 134 | kg_with_provider.process_diagrams(diagrams) |
| 135 | assert kg_with_provider._store.has_entity("diagram_0") |
| 136 | assert kg_with_provider._store.has_entity("diagram_1") |
| 137 | |
| 138 | |
| 130 | class TestProcessScreenshots: |
| 131 |     def test_process_screenshots_with_text(self, kg_with_provider, mock_pm): |
| 132 |         screenshots = [ |
| 133 |             { |
| 134 |                 "text_content": "Flask app running on Python", |
| 135 |                 "entities": ["Flask", "Python"], |
| 136 |                 "frame_index": 0, |
| 137 |             }, |
| 138 |         ] |
| 139 |         kg_with_provider.process_screenshots(screenshots) |
| 140 | # LLM extraction from text_content |
| 141 | mock_pm.chat.assert_called() |
| 142 | # Explicitly listed entities should be added |
| 143 | assert kg_with_provider._store.has_entity("Flask") |
| 144 | assert kg_with_provider._store.has_entity("Python") |
| 145 | |
| 146 | def test_process_screenshots_without_text(self, kg_with_provider, mock_pm): |
| 147 | screenshots = [ |
| 148 | { |
| 149 | "text_content": "", |
| 150 | "content_type": "other", |
| 151 | "entities": ["Docker"], |
| 152 | "frame_index": 5, |
| 153 | }, |
| 154 | ] |
| 155 | kg_with_provider.process_screenshots(screenshots) |
| 156 | # No chat call for empty text |
| 157 | mock_pm.chat.assert_not_called() |
| 158 | # But explicit entities still added |
| 159 | assert kg_with_provider._store.has_entity("Docker") |
| 160 | |
| 161 | def test_process_screenshots_empty_entities(self, kg_with_provider): |
| 162 | screenshots = [ |
| 163 | { |
| 164 | "text_content": "", |
| 165 | "content_type": "slide", |
| 166 | "entities": [], |
| 167 | "frame_index": 0, |
| 168 | }, |
| 169 | ] |
| 170 | kg_with_provider.process_screenshots(screenshots) |
| 171 | # No crash, no entities added |
| 172 | |
| 173 | def test_process_screenshots_filters_short_names(self, kg_with_provider): |
| 174 | screenshots = [ |
| 175 | { |
| 176 | "text_content": "", |
| 177 | "entities": ["A", "Go", "Python"], |
| 178 | "frame_index": 0, |
| 179 | }, |
| 180 | ] |
| 181 | kg_with_provider.process_screenshots(screenshots) |
| 182 | # "A" is too short (< 2 chars), filtered out |
| 183 | assert not kg_with_provider._store.has_entity("A") |
| 184 | assert kg_with_provider._store.has_entity("Go") |
| 185 | assert kg_with_provider._store.has_entity("Python") |
| 186 | |
| 187 | |
| 188 | class TestToDictFromDict: |
| 189 | def test_round_trip_empty(self): |
| 190 | kg = KnowledgeGraph() |
| 191 | data = kg.to_dict() |
| 192 | kg2 = KnowledgeGraph.from_dict(data) |
| 193 | assert kg2._store.get_entity_count() == 0 |
| 194 | assert kg2._store.get_relationship_count() == 0 |
| 195 | |
| 196 | def test_round_trip_with_entities(self, kg_with_provider, mock_pm): |
| 197 | # Add some content to populate the graph |
| 198 | kg_with_provider.add_content("Alice uses Python", "test_source") |
| 199 | original = kg_with_provider.to_dict() |
| 200 | |
| 201 | restored = KnowledgeGraph.from_dict(original) |
| 202 | restored_dict = restored.to_dict() |
| 203 | |
| 204 | assert len(restored_dict["nodes"]) == len(original["nodes"]) |
| 205 | assert len(restored_dict["relationships"]) == len(original["relationships"]) |
| 206 | |
| 207 | original_names = {n["name"] for n in original["nodes"]} |
| 208 | restored_names = {n["name"] for n in restored_dict["nodes"]} |
| 209 | assert original_names == restored_names |
| 210 | |
| 211 | def test_round_trip_with_sources(self): |
| 212 | kg = KnowledgeGraph() |
| 213 | kg.register_source( |
| 214 | { |
| 215 | "source_id": "src1", |
| 216 | "source_type": "video", |
| 217 | "title": "Test Video", |
| 218 | "ingested_at": "2025-01-01T00:00:00", |
| 219 | } |
| 220 | ) |
| 221 | data = kg.to_dict() |
| 222 | assert "sources" in data |
| 223 | assert data["sources"][0]["source_id"] == "src1" |
| 224 | |
| 225 | kg2 = KnowledgeGraph.from_dict(data) |
| 226 | sources = kg2._store.get_sources() |
| 227 | assert len(sources) == 1 |
| 228 | assert sources[0]["source_id"] == "src1" |
| 229 | |
| 230 | def test_from_dict_with_db_path(self, tmp_path): |
| 231 | data = { |
| 232 | "nodes": [ |
| 233 | {"name": "TestEntity", "type": "concept", "descriptions": ["A test"]}, |
| 234 | ], |
| 235 | "relationships": [], |
| 236 | } |
| 237 | db_path = tmp_path / "restored.db" |
| 238 | kg = KnowledgeGraph.from_dict(data, db_path=db_path) |
| 239 | assert kg._store.has_entity("TestEntity") |
| 240 | assert db_path.exists() |
| 241 | |
| 242 | |
| 243 | class TestSave: |
| 244 | def test_save_json(self, tmp_path, kg_with_provider, mock_pm): |
| 245 | kg_with_provider.add_content("Alice uses Python", "source1") |
| 246 | path = tmp_path / "graph.json" |
| 247 | result = kg_with_provider.save(path) |
| 248 | |
| 249 | assert result == path |
| 250 | assert path.exists() |
| 251 | data = json.loads(path.read_text()) |
| 252 | assert "nodes" in data |
| 253 | assert "relationships" in data |
| 254 | |
| 255 | def test_save_db(self, tmp_path, kg_with_provider, mock_pm): |
| 256 | kg_with_provider.add_content("Alice uses Python", "source1") |
| 257 | path = tmp_path / "graph.db" |
| 258 | result = kg_with_provider.save(path) |
| 259 | |
| 260 | assert result == path |
| 261 | assert path.exists() |
| 262 | |
| 263 | def test_save_no_suffix_defaults_to_db(self, tmp_path, kg_with_provider, mock_pm): |
| 264 | kg_with_provider.add_content("Alice uses Python", "source1") |
| 265 | path = tmp_path / "graph" |
| 266 | result = kg_with_provider.save(path) |
| 267 | assert result.suffix == ".db" |
| 268 | assert result.exists() |
| 269 | |
| 270 | def test_save_creates_parent_dirs(self, tmp_path, kg_with_provider, mock_pm): |
| 271 | kg_with_provider.add_content("Alice uses Python", "source1") |
| 272 | path = tmp_path / "nested" / "dir" / "graph.json" |
| 273 | result = kg_with_provider.save(path) |
| 274 | assert result.exists() |
| 275 | |
| 276 | def test_save_unknown_suffix_falls_back_to_json(self, tmp_path): |
| 277 | kg = KnowledgeGraph() |
| 278 | kg._store.merge_entity("TestNode", "concept", ["test"]) |
| 279 | path = tmp_path / "graph.xyz" |
| 280 | result = kg.save(path) |
| 281 | assert result.exists() |
| 282 | # Should be valid JSON |
| 283 | data = json.loads(path.read_text()) |
| 284 | assert "nodes" in data |
| 285 | |
| 286 | |
| 287 | class TestMerge: |
| 288 | def test_merge_disjoint(self): |
| 289 | kg1 = KnowledgeGraph() |
| 290 | kg1._store.merge_entity("Alice", "person", ["Developer"]) |
| 291 | |
| 292 | kg2 = KnowledgeGraph() |
| 293 | kg2._store.merge_entity("Bob", "person", ["Manager"]) |
| 294 | |
| 295 | kg1.merge(kg2) |
| 296 | assert kg1._store.has_entity("Alice") |
| 297 | assert kg1._store.has_entity("Bob") |
| 298 | assert kg1._store.get_entity_count() == 2 |
| 299 | |
| 300 | def test_merge_overlapping_entities_descriptions_merged(self): |
| 301 | kg1 = KnowledgeGraph() |
| 302 | kg1._store.merge_entity("Python", "concept", ["A language"]) |
| 303 | |
| 304 | kg2 = KnowledgeGraph() |
| 305 | kg2._store.merge_entity("Python", "technology", ["Programming language"]) |
| 306 | |
| 307 | kg1.merge(kg2) |
| 308 | entity = kg1._store.get_entity("Python") |
| 309 | # Descriptions from both should be present |
| 310 | descs = entity["descriptions"] |
| 311 | if isinstance(descs, set): |
| 312 | descs = list(descs) |
| 313 | assert "A language" in descs |
| 314 | assert "Programming language" in descs |
| 315 | |
| 316 | def test_merge_overlapping_entities_with_sqlite(self, tmp_path): |
| 317 | """SQLiteStore does update type on merge_entity, so type resolution works there.""" |
| 318 | kg1 = KnowledgeGraph(db_path=tmp_path / "kg1.db") |
| 319 | kg1._store.merge_entity("Python", "concept", ["A language"]) |
| 320 | |
| 321 | kg2 = KnowledgeGraph(db_path=tmp_path / "kg2.db") |
| 322 | kg2._store.merge_entity("Python", "technology", ["Programming language"]) |
| 323 | |
| 324 | kg1.merge(kg2) |
| 325 | entity = kg1._store.get_entity("Python") |
| 326 | # SQLiteStore overwrites type — merge resolves to more specific |
| 327 | # (The merge method computes the resolved type and passes it to merge_entity, |
| 328 | # but InMemoryStore ignores type for existing entities while SQLiteStore does not) |
| 329 | assert entity is not None |
| 330 | assert kg1._store.get_entity_count() == 1 |
| 331 | |
| 332 | def test_merge_fuzzy_match(self): |
| 333 | kg1 = KnowledgeGraph() |
| 334 | kg1._store.merge_entity("JavaScript", "technology", ["A language"]) |
| 335 | |
| 336 | kg2 = KnowledgeGraph() |
| 337 | kg2._store.merge_entity("Javascript", "technology", ["Web language"]) |
| 338 | |
| 339 | kg1.merge(kg2) |
| 340 | # Should fuzzy-match and merge, not create two entities |
| 341 | assert kg1._store.get_entity_count() == 1 |
| 342 | entity = kg1._store.get_entity("JavaScript") |
| 343 | assert entity is not None |
| 344 | |
| 345 | def test_merge_relationships(self): |
| 346 | kg1 = KnowledgeGraph() |
| 347 | kg1._store.merge_entity("Alice", "person", []) |
| 348 | |
| 349 | kg2 = KnowledgeGraph() |
| 350 | kg2._store.merge_entity("Bob", "person", []) |
| 351 | kg2._store.add_relationship("Alice", "Bob", "collaborates_with") |
| 352 | |
| 353 | kg1.merge(kg2) |
| 354 | rels = kg1._store.get_all_relationships() |
| 355 | assert len(rels) == 1 |
| 356 | assert rels[0]["type"] == "collaborates_with" |
| 357 | |
| 358 | def test_merge_sources(self): |
| 359 | kg1 = KnowledgeGraph() |
| 360 | kg2 = KnowledgeGraph() |
| 361 | kg2.register_source( |
| 362 | { |
| 363 | "source_id": "vid2", |
| 364 | "source_type": "video", |
| 365 | "title": "Video 2", |
| 366 | "ingested_at": "2025-01-01T00:00:00", |
| 367 | } |
| 368 | ) |
| 369 | kg1.merge(kg2) |
| 370 | sources = kg1._store.get_sources() |
| 371 | assert len(sources) == 1 |
| 372 | assert sources[0]["source_id"] == "vid2" |
| 373 | |
| 374 | def test_merge_type_specificity_with_sqlite(self, tmp_path): |
| 375 | """Type specificity resolution works with SQLiteStore which updates type.""" |
| 376 | kg1 = KnowledgeGraph(db_path=tmp_path / "kg1.db") |
| 377 | kg1._store.merge_entity("React", "concept", []) |
| 378 | |
| 379 | kg2 = KnowledgeGraph(db_path=tmp_path / "kg2.db") |
| 380 | kg2._store.merge_entity("React", "technology", []) |
| 381 | |
| 382 | kg1.merge(kg2) |
| 383 | entity = kg1._store.get_entity("React") |
| 384 | assert entity is not None |
| 385 | assert kg1._store.get_entity_count() == 1 |
| 386 | |
| 387 | |
| 388 | class TestRegisterSource: |
| 389 | def test_register_and_retrieve(self): |
| 390 | kg = KnowledgeGraph() |
| 391 | source = { |
| 392 | "source_id": "src123", |
| 393 | "source_type": "video", |
| 394 | "title": "Meeting Recording", |
| 395 | "path": "/tmp/meeting.mp4", |
| 396 | "ingested_at": "2025-06-01T10:00:00", |
| 397 | } |
| 398 | kg.register_source(source) |
| 399 | sources = kg._store.get_sources() |
| 400 | assert len(sources) == 1 |
| 401 | assert sources[0]["source_id"] == "src123" |
| 402 | assert sources[0]["title"] == "Meeting Recording" |
| 403 | |
| 404 | def test_register_multiple_sources(self): |
| 405 | kg = KnowledgeGraph() |
| 406 | for i in range(3): |
| 407 | kg.register_source( |
| 408 | { |
| 409 | "source_id": f"src{i}", |
| 410 | "source_type": "video", |
| 411 | "title": f"Video {i}", |
| 412 | "ingested_at": "2025-01-01", |
| 413 | } |
| 414 | ) |
| 415 | assert len(kg._store.get_sources()) == 3 |
| 416 | |
| 417 | |
| 418 | class TestClassifyForPlanning: |
| 419 | @patch("video_processor.integrators.knowledge_graph.TaxonomyClassifier", create=True) |
| 420 | def test_classify_calls_taxonomy(self, mock_cls): |
| 421 | """classify_for_planning should delegate to TaxonomyClassifier.""" |
| 422 | mock_ |
+433 -2
| --- tests/test_pipeline.py | ||
| +++ tests/test_pipeline.py | ||
| @@ -1,11 +1,19 @@ | ||
| 1 | 1 | """Tests for the core video processing pipeline.""" |
| 2 | 2 | |
| 3 | 3 | import json |
| 4 | -from unittest.mock import MagicMock | |
| 4 | +from pathlib import Path | |
| 5 | +from unittest.mock import MagicMock, patch | |
| 5 | 6 | |
| 6 | -from video_processor.pipeline import _extract_action_items, _extract_key_points, _format_srt_time | |
| 7 | +import pytest | |
| 8 | + | |
| 9 | +from video_processor.pipeline import ( | |
| 10 | + _extract_action_items, | |
| 11 | + _extract_key_points, | |
| 12 | + _format_srt_time, | |
| 13 | + process_single_video, | |
| 14 | +) | |
| 7 | 15 | |
| 8 | 16 | |
| 9 | 17 | class TestFormatSrtTime: |
| 10 | 18 | def test_zero(self): |
| 11 | 19 | assert _format_srt_time(0) == "00:00:00,000" |
| @@ -99,5 +107,428 @@ | ||
| 99 | 107 | def test_handles_error(self): |
| 100 | 108 | pm = MagicMock() |
| 101 | 109 | pm.chat.side_effect = Exception("API down") |
| 102 | 110 | result = _extract_action_items(pm, "text") |
| 103 | 111 | assert result == [] |
| 112 | + | |
| 113 | + | |
| 114 | +# --------------------------------------------------------------------------- | |
| 115 | +# process_single_video tests (heavily mocked) | |
| 116 | +# --------------------------------------------------------------------------- | |
| 117 | + | |
| 118 | + | |
| 119 | +def _make_mock_pm(): | |
| 120 | + """Build a mock ProviderManager with usage tracker and predictable responses.""" | |
| 121 | + pm = MagicMock() | |
| 122 | + | |
| 123 | + # Usage tracker stub | |
| 124 | + pm.usage = MagicMock() | |
| 125 | + pm.usage.start_step = MagicMock() | |
| 126 | + pm.usage.end_step = MagicMock() | |
| 127 | + | |
| 128 | + # transcribe_audio returns a simple transcript | |
| 129 | + pm.transcribe_audio.return_value = { | |
| 130 | + "text": "Alice discussed the Python deployment strategy with Bob.", | |
| 131 | + "segments": [ | |
| 132 | + {"start": 0.0, "end": 5.0, "text": "Alice discussed the Python deployment strategy."}, | |
| 133 | + {"start": 5.0, "end": 10.0, "text": "Bob agreed on the timeline."}, | |
| 134 | + ], | |
| 135 | + "duration": 10.0, | |
| 136 | + "language": "en", | |
| 137 | + "provider": "mock", | |
| 138 | + "model": "mock-whisper", | |
| 139 | + } | |
| 140 | + | |
| 141 | + # chat returns predictable JSON depending on the call | |
| 142 | + def _chat_side_effect(messages, **kwargs): | |
| 143 | + content = messages[0]["content"] if messages else "" | |
| 144 | + if "key points" in content.lower(): | |
| 145 | + return json.dumps( | |
| 146 | + [{"point": "Deployment strategy discussed", "topic": "DevOps", "details": "Python"}] | |
| 147 | + ) | |
| 148 | + if "action items" in content.lower(): | |
| 149 | + return json.dumps( | |
| 150 | + [{"action": "Deploy to production", "assignee": "Bob", "priority": "high"}] | |
| 151 | + ) | |
| 152 | + # Default: entity extraction for knowledge graph | |
| 153 | + return json.dumps( | |
| 154 | + { | |
| 155 | + "entities": [ | |
| 156 | + {"name": "Python", "type": "technology", "description": "Programming language"}, | |
| 157 | + {"name": "Alice", "type": "person", "description": "Engineer"}, | |
| 158 | + ], | |
| 159 | + "relationships": [ | |
| 160 | + {"source": "Alice", "target": "Python", "type": "uses"}, | |
| 161 | + ], | |
| 162 | + } | |
| 163 | + ) | |
| 164 | + | |
| 165 | + pm.chat.side_effect = _chat_side_effect | |
| 166 | + pm.get_models_used.return_value = {"chat": "mock-gpt", "transcription": "mock-whisper"} | |
| 167 | + return pm | |
| 168 | + | |
| 169 | + | |
| 170 | +def _make_tqdm_passthrough(mock_tqdm): | |
| 171 | + """Configure mock tqdm to pass through iterables while supporting .set_description() etc.""" | |
| 172 | + | |
| 173 | + def _tqdm_side_effect(iterable, **kw): | |
| 174 | + wrapper = MagicMock() | |
| 175 | + wrapper.__iter__ = lambda self: iter(iterable) | |
| 176 | + return wrapper | |
| 177 | + | |
| 178 | + mock_tqdm.side_effect = _tqdm_side_effect | |
| 179 | + | |
| 180 | + | |
| 181 | +def _create_fake_video(path: Path) -> Path: | |
| 182 | + """Create a tiny file that stands in for a video (all extractors are mocked).""" | |
| 183 | + path.parent.mkdir(parents=True, exist_ok=True) | |
| 184 | + path.write_bytes(b"\x00" * 64) | |
| 185 | + return path | |
| 186 | + | |
| 187 | + | |
| 188 | +class TestProcessSingleVideo: | |
| 189 | + """Integration-level tests for process_single_video with heavy mocking.""" | |
| 190 | + | |
| 191 | + @pytest.fixture | |
| 192 | + def setup(self, tmp_path): | |
| 193 | + """Create fake video, output dir, and mock PM.""" | |
| 194 | + video_path = _create_fake_video(tmp_path / "input" / "meeting.mp4") | |
| 195 | + output_dir = tmp_path / "output" | |
| 196 | + pm = _make_mock_pm() | |
| 197 | + return video_path, output_dir, pm | |
| 198 | + | |
| 199 | + @patch("video_processor.pipeline.export_all_formats") | |
| 200 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 201 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 202 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 203 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 204 | + @patch("video_processor.pipeline.save_frames") | |
| 205 | + @patch("video_processor.pipeline.extract_frames") | |
| 206 | + @patch("video_processor.pipeline.tqdm") | |
| 207 | + def test_returns_manifest( | |
| 208 | + self, | |
| 209 | + mock_tqdm, | |
| 210 | + mock_extract_frames, | |
| 211 | + mock_save_frames, | |
| 212 | + mock_filter_people, | |
| 213 | + mock_audio_extractor_cls, | |
| 214 | + mock_diagram_analyzer_cls, | |
| 215 | + mock_plan_gen_cls, | |
| 216 | + mock_export, | |
| 217 | + setup, | |
| 218 | + ): | |
| 219 | + video_path, output_dir, pm = setup | |
| 220 | + | |
| 221 | + # tqdm pass-through | |
| 222 | + _make_tqdm_passthrough(mock_tqdm) | |
| 223 | + | |
| 224 | + # Frame extraction mocks | |
| 225 | + mock_extract_frames.return_value = [b"fake_frame_1", b"fake_frame_2"] | |
| 226 | + mock_filter_people.return_value = ([b"fake_frame_1", b"fake_frame_2"], 0) | |
| 227 | + | |
| 228 | + frames_dir = output_dir / "frames" | |
| 229 | + frames_dir.mkdir(parents=True, exist_ok=True) | |
| 230 | + frame_paths = [] | |
| 231 | + for i in range(2): | |
| 232 | + fp = frames_dir / f"frame_{i:04d}.jpg" | |
| 233 | + fp.write_bytes(b"\xff") | |
| 234 | + frame_paths.append(fp) | |
| 235 | + mock_save_frames.return_value = frame_paths | |
| 236 | + | |
| 237 | + # Audio extractor mock | |
| 238 | + audio_ext = MagicMock() | |
| 239 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 240 | + audio_ext.get_audio_properties.return_value = {"duration": 10.0} | |
| 241 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 242 | + | |
| 243 | + # Diagram analyzer mock | |
| 244 | + diag_analyzer = MagicMock() | |
| 245 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 246 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 247 | + | |
| 248 | + # Plan generator mock | |
| 249 | + plan_gen = MagicMock() | |
| 250 | + mock_plan_gen_cls.return_value = plan_gen | |
| 251 | + | |
| 252 | + # export_all_formats returns the manifest it receives | |
| 253 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 254 | + | |
| 255 | + manifest = process_single_video( | |
| 256 | + input_path=video_path, | |
| 257 | + output_dir=output_dir, | |
| 258 | + provider_manager=pm, | |
| 259 | + depth="standard", | |
| 260 | + ) | |
| 261 | + | |
| 262 | + from video_processor.models import VideoManifest | |
| 263 | + | |
| 264 | + assert isinstance(manifest, VideoManifest) | |
| 265 | + assert manifest.video.title == "Analysis of meeting" | |
| 266 | + assert manifest.stats.frames_extracted == 2 | |
| 267 | + assert manifest.transcript_json == "transcript/transcript.json" | |
| 268 | + assert manifest.knowledge_graph_json == "results/knowledge_graph.json" | |
| 269 | + | |
| 270 | + @patch("video_processor.pipeline.export_all_formats") | |
| 271 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 272 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 273 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 274 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 275 | + @patch("video_processor.pipeline.save_frames") | |
| 276 | + @patch("video_processor.pipeline.extract_frames") | |
| 277 | + @patch("video_processor.pipeline.tqdm") | |
| 278 | + def test_creates_output_directories( | |
| 279 | + self, | |
| 280 | + mock_tqdm, | |
| 281 | + mock_extract_frames, | |
| 282 | + mock_save_frames, | |
| 283 | + mock_filter_people, | |
| 284 | + mock_audio_extractor_cls, | |
| 285 | + mock_diagram_analyzer_cls, | |
| 286 | + mock_plan_gen_cls, | |
| 287 | + mock_export, | |
| 288 | + setup, | |
| 289 | + ): | |
| 290 | + video_path, output_dir, pm = setup | |
| 291 | + | |
| 292 | + _make_tqdm_passthrough(mock_tqdm) | |
| 293 | + mock_extract_frames.return_value = [] | |
| 294 | + mock_filter_people.return_value = ([], 0) | |
| 295 | + mock_save_frames.return_value = [] | |
| 296 | + | |
| 297 | + audio_ext = MagicMock() | |
| 298 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 299 | + audio_ext.get_audio_properties.return_value = {"duration": 5.0} | |
| 300 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 301 | + | |
| 302 | + diag_analyzer = MagicMock() | |
| 303 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 304 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 305 | + | |
| 306 | + plan_gen = MagicMock() | |
| 307 | + mock_plan_gen_cls.return_value = plan_gen | |
| 308 | + | |
| 309 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 310 | + | |
| 311 | + process_single_video( | |
| 312 | + input_path=video_path, | |
| 313 | + output_dir=output_dir, | |
| 314 | + provider_manager=pm, | |
| 315 | + ) | |
| 316 | + | |
| 317 | + # Verify standard output directories were created | |
| 318 | + assert (output_dir / "transcript").is_dir() | |
| 319 | + assert (output_dir / "frames").is_dir() | |
| 320 | + assert (output_dir / "results").is_dir() | |
| 321 | + | |
| 322 | + @patch("video_processor.pipeline.export_all_formats") | |
| 323 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 324 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 325 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 326 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 327 | + @patch("video_processor.pipeline.save_frames") | |
| 328 | + @patch("video_processor.pipeline.extract_frames") | |
| 329 | + @patch("video_processor.pipeline.tqdm") | |
| 330 | + def test_resume_existing_frames( | |
| 331 | + self, | |
| 332 | + mock_tqdm, | |
| 333 | + mock_extract_frames, | |
| 334 | + mock_save_frames, | |
| 335 | + mock_filter_people, | |
| 336 | + mock_audio_extractor_cls, | |
| 337 | + mock_diagram_analyzer_cls, | |
| 338 | + mock_plan_gen_cls, | |
| 339 | + mock_export, | |
| 340 | + setup, | |
| 341 | + ): | |
| 342 | + """When frames already exist on disk, extraction should be skipped.""" | |
| 343 | + video_path, output_dir, pm = setup | |
| 344 | + | |
| 345 | + _make_tqdm_passthrough(mock_tqdm) | |
| 346 | + | |
| 347 | + # Pre-create frames directory with existing frames | |
| 348 | + frames_dir = output_dir / "frames" | |
| 349 | + frames_dir.mkdir(parents=True, exist_ok=True) | |
| 350 | + for i in range(3): | |
| 351 | + (frames_dir / f"frame_{i:04d}.jpg").write_bytes(b"\xff") | |
| 352 | + | |
| 353 | + audio_ext = MagicMock() | |
| 354 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 355 | + audio_ext.get_audio_properties.return_value = {"duration": 10.0} | |
| 356 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 357 | + | |
| 358 | + diag_analyzer = MagicMock() | |
| 359 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 360 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 361 | + | |
| 362 | + plan_gen = MagicMock() | |
| 363 | + mock_plan_gen_cls.return_value = plan_gen | |
| 364 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 365 | + | |
| 366 | + manifest = process_single_video( | |
| 367 | + input_path=video_path, | |
| 368 | + output_dir=output_dir, | |
| 369 | + provider_manager=pm, | |
| 370 | + ) | |
| 371 | + | |
| 372 | + # extract_frames should NOT have been called (resume path) | |
| 373 | + mock_extract_frames.assert_not_called() | |
| 374 | + assert manifest.stats.frames_extracted == 3 | |
| 375 | + | |
| 376 | + @patch("video_processor.pipeline.export_all_formats") | |
| 377 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 378 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 379 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 380 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 381 | + @patch("video_processor.pipeline.save_frames") | |
| 382 | + @patch("video_processor.pipeline.extract_frames") | |
| 383 | + @patch("video_processor.pipeline.tqdm") | |
| 384 | + def test_resume_existing_transcript( | |
| 385 | + self, | |
| 386 | + mock_tqdm, | |
| 387 | + mock_extract_frames, | |
| 388 | + mock_save_frames, | |
| 389 | + mock_filter_people, | |
| 390 | + mock_audio_extractor_cls, | |
| 391 | + mock_diagram_analyzer_cls, | |
| 392 | + mock_plan_gen_cls, | |
| 393 | + mock_export, | |
| 394 | + setup, | |
| 395 | + ): | |
| 396 | + """When transcript exists on disk, transcription should be skipped.""" | |
| 397 | + video_path, output_dir, pm = setup | |
| 398 | + | |
| 399 | + _make_tqdm_passthrough(mock_tqdm) | |
| 400 | + mock_extract_frames.return_value = [] | |
| 401 | + mock_filter_people.return_value = ([], 0) | |
| 402 | + mock_save_frames.return_value = [] | |
| 403 | + | |
| 404 | + audio_ext = MagicMock() | |
| 405 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 406 | + audio_ext.get_audio_properties.return_value = {"duration": 10.0} | |
| 407 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 408 | + | |
| 409 | + # Pre-create transcript file | |
| 410 | + transcript_dir = output_dir / "transcript" | |
| 411 | + transcript_dir.mkdir(parents=True, exist_ok=True) | |
| 412 | + transcript_data = { | |
| 413 | + "text": "Pre-existing transcript text.", | |
| 414 | + "segments": [{"start": 0.0, "end": 5.0, "text": "Pre-existing transcript text."}], | |
| 415 | + "duration": 5.0, | |
| 416 | + } | |
| 417 | + (transcript_dir / "transcript.json").write_text(json.dumps(transcript_data)) | |
| 418 | + | |
| 419 | + diag_analyzer = MagicMock() | |
| 420 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 421 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 422 | + | |
| 423 | + plan_gen = MagicMock() | |
| 424 | + mock_plan_gen_cls.return_value = plan_gen | |
| 425 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 426 | + | |
| 427 | + process_single_video( | |
| 428 | + input_path=video_path, | |
| 429 | + output_dir=output_dir, | |
| 430 | + provider_manager=pm, | |
| 431 | + ) | |
| 432 | + | |
| 433 | + # transcribe_audio should NOT have been called (resume path) | |
| 434 | + pm.transcribe_audio.assert_not_called() | |
| 435 | + | |
| 436 | + @patch("video_processor.pipeline.export_all_formats") | |
| 437 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 438 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 439 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 440 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 441 | + @patch("video_processor.pipeline.save_frames") | |
| 442 | + @patch("video_processor.pipeline.extract_frames") | |
| 443 | + @patch("video_processor.pipeline.tqdm") | |
| 444 | + def test_custom_title( | |
| 445 | + self, | |
| 446 | + mock_tqdm, | |
| 447 | + mock_extract_frames, | |
| 448 | + mock_save_frames, | |
| 449 | + mock_filter_people, | |
| 450 | + mock_audio_extractor_cls, | |
| 451 | + mock_diagram_analyzer_cls, | |
| 452 | + mock_plan_gen_cls, | |
| 453 | + mock_export, | |
| 454 | + setup, | |
| 455 | + ): | |
| 456 | + video_path, output_dir, pm = setup | |
| 457 | + | |
| 458 | + _make_tqdm_passthrough(mock_tqdm) | |
| 459 | + mock_extract_frames.return_value = [] | |
| 460 | + mock_filter_people.return_value = ([], 0) | |
| 461 | + mock_save_frames.return_value = [] | |
| 462 | + | |
| 463 | + audio_ext = MagicMock() | |
| 464 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 465 | + audio_ext.get_audio_properties.return_value = {"duration": 5.0} | |
| 466 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 467 | + | |
| 468 | + diag_analyzer = MagicMock() | |
| 469 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 470 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 471 | + | |
| 472 | + plan_gen = MagicMock() | |
| 473 | + mock_plan_gen_cls.return_value = plan_gen | |
| 474 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 475 | + | |
| 476 | + manifest = process_single_video( | |
| 477 | + input_path=video_path, | |
| 478 | + output_dir=output_dir, | |
| 479 | + provider_manager=pm, | |
| 480 | + title="My Custom Title", | |
| 481 | + ) | |
| 482 | + | |
| 483 | + assert manifest.video.title == "My Custom Title" | |
| 484 | + | |
| 485 | + @patch("video_processor.pipeline.export_all_formats") | |
| 486 | + @patch("video_processor.pipeline.PlanGenerator") | |
| 487 | + @patch("video_processor.pipeline.DiagramAnalyzer") | |
| 488 | + @patch("video_processor.pipeline.AudioExtractor") | |
| 489 | + @patch("video_processor.pipeline.filter_people_frames") | |
| 490 | + @patch("video_processor.pipeline.save_frames") | |
| 491 | + @patch("video_processor.pipeline.extract_frames") | |
| 492 | + @patch("video_processor.pipeline.tqdm") | |
| 493 | + def test_key_points_and_action_items_extracted( | |
| 494 | + self, | |
| 495 | + mock_tqdm, | |
| 496 | + mock_extract_frames, | |
| 497 | + mock_save_frames, | |
| 498 | + mock_filter_people, | |
| 499 | + mock_audio_extractor_cls, | |
| 500 | + mock_diagram_analyzer_cls, | |
| 501 | + mock_plan_gen_cls, | |
| 502 | + mock_export, | |
| 503 | + setup, | |
| 504 | + ): | |
| 505 | + video_path, output_dir, pm = setup | |
| 506 | + | |
| 507 | + _make_tqdm_passthrough(mock_tqdm) | |
| 508 | + mock_extract_frames.return_value = [] | |
| 509 | + mock_filter_people.return_value = ([], 0) | |
| 510 | + mock_save_frames.return_value = [] | |
| 511 | + | |
| 512 | + audio_ext = MagicMock() | |
| 513 | + audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" | |
| 514 | + audio_ext.get_audio_properties.return_value = {"duration": 10.0} | |
| 515 | + mock_audio_extractor_cls.return_value = audio_ext | |
| 516 | + | |
| 517 | + diag_analyzer = MagicMock() | |
| 518 | + diag_analyzer.process_frames.return_value = ([], []) | |
| 519 | + mock_diagram_analyzer_cls.return_value = diag_analyzer | |
| 520 | + | |
| 521 | + plan_gen = MagicMock() | |
| 522 | + mock_plan_gen_cls.return_value = plan_gen | |
| 523 | + mock_export.side_effect = lambda out_dir, manifest: manifest | |
| 524 | + | |
| 525 | + manifest = process_single_video( | |
| 526 | + input_path=video_path, | |
| 527 | + output_dir=output_dir, | |
| 528 | + provider_manager=pm, | |
| 529 | + ) | |
| 530 | + | |
| 531 | + assert len(manifest.key_points) == 1 | |
| 532 | + assert manifest.key_points[0].point == "Deployment strategy discussed" | |
| 533 | + assert len(manifest.action_items) == 1 | |
| 534 | + assert manifest.action_items[0].action == "Deploy to production" | |
| 104 | 535 |
| --- tests/test_pipeline.py | |
| +++ tests/test_pipeline.py | |
| @@ -1,11 +1,19 @@ | |
| 1 | """Tests for the core video processing pipeline.""" |
| 2 | |
| 3 | import json |
| 4 | from unittest.mock import MagicMock |
| 5 | |
| 6 | from video_processor.pipeline import _extract_action_items, _extract_key_points, _format_srt_time |
| 7 | |
| 8 | |
| 9 | class TestFormatSrtTime: |
| 10 | def test_zero(self): |
| 11 | assert _format_srt_time(0) == "00:00:00,000" |
| @@ -99,5 +107,428 @@ | |
| 99 | def test_handles_error(self): |
| 100 | pm = MagicMock() |
| 101 | pm.chat.side_effect = Exception("API down") |
| 102 | result = _extract_action_items(pm, "text") |
| 103 | assert result == [] |
| 104 |
| --- tests/test_pipeline.py | |
| +++ tests/test_pipeline.py | |
| @@ -1,11 +1,19 @@ | |
| 1 | """Tests for the core video processing pipeline.""" |
| 2 | |
| 3 | import json |
| 4 | from pathlib import Path |
| 5 | from unittest.mock import MagicMock, patch |
| 6 | |
| 7 | import pytest |
| 8 | |
| 9 | from video_processor.pipeline import ( |
| 10 | _extract_action_items, |
| 11 | _extract_key_points, |
| 12 | _format_srt_time, |
| 13 | process_single_video, |
| 14 | ) |
| 15 | |
| 16 | |
| 17 | class TestFormatSrtTime: |
| 18 | def test_zero(self): |
| 19 | assert _format_srt_time(0) == "00:00:00,000" |
| @@ -99,5 +107,428 @@ | |
| 107 | def test_handles_error(self): |
| 108 | pm = MagicMock() |
| 109 | pm.chat.side_effect = Exception("API down") |
| 110 | result = _extract_action_items(pm, "text") |
| 111 | assert result == [] |
| 112 | |
| 113 | |
| 114 | # --------------------------------------------------------------------------- |
| 115 | # process_single_video tests (heavily mocked) |
| 116 | # --------------------------------------------------------------------------- |
| 117 | |
| 118 | |
| 119 | def _make_mock_pm(): |
| 120 | """Build a mock ProviderManager with usage tracker and predictable responses.""" |
| 121 | pm = MagicMock() |
| 122 | |
| 123 | # Usage tracker stub |
| 124 | pm.usage = MagicMock() |
| 125 | pm.usage.start_step = MagicMock() |
| 126 | pm.usage.end_step = MagicMock() |
| 127 | |
| 128 | # transcribe_audio returns a simple transcript |
| 129 | pm.transcribe_audio.return_value = { |
| 130 | "text": "Alice discussed the Python deployment strategy with Bob.", |
| 131 | "segments": [ |
| 132 | {"start": 0.0, "end": 5.0, "text": "Alice discussed the Python deployment strategy."}, |
| 133 | {"start": 5.0, "end": 10.0, "text": "Bob agreed on the timeline."}, |
| 134 | ], |
| 135 | "duration": 10.0, |
| 136 | "language": "en", |
| 137 | "provider": "mock", |
| 138 | "model": "mock-whisper", |
| 139 | } |
| 140 | |
| 141 | # chat returns predictable JSON depending on the call |
| 142 | def _chat_side_effect(messages, **kwargs): |
| 143 | content = messages[0]["content"] if messages else "" |
| 144 | if "key points" in content.lower(): |
| 145 | return json.dumps( |
| 146 | [{"point": "Deployment strategy discussed", "topic": "DevOps", "details": "Python"}] |
| 147 | ) |
| 148 | if "action items" in content.lower(): |
| 149 | return json.dumps( |
| 150 | [{"action": "Deploy to production", "assignee": "Bob", "priority": "high"}] |
| 151 | ) |
| 152 | # Default: entity extraction for knowledge graph |
| 153 | return json.dumps( |
| 154 | { |
| 155 | "entities": [ |
| 156 | {"name": "Python", "type": "technology", "description": "Programming language"}, |
| 157 | {"name": "Alice", "type": "person", "description": "Engineer"}, |
| 158 | ], |
| 159 | "relationships": [ |
| 160 | {"source": "Alice", "target": "Python", "type": "uses"}, |
| 161 | ], |
| 162 | } |
| 163 | ) |
| 164 | |
| 165 | pm.chat.side_effect = _chat_side_effect |
| 166 | pm.get_models_used.return_value = {"chat": "mock-gpt", "transcription": "mock-whisper"} |
| 167 | return pm |
| 168 | |
| 169 | |
| 170 | def _make_tqdm_passthrough(mock_tqdm): |
| 171 | """Configure mock tqdm to pass through iterables while supporting .set_description() etc.""" |
| 172 | |
| 173 | def _tqdm_side_effect(iterable, **kw): |
| 174 | wrapper = MagicMock() |
| 175 | wrapper.__iter__ = lambda self: iter(iterable) |
| 176 | return wrapper |
| 177 | |
| 178 | mock_tqdm.side_effect = _tqdm_side_effect |
| 179 | |
| 180 | |
| 181 | def _create_fake_video(path: Path) -> Path: |
| 182 | """Create a tiny file that stands in for a video (all extractors are mocked).""" |
| 183 | path.parent.mkdir(parents=True, exist_ok=True) |
| 184 | path.write_bytes(b"\x00" * 64) |
| 185 | return path |
| 186 | |
| 187 | |
| 188 | class TestProcessSingleVideo: |
| 189 | """Integration-level tests for process_single_video with heavy mocking.""" |
| 190 | |
| 191 | @pytest.fixture |
| 192 | def setup(self, tmp_path): |
| 193 | """Create fake video, output dir, and mock PM.""" |
| 194 | video_path = _create_fake_video(tmp_path / "input" / "meeting.mp4") |
| 195 | output_dir = tmp_path / "output" |
| 196 | pm = _make_mock_pm() |
| 197 | return video_path, output_dir, pm |
| 198 | |
| 199 | @patch("video_processor.pipeline.export_all_formats") |
| 200 | @patch("video_processor.pipeline.PlanGenerator") |
| 201 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 202 | @patch("video_processor.pipeline.AudioExtractor") |
| 203 | @patch("video_processor.pipeline.filter_people_frames") |
| 204 | @patch("video_processor.pipeline.save_frames") |
| 205 | @patch("video_processor.pipeline.extract_frames") |
| 206 | @patch("video_processor.pipeline.tqdm") |
| 207 | def test_returns_manifest( |
| 208 | self, |
| 209 | mock_tqdm, |
| 210 | mock_extract_frames, |
| 211 | mock_save_frames, |
| 212 | mock_filter_people, |
| 213 | mock_audio_extractor_cls, |
| 214 | mock_diagram_analyzer_cls, |
| 215 | mock_plan_gen_cls, |
| 216 | mock_export, |
| 217 | setup, |
| 218 | ): |
| 219 | video_path, output_dir, pm = setup |
| 220 | |
| 221 | # tqdm pass-through |
| 222 | _make_tqdm_passthrough(mock_tqdm) |
| 223 | |
| 224 | # Frame extraction mocks |
| 225 | mock_extract_frames.return_value = [b"fake_frame_1", b"fake_frame_2"] |
| 226 | mock_filter_people.return_value = ([b"fake_frame_1", b"fake_frame_2"], 0) |
| 227 | |
| 228 | frames_dir = output_dir / "frames" |
| 229 | frames_dir.mkdir(parents=True, exist_ok=True) |
| 230 | frame_paths = [] |
| 231 | for i in range(2): |
| 232 | fp = frames_dir / f"frame_{i:04d}.jpg" |
| 233 | fp.write_bytes(b"\xff") |
| 234 | frame_paths.append(fp) |
| 235 | mock_save_frames.return_value = frame_paths |
| 236 | |
| 237 | # Audio extractor mock |
| 238 | audio_ext = MagicMock() |
| 239 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 240 | audio_ext.get_audio_properties.return_value = {"duration": 10.0} |
| 241 | mock_audio_extractor_cls.return_value = audio_ext |
| 242 | |
| 243 | # Diagram analyzer mock |
| 244 | diag_analyzer = MagicMock() |
| 245 | diag_analyzer.process_frames.return_value = ([], []) |
| 246 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 247 | |
| 248 | # Plan generator mock |
| 249 | plan_gen = MagicMock() |
| 250 | mock_plan_gen_cls.return_value = plan_gen |
| 251 | |
| 252 | # export_all_formats returns the manifest it receives |
| 253 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 254 | |
| 255 | manifest = process_single_video( |
| 256 | input_path=video_path, |
| 257 | output_dir=output_dir, |
| 258 | provider_manager=pm, |
| 259 | depth="standard", |
| 260 | ) |
| 261 | |
| 262 | from video_processor.models import VideoManifest |
| 263 | |
| 264 | assert isinstance(manifest, VideoManifest) |
| 265 | assert manifest.video.title == "Analysis of meeting" |
| 266 | assert manifest.stats.frames_extracted == 2 |
| 267 | assert manifest.transcript_json == "transcript/transcript.json" |
| 268 | assert manifest.knowledge_graph_json == "results/knowledge_graph.json" |
| 269 | |
| 270 | @patch("video_processor.pipeline.export_all_formats") |
| 271 | @patch("video_processor.pipeline.PlanGenerator") |
| 272 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 273 | @patch("video_processor.pipeline.AudioExtractor") |
| 274 | @patch("video_processor.pipeline.filter_people_frames") |
| 275 | @patch("video_processor.pipeline.save_frames") |
| 276 | @patch("video_processor.pipeline.extract_frames") |
| 277 | @patch("video_processor.pipeline.tqdm") |
| 278 | def test_creates_output_directories( |
| 279 | self, |
| 280 | mock_tqdm, |
| 281 | mock_extract_frames, |
| 282 | mock_save_frames, |
| 283 | mock_filter_people, |
| 284 | mock_audio_extractor_cls, |
| 285 | mock_diagram_analyzer_cls, |
| 286 | mock_plan_gen_cls, |
| 287 | mock_export, |
| 288 | setup, |
| 289 | ): |
| 290 | video_path, output_dir, pm = setup |
| 291 | |
| 292 | _make_tqdm_passthrough(mock_tqdm) |
| 293 | mock_extract_frames.return_value = [] |
| 294 | mock_filter_people.return_value = ([], 0) |
| 295 | mock_save_frames.return_value = [] |
| 296 | |
| 297 | audio_ext = MagicMock() |
| 298 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 299 | audio_ext.get_audio_properties.return_value = {"duration": 5.0} |
| 300 | mock_audio_extractor_cls.return_value = audio_ext |
| 301 | |
| 302 | diag_analyzer = MagicMock() |
| 303 | diag_analyzer.process_frames.return_value = ([], []) |
| 304 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 305 | |
| 306 | plan_gen = MagicMock() |
| 307 | mock_plan_gen_cls.return_value = plan_gen |
| 308 | |
| 309 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 310 | |
| 311 | process_single_video( |
| 312 | input_path=video_path, |
| 313 | output_dir=output_dir, |
| 314 | provider_manager=pm, |
| 315 | ) |
| 316 | |
| 317 | # Verify standard output directories were created |
| 318 | assert (output_dir / "transcript").is_dir() |
| 319 | assert (output_dir / "frames").is_dir() |
| 320 | assert (output_dir / "results").is_dir() |
| 321 | |
| 322 | @patch("video_processor.pipeline.export_all_formats") |
| 323 | @patch("video_processor.pipeline.PlanGenerator") |
| 324 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 325 | @patch("video_processor.pipeline.AudioExtractor") |
| 326 | @patch("video_processor.pipeline.filter_people_frames") |
| 327 | @patch("video_processor.pipeline.save_frames") |
| 328 | @patch("video_processor.pipeline.extract_frames") |
| 329 | @patch("video_processor.pipeline.tqdm") |
| 330 | def test_resume_existing_frames( |
| 331 | self, |
| 332 | mock_tqdm, |
| 333 | mock_extract_frames, |
| 334 | mock_save_frames, |
| 335 | mock_filter_people, |
| 336 | mock_audio_extractor_cls, |
| 337 | mock_diagram_analyzer_cls, |
| 338 | mock_plan_gen_cls, |
| 339 | mock_export, |
| 340 | setup, |
| 341 | ): |
| 342 | """When frames already exist on disk, extraction should be skipped.""" |
| 343 | video_path, output_dir, pm = setup |
| 344 | |
| 345 | _make_tqdm_passthrough(mock_tqdm) |
| 346 | |
| 347 | # Pre-create frames directory with existing frames |
| 348 | frames_dir = output_dir / "frames" |
| 349 | frames_dir.mkdir(parents=True, exist_ok=True) |
| 350 | for i in range(3): |
| 351 | (frames_dir / f"frame_{i:04d}.jpg").write_bytes(b"\xff") |
| 352 | |
| 353 | audio_ext = MagicMock() |
| 354 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 355 | audio_ext.get_audio_properties.return_value = {"duration": 10.0} |
| 356 | mock_audio_extractor_cls.return_value = audio_ext |
| 357 | |
| 358 | diag_analyzer = MagicMock() |
| 359 | diag_analyzer.process_frames.return_value = ([], []) |
| 360 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 361 | |
| 362 | plan_gen = MagicMock() |
| 363 | mock_plan_gen_cls.return_value = plan_gen |
| 364 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 365 | |
| 366 | manifest = process_single_video( |
| 367 | input_path=video_path, |
| 368 | output_dir=output_dir, |
| 369 | provider_manager=pm, |
| 370 | ) |
| 371 | |
| 372 | # extract_frames should NOT have been called (resume path) |
| 373 | mock_extract_frames.assert_not_called() |
| 374 | assert manifest.stats.frames_extracted == 3 |
| 375 | |
| 376 | @patch("video_processor.pipeline.export_all_formats") |
| 377 | @patch("video_processor.pipeline.PlanGenerator") |
| 378 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 379 | @patch("video_processor.pipeline.AudioExtractor") |
| 380 | @patch("video_processor.pipeline.filter_people_frames") |
| 381 | @patch("video_processor.pipeline.save_frames") |
| 382 | @patch("video_processor.pipeline.extract_frames") |
| 383 | @patch("video_processor.pipeline.tqdm") |
| 384 | def test_resume_existing_transcript( |
| 385 | self, |
| 386 | mock_tqdm, |
| 387 | mock_extract_frames, |
| 388 | mock_save_frames, |
| 389 | mock_filter_people, |
| 390 | mock_audio_extractor_cls, |
| 391 | mock_diagram_analyzer_cls, |
| 392 | mock_plan_gen_cls, |
| 393 | mock_export, |
| 394 | setup, |
| 395 | ): |
| 396 | """When transcript exists on disk, transcription should be skipped.""" |
| 397 | video_path, output_dir, pm = setup |
| 398 | |
| 399 | _make_tqdm_passthrough(mock_tqdm) |
| 400 | mock_extract_frames.return_value = [] |
| 401 | mock_filter_people.return_value = ([], 0) |
| 402 | mock_save_frames.return_value = [] |
| 403 | |
| 404 | audio_ext = MagicMock() |
| 405 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 406 | audio_ext.get_audio_properties.return_value = {"duration": 10.0} |
| 407 | mock_audio_extractor_cls.return_value = audio_ext |
| 408 | |
| 409 | # Pre-create transcript file |
| 410 | transcript_dir = output_dir / "transcript" |
| 411 | transcript_dir.mkdir(parents=True, exist_ok=True) |
| 412 | transcript_data = { |
| 413 | "text": "Pre-existing transcript text.", |
| 414 | "segments": [{"start": 0.0, "end": 5.0, "text": "Pre-existing transcript text."}], |
| 415 | "duration": 5.0, |
| 416 | } |
| 417 | (transcript_dir / "transcript.json").write_text(json.dumps(transcript_data)) |
| 418 | |
| 419 | diag_analyzer = MagicMock() |
| 420 | diag_analyzer.process_frames.return_value = ([], []) |
| 421 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 422 | |
| 423 | plan_gen = MagicMock() |
| 424 | mock_plan_gen_cls.return_value = plan_gen |
| 425 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 426 | |
| 427 | process_single_video( |
| 428 | input_path=video_path, |
| 429 | output_dir=output_dir, |
| 430 | provider_manager=pm, |
| 431 | ) |
| 432 | |
| 433 | # transcribe_audio should NOT have been called (resume path) |
| 434 | pm.transcribe_audio.assert_not_called() |
| 435 | |
| 436 | @patch("video_processor.pipeline.export_all_formats") |
| 437 | @patch("video_processor.pipeline.PlanGenerator") |
| 438 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 439 | @patch("video_processor.pipeline.AudioExtractor") |
| 440 | @patch("video_processor.pipeline.filter_people_frames") |
| 441 | @patch("video_processor.pipeline.save_frames") |
| 442 | @patch("video_processor.pipeline.extract_frames") |
| 443 | @patch("video_processor.pipeline.tqdm") |
| 444 | def test_custom_title( |
| 445 | self, |
| 446 | mock_tqdm, |
| 447 | mock_extract_frames, |
| 448 | mock_save_frames, |
| 449 | mock_filter_people, |
| 450 | mock_audio_extractor_cls, |
| 451 | mock_diagram_analyzer_cls, |
| 452 | mock_plan_gen_cls, |
| 453 | mock_export, |
| 454 | setup, |
| 455 | ): |
| 456 | video_path, output_dir, pm = setup |
| 457 | |
| 458 | _make_tqdm_passthrough(mock_tqdm) |
| 459 | mock_extract_frames.return_value = [] |
| 460 | mock_filter_people.return_value = ([], 0) |
| 461 | mock_save_frames.return_value = [] |
| 462 | |
| 463 | audio_ext = MagicMock() |
| 464 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 465 | audio_ext.get_audio_properties.return_value = {"duration": 5.0} |
| 466 | mock_audio_extractor_cls.return_value = audio_ext |
| 467 | |
| 468 | diag_analyzer = MagicMock() |
| 469 | diag_analyzer.process_frames.return_value = ([], []) |
| 470 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 471 | |
| 472 | plan_gen = MagicMock() |
| 473 | mock_plan_gen_cls.return_value = plan_gen |
| 474 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 475 | |
| 476 | manifest = process_single_video( |
| 477 | input_path=video_path, |
| 478 | output_dir=output_dir, |
| 479 | provider_manager=pm, |
| 480 | title="My Custom Title", |
| 481 | ) |
| 482 | |
| 483 | assert manifest.video.title == "My Custom Title" |
| 484 | |
| 485 | @patch("video_processor.pipeline.export_all_formats") |
| 486 | @patch("video_processor.pipeline.PlanGenerator") |
| 487 | @patch("video_processor.pipeline.DiagramAnalyzer") |
| 488 | @patch("video_processor.pipeline.AudioExtractor") |
| 489 | @patch("video_processor.pipeline.filter_people_frames") |
| 490 | @patch("video_processor.pipeline.save_frames") |
| 491 | @patch("video_processor.pipeline.extract_frames") |
| 492 | @patch("video_processor.pipeline.tqdm") |
| 493 | def test_key_points_and_action_items_extracted( |
| 494 | self, |
| 495 | mock_tqdm, |
| 496 | mock_extract_frames, |
| 497 | mock_save_frames, |
| 498 | mock_filter_people, |
| 499 | mock_audio_extractor_cls, |
| 500 | mock_diagram_analyzer_cls, |
| 501 | mock_plan_gen_cls, |
| 502 | mock_export, |
| 503 | setup, |
| 504 | ): |
| 505 | video_path, output_dir, pm = setup |
| 506 | |
| 507 | _make_tqdm_passthrough(mock_tqdm) |
| 508 | mock_extract_frames.return_value = [] |
| 509 | mock_filter_people.return_value = ([], 0) |
| 510 | mock_save_frames.return_value = [] |
| 511 | |
| 512 | audio_ext = MagicMock() |
| 513 | audio_ext.extract_audio.return_value = output_dir / "audio" / "meeting.wav" |
| 514 | audio_ext.get_audio_properties.return_value = {"duration": 10.0} |
| 515 | mock_audio_extractor_cls.return_value = audio_ext |
| 516 | |
| 517 | diag_analyzer = MagicMock() |
| 518 | diag_analyzer.process_frames.return_value = ([], []) |
| 519 | mock_diagram_analyzer_cls.return_value = diag_analyzer |
| 520 | |
| 521 | plan_gen = MagicMock() |
| 522 | mock_plan_gen_cls.return_value = plan_gen |
| 523 | mock_export.side_effect = lambda out_dir, manifest: manifest |
| 524 | |
| 525 | manifest = process_single_video( |
| 526 | input_path=video_path, |
| 527 | output_dir=output_dir, |
| 528 | provider_manager=pm, |
| 529 | ) |
| 530 | |
| 531 | assert len(manifest.key_points) == 1 |
| 532 | assert manifest.key_points[0].point == "Deployment strategy discussed" |
| 533 | assert len(manifest.action_items) == 1 |
| 534 | assert manifest.action_items[0].action == "Deploy to production" |
| 535 |
+286
-33
| --- tests/test_providers.py | ||
| +++ tests/test_providers.py | ||
| @@ -1,13 +1,23 @@ | ||
| 1 | 1 | """Tests for the provider abstraction layer.""" |
| 2 | 2 | |
| 3 | +import importlib | |
| 3 | 4 | from unittest.mock import MagicMock, patch |
| 4 | 5 | |
| 5 | 6 | import pytest |
| 6 | 7 | |
| 7 | -from video_processor.providers.base import BaseProvider, ModelInfo | |
| 8 | +from video_processor.providers.base import ( | |
| 9 | + BaseProvider, | |
| 10 | + ModelInfo, | |
| 11 | + OpenAICompatibleProvider, | |
| 12 | + ProviderRegistry, | |
| 13 | +) | |
| 8 | 14 | from video_processor.providers.manager import ProviderManager |
| 15 | + | |
| 16 | +# --------------------------------------------------------------------------- | |
| 17 | +# ModelInfo | |
| 18 | +# --------------------------------------------------------------------------- | |
| 9 | 19 | |
| 10 | 20 | |
| 11 | 21 | class TestModelInfo: |
| 12 | 22 | def test_basic(self): |
| 13 | 23 | m = ModelInfo(id="gpt-4o", provider="openai", capabilities=["chat", "vision"]) |
| @@ -22,14 +32,97 @@ | ||
| 22 | 32 | capabilities=["chat", "vision"], |
| 23 | 33 | ) |
| 24 | 34 | restored = ModelInfo.model_validate_json(m.model_dump_json()) |
| 25 | 35 | assert restored == m |
| 26 | 36 | |
| 37 | + def test_defaults(self): | |
| 38 | + m = ModelInfo(id="x", provider="y") | |
| 39 | + assert m.display_name == "" | |
| 40 | + assert m.capabilities == [] | |
| 41 | + | |
| 42 | + | |
| 43 | +# --------------------------------------------------------------------------- | |
| 44 | +# ProviderRegistry | |
| 45 | +# --------------------------------------------------------------------------- | |
| 46 | + | |
| 47 | + | |
| 48 | +class TestProviderRegistry: | |
| 49 | + """Test ProviderRegistry class methods. | |
| 50 | + | |
| 51 | + We save and restore the internal _providers dict around each test so that | |
| 52 | + registrations from one test don't leak into another. | |
| 53 | + """ | |
| 54 | + | |
| 55 | + @pytest.fixture(autouse=True) | |
| 56 | + def _save_restore_registry(self): | |
| 57 | + original = dict(ProviderRegistry._providers) | |
| 58 | + yield | |
| 59 | + ProviderRegistry._providers = original | |
| 60 | + | |
| 61 | + def test_register_and_get(self): | |
| 62 | + dummy_cls = type("Dummy", (), {}) | |
| 63 | + ProviderRegistry.register("test_prov", dummy_cls, env_var="TEST_KEY") | |
| 64 | + assert ProviderRegistry.get("test_prov") is dummy_cls | |
| 65 | + | |
| 66 | + def test_get_unknown_raises(self): | |
| 67 | + with pytest.raises(ValueError, match="Unknown provider"): | |
| 68 | + ProviderRegistry.get("nonexistent_provider_xyz") | |
| 69 | + | |
| 70 | + def test_get_by_model_prefix(self): | |
| 71 | + dummy_cls = type("Dummy", (), {}) | |
| 72 | + ProviderRegistry.register("myprov", dummy_cls, model_prefixes=["mymodel-"]) | |
| 73 | + assert ProviderRegistry.get_by_model("mymodel-7b") == "myprov" | |
| 74 | + assert ProviderRegistry.get_by_model("othermodel-7b") is None | |
| 75 | + | |
| 76 | + def test_get_by_model_returns_none_for_no_match(self): | |
| 77 | + assert ProviderRegistry.get_by_model("totally_unknown_model_xyz") is None | |
| 78 | + | |
| 79 | + def test_available_with_env_var(self): | |
| 80 | + dummy_cls = type("Dummy", (), {}) | |
| 81 | + ProviderRegistry.register("envprov", dummy_cls, env_var="ENVPROV_KEY") | |
| 82 | + # Not in env -> should not appear | |
| 83 | + with patch.dict("os.environ", {}, clear=True): | |
| 84 | + avail = ProviderRegistry.available() | |
| 85 | + assert "envprov" not in avail | |
| 86 | + | |
| 87 | + # In env -> should appear | |
| 88 | + with patch.dict("os.environ", {"ENVPROV_KEY": "secret"}): | |
| 89 | + avail = ProviderRegistry.available() | |
| 90 | + assert "envprov" in avail | |
| 91 | + | |
| 92 | + def test_available_no_env_var_required(self): | |
| 93 | + dummy_cls = type("Dummy", (), {}) | |
| 94 | + ProviderRegistry.register("noenvprov", dummy_cls, env_var="") | |
| 95 | + avail = ProviderRegistry.available() | |
| 96 | + assert "noenvprov" in avail | |
| 97 | + | |
| 98 | + def test_all_registered(self): | |
| 99 | + dummy_cls = type("Dummy", (), {}) | |
| 100 | + ProviderRegistry.register("regprov", dummy_cls, env_var="X", default_models={"chat": "m1"}) | |
| 101 | + all_reg = ProviderRegistry.all_registered() | |
| 102 | + assert "regprov" in all_reg | |
| 103 | + assert all_reg["regprov"]["class"] is dummy_cls | |
| 104 | + | |
| 105 | + def test_get_default_models(self): | |
| 106 | + dummy_cls = type("Dummy", (), {}) | |
| 107 | + ProviderRegistry.register( | |
| 108 | + "defprov", dummy_cls, default_models={"chat": "c1", "vision": "v1"} | |
| 109 | + ) | |
| 110 | + defaults = ProviderRegistry.get_default_models("defprov") | |
| 111 | + assert defaults == {"chat": "c1", "vision": "v1"} | |
| 112 | + | |
| 113 | + def test_get_default_models_unknown(self): | |
| 114 | + assert ProviderRegistry.get_default_models("unknown_prov_xyz") == {} | |
| 115 | + | |
| 116 | + | |
| 117 | +# --------------------------------------------------------------------------- | |
| 118 | +# ProviderManager | |
| 119 | +# --------------------------------------------------------------------------- | |
| 120 | + | |
| 27 | 121 | |
| 28 | 122 | class TestProviderManager: |
| 29 | 123 | def _make_mock_provider(self, name="openai"): |
| 30 | - """Create a mock provider.""" | |
| 31 | 124 | provider = MagicMock(spec=BaseProvider) |
| 32 | 125 | provider.provider_name = name |
| 33 | 126 | provider.chat.return_value = "test response" |
| 34 | 127 | provider.analyze_image.return_value = "image analysis" |
| 35 | 128 | provider.transcribe_audio.return_value = { |
| @@ -53,18 +146,58 @@ | ||
| 53 | 146 | def test_init_forced_provider(self): |
| 54 | 147 | mgr = ProviderManager(provider="gemini") |
| 55 | 148 | assert mgr.vision_model == "gemini-2.5-flash" |
| 56 | 149 | assert mgr.chat_model == "gemini-2.5-flash" |
| 57 | 150 | assert mgr.transcription_model == "gemini-2.5-flash" |
| 151 | + | |
| 152 | + def test_init_forced_provider_ollama(self): | |
| 153 | + mgr = ProviderManager(provider="ollama") | |
| 154 | + assert mgr.vision_model == "" | |
| 155 | + assert mgr.chat_model == "" | |
| 156 | + assert mgr.transcription_model == "" | |
| 157 | + | |
| 158 | + def test_init_no_overrides(self): | |
| 159 | + mgr = ProviderManager() | |
| 160 | + assert mgr.vision_model is None | |
| 161 | + assert mgr.chat_model is None | |
| 162 | + assert mgr.transcription_model is None | |
| 163 | + assert mgr.auto is True | |
| 164 | + | |
| 165 | + def test_default_for_provider_gemini(self): | |
| 166 | + result = ProviderManager._default_for_provider("gemini", "vision") | |
| 167 | + assert result == "gemini-2.5-flash" | |
| 168 | + | |
| 169 | + def test_default_for_provider_openai(self): | |
| 170 | + result = ProviderManager._default_for_provider("openai", "chat") | |
| 171 | + assert isinstance(result, str) | |
| 172 | + assert len(result) > 0 | |
| 173 | + | |
| 174 | + def test_default_for_provider_unknown(self): | |
| 175 | + result = ProviderManager._default_for_provider("nonexistent_xyz", "chat") | |
| 176 | + assert result == "" | |
| 58 | 177 | |
| 59 | 178 | def test_provider_for_model(self): |
| 60 | 179 | mgr = ProviderManager() |
| 61 | 180 | assert mgr._provider_for_model("gpt-4o") == "openai" |
| 62 | 181 | assert mgr._provider_for_model("claude-sonnet-4-5-20250929") == "anthropic" |
| 63 | 182 | assert mgr._provider_for_model("gemini-2.5-flash") == "gemini" |
| 64 | 183 | assert mgr._provider_for_model("whisper-1") == "openai" |
| 65 | 184 | |
| 185 | + def test_provider_for_model_ollama_via_discovery(self): | |
| 186 | + mgr = ProviderManager() | |
| 187 | + mgr._available_models = [ | |
| 188 | + ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), | |
| 189 | + ] | |
| 190 | + assert mgr._provider_for_model("llama3.2:latest") == "ollama" | |
| 191 | + | |
| 192 | + def test_provider_for_model_ollama_fuzzy_tag(self): | |
| 193 | + mgr = ProviderManager() | |
| 194 | + mgr._available_models = [ | |
| 195 | + ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), | |
| 196 | + ] | |
| 197 | + assert mgr._provider_for_model("llama3.2") == "ollama" | |
| 198 | + | |
| 66 | 199 | @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}) |
| 67 | 200 | def test_chat_routes_to_provider(self): |
| 68 | 201 | mgr = ProviderManager(chat_model="gpt-4o") |
| 69 | 202 | mock_prov = self._make_mock_provider("openai") |
| 70 | 203 | mgr._providers["openai"] = mock_prov |
| @@ -97,36 +230,126 @@ | ||
| 97 | 230 | mgr = ProviderManager( |
| 98 | 231 | vision_model="gpt-4o", |
| 99 | 232 | chat_model="claude-sonnet-4-5-20250929", |
| 100 | 233 | transcription_model="whisper-1", |
| 101 | 234 | ) |
| 102 | - # Pre-fill providers so _resolve_model doesn't try to instantiate real ones | |
| 103 | 235 | for name in ["openai", "anthropic"]: |
| 104 | 236 | mgr._providers[name] = self._make_mock_provider(name) |
| 105 | 237 | |
| 106 | 238 | used = mgr.get_models_used() |
| 107 | 239 | assert "vision" in used |
| 108 | - assert "openai/gpt-4o" == used["vision"] | |
| 109 | - assert "anthropic/claude-sonnet-4-5-20250929" == used["chat"] | |
| 240 | + assert used["vision"] == "openai/gpt-4o" | |
| 241 | + assert used["chat"] == "anthropic/claude-sonnet-4-5-20250929" | |
| 242 | + | |
| 243 | + def test_track_records_usage(self): | |
| 244 | + mgr = ProviderManager(chat_model="gpt-4o") | |
| 245 | + mock_prov = self._make_mock_provider("openai") | |
| 246 | + mock_prov._last_usage = {"input_tokens": 10, "output_tokens": 20} | |
| 247 | + mgr._providers["openai"] = mock_prov | |
| 248 | + | |
| 249 | + mgr.chat([{"role": "user", "content": "hi"}]) | |
| 250 | + assert mgr.usage.total_input_tokens == 10 | |
| 251 | + assert mgr.usage.total_output_tokens == 20 | |
| 252 | + | |
| 253 | + | |
| 254 | +# --------------------------------------------------------------------------- | |
| 255 | +# OpenAICompatibleProvider | |
| 256 | +# --------------------------------------------------------------------------- | |
| 257 | + | |
| 258 | + | |
| 259 | +class TestOpenAICompatibleProvider: | |
| 260 | + @patch("openai.OpenAI") | |
| 261 | + def test_chat(self, mock_openai_cls): | |
| 262 | + mock_client = MagicMock() | |
| 263 | + mock_openai_cls.return_value = mock_client | |
| 264 | + | |
| 265 | + mock_choice = MagicMock() | |
| 266 | + mock_choice.message.content = "hello back" | |
| 267 | + mock_response = MagicMock() | |
| 268 | + mock_response.choices = [mock_choice] | |
| 269 | + mock_response.usage.prompt_tokens = 5 | |
| 270 | + mock_response.usage.completion_tokens = 10 | |
| 271 | + mock_client.chat.completions.create.return_value = mock_response | |
| 272 | + | |
| 273 | + provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") | |
| 274 | + result = provider.chat([{"role": "user", "content": "hi"}], model="test-model") | |
| 275 | + assert result == "hello back" | |
| 276 | + assert provider._last_usage == {"input_tokens": 5, "output_tokens": 10} | |
| 277 | + | |
| 278 | + @patch("openai.OpenAI") | |
| 279 | + def test_analyze_image(self, mock_openai_cls): | |
| 280 | + mock_client = MagicMock() | |
| 281 | + mock_openai_cls.return_value = mock_client | |
| 282 | + | |
| 283 | + mock_choice = MagicMock() | |
| 284 | + mock_choice.message.content = "a cat" | |
| 285 | + mock_response = MagicMock() | |
| 286 | + mock_response.choices = [mock_choice] | |
| 287 | + mock_response.usage.prompt_tokens = 100 | |
| 288 | + mock_response.usage.completion_tokens = 5 | |
| 289 | + mock_client.chat.completions.create.return_value = mock_response | |
| 290 | + | |
| 291 | + provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") | |
| 292 | + result = provider.analyze_image(b"\x89PNG", "what is this?") | |
| 293 | + assert result == "a cat" | |
| 294 | + assert provider._last_usage["input_tokens"] == 100 | |
| 295 | + | |
| 296 | + @patch("openai.OpenAI") | |
| 297 | + def test_transcribe_raises(self, mock_openai_cls): | |
| 298 | + provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") | |
| 299 | + with pytest.raises(NotImplementedError): | |
| 300 | + provider.transcribe_audio("/tmp/audio.wav") | |
| 301 | + | |
| 302 | + @patch("openai.OpenAI") | |
| 303 | + def test_list_models(self, mock_openai_cls): | |
| 304 | + mock_client = MagicMock() | |
| 305 | + mock_openai_cls.return_value = mock_client | |
| 306 | + | |
| 307 | + mock_model = MagicMock() | |
| 308 | + mock_model.id = "test-model-1" | |
| 309 | + mock_client.models.list.return_value = [mock_model] | |
| 310 | + | |
| 311 | + provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") | |
| 312 | + provider.provider_name = "testprov" | |
| 313 | + models = provider.list_models() | |
| 314 | + assert len(models) == 1 | |
| 315 | + assert models[0].id == "test-model-1" | |
| 316 | + assert models[0].provider == "testprov" | |
| 317 | + | |
| 318 | + @patch("openai.OpenAI") | |
| 319 | + def test_list_models_handles_error(self, mock_openai_cls): | |
| 320 | + mock_client = MagicMock() | |
| 321 | + mock_openai_cls.return_value = mock_client | |
| 322 | + mock_client.models.list.side_effect = Exception("connection error") | |
| 323 | + | |
| 324 | + provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") | |
| 325 | + models = provider.list_models() | |
| 326 | + assert models == [] | |
| 327 | + | |
| 328 | + | |
| 329 | +# --------------------------------------------------------------------------- | |
| 330 | +# Discovery | |
| 331 | +# --------------------------------------------------------------------------- | |
| 110 | 332 | |
| 111 | 333 | |
| 112 | 334 | class TestDiscovery: |
| 113 | 335 | @patch("video_processor.providers.discovery._cached_models", None) |
| 114 | 336 | @patch( |
| 115 | - "video_processor.providers.ollama_provider.OllamaProvider.is_available", return_value=False | |
| 337 | + "video_processor.providers.ollama_provider.OllamaProvider.is_available", | |
| 338 | + return_value=False, | |
| 116 | 339 | ) |
| 117 | 340 | @patch.dict("os.environ", {}, clear=True) |
| 118 | 341 | def test_discover_skips_missing_keys(self, mock_ollama): |
| 119 | 342 | from video_processor.providers.discovery import discover_available_models |
| 120 | 343 | |
| 121 | - # No API keys and no Ollama -> empty list, no errors | |
| 122 | 344 | models = discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""}) |
| 123 | 345 | assert models == [] |
| 124 | 346 | |
| 125 | 347 | @patch.dict("os.environ", {}, clear=True) |
| 126 | 348 | @patch( |
| 127 | - "video_processor.providers.ollama_provider.OllamaProvider.is_available", return_value=False | |
| 349 | + "video_processor.providers.ollama_provider.OllamaProvider.is_available", | |
| 350 | + return_value=False, | |
| 128 | 351 | ) |
| 129 | 352 | @patch("video_processor.providers.discovery._cached_models", None) |
| 130 | 353 | def test_discover_caches_results(self, mock_ollama): |
| 131 | 354 | from video_processor.providers import discovery |
| 132 | 355 | |
| @@ -136,13 +359,41 @@ | ||
| 136 | 359 | assert models == [] |
| 137 | 360 | # Second call should use cache |
| 138 | 361 | models2 = discovery.discover_available_models(api_keys={"openai": "key"}) |
| 139 | 362 | assert models2 == [] # Still cached empty result |
| 140 | 363 | |
| 141 | - # Force refresh | |
| 364 | + discovery.clear_discovery_cache() | |
| 365 | + | |
| 366 | + @patch("video_processor.providers.discovery._cached_models", None) | |
| 367 | + @patch( | |
| 368 | + "video_processor.providers.ollama_provider.OllamaProvider.is_available", | |
| 369 | + return_value=False, | |
| 370 | + ) | |
| 371 | + @patch.dict("os.environ", {}, clear=True) | |
| 372 | + def test_force_refresh_clears_cache(self, mock_ollama): | |
| 373 | + from video_processor.providers import discovery | |
| 374 | + | |
| 375 | + # Warm the cache | |
| 376 | + discovery.discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""}) | |
| 377 | + # Force refresh should re-run | |
| 378 | + models = discovery.discover_available_models( | |
| 379 | + api_keys={"openai": "", "anthropic": "", "gemini": ""}, | |
| 380 | + force_refresh=True, | |
| 381 | + ) | |
| 382 | + assert models == [] | |
| 383 | + | |
| 384 | + def test_clear_discovery_cache(self): | |
| 385 | + from video_processor.providers import discovery | |
| 386 | + | |
| 387 | + discovery._cached_models = [ModelInfo(id="x", provider="y")] | |
| 142 | 388 | discovery.clear_discovery_cache() |
| 143 | - # Would try to connect with real key, so skip that test | |
| 389 | + assert discovery._cached_models is None | |
| 390 | + | |
| 391 | + | |
| 392 | +# --------------------------------------------------------------------------- | |
| 393 | +# OllamaProvider | |
| 394 | +# --------------------------------------------------------------------------- | |
| 144 | 395 | |
| 145 | 396 | |
| 146 | 397 | class TestOllamaProvider: |
| 147 | 398 | @patch("video_processor.providers.ollama_provider.requests") |
| 148 | 399 | def test_is_available_when_running(self, mock_requests): |
| @@ -189,35 +440,37 @@ | ||
| 189 | 440 | provider = OllamaProvider() |
| 190 | 441 | models = provider.list_models() |
| 191 | 442 | assert len(models) == 2 |
| 192 | 443 | assert models[0].provider == "ollama" |
| 193 | 444 | |
| 194 | - # llava should have vision capability | |
| 195 | 445 | llava = [m for m in models if "llava" in m.id][0] |
| 196 | 446 | assert "vision" in llava.capabilities |
| 197 | 447 | |
| 198 | - # llama should have only chat | |
| 199 | 448 | llama = [m for m in models if "llama" in m.id][0] |
| 200 | 449 | assert "chat" in llama.capabilities |
| 201 | 450 | assert "vision" not in llama.capabilities |
| 202 | 451 | |
| 203 | - def test_provider_for_model_ollama_via_discovery(self): | |
| 204 | - mgr = ProviderManager() | |
| 205 | - mgr._available_models = [ | |
| 206 | - ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), | |
| 207 | - ] | |
| 208 | - assert mgr._provider_for_model("llama3.2:latest") == "ollama" | |
| 209 | - | |
| 210 | - def test_provider_for_model_ollama_fuzzy_tag(self): | |
| 211 | - mgr = ProviderManager() | |
| 212 | - mgr._available_models = [ | |
| 213 | - ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), | |
| 214 | - ] | |
| 215 | - # Should match "llama3.2" to "llama3.2:latest" via prefix | |
| 216 | - assert mgr._provider_for_model("llama3.2") == "ollama" | |
| 217 | - | |
| 218 | - def test_init_forced_provider_ollama(self): | |
| 219 | - mgr = ProviderManager(provider="ollama") | |
| 220 | - # Ollama defaults are empty (resolved dynamically) | |
| 221 | - assert mgr.vision_model == "" | |
| 222 | - assert mgr.chat_model == "" | |
| 223 | - assert mgr.transcription_model == "" | |
| 452 | + | |
| 453 | +# --------------------------------------------------------------------------- | |
| 454 | +# Provider module imports | |
| 455 | +# --------------------------------------------------------------------------- | |
| 456 | + | |
| 457 | + | |
| 458 | +class TestProviderImports: | |
| 459 | + """Verify that all provider modules import without errors.""" | |
| 460 | + | |
| 461 | + PROVIDER_MODULES = [ | |
| 462 | + "video_processor.providers.openai_provider", | |
| 463 | + "video_processor.providers.anthropic_provider", | |
| 464 | + "video_processor.providers.gemini_provider", | |
| 465 | + "video_processor.providers.ollama_provider", | |
| 466 | + "video_processor.providers.azure_provider", | |
| 467 | + "video_processor.providers.together_provider", | |
| 468 | + "video_processor.providers.fireworks_provider", | |
| 469 | + "video_processor.providers.cerebras_provider", | |
| 470 | + "video_processor.providers.xai_provider", | |
| 471 | + ] | |
| 472 | + | |
| 473 | + @pytest.mark.parametrize("module_name", PROVIDER_MODULES) | |
| 474 | + def test_import(self, module_name): | |
| 475 | + mod = importlib.import_module(module_name) | |
| 476 | + assert mod is not None | |
| 224 | 477 | |
| 225 | 478 |
ADDED tests/test_usage_tracker.py
| --- tests/test_providers.py | |
| +++ tests/test_providers.py | |
| @@ -1,13 +1,23 @@ | |
| 1 | """Tests for the provider abstraction layer.""" |
| 2 | |
| 3 | from unittest.mock import MagicMock, patch |
| 4 | |
| 5 | import pytest |
| 6 | |
| 7 | from video_processor.providers.base import BaseProvider, ModelInfo |
| 8 | from video_processor.providers.manager import ProviderManager |
| 9 | |
| 10 | |
| 11 | class TestModelInfo: |
| 12 | def test_basic(self): |
| 13 | m = ModelInfo(id="gpt-4o", provider="openai", capabilities=["chat", "vision"]) |
| @@ -22,14 +32,97 @@ | |
| 22 | capabilities=["chat", "vision"], |
| 23 | ) |
| 24 | restored = ModelInfo.model_validate_json(m.model_dump_json()) |
| 25 | assert restored == m |
| 26 | |
| 27 | |
| 28 | class TestProviderManager: |
| 29 | def _make_mock_provider(self, name="openai"): |
| 30 | """Create a mock provider.""" |
| 31 | provider = MagicMock(spec=BaseProvider) |
| 32 | provider.provider_name = name |
| 33 | provider.chat.return_value = "test response" |
| 34 | provider.analyze_image.return_value = "image analysis" |
| 35 | provider.transcribe_audio.return_value = { |
| @@ -53,18 +146,58 @@ | |
| 53 | def test_init_forced_provider(self): |
| 54 | mgr = ProviderManager(provider="gemini") |
| 55 | assert mgr.vision_model == "gemini-2.5-flash" |
| 56 | assert mgr.chat_model == "gemini-2.5-flash" |
| 57 | assert mgr.transcription_model == "gemini-2.5-flash" |
| 58 | |
| 59 | def test_provider_for_model(self): |
| 60 | mgr = ProviderManager() |
| 61 | assert mgr._provider_for_model("gpt-4o") == "openai" |
| 62 | assert mgr._provider_for_model("claude-sonnet-4-5-20250929") == "anthropic" |
| 63 | assert mgr._provider_for_model("gemini-2.5-flash") == "gemini" |
| 64 | assert mgr._provider_for_model("whisper-1") == "openai" |
| 65 | |
| 66 | @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}) |
| 67 | def test_chat_routes_to_provider(self): |
| 68 | mgr = ProviderManager(chat_model="gpt-4o") |
| 69 | mock_prov = self._make_mock_provider("openai") |
| 70 | mgr._providers["openai"] = mock_prov |
| @@ -97,36 +230,126 @@ | |
| 97 | mgr = ProviderManager( |
| 98 | vision_model="gpt-4o", |
| 99 | chat_model="claude-sonnet-4-5-20250929", |
| 100 | transcription_model="whisper-1", |
| 101 | ) |
| 102 | # Pre-fill providers so _resolve_model doesn't try to instantiate real ones |
| 103 | for name in ["openai", "anthropic"]: |
| 104 | mgr._providers[name] = self._make_mock_provider(name) |
| 105 | |
| 106 | used = mgr.get_models_used() |
| 107 | assert "vision" in used |
| 108 | assert "openai/gpt-4o" == used["vision"] |
| 109 | assert "anthropic/claude-sonnet-4-5-20250929" == used["chat"] |
| 110 | |
| 111 | |
| 112 | class TestDiscovery: |
| 113 | @patch("video_processor.providers.discovery._cached_models", None) |
| 114 | @patch( |
| 115 | "video_processor.providers.ollama_provider.OllamaProvider.is_available", return_value=False |
| 116 | ) |
| 117 | @patch.dict("os.environ", {}, clear=True) |
| 118 | def test_discover_skips_missing_keys(self, mock_ollama): |
| 119 | from video_processor.providers.discovery import discover_available_models |
| 120 | |
| 121 | # No API keys and no Ollama -> empty list, no errors |
| 122 | models = discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""}) |
| 123 | assert models == [] |
| 124 | |
| 125 | @patch.dict("os.environ", {}, clear=True) |
| 126 | @patch( |
| 127 | "video_processor.providers.ollama_provider.OllamaProvider.is_available", return_value=False |
| 128 | ) |
| 129 | @patch("video_processor.providers.discovery._cached_models", None) |
| 130 | def test_discover_caches_results(self, mock_ollama): |
| 131 | from video_processor.providers import discovery |
| 132 | |
| @@ -136,13 +359,41 @@ | |
| 136 | assert models == [] |
| 137 | # Second call should use cache |
| 138 | models2 = discovery.discover_available_models(api_keys={"openai": "key"}) |
| 139 | assert models2 == [] # Still cached empty result |
| 140 | |
| 141 | # Force refresh |
| 142 | discovery.clear_discovery_cache() |
| 143 | # Would try to connect with real key, so skip that test |
| 144 | |
| 145 | |
| 146 | class TestOllamaProvider: |
| 147 | @patch("video_processor.providers.ollama_provider.requests") |
| 148 | def test_is_available_when_running(self, mock_requests): |
| @@ -189,35 +440,37 @@ | |
| 189 | provider = OllamaProvider() |
| 190 | models = provider.list_models() |
| 191 | assert len(models) == 2 |
| 192 | assert models[0].provider == "ollama" |
| 193 | |
| 194 | # llava should have vision capability |
| 195 | llava = [m for m in models if "llava" in m.id][0] |
| 196 | assert "vision" in llava.capabilities |
| 197 | |
| 198 | # llama should have only chat |
| 199 | llama = [m for m in models if "llama" in m.id][0] |
| 200 | assert "chat" in llama.capabilities |
| 201 | assert "vision" not in llama.capabilities |
| 202 | |
| 203 | def test_provider_for_model_ollama_via_discovery(self): |
| 204 | mgr = ProviderManager() |
| 205 | mgr._available_models = [ |
| 206 | ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), |
| 207 | ] |
| 208 | assert mgr._provider_for_model("llama3.2:latest") == "ollama" |
| 209 | |
| 210 | def test_provider_for_model_ollama_fuzzy_tag(self): |
| 211 | mgr = ProviderManager() |
| 212 | mgr._available_models = [ |
| 213 | ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), |
| 214 | ] |
| 215 | # Should match "llama3.2" to "llama3.2:latest" via prefix |
| 216 | assert mgr._provider_for_model("llama3.2") == "ollama" |
| 217 | |
| 218 | def test_init_forced_provider_ollama(self): |
| 219 | mgr = ProviderManager(provider="ollama") |
| 220 | # Ollama defaults are empty (resolved dynamically) |
| 221 | assert mgr.vision_model == "" |
| 222 | assert mgr.chat_model == "" |
| 223 | assert mgr.transcription_model == "" |
| 224 | |
| 225 |
ADDED tests/test_usage_tracker.py
| --- tests/test_providers.py | |
| +++ tests/test_providers.py | |
| @@ -1,13 +1,23 @@ | |
| 1 | """Tests for the provider abstraction layer.""" |
| 2 | |
| 3 | import importlib |
| 4 | from unittest.mock import MagicMock, patch |
| 5 | |
| 6 | import pytest |
| 7 | |
| 8 | from video_processor.providers.base import ( |
| 9 | BaseProvider, |
| 10 | ModelInfo, |
| 11 | OpenAICompatibleProvider, |
| 12 | ProviderRegistry, |
| 13 | ) |
| 14 | from video_processor.providers.manager import ProviderManager |
| 15 | |
| 16 | # --------------------------------------------------------------------------- |
| 17 | # ModelInfo |
| 18 | # --------------------------------------------------------------------------- |
| 19 | |
| 20 | |
| 21 | class TestModelInfo: |
| 22 | def test_basic(self): |
| 23 | m = ModelInfo(id="gpt-4o", provider="openai", capabilities=["chat", "vision"]) |
| @@ -22,14 +32,97 @@ | |
| 32 | capabilities=["chat", "vision"], |
| 33 | ) |
| 34 | restored = ModelInfo.model_validate_json(m.model_dump_json()) |
| 35 | assert restored == m |
| 36 | |
| 37 | def test_defaults(self): |
| 38 | m = ModelInfo(id="x", provider="y") |
| 39 | assert m.display_name == "" |
| 40 | assert m.capabilities == [] |
| 41 | |
| 42 | |
| 43 | # --------------------------------------------------------------------------- |
| 44 | # ProviderRegistry |
| 45 | # --------------------------------------------------------------------------- |
| 46 | |
| 47 | |
| 48 | class TestProviderRegistry: |
| 49 | """Test ProviderRegistry class methods. |
| 50 | |
| 51 | We save and restore the internal _providers dict around each test so that |
| 52 | registrations from one test don't leak into another. |
| 53 | """ |
| 54 | |
| 55 | @pytest.fixture(autouse=True) |
| 56 | def _save_restore_registry(self): |
| 57 | original = dict(ProviderRegistry._providers) |
| 58 | yield |
| 59 | ProviderRegistry._providers = original |
| 60 | |
| 61 | def test_register_and_get(self): |
| 62 | dummy_cls = type("Dummy", (), {}) |
| 63 | ProviderRegistry.register("test_prov", dummy_cls, env_var="TEST_KEY") |
| 64 | assert ProviderRegistry.get("test_prov") is dummy_cls |
| 65 | |
| 66 | def test_get_unknown_raises(self): |
| 67 | with pytest.raises(ValueError, match="Unknown provider"): |
| 68 | ProviderRegistry.get("nonexistent_provider_xyz") |
| 69 | |
| 70 | def test_get_by_model_prefix(self): |
| 71 | dummy_cls = type("Dummy", (), {}) |
| 72 | ProviderRegistry.register("myprov", dummy_cls, model_prefixes=["mymodel-"]) |
| 73 | assert ProviderRegistry.get_by_model("mymodel-7b") == "myprov" |
| 74 | assert ProviderRegistry.get_by_model("othermodel-7b") is None |
| 75 | |
| 76 | def test_get_by_model_returns_none_for_no_match(self): |
| 77 | assert ProviderRegistry.get_by_model("totally_unknown_model_xyz") is None |
| 78 | |
| 79 | def test_available_with_env_var(self): |
| 80 | dummy_cls = type("Dummy", (), {}) |
| 81 | ProviderRegistry.register("envprov", dummy_cls, env_var="ENVPROV_KEY") |
| 82 | # Not in env -> should not appear |
| 83 | with patch.dict("os.environ", {}, clear=True): |
| 84 | avail = ProviderRegistry.available() |
| 85 | assert "envprov" not in avail |
| 86 | |
| 87 | # In env -> should appear |
| 88 | with patch.dict("os.environ", {"ENVPROV_KEY": "secret"}): |
| 89 | avail = ProviderRegistry.available() |
| 90 | assert "envprov" in avail |
| 91 | |
| 92 | def test_available_no_env_var_required(self): |
| 93 | dummy_cls = type("Dummy", (), {}) |
| 94 | ProviderRegistry.register("noenvprov", dummy_cls, env_var="") |
| 95 | avail = ProviderRegistry.available() |
| 96 | assert "noenvprov" in avail |
| 97 | |
| 98 | def test_all_registered(self): |
| 99 | dummy_cls = type("Dummy", (), {}) |
| 100 | ProviderRegistry.register("regprov", dummy_cls, env_var="X", default_models={"chat": "m1"}) |
| 101 | all_reg = ProviderRegistry.all_registered() |
| 102 | assert "regprov" in all_reg |
| 103 | assert all_reg["regprov"]["class"] is dummy_cls |
| 104 | |
| 105 | def test_get_default_models(self): |
| 106 | dummy_cls = type("Dummy", (), {}) |
| 107 | ProviderRegistry.register( |
| 108 | "defprov", dummy_cls, default_models={"chat": "c1", "vision": "v1"} |
| 109 | ) |
| 110 | defaults = ProviderRegistry.get_default_models("defprov") |
| 111 | assert defaults == {"chat": "c1", "vision": "v1"} |
| 112 | |
| 113 | def test_get_default_models_unknown(self): |
| 114 | assert ProviderRegistry.get_default_models("unknown_prov_xyz") == {} |
| 115 | |
| 116 | |
| 117 | # --------------------------------------------------------------------------- |
| 118 | # ProviderManager |
| 119 | # --------------------------------------------------------------------------- |
| 120 | |
| 121 | |
| 122 | class TestProviderManager: |
| 123 | def _make_mock_provider(self, name="openai"): |
| 124 | provider = MagicMock(spec=BaseProvider) |
| 125 | provider.provider_name = name |
| 126 | provider.chat.return_value = "test response" |
| 127 | provider.analyze_image.return_value = "image analysis" |
| 128 | provider.transcribe_audio.return_value = { |
| @@ -53,18 +146,58 @@ | |
| 146 | def test_init_forced_provider(self): |
| 147 | mgr = ProviderManager(provider="gemini") |
| 148 | assert mgr.vision_model == "gemini-2.5-flash" |
| 149 | assert mgr.chat_model == "gemini-2.5-flash" |
| 150 | assert mgr.transcription_model == "gemini-2.5-flash" |
| 151 | |
| 152 | def test_init_forced_provider_ollama(self): |
| 153 | mgr = ProviderManager(provider="ollama") |
| 154 | assert mgr.vision_model == "" |
| 155 | assert mgr.chat_model == "" |
| 156 | assert mgr.transcription_model == "" |
| 157 | |
| 158 | def test_init_no_overrides(self): |
| 159 | mgr = ProviderManager() |
| 160 | assert mgr.vision_model is None |
| 161 | assert mgr.chat_model is None |
| 162 | assert mgr.transcription_model is None |
| 163 | assert mgr.auto is True |
| 164 | |
| 165 | def test_default_for_provider_gemini(self): |
| 166 | result = ProviderManager._default_for_provider("gemini", "vision") |
| 167 | assert result == "gemini-2.5-flash" |
| 168 | |
| 169 | def test_default_for_provider_openai(self): |
| 170 | result = ProviderManager._default_for_provider("openai", "chat") |
| 171 | assert isinstance(result, str) |
| 172 | assert len(result) > 0 |
| 173 | |
| 174 | def test_default_for_provider_unknown(self): |
| 175 | result = ProviderManager._default_for_provider("nonexistent_xyz", "chat") |
| 176 | assert result == "" |
| 177 | |
| 178 | def test_provider_for_model(self): |
| 179 | mgr = ProviderManager() |
| 180 | assert mgr._provider_for_model("gpt-4o") == "openai" |
| 181 | assert mgr._provider_for_model("claude-sonnet-4-5-20250929") == "anthropic" |
| 182 | assert mgr._provider_for_model("gemini-2.5-flash") == "gemini" |
| 183 | assert mgr._provider_for_model("whisper-1") == "openai" |
| 184 | |
| 185 | def test_provider_for_model_ollama_via_discovery(self): |
| 186 | mgr = ProviderManager() |
| 187 | mgr._available_models = [ |
| 188 | ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), |
| 189 | ] |
| 190 | assert mgr._provider_for_model("llama3.2:latest") == "ollama" |
| 191 | |
| 192 | def test_provider_for_model_ollama_fuzzy_tag(self): |
| 193 | mgr = ProviderManager() |
| 194 | mgr._available_models = [ |
| 195 | ModelInfo(id="llama3.2:latest", provider="ollama", capabilities=["chat"]), |
| 196 | ] |
| 197 | assert mgr._provider_for_model("llama3.2") == "ollama" |
| 198 | |
| 199 | @patch.dict("os.environ", {"OPENAI_API_KEY": "test-key"}) |
| 200 | def test_chat_routes_to_provider(self): |
| 201 | mgr = ProviderManager(chat_model="gpt-4o") |
| 202 | mock_prov = self._make_mock_provider("openai") |
| 203 | mgr._providers["openai"] = mock_prov |
| @@ -97,36 +230,126 @@ | |
| 230 | mgr = ProviderManager( |
| 231 | vision_model="gpt-4o", |
| 232 | chat_model="claude-sonnet-4-5-20250929", |
| 233 | transcription_model="whisper-1", |
| 234 | ) |
| 235 | for name in ["openai", "anthropic"]: |
| 236 | mgr._providers[name] = self._make_mock_provider(name) |
| 237 | |
| 238 | used = mgr.get_models_used() |
| 239 | assert "vision" in used |
| 240 | assert used["vision"] == "openai/gpt-4o" |
| 241 | assert used["chat"] == "anthropic/claude-sonnet-4-5-20250929" |
| 242 | |
| 243 | def test_track_records_usage(self): |
| 244 | mgr = ProviderManager(chat_model="gpt-4o") |
| 245 | mock_prov = self._make_mock_provider("openai") |
| 246 | mock_prov._last_usage = {"input_tokens": 10, "output_tokens": 20} |
| 247 | mgr._providers["openai"] = mock_prov |
| 248 | |
| 249 | mgr.chat([{"role": "user", "content": "hi"}]) |
| 250 | assert mgr.usage.total_input_tokens == 10 |
| 251 | assert mgr.usage.total_output_tokens == 20 |
| 252 | |
| 253 | |
| 254 | # --------------------------------------------------------------------------- |
| 255 | # OpenAICompatibleProvider |
| 256 | # --------------------------------------------------------------------------- |
| 257 | |
| 258 | |
| 259 | class TestOpenAICompatibleProvider: |
| 260 | @patch("openai.OpenAI") |
| 261 | def test_chat(self, mock_openai_cls): |
| 262 | mock_client = MagicMock() |
| 263 | mock_openai_cls.return_value = mock_client |
| 264 | |
| 265 | mock_choice = MagicMock() |
| 266 | mock_choice.message.content = "hello back" |
| 267 | mock_response = MagicMock() |
| 268 | mock_response.choices = [mock_choice] |
| 269 | mock_response.usage.prompt_tokens = 5 |
| 270 | mock_response.usage.completion_tokens = 10 |
| 271 | mock_client.chat.completions.create.return_value = mock_response |
| 272 | |
| 273 | provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") |
| 274 | result = provider.chat([{"role": "user", "content": "hi"}], model="test-model") |
| 275 | assert result == "hello back" |
| 276 | assert provider._last_usage == {"input_tokens": 5, "output_tokens": 10} |
| 277 | |
| 278 | @patch("openai.OpenAI") |
| 279 | def test_analyze_image(self, mock_openai_cls): |
| 280 | mock_client = MagicMock() |
| 281 | mock_openai_cls.return_value = mock_client |
| 282 | |
| 283 | mock_choice = MagicMock() |
| 284 | mock_choice.message.content = "a cat" |
| 285 | mock_response = MagicMock() |
| 286 | mock_response.choices = [mock_choice] |
| 287 | mock_response.usage.prompt_tokens = 100 |
| 288 | mock_response.usage.completion_tokens = 5 |
| 289 | mock_client.chat.completions.create.return_value = mock_response |
| 290 | |
| 291 | provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") |
| 292 | result = provider.analyze_image(b"\x89PNG", "what is this?") |
| 293 | assert result == "a cat" |
| 294 | assert provider._last_usage["input_tokens"] == 100 |
| 295 | |
| 296 | @patch("openai.OpenAI") |
| 297 | def test_transcribe_raises(self, mock_openai_cls): |
| 298 | provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") |
| 299 | with pytest.raises(NotImplementedError): |
| 300 | provider.transcribe_audio("/tmp/audio.wav") |
| 301 | |
| 302 | @patch("openai.OpenAI") |
| 303 | def test_list_models(self, mock_openai_cls): |
| 304 | mock_client = MagicMock() |
| 305 | mock_openai_cls.return_value = mock_client |
| 306 | |
| 307 | mock_model = MagicMock() |
| 308 | mock_model.id = "test-model-1" |
| 309 | mock_client.models.list.return_value = [mock_model] |
| 310 | |
| 311 | provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") |
| 312 | provider.provider_name = "testprov" |
| 313 | models = provider.list_models() |
| 314 | assert len(models) == 1 |
| 315 | assert models[0].id == "test-model-1" |
| 316 | assert models[0].provider == "testprov" |
| 317 | |
| 318 | @patch("openai.OpenAI") |
| 319 | def test_list_models_handles_error(self, mock_openai_cls): |
| 320 | mock_client = MagicMock() |
| 321 | mock_openai_cls.return_value = mock_client |
| 322 | mock_client.models.list.side_effect = Exception("connection error") |
| 323 | |
| 324 | provider = OpenAICompatibleProvider(api_key="test", base_url="http://test") |
| 325 | models = provider.list_models() |
| 326 | assert models == [] |
| 327 | |
| 328 | |
| 329 | # --------------------------------------------------------------------------- |
| 330 | # Discovery |
| 331 | # --------------------------------------------------------------------------- |
| 332 | |
| 333 | |
| 334 | class TestDiscovery: |
| 335 | @patch("video_processor.providers.discovery._cached_models", None) |
| 336 | @patch( |
| 337 | "video_processor.providers.ollama_provider.OllamaProvider.is_available", |
| 338 | return_value=False, |
| 339 | ) |
| 340 | @patch.dict("os.environ", {}, clear=True) |
| 341 | def test_discover_skips_missing_keys(self, mock_ollama): |
| 342 | from video_processor.providers.discovery import discover_available_models |
| 343 | |
| 344 | models = discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""}) |
| 345 | assert models == [] |
| 346 | |
| 347 | @patch.dict("os.environ", {}, clear=True) |
| 348 | @patch( |
| 349 | "video_processor.providers.ollama_provider.OllamaProvider.is_available", |
| 350 | return_value=False, |
| 351 | ) |
| 352 | @patch("video_processor.providers.discovery._cached_models", None) |
| 353 | def test_discover_caches_results(self, mock_ollama): |
| 354 | from video_processor.providers import discovery |
| 355 | |
| @@ -136,13 +359,41 @@ | |
| 359 | assert models == [] |
| 360 | # Second call should use cache |
| 361 | models2 = discovery.discover_available_models(api_keys={"openai": "key"}) |
| 362 | assert models2 == [] # Still cached empty result |
| 363 | |
| 364 | discovery.clear_discovery_cache() |
| 365 | |
| 366 | @patch("video_processor.providers.discovery._cached_models", None) |
| 367 | @patch( |
| 368 | "video_processor.providers.ollama_provider.OllamaProvider.is_available", |
| 369 | return_value=False, |
| 370 | ) |
| 371 | @patch.dict("os.environ", {}, clear=True) |
| 372 | def test_force_refresh_clears_cache(self, mock_ollama): |
| 373 | from video_processor.providers import discovery |
| 374 | |
| 375 | # Warm the cache |
| 376 | discovery.discover_available_models(api_keys={"openai": "", "anthropic": "", "gemini": ""}) |
| 377 | # Force refresh should re-run |
| 378 | models = discovery.discover_available_models( |
| 379 | api_keys={"openai": "", "anthropic": "", "gemini": ""}, |
| 380 | force_refresh=True, |
| 381 | ) |
| 382 | assert models == [] |
| 383 | |
| 384 | def test_clear_discovery_cache(self): |
| 385 | from video_processor.providers import discovery |
| 386 | |
| 387 | discovery._cached_models = [ModelInfo(id="x", provider="y")] |
| 388 | discovery.clear_discovery_cache() |
| 389 | assert discovery._cached_models is None |
| 390 | |
| 391 | |
| 392 | # --------------------------------------------------------------------------- |
| 393 | # OllamaProvider |
| 394 | # --------------------------------------------------------------------------- |
| 395 | |
| 396 | |
| 397 | class TestOllamaProvider: |
| 398 | @patch("video_processor.providers.ollama_provider.requests") |
| 399 | def test_is_available_when_running(self, mock_requests): |
| @@ -189,35 +440,37 @@ | |
| 440 | provider = OllamaProvider() |
| 441 | models = provider.list_models() |
| 442 | assert len(models) == 2 |
| 443 | assert models[0].provider == "ollama" |
| 444 | |
| 445 | llava = [m for m in models if "llava" in m.id][0] |
| 446 | assert "vision" in llava.capabilities |
| 447 | |
| 448 | llama = [m for m in models if "llama" in m.id][0] |
| 449 | assert "chat" in llama.capabilities |
| 450 | assert "vision" not in llama.capabilities |
| 451 | |
| 452 | |
| 453 | # --------------------------------------------------------------------------- |
| 454 | # Provider module imports |
| 455 | # --------------------------------------------------------------------------- |
| 456 | |
| 457 | |
| 458 | class TestProviderImports: |
| 459 | """Verify that all provider modules import without errors.""" |
| 460 | |
| 461 | PROVIDER_MODULES = [ |
| 462 | "video_processor.providers.openai_provider", |
| 463 | "video_processor.providers.anthropic_provider", |
| 464 | "video_processor.providers.gemini_provider", |
| 465 | "video_processor.providers.ollama_provider", |
| 466 | "video_processor.providers.azure_provider", |
| 467 | "video_processor.providers.together_provider", |
| 468 | "video_processor.providers.fireworks_provider", |
| 469 | "video_processor.providers.cerebras_provider", |
| 470 | "video_processor.providers.xai_provider", |
| 471 | ] |
| 472 | |
| 473 | @pytest.mark.parametrize("module_name", PROVIDER_MODULES) |
| 474 | def test_import(self, module_name): |
| 475 | mod = importlib.import_module(module_name) |
| 476 | assert mod is not None |
| 477 | |
| 478 | ADDED tests/test_usage_tracker.py |
+198
| --- a/tests/test_usage_tracker.py | ||
| +++ b/tests/test_usage_tracker.py | ||
| @@ -0,0 +1,198 @@ | ||
| 1 | +"""Tests for the UsageTracker class.""" | |
| 2 | + | |
| 3 | +import time | |
| 4 | + | |
| 5 | +from video_processor.utils.usage_tracker import ModelUsage, StepTiming, UsageTracker, _fmt_duration | |
| 6 | + | |
| 7 | + | |
| 8 | +class TestModelUsage: | |
| 9 | + def test_total_tokens(self): | |
| 10 | + mu = ModelUsage(provider="openai", model="gpt-4o", input_tokens=100, output_tokens=50) | |
| 11 | + assert mu.total_tokens == 150 | |
| 12 | + | |
| 13 | + def test_estimated_cost_known_model(self): | |
| 14 | + mu = ModelUsage( | |
| 15 | + provider="openai", | |
| 16 | + model="gpt-4o", | |
| 17 | + input_tokens=1_000_000, | |
| 18 | + output_tokens=500_000, | |
| 19 | + ) | |
| 20 | + # gpt-4o: input $2.50/M, output $10.00/M | |
| 21 | + expected = 1_000_000 * 2.50 / 1_000_000 + 500_000 * 10.00 / 1_000_000 | |
| 22 | + assert abs(mu.estimated_cost - expected) < 0.001 | |
| 23 | + | |
| 24 | + def test_estimated_cost_unknown_model(self): | |
| 25 | + mu = ModelUsage( | |
| 26 | + provider="local", | |
| 27 | + model="my-custom-model", | |
| 28 | + input_tokens=1000, | |
| 29 | + output_tokens=500, | |
| 30 | + ) | |
| 31 | + assert mu.estimated_cost == 0.0 | |
| 32 | + | |
| 33 | + def test_estimated_cost_whisper(self): | |
| 34 | + mu = ModelUsage( | |
| 35 | + provider="openai", | |
| 36 | + model="whisper-1", | |
| 37 | + audio_minutes=10.0, | |
| 38 | + ) | |
| 39 | + # whisper-1: $0.006/min | |
| 40 | + assert abs(mu.estimated_cost - 0.06) < 0.001 | |
| 41 | + | |
| 42 | + def test_estimated_cost_partial_match(self): | |
| 43 | + mu = ModelUsage( | |
| 44 | + provider="openai", | |
| 45 | + model="gpt-4o-2024-08-06", | |
| 46 | + input_tokens=1_000_000, | |
| 47 | + output_tokens=0, | |
| 48 | + ) | |
| 49 | + # Should partial-match to gpt-4o | |
| 50 | + assert mu.estimated_cost > 0 | |
| 51 | + | |
| 52 | + def test_calls_default_zero(self): | |
| 53 | + mu = ModelUsage() | |
| 54 | + assert mu.calls == 0 | |
| 55 | + assert mu.total_tokens == 0 | |
| 56 | + assert mu.estimated_cost == 0.0 | |
| 57 | + | |
| 58 | + | |
| 59 | +class TestStepTiming: | |
| 60 | + def test_duration_with_times(self): | |
| 61 | + st = StepTiming(name="test", start_time=100.0, end_time=105.5) | |
| 62 | + assert abs(st.duration - 5.5) < 0.001 | |
| 63 | + | |
| 64 | + def test_duration_no_end_time(self): | |
| 65 | + st = StepTiming(name="test", start_time=100.0) | |
| 66 | + assert st.duration == 0.0 | |
| 67 | + | |
| 68 | + def test_duration_no_start_time(self): | |
| 69 | + st = StepTiming(name="test") | |
| 70 | + assert st.duration == 0.0 | |
| 71 | + | |
| 72 | + | |
| 73 | +class TestUsageTracker: | |
| 74 | + def test_record_single_call(self): | |
| 75 | + tracker = UsageTracker() | |
| 76 | + tracker.record("openai", "gpt-4o", input_tokens=500, output_tokens=200) | |
| 77 | + assert tracker.total_api_calls == 1 | |
| 78 | + assert tracker.total_input_tokens == 500 | |
| 79 | + assert tracker.total_output_tokens == 200 | |
| 80 | + assert tracker.total_tokens == 700 | |
| 81 | + | |
| 82 | + def test_record_multiple_calls_same_model(self): | |
| 83 | + tracker = UsageTracker() | |
| 84 | + tracker.record("openai", "gpt-4o", input_tokens=100, output_tokens=50) | |
| 85 | + tracker.record("openai", "gpt-4o", input_tokens=200, output_tokens=100) | |
| 86 | + assert tracker.total_api_calls == 2 | |
| 87 | + assert tracker.total_input_tokens == 300 | |
| 88 | + assert tracker.total_output_tokens == 150 | |
| 89 | + | |
| 90 | + def test_record_multiple_models(self): | |
| 91 | + tracker = UsageTracker() | |
| 92 | + tracker.record("openai", "gpt-4o", input_tokens=100, output_tokens=50) | |
| 93 | + tracker.record( | |
| 94 | + "anthropic", "claude-sonnet-4-5-20250929", input_tokens=200, output_tokens=100 | |
| 95 | + ) | |
| 96 | + assert tracker.total_api_calls == 2 | |
| 97 | + assert tracker.total_input_tokens == 300 | |
| 98 | + assert len(tracker._models) == 2 | |
| 99 | + | |
| 100 | + def test_total_cost(self): | |
| 101 | + tracker = UsageTracker() | |
| 102 | + tracker.record("openai", "gpt-4o", input_tokens=1_000_000, output_tokens=500_000) | |
| 103 | + cost = tracker.total_cost | |
| 104 | + assert cost > 0 | |
| 105 | + | |
| 106 | + def test_start_and_end_step(self): | |
| 107 | + tracker = UsageTracker() | |
| 108 | + tracker.start_step("Frame extraction") | |
| 109 | + time.sleep(0.01) | |
| 110 | + tracker.end_step() | |
| 111 | + | |
| 112 | + assert len(tracker._steps) == 1 | |
| 113 | + assert tracker._steps[0].name == "Frame extraction" | |
| 114 | + assert tracker._steps[0].duration > 0 | |
| 115 | + | |
| 116 | + def test_start_step_auto_closes_previous(self): | |
| 117 | + tracker = UsageTracker() | |
| 118 | + tracker.start_step("Step 1") | |
| 119 | + time.sleep(0.01) | |
| 120 | + tracker.start_step("Step 2") | |
| 121 | + # Step 1 should have been auto-closed | |
| 122 | + assert len(tracker._steps) == 1 | |
| 123 | + assert tracker._steps[0].name == "Step 1" | |
| 124 | + assert tracker._steps[0].duration > 0 | |
| 125 | + # Step 2 is current | |
| 126 | + assert tracker._current_step.name == "Step 2" | |
| 127 | + | |
| 128 | + def test_end_step_when_none(self): | |
| 129 | + tracker = UsageTracker() | |
| 130 | + tracker.end_step() # Should not raise | |
| 131 | + assert len(tracker._steps) == 0 | |
| 132 | + | |
| 133 | + def test_total_duration(self): | |
| 134 | + tracker = UsageTracker() | |
| 135 | + time.sleep(0.01) | |
| 136 | + assert tracker.total_duration > 0 | |
| 137 | + | |
| 138 | + def test_format_summary_empty(self): | |
| 139 | + tracker = UsageTracker() | |
| 140 | + summary = tracker.format_summary() | |
| 141 | + assert "PROCESSING SUMMARY" in summary | |
| 142 | + assert "Total time" in summary | |
| 143 | + | |
| 144 | + def test_format_summary_with_usage(self): | |
| 145 | + tracker = UsageTracker() | |
| 146 | + tracker.record("openai", "gpt-4o", input_tokens=1000, output_tokens=500) | |
| 147 | + tracker.start_step("Analysis") | |
| 148 | + tracker.end_step() | |
| 149 | + | |
| 150 | + summary = tracker.format_summary() | |
| 151 | + assert "API Calls" in summary | |
| 152 | + assert "Tokens" in summary | |
| 153 | + assert "gpt-4o" in summary | |
| 154 | + assert "Analysis" in summary | |
| 155 | + | |
| 156 | + def test_format_summary_with_audio(self): | |
| 157 | + tracker = UsageTracker() | |
| 158 | + tracker.record("openai", "whisper-1", audio_minutes=5.0) | |
| 159 | + summary = tracker.format_summary() | |
| 160 | + assert "whisper" in summary | |
| 161 | + assert "5.0m" in summary | |
| 162 | + | |
| 163 | + def test_format_summary_cost_display(self): | |
| 164 | + tracker = UsageTracker() | |
| 165 | + tracker.record("openai", "gpt-4o", input_tokens=1_000_000, output_tokens=500_000) | |
| 166 | + summary = tracker.format_summary() | |
| 167 | + assert "Estimated total cost: $" in summary | |
| 168 | + | |
| 169 | + def test_format_summary_step_percentages(self): | |
| 170 | + tracker = UsageTracker() | |
| 171 | + # Manually create steps with known timings | |
| 172 | + tracker._steps = [ | |
| 173 | + StepTiming(name="Step A", start_time=0.0, end_time=1.0), | |
| 174 | + StepTiming(name="Step B", start_time=1.0, end_time=3.0), | |
| 175 | + ] | |
| 176 | + summary = tracker.format_summary() | |
| 177 | + assert "Step A" in summary | |
| 178 | + assert "Step B" in summary | |
| 179 | + assert "%" in summary | |
| 180 | + | |
| 181 | + | |
| 182 | +class TestFmtDuration: | |
| 183 | + def test_seconds(self): | |
| 184 | + assert _fmt_duration(5.3) == "5.3s" | |
| 185 | + | |
| 186 | + def test_minutes(self): | |
| 187 | + result = _fmt_duration(90.0) | |
| 188 | + assert result == "1m 30s" | |
| 189 | + | |
| 190 | + def test_hours(self): | |
| 191 | + result = _fmt_duration(3661.0) | |
| 192 | + assert result == "1h 1m 1s" | |
| 193 | + | |
| 194 | + def test_zero(self): | |
| 195 | + assert _fmt_duration(0.0) == "0.0s" | |
| 196 | + | |
| 197 | + def test_just_under_minute(self): | |
| 198 | + assert _fmt_duration(59.9) == "59.9s" |
| --- a/tests/test_usage_tracker.py |
| +++ b/tests/test_usage_tracker.py |
| @@ -0,0 +1,198 @@ |
| 1 | """Tests for the UsageTracker class.""" |
| 2 | |
| 3 | import time |
| 4 | |
| 5 | from video_processor.utils.usage_tracker import ModelUsage, StepTiming, UsageTracker, _fmt_duration |
| 6 | |
| 7 | |
| 8 | class TestModelUsage: |
| 9 | def test_total_tokens(self): |
| 10 | mu = ModelUsage(provider="openai", model="gpt-4o", input_tokens=100, output_tokens=50) |
| 11 | assert mu.total_tokens == 150 |
| 12 | |
| 13 | def test_estimated_cost_known_model(self): |
| 14 | mu = ModelUsage( |
| 15 | provider="openai", |
| 16 | model="gpt-4o", |
| 17 | input_tokens=1_000_000, |
| 18 | output_tokens=500_000, |
| 19 | ) |
| 20 | # gpt-4o: input $2.50/M, output $10.00/M |
| 21 | expected = 1_000_000 * 2.50 / 1_000_000 + 500_000 * 10.00 / 1_000_000 |
| 22 | assert abs(mu.estimated_cost - expected) < 0.001 |
| 23 | |
| 24 | def test_estimated_cost_unknown_model(self): |
| 25 | mu = ModelUsage( |
| 26 | provider="local", |
| 27 | model="my-custom-model", |
| 28 | input_tokens=1000, |
| 29 | output_tokens=500, |
| 30 | ) |
| 31 | assert mu.estimated_cost == 0.0 |
| 32 | |
| 33 | def test_estimated_cost_whisper(self): |
| 34 | mu = ModelUsage( |
| 35 | provider="openai", |
| 36 | model="whisper-1", |
| 37 | audio_minutes=10.0, |
| 38 | ) |
| 39 | # whisper-1: $0.006/min |
| 40 | assert abs(mu.estimated_cost - 0.06) < 0.001 |
| 41 | |
| 42 | def test_estimated_cost_partial_match(self): |
| 43 | mu = ModelUsage( |
| 44 | provider="openai", |
| 45 | model="gpt-4o-2024-08-06", |
| 46 | input_tokens=1_000_000, |
| 47 | output_tokens=0, |
| 48 | ) |
| 49 | # Should partial-match to gpt-4o |
| 50 | assert mu.estimated_cost > 0 |
| 51 | |
| 52 | def test_calls_default_zero(self): |
| 53 | mu = ModelUsage() |
| 54 | assert mu.calls == 0 |
| 55 | assert mu.total_tokens == 0 |
| 56 | assert mu.estimated_cost == 0.0 |
| 57 | |
| 58 | |
| 59 | class TestStepTiming: |
| 60 | def test_duration_with_times(self): |
| 61 | st = StepTiming(name="test", start_time=100.0, end_time=105.5) |
| 62 | assert abs(st.duration - 5.5) < 0.001 |
| 63 | |
| 64 | def test_duration_no_end_time(self): |
| 65 | st = StepTiming(name="test", start_time=100.0) |
| 66 | assert st.duration == 0.0 |
| 67 | |
| 68 | def test_duration_no_start_time(self): |
| 69 | st = StepTiming(name="test") |
| 70 | assert st.duration == 0.0 |
| 71 | |
| 72 | |
| 73 | class TestUsageTracker: |
| 74 | def test_record_single_call(self): |
| 75 | tracker = UsageTracker() |
| 76 | tracker.record("openai", "gpt-4o", input_tokens=500, output_tokens=200) |
| 77 | assert tracker.total_api_calls == 1 |
| 78 | assert tracker.total_input_tokens == 500 |
| 79 | assert tracker.total_output_tokens == 200 |
| 80 | assert tracker.total_tokens == 700 |
| 81 | |
| 82 | def test_record_multiple_calls_same_model(self): |
| 83 | tracker = UsageTracker() |
| 84 | tracker.record("openai", "gpt-4o", input_tokens=100, output_tokens=50) |
| 85 | tracker.record("openai", "gpt-4o", input_tokens=200, output_tokens=100) |
| 86 | assert tracker.total_api_calls == 2 |
| 87 | assert tracker.total_input_tokens == 300 |
| 88 | assert tracker.total_output_tokens == 150 |
| 89 | |
| 90 | def test_record_multiple_models(self): |
| 91 | tracker = UsageTracker() |
| 92 | tracker.record("openai", "gpt-4o", input_tokens=100, output_tokens=50) |
| 93 | tracker.record( |
| 94 | "anthropic", "claude-sonnet-4-5-20250929", input_tokens=200, output_tokens=100 |
| 95 | ) |
| 96 | assert tracker.total_api_calls == 2 |
| 97 | assert tracker.total_input_tokens == 300 |
| 98 | assert len(tracker._models) == 2 |
| 99 | |
| 100 | def test_total_cost(self): |
| 101 | tracker = UsageTracker() |
| 102 | tracker.record("openai", "gpt-4o", input_tokens=1_000_000, output_tokens=500_000) |
| 103 | cost = tracker.total_cost |
| 104 | assert cost > 0 |
| 105 | |
| 106 | def test_start_and_end_step(self): |
| 107 | tracker = UsageTracker() |
| 108 | tracker.start_step("Frame extraction") |
| 109 | time.sleep(0.01) |
| 110 | tracker.end_step() |
| 111 | |
| 112 | assert len(tracker._steps) == 1 |
| 113 | assert tracker._steps[0].name == "Frame extraction" |
| 114 | assert tracker._steps[0].duration > 0 |
| 115 | |
| 116 | def test_start_step_auto_closes_previous(self): |
| 117 | tracker = UsageTracker() |
| 118 | tracker.start_step("Step 1") |
| 119 | time.sleep(0.01) |
| 120 | tracker.start_step("Step 2") |
| 121 | # Step 1 should have been auto-closed |
| 122 | assert len(tracker._steps) == 1 |
| 123 | assert tracker._steps[0].name == "Step 1" |
| 124 | assert tracker._steps[0].duration > 0 |
| 125 | # Step 2 is current |
| 126 | assert tracker._current_step.name == "Step 2" |
| 127 | |
| 128 | def test_end_step_when_none(self): |
| 129 | tracker = UsageTracker() |
| 130 | tracker.end_step() # Should not raise |
| 131 | assert len(tracker._steps) == 0 |
| 132 | |
| 133 | def test_total_duration(self): |
| 134 | tracker = UsageTracker() |
| 135 | time.sleep(0.01) |
| 136 | assert tracker.total_duration > 0 |
| 137 | |
| 138 | def test_format_summary_empty(self): |
| 139 | tracker = UsageTracker() |
| 140 | summary = tracker.format_summary() |
| 141 | assert "PROCESSING SUMMARY" in summary |
| 142 | assert "Total time" in summary |
| 143 | |
| 144 | def test_format_summary_with_usage(self): |
| 145 | tracker = UsageTracker() |
| 146 | tracker.record("openai", "gpt-4o", input_tokens=1000, output_tokens=500) |
| 147 | tracker.start_step("Analysis") |
| 148 | tracker.end_step() |
| 149 | |
| 150 | summary = tracker.format_summary() |
| 151 | assert "API Calls" in summary |
| 152 | assert "Tokens" in summary |
| 153 | assert "gpt-4o" in summary |
| 154 | assert "Analysis" in summary |
| 155 | |
| 156 | def test_format_summary_with_audio(self): |
| 157 | tracker = UsageTracker() |
| 158 | tracker.record("openai", "whisper-1", audio_minutes=5.0) |
| 159 | summary = tracker.format_summary() |
| 160 | assert "whisper" in summary |
| 161 | assert "5.0m" in summary |
| 162 | |
| 163 | def test_format_summary_cost_display(self): |
| 164 | tracker = UsageTracker() |
| 165 | tracker.record("openai", "gpt-4o", input_tokens=1_000_000, output_tokens=500_000) |
| 166 | summary = tracker.format_summary() |
| 167 | assert "Estimated total cost: $" in summary |
| 168 | |
| 169 | def test_format_summary_step_percentages(self): |
| 170 | tracker = UsageTracker() |
| 171 | # Manually create steps with known timings |
| 172 | tracker._steps = [ |
| 173 | StepTiming(name="Step A", start_time=0.0, end_time=1.0), |
| 174 | StepTiming(name="Step B", start_time=1.0, end_time=3.0), |
| 175 | ] |
| 176 | summary = tracker.format_summary() |
| 177 | assert "Step A" in summary |
| 178 | assert "Step B" in summary |
| 179 | assert "%" in summary |
| 180 | |
| 181 | |
| 182 | class TestFmtDuration: |
| 183 | def test_seconds(self): |
| 184 | assert _fmt_duration(5.3) == "5.3s" |
| 185 | |
| 186 | def test_minutes(self): |
| 187 | result = _fmt_duration(90.0) |
| 188 | assert result == "1m 30s" |
| 189 | |
| 190 | def test_hours(self): |
| 191 | result = _fmt_duration(3661.0) |
| 192 | assert result == "1h 1m 1s" |
| 193 | |
| 194 | def test_zero(self): |
| 195 | assert _fmt_duration(0.0) == "0.0s" |
| 196 | |
| 197 | def test_just_under_minute(self): |
| 198 | assert _fmt_duration(59.9) == "59.9s" |
| --- video_processor/integrators/graph_store.py | ||
| +++ video_processor/integrators/graph_store.py | ||
| @@ -188,10 +188,12 @@ | ||
| 188 | 188 | ) -> None: |
| 189 | 189 | key = name.lower() |
| 190 | 190 | if key in self._nodes: |
| 191 | 191 | if descriptions: |
| 192 | 192 | self._nodes[key]["descriptions"].update(descriptions) |
| 193 | + if entity_type and entity_type != self._nodes[key]["type"]: | |
| 194 | + self._nodes[key]["type"] = entity_type | |
| 193 | 195 | else: |
| 194 | 196 | self._nodes[key] = { |
| 195 | 197 | "id": name, |
| 196 | 198 | "name": name, |
| 197 | 199 | "type": entity_type, |
| @@ -411,12 +413,12 @@ | ||
| 411 | 413 | |
| 412 | 414 | if row: |
| 413 | 415 | existing = json.loads(row[0]) |
| 414 | 416 | merged = list(set(existing + descriptions)) |
| 415 | 417 | self._conn.execute( |
| 416 | - "UPDATE entities SET descriptions = ? WHERE name_lower = ?", | |
| 417 | - (json.dumps(merged), name_lower), | |
| 418 | + "UPDATE entities SET descriptions = ?, type = ? WHERE name_lower = ?", | |
| 419 | + (json.dumps(merged), entity_type, name_lower), | |
| 418 | 420 | ) |
| 419 | 421 | else: |
| 420 | 422 | self._conn.execute( |
| 421 | 423 | "INSERT INTO entities (name, name_lower, type, descriptions, source) " |
| 422 | 424 | "VALUES (?, ?, ?, ?, ?)", |
| 423 | 425 |
| --- video_processor/integrators/graph_store.py | |
| +++ video_processor/integrators/graph_store.py | |
| @@ -188,10 +188,12 @@ | |
| 188 | ) -> None: |
| 189 | key = name.lower() |
| 190 | if key in self._nodes: |
| 191 | if descriptions: |
| 192 | self._nodes[key]["descriptions"].update(descriptions) |
| 193 | else: |
| 194 | self._nodes[key] = { |
| 195 | "id": name, |
| 196 | "name": name, |
| 197 | "type": entity_type, |
| @@ -411,12 +413,12 @@ | |
| 411 | |
| 412 | if row: |
| 413 | existing = json.loads(row[0]) |
| 414 | merged = list(set(existing + descriptions)) |
| 415 | self._conn.execute( |
| 416 | "UPDATE entities SET descriptions = ? WHERE name_lower = ?", |
| 417 | (json.dumps(merged), name_lower), |
| 418 | ) |
| 419 | else: |
| 420 | self._conn.execute( |
| 421 | "INSERT INTO entities (name, name_lower, type, descriptions, source) " |
| 422 | "VALUES (?, ?, ?, ?, ?)", |
| 423 |
| --- video_processor/integrators/graph_store.py | |
| +++ video_processor/integrators/graph_store.py | |
| @@ -188,10 +188,12 @@ | |
| 188 | ) -> None: |
| 189 | key = name.lower() |
| 190 | if key in self._nodes: |
| 191 | if descriptions: |
| 192 | self._nodes[key]["descriptions"].update(descriptions) |
| 193 | if entity_type and entity_type != self._nodes[key]["type"]: |
| 194 | self._nodes[key]["type"] = entity_type |
| 195 | else: |
| 196 | self._nodes[key] = { |
| 197 | "id": name, |
| 198 | "name": name, |
| 199 | "type": entity_type, |
| @@ -411,12 +413,12 @@ | |
| 413 | |
| 414 | if row: |
| 415 | existing = json.loads(row[0]) |
| 416 | merged = list(set(existing + descriptions)) |
| 417 | self._conn.execute( |
| 418 | "UPDATE entities SET descriptions = ?, type = ? WHERE name_lower = ?", |
| 419 | (json.dumps(merged), entity_type, name_lower), |
| 420 | ) |
| 421 | else: |
| 422 | self._conn.execute( |
| 423 | "INSERT INTO entities (name, name_lower, type, descriptions, source) " |
| 424 | "VALUES (?, ?, ?, ?, ?)", |
| 425 |