|
1
|
package oracle_test |
|
2
|
|
|
3
|
import ( |
|
4
|
"context" |
|
5
|
"errors" |
|
6
|
"testing" |
|
7
|
|
|
8
|
"github.com/conflicthq/scuttlebot/internal/bots/oracle" |
|
9
|
) |
|
10
|
|
|
11
|
// --- mock history --- |
|
12
|
|
|
13
|
type mockHistory struct { |
|
14
|
entries map[string][]oracle.HistoryEntry |
|
15
|
} |
|
16
|
|
|
17
|
func (m *mockHistory) Query(channel string, limit int) ([]oracle.HistoryEntry, error) { |
|
18
|
entries := m.entries[channel] |
|
19
|
if len(entries) > limit { |
|
20
|
entries = entries[:limit] |
|
21
|
} |
|
22
|
return entries, nil |
|
23
|
} |
|
24
|
|
|
25
|
func newHistory(channel string, entries []oracle.HistoryEntry) *mockHistory { |
|
26
|
return &mockHistory{entries: map[string][]oracle.HistoryEntry{channel: entries}} |
|
27
|
} |
|
28
|
|
|
29
|
// --- ParseCommand tests --- |
|
30
|
|
|
31
|
func TestParseCommandValid(t *testing.T) { |
|
32
|
tests := []struct { |
|
33
|
input string |
|
34
|
channel string |
|
35
|
limit int |
|
36
|
format oracle.Format |
|
37
|
}{ |
|
38
|
{"summarize #fleet", "#fleet", 50, oracle.FormatTOON}, |
|
39
|
{"summarize #fleet last=20", "#fleet", 20, oracle.FormatTOON}, |
|
40
|
{"summarize #fleet last=100 format=json", "#fleet", 100, oracle.FormatJSON}, |
|
41
|
{"summarize #project.test format=toon last=10", "#project.test", 10, oracle.FormatTOON}, |
|
42
|
} |
|
43
|
for _, tt := range tests { |
|
44
|
t.Run(tt.input, func(t *testing.T) { |
|
45
|
req, err := oracle.ParseCommand(tt.input) |
|
46
|
if err != nil { |
|
47
|
t.Fatalf("unexpected error: %v", err) |
|
48
|
} |
|
49
|
if req.Channel != tt.channel { |
|
50
|
t.Errorf("Channel: got %q, want %q", req.Channel, tt.channel) |
|
51
|
} |
|
52
|
if req.Limit != tt.limit { |
|
53
|
t.Errorf("Limit: got %d, want %d", req.Limit, tt.limit) |
|
54
|
} |
|
55
|
if req.Format != tt.format { |
|
56
|
t.Errorf("Format: got %q, want %q", req.Format, tt.format) |
|
57
|
} |
|
58
|
}) |
|
59
|
} |
|
60
|
} |
|
61
|
|
|
62
|
func TestParseCommandInvalid(t *testing.T) { |
|
63
|
tests := []struct { |
|
64
|
input string |
|
65
|
}{ |
|
66
|
{"summarize"}, // missing channel |
|
67
|
{""}, // empty |
|
68
|
{"do-something #fleet"}, // unknown command |
|
69
|
{"summarize fleet"}, // missing # |
|
70
|
{"summarize #fleet last=notanumber"}, // bad last |
|
71
|
{"summarize #fleet format=xml"}, // unknown format |
|
72
|
{"summarize #fleet last=-5"}, // negative |
|
73
|
} |
|
74
|
for _, tt := range tests { |
|
75
|
t.Run(tt.input, func(t *testing.T) { |
|
76
|
if _, err := oracle.ParseCommand(tt.input); err == nil { |
|
77
|
t.Errorf("expected error for %q, got nil", tt.input) |
|
78
|
} |
|
79
|
}) |
|
80
|
} |
|
81
|
} |
|
82
|
|
|
83
|
func TestParseCommandLimitCap(t *testing.T) { |
|
84
|
req, err := oracle.ParseCommand("summarize #fleet last=9999") |
|
85
|
if err != nil { |
|
86
|
t.Fatalf("unexpected error: %v", err) |
|
87
|
} |
|
88
|
if req.Limit > 200 { |
|
89
|
t.Errorf("limit should be capped at 200, got %d", req.Limit) |
|
90
|
} |
|
91
|
} |
|
92
|
|
|
93
|
// --- Bot construction --- |
|
94
|
|
|
95
|
func TestBotName(t *testing.T) { |
|
96
|
b := oracle.New("localhost:6667", "pass", nil, |
|
97
|
newHistory("#fleet", nil), |
|
98
|
&oracle.StubProvider{Response: "summary"}, |
|
99
|
nil, |
|
100
|
) |
|
101
|
if b.Name() != "oracle" { |
|
102
|
t.Errorf("Name(): got %q", b.Name()) |
|
103
|
} |
|
104
|
} |
|
105
|
|
|
106
|
// --- StubProvider --- |
|
107
|
|
|
108
|
func TestStubProviderReturnsResponse(t *testing.T) { |
|
109
|
p := &oracle.StubProvider{Response: "the fleet is idle"} |
|
110
|
summary, err := p.Summarize(context.TODO(), "prompt") |
|
111
|
if err != nil { |
|
112
|
t.Fatalf("unexpected error: %v", err) |
|
113
|
} |
|
114
|
if summary != "the fleet is idle" { |
|
115
|
t.Errorf("got %q", summary) |
|
116
|
} |
|
117
|
} |
|
118
|
|
|
119
|
func TestStubProviderReturnsError(t *testing.T) { |
|
120
|
p := &oracle.StubProvider{Err: errors.New("llm unavailable")} |
|
121
|
_, err := p.Summarize(context.TODO(), "prompt") |
|
122
|
if err == nil { |
|
123
|
t.Error("expected error") |
|
124
|
} |
|
125
|
} |
|
126
|
|
|
127
|
// --- HistoryFetcher --- |
|
128
|
|
|
129
|
func TestHistoryFetcherReturnsEntries(t *testing.T) { |
|
130
|
h := newHistory("#fleet", []oracle.HistoryEntry{ |
|
131
|
{Nick: "agent-01", MessageType: "task.create", Raw: `{"v":1}`}, |
|
132
|
{Nick: "human", Raw: "looks good"}, |
|
133
|
}) |
|
134
|
entries, err := h.Query("#fleet", 10) |
|
135
|
if err != nil { |
|
136
|
t.Fatalf("unexpected error: %v", err) |
|
137
|
} |
|
138
|
if len(entries) != 2 { |
|
139
|
t.Errorf("expected 2 entries, got %d", len(entries)) |
|
140
|
} |
|
141
|
} |
|
142
|
|
|
143
|
func TestHistoryFetcherEmptyChannel(t *testing.T) { |
|
144
|
h := newHistory("#fleet", nil) |
|
145
|
entries, err := h.Query("#empty", 10) |
|
146
|
if err != nil { |
|
147
|
t.Fatalf("unexpected error: %v", err) |
|
148
|
} |
|
149
|
if len(entries) != 0 { |
|
150
|
t.Errorf("expected 0 entries, got %d", len(entries)) |
|
151
|
} |
|
152
|
} |
|
153
|
|
|
154
|
func TestHistoryFetcherLimitRespected(t *testing.T) { |
|
155
|
entries := make([]oracle.HistoryEntry, 100) |
|
156
|
for i := range entries { |
|
157
|
entries[i] = oracle.HistoryEntry{Nick: "a", Raw: "msg"} |
|
158
|
} |
|
159
|
h := newHistory("#fleet", entries) |
|
160
|
got, _ := h.Query("#fleet", 10) |
|
161
|
if len(got) != 10 { |
|
162
|
t.Errorf("expected 10 entries (limit), got %d", len(got)) |
|
163
|
} |
|
164
|
} |
|
165
|
|