|
ccf32cc…
|
leo
|
1 |
"""Tests for prompt template management.""" |
|
ccf32cc…
|
leo
|
2 |
|
|
ccf32cc…
|
leo
|
3 |
from video_processor.utils.prompt_templates import ( |
|
ccf32cc…
|
leo
|
4 |
DEFAULT_TEMPLATES, |
|
ccf32cc…
|
leo
|
5 |
PromptTemplate, |
|
ccf32cc…
|
leo
|
6 |
default_prompt_manager, |
|
ccf32cc…
|
leo
|
7 |
) |
|
ccf32cc…
|
leo
|
8 |
|
|
ccf32cc…
|
leo
|
9 |
|
|
ccf32cc…
|
leo
|
10 |
class TestPromptTemplate:
    """Unit tests for ``PromptTemplate`` lookup, formatting, and persistence."""

    def test_default_templates_loaded(self):
        """Constructing from DEFAULT_TEMPLATES exposes every default by name."""
        pm = PromptTemplate(default_templates=DEFAULT_TEMPLATES)
        assert len(pm.templates) == 10
        # The count alone would pass with the wrong names; also check membership.
        missing = [name for name in DEFAULT_TEMPLATES if name not in pm.templates]
        assert not missing, f"Templates not loaded: {missing}"

    def test_all_expected_templates_exist(self):
        """Every built-in prompt the pipeline relies on is defined."""
        expected = [
            "content_analysis",
            "diagram_extraction",
            "action_item_detection",
            "content_summary",
            "summary_generation",
            "key_points_extraction",
            "entity_extraction",
            "relationship_extraction",
            "diagram_analysis",
            "mermaid_generation",
        ]
        # Collect all gaps so a failure reports every missing template at once.
        missing = [name for name in expected if name not in DEFAULT_TEMPLATES]
        assert not missing, f"Missing templates: {missing}"

    def test_get_template(self):
        """A registered template can be retrieved by name."""
        pm = PromptTemplate(default_templates={"test": "Hello $name"})
        template = pm.get_template("test")
        assert template is not None

    def test_get_missing_template(self):
        """Looking up an unknown name returns None rather than raising."""
        pm = PromptTemplate(default_templates={})
        assert pm.get_template("nonexistent") is None

    def test_format_prompt(self):
        """All supplied ``$var`` placeholders are substituted."""
        pm = PromptTemplate(default_templates={"greet": "Hello $name, welcome to $place"})
        result = pm.format_prompt("greet", name="Alice", place="Wonderland")
        assert "Alice" in result
        assert "Wonderland" in result

    def test_format_missing_template(self):
        """Formatting an unknown template returns None."""
        pm = PromptTemplate(default_templates={})
        result = pm.format_prompt("nonexistent", key="value")
        assert result is None

    def test_safe_substitute_missing_vars(self):
        """Placeholders with no supplied value are left in place, not an error."""
        pm = PromptTemplate(default_templates={"test": "Hello $name and $other"})
        result = pm.format_prompt("test", name="Alice")
        assert "Alice" in result
        assert "$other" in result  # safe_substitute keeps unresolved vars

    def test_add_template(self):
        """Templates added at runtime are immediately usable by format_prompt."""
        pm = PromptTemplate(default_templates={})
        pm.add_template("new", "New template: $var")
        result = pm.format_prompt("new", var="value")
        assert "value" in result

    def test_save_template_no_dir(self):
        """Saving without a configured templates_dir reports failure."""
        pm = PromptTemplate(default_templates={"test": "content"})
        assert pm.save_template("test") is False

    def test_save_template_missing_name(self):
        """Saving a template that does not exist reports failure."""
        pm = PromptTemplate(default_templates={})
        assert pm.save_template("nonexistent") is False

    def test_save_and_load_from_dir(self, tmp_path):
        """``*.txt`` files in templates_dir are loaded under their file stem."""
        templates_dir = tmp_path / "templates"
        templates_dir.mkdir()
        (templates_dir / "custom.txt").write_text("Custom: $data")

        pm = PromptTemplate(templates_dir=templates_dir)
        assert "custom" in pm.templates
        result = pm.format_prompt("custom", data="hello")
        assert "hello" in result

    def test_save_template_to_dir(self, tmp_path):
        """A saved template is written to ``<name>.txt`` and round-trips on reload."""
        templates_dir = tmp_path / "templates"
        pm = PromptTemplate(
            templates_dir=templates_dir,
            default_templates={"saveme": "Save this: $x"},
        )
        result = pm.save_template("saveme")
        assert result is True
        assert (templates_dir / "saveme.txt").exists()

        # Existence alone does not prove the content was written correctly;
        # load it through a fresh manager and render it.
        reloaded = PromptTemplate(templates_dir=templates_dir)
        assert "saveme" in reloaded.templates
        rendered = reloaded.format_prompt("saveme", x="1")
        assert "Save this: 1" in rendered
|
ccf32cc…
|
leo
|
90 |
|
|
ccf32cc…
|
leo
|
91 |
|
|
ccf32cc…
|
leo
|
92 |
class TestDefaultPromptManager:
    """Smoke tests for the module-level ``default_prompt_manager`` instance."""

    def test_is_initialized(self):
        """The shared manager exists and holds exactly ten templates."""
        manager = default_prompt_manager
        assert manager is not None
        template_count = len(manager.templates)
        assert template_count == 10

    def test_entity_extraction_template_has_content_var(self):
        """The entity-extraction prompt substitutes the ``content`` variable."""
        transcript = "some transcript"
        rendered = default_prompt_manager.format_prompt(
            "entity_extraction", content=transcript
        )
        assert transcript in rendered

    def test_mermaid_generation_template(self):
        """The mermaid prompt embeds the diagram type and the text content."""
        variables = {
            "diagram_type": "flowchart",
            "text_content": "A -> B",
            "semantic_analysis": "Flow diagram",
        }
        rendered = default_prompt_manager.format_prompt("mermaid_generation", **variables)
        for fragment in ("flowchart", "A -> B"):
            assert fragment in rendered