|
49f7ee2…
|
lmata
|
1 |
package oracle_test |
|
49f7ee2…
|
lmata
|
2 |
|
|
49f7ee2…
|
lmata
|
3 |
import ( |
|
0e244d2…
|
lmata
|
4 |
"context" |
|
49f7ee2…
|
lmata
|
5 |
"errors" |
|
49f7ee2…
|
lmata
|
6 |
"testing" |
|
49f7ee2…
|
lmata
|
7 |
|
|
49f7ee2…
|
lmata
|
8 |
"github.com/conflicthq/scuttlebot/internal/bots/oracle" |
|
49f7ee2…
|
lmata
|
9 |
) |
|
49f7ee2…
|
lmata
|
10 |
|
|
49f7ee2…
|
lmata
|
11 |
// --- mock history --- |
|
49f7ee2…
|
lmata
|
12 |
|
|
49f7ee2…
|
lmata
|
13 |
// mockHistory is an in-memory test double that serves canned
// oracle.HistoryEntry slices keyed by channel name. It is handed to
// oracle.New in TestBotName as the bot's history source and is queried
// directly by the HistoryFetcher tests.
type mockHistory struct {
	// entries maps a channel name (e.g. "#fleet") to the entries Query
	// returns for that channel.
	entries map[string][]oracle.HistoryEntry
}
|
49f7ee2…
|
lmata
|
16 |
|
|
49f7ee2…
|
lmata
|
17 |
func (m *mockHistory) Query(channel string, limit int) ([]oracle.HistoryEntry, error) { |
|
49f7ee2…
|
lmata
|
18 |
entries := m.entries[channel] |
|
49f7ee2…
|
lmata
|
19 |
if len(entries) > limit { |
|
49f7ee2…
|
lmata
|
20 |
entries = entries[:limit] |
|
49f7ee2…
|
lmata
|
21 |
} |
|
49f7ee2…
|
lmata
|
22 |
return entries, nil |
|
49f7ee2…
|
lmata
|
23 |
} |
|
49f7ee2…
|
lmata
|
24 |
|
|
49f7ee2…
|
lmata
|
25 |
func newHistory(channel string, entries []oracle.HistoryEntry) *mockHistory { |
|
49f7ee2…
|
lmata
|
26 |
return &mockHistory{entries: map[string][]oracle.HistoryEntry{channel: entries}} |
|
49f7ee2…
|
lmata
|
27 |
} |
|
49f7ee2…
|
lmata
|
28 |
|
|
49f7ee2…
|
lmata
|
29 |
// --- ParseCommand tests --- |
|
49f7ee2…
|
lmata
|
30 |
|
|
49f7ee2…
|
lmata
|
31 |
func TestParseCommandValid(t *testing.T) { |
|
49f7ee2…
|
lmata
|
32 |
tests := []struct { |
|
49f7ee2…
|
lmata
|
33 |
input string |
|
49f7ee2…
|
lmata
|
34 |
channel string |
|
49f7ee2…
|
lmata
|
35 |
limit int |
|
49f7ee2…
|
lmata
|
36 |
format oracle.Format |
|
49f7ee2…
|
lmata
|
37 |
}{ |
|
49f7ee2…
|
lmata
|
38 |
{"summarize #fleet", "#fleet", 50, oracle.FormatTOON}, |
|
49f7ee2…
|
lmata
|
39 |
{"summarize #fleet last=20", "#fleet", 20, oracle.FormatTOON}, |
|
49f7ee2…
|
lmata
|
40 |
{"summarize #fleet last=100 format=json", "#fleet", 100, oracle.FormatJSON}, |
|
49f7ee2…
|
lmata
|
41 |
{"summarize #project.test format=toon last=10", "#project.test", 10, oracle.FormatTOON}, |
|
49f7ee2…
|
lmata
|
42 |
} |
|
49f7ee2…
|
lmata
|
43 |
for _, tt := range tests { |
|
49f7ee2…
|
lmata
|
44 |
t.Run(tt.input, func(t *testing.T) { |
|
49f7ee2…
|
lmata
|
45 |
req, err := oracle.ParseCommand(tt.input) |
|
49f7ee2…
|
lmata
|
46 |
if err != nil { |
|
49f7ee2…
|
lmata
|
47 |
t.Fatalf("unexpected error: %v", err) |
|
49f7ee2…
|
lmata
|
48 |
} |
|
49f7ee2…
|
lmata
|
49 |
if req.Channel != tt.channel { |
|
49f7ee2…
|
lmata
|
50 |
t.Errorf("Channel: got %q, want %q", req.Channel, tt.channel) |
|
49f7ee2…
|
lmata
|
51 |
} |
|
49f7ee2…
|
lmata
|
52 |
if req.Limit != tt.limit { |
|
49f7ee2…
|
lmata
|
53 |
t.Errorf("Limit: got %d, want %d", req.Limit, tt.limit) |
|
49f7ee2…
|
lmata
|
54 |
} |
|
49f7ee2…
|
lmata
|
55 |
if req.Format != tt.format { |
|
49f7ee2…
|
lmata
|
56 |
t.Errorf("Format: got %q, want %q", req.Format, tt.format) |
|
49f7ee2…
|
lmata
|
57 |
} |
|
49f7ee2…
|
lmata
|
58 |
}) |
|
49f7ee2…
|
lmata
|
59 |
} |
|
49f7ee2…
|
lmata
|
60 |
} |
|
49f7ee2…
|
lmata
|
61 |
|
|
49f7ee2…
|
lmata
|
62 |
func TestParseCommandInvalid(t *testing.T) { |
|
49f7ee2…
|
lmata
|
63 |
tests := []struct { |
|
49f7ee2…
|
lmata
|
64 |
input string |
|
49f7ee2…
|
lmata
|
65 |
}{ |
|
1066004…
|
lmata
|
66 |
{"summarize"}, // missing channel |
|
1066004…
|
lmata
|
67 |
{""}, // empty |
|
1066004…
|
lmata
|
68 |
{"do-something #fleet"}, // unknown command |
|
1066004…
|
lmata
|
69 |
{"summarize fleet"}, // missing # |
|
1066004…
|
lmata
|
70 |
{"summarize #fleet last=notanumber"}, // bad last |
|
1066004…
|
lmata
|
71 |
{"summarize #fleet format=xml"}, // unknown format |
|
1066004…
|
lmata
|
72 |
{"summarize #fleet last=-5"}, // negative |
|
49f7ee2…
|
lmata
|
73 |
} |
|
49f7ee2…
|
lmata
|
74 |
for _, tt := range tests { |
|
49f7ee2…
|
lmata
|
75 |
t.Run(tt.input, func(t *testing.T) { |
|
49f7ee2…
|
lmata
|
76 |
if _, err := oracle.ParseCommand(tt.input); err == nil { |
|
49f7ee2…
|
lmata
|
77 |
t.Errorf("expected error for %q, got nil", tt.input) |
|
49f7ee2…
|
lmata
|
78 |
} |
|
49f7ee2…
|
lmata
|
79 |
}) |
|
49f7ee2…
|
lmata
|
80 |
} |
|
49f7ee2…
|
lmata
|
81 |
} |
|
49f7ee2…
|
lmata
|
82 |
|
|
49f7ee2…
|
lmata
|
83 |
func TestParseCommandLimitCap(t *testing.T) { |
|
49f7ee2…
|
lmata
|
84 |
req, err := oracle.ParseCommand("summarize #fleet last=9999") |
|
49f7ee2…
|
lmata
|
85 |
if err != nil { |
|
49f7ee2…
|
lmata
|
86 |
t.Fatalf("unexpected error: %v", err) |
|
49f7ee2…
|
lmata
|
87 |
} |
|
49f7ee2…
|
lmata
|
88 |
if req.Limit > 200 { |
|
49f7ee2…
|
lmata
|
89 |
t.Errorf("limit should be capped at 200, got %d", req.Limit) |
|
49f7ee2…
|
lmata
|
90 |
} |
|
49f7ee2…
|
lmata
|
91 |
} |
|
49f7ee2…
|
lmata
|
92 |
|
|
49f7ee2…
|
lmata
|
93 |
// --- Bot construction --- |
|
49f7ee2…
|
lmata
|
94 |
|
|
49f7ee2…
|
lmata
|
95 |
func TestBotName(t *testing.T) { |
|
3420a83…
|
lmata
|
96 |
b := oracle.New("localhost:6667", "pass", nil, |
|
49f7ee2…
|
lmata
|
97 |
newHistory("#fleet", nil), |
|
49f7ee2…
|
lmata
|
98 |
&oracle.StubProvider{Response: "summary"}, |
|
49f7ee2…
|
lmata
|
99 |
nil, |
|
49f7ee2…
|
lmata
|
100 |
) |
|
49f7ee2…
|
lmata
|
101 |
if b.Name() != "oracle" { |
|
49f7ee2…
|
lmata
|
102 |
t.Errorf("Name(): got %q", b.Name()) |
|
49f7ee2…
|
lmata
|
103 |
} |
|
49f7ee2…
|
lmata
|
104 |
} |
|
49f7ee2…
|
lmata
|
105 |
|
|
49f7ee2…
|
lmata
|
106 |
// --- StubProvider --- |
|
49f7ee2…
|
lmata
|
107 |
|
|
49f7ee2…
|
lmata
|
108 |
func TestStubProviderReturnsResponse(t *testing.T) { |
|
49f7ee2…
|
lmata
|
109 |
p := &oracle.StubProvider{Response: "the fleet is idle"} |
|
0e244d2…
|
lmata
|
110 |
summary, err := p.Summarize(context.TODO(), "prompt") |
|
49f7ee2…
|
lmata
|
111 |
if err != nil { |
|
49f7ee2…
|
lmata
|
112 |
t.Fatalf("unexpected error: %v", err) |
|
49f7ee2…
|
lmata
|
113 |
} |
|
49f7ee2…
|
lmata
|
114 |
if summary != "the fleet is idle" { |
|
49f7ee2…
|
lmata
|
115 |
t.Errorf("got %q", summary) |
|
49f7ee2…
|
lmata
|
116 |
} |
|
49f7ee2…
|
lmata
|
117 |
} |
|
49f7ee2…
|
lmata
|
118 |
|
|
49f7ee2…
|
lmata
|
119 |
func TestStubProviderReturnsError(t *testing.T) { |
|
49f7ee2…
|
lmata
|
120 |
p := &oracle.StubProvider{Err: errors.New("llm unavailable")} |
|
0e244d2…
|
lmata
|
121 |
_, err := p.Summarize(context.TODO(), "prompt") |
|
49f7ee2…
|
lmata
|
122 |
if err == nil { |
|
49f7ee2…
|
lmata
|
123 |
t.Error("expected error") |
|
49f7ee2…
|
lmata
|
124 |
} |
|
49f7ee2…
|
lmata
|
125 |
} |
|
49f7ee2…
|
lmata
|
126 |
|
|
49f7ee2…
|
lmata
|
127 |
// --- HistoryFetcher --- |
|
49f7ee2…
|
lmata
|
128 |
|
|
49f7ee2…
|
lmata
|
129 |
func TestHistoryFetcherReturnsEntries(t *testing.T) { |
|
49f7ee2…
|
lmata
|
130 |
h := newHistory("#fleet", []oracle.HistoryEntry{ |
|
49f7ee2…
|
lmata
|
131 |
{Nick: "agent-01", MessageType: "task.create", Raw: `{"v":1}`}, |
|
49f7ee2…
|
lmata
|
132 |
{Nick: "human", Raw: "looks good"}, |
|
49f7ee2…
|
lmata
|
133 |
}) |
|
49f7ee2…
|
lmata
|
134 |
entries, err := h.Query("#fleet", 10) |
|
49f7ee2…
|
lmata
|
135 |
if err != nil { |
|
49f7ee2…
|
lmata
|
136 |
t.Fatalf("unexpected error: %v", err) |
|
49f7ee2…
|
lmata
|
137 |
} |
|
49f7ee2…
|
lmata
|
138 |
if len(entries) != 2 { |
|
49f7ee2…
|
lmata
|
139 |
t.Errorf("expected 2 entries, got %d", len(entries)) |
|
49f7ee2…
|
lmata
|
140 |
} |
|
49f7ee2…
|
lmata
|
141 |
} |
|
49f7ee2…
|
lmata
|
142 |
|
|
49f7ee2…
|
lmata
|
143 |
func TestHistoryFetcherEmptyChannel(t *testing.T) { |
|
49f7ee2…
|
lmata
|
144 |
h := newHistory("#fleet", nil) |
|
49f7ee2…
|
lmata
|
145 |
entries, err := h.Query("#empty", 10) |
|
49f7ee2…
|
lmata
|
146 |
if err != nil { |
|
49f7ee2…
|
lmata
|
147 |
t.Fatalf("unexpected error: %v", err) |
|
49f7ee2…
|
lmata
|
148 |
} |
|
49f7ee2…
|
lmata
|
149 |
if len(entries) != 0 { |
|
49f7ee2…
|
lmata
|
150 |
t.Errorf("expected 0 entries, got %d", len(entries)) |
|
49f7ee2…
|
lmata
|
151 |
} |
|
49f7ee2…
|
lmata
|
152 |
} |
|
49f7ee2…
|
lmata
|
153 |
|
|
49f7ee2…
|
lmata
|
154 |
func TestHistoryFetcherLimitRespected(t *testing.T) { |
|
49f7ee2…
|
lmata
|
155 |
entries := make([]oracle.HistoryEntry, 100) |
|
49f7ee2…
|
lmata
|
156 |
for i := range entries { |
|
49f7ee2…
|
lmata
|
157 |
entries[i] = oracle.HistoryEntry{Nick: "a", Raw: "msg"} |
|
49f7ee2…
|
lmata
|
158 |
} |
|
49f7ee2…
|
lmata
|
159 |
h := newHistory("#fleet", entries) |
|
49f7ee2…
|
lmata
|
160 |
got, _ := h.Query("#fleet", 10) |
|
49f7ee2…
|
lmata
|
161 |
if len(got) != 10 { |
|
49f7ee2…
|
lmata
|
162 |
t.Errorf("expected 10 entries (limit), got %d", len(got)) |
|
49f7ee2…
|
lmata
|
163 |
} |
|
49f7ee2…
|
lmata
|
164 |
} |