ScuttleBot

scuttlebot / internal / mcp / mcp_test.go
Blame History Raw 376 lines
1
package mcp_test
2
3
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync"
	"testing"

	"github.com/conflicthq/scuttlebot/internal/mcp"
	"github.com/conflicthq/scuttlebot/internal/registry"
)
18
19
// testToken is the bearer token registered in the test server's tokenSet;
// requests must present it to pass authentication.
const testToken = "test-mcp-token"

// testLog emits text-formatted records to stderr, but only at error level and
// above, so passing tests stay quiet.
var testLog = slog.New(
	slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: slog.LevelError,
	}),
)
22
23
// --- mocks ---
24
25
// tokenSet is a set of accepted bearer tokens; a token is valid exactly when
// it is a key of the map.
type tokenSet map[string]struct{}

// ValidToken reports whether tok is a member of the set.
func (ts tokenSet) ValidToken(tok string) bool {
	if _, member := ts[tok]; member {
		return true
	}
	return false
}
31
32
// mockProvisioner is an in-memory account store standing in for the real
// provisioner. It maps account name to password and guards the map with a mutex.
type mockProvisioner struct {
	mu       sync.Mutex
	accounts map[string]string // name -> password
}

// newMock returns a mockProvisioner with an empty, ready-to-use account map.
func newMock() *mockProvisioner {
	p := &mockProvisioner{}
	p.accounts = map[string]string{}
	return p
}

// RegisterAccount creates name with password pass. It fails with
// "ACCOUNT_EXISTS" if the name is already taken.
func (p *mockProvisioner) RegisterAccount(name, pass string) error {
	return p.setPassword(name, pass, false)
}

// ChangePassword replaces the password of an existing account. It fails with
// "ACCOUNT_DOES_NOT_EXIST" if the name is unknown.
func (p *mockProvisioner) ChangePassword(name, pass string) error {
	return p.setPassword(name, pass, true)
}

// setPassword stores pass for name under the lock after checking one of two
// preconditions. mustExist=true requires an existing account (the change
// path). mustExist=false requires a new one (the register path).
func (p *mockProvisioner) setPassword(name, pass string, mustExist bool) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	_, exists := p.accounts[name]
	switch {
	case mustExist && !exists:
		return fmt.Errorf("ACCOUNT_DOES_NOT_EXIST")
	case !mustExist && exists:
		return fmt.Errorf("ACCOUNT_EXISTS")
	}
	p.accounts[name] = pass
	return nil
}
60
61
type mockChannelLister struct {
62
channels []mcp.ChannelInfo
63
}
64
65
func (m *mockChannelLister) ListChannels() ([]mcp.ChannelInfo, error) {
66
return m.channels, nil
67
}
68
69
// mockSender records every Send call as a "channel/msgType" string and
// ignores the context and payload.
type mockSender struct {
	sent []string
}

// Send appends "channel/msgType" to s.sent and always succeeds.
func (s *mockSender) Send(_ context.Context, channel, msgType string, _ any) error {
	entry := channel + "/" + msgType
	s.sent = append(s.sent, entry)
	return nil
}
77
78
type mockHistory struct {
79
entries map[string][]mcp.HistoryEntry
80
}
81
82
func (m *mockHistory) Query(channel string, limit int) ([]mcp.HistoryEntry, error) {
83
entries := m.entries[channel]
84
if len(entries) > limit {
85
entries = entries[len(entries)-limit:]
86
}
87
return entries, nil
88
}
89
90
// --- test server setup ---
91
92
func newTestServer(t *testing.T) *httptest.Server {
93
t.Helper()
94
reg := registry.New(newMock(), []byte("test-signing-key"))
95
channels := &mockChannelLister{channels: []mcp.ChannelInfo{
96
{Name: "#fleet", Topic: "main coordination", Count: 3},
97
{Name: "#task.abc", Count: 1},
98
}}
99
sender := &mockSender{}
100
hist := &mockHistory{entries: map[string][]mcp.HistoryEntry{
101
"#fleet": {
102
{Nick: "agent-01", MessageType: "task.create", MessageID: "01HX", Raw: `{"v":1}`},
103
},
104
}}
105
srv := mcp.New(reg, channels, tokenSet{testToken: {}}, testLog).
106
WithSender(sender).
107
WithHistory(hist)
108
return httptest.NewServer(srv.Handler())
109
}
110
111
// rpc POSTs one JSON-RPC 2.0 request (id 1) to srv's /mcp endpoint and returns
// the decoded JSON response object.
//
// params is left out of the request when nil. A non-empty token is sent as
// "Authorization: Bearer <token>". Any failure to encode, build, send, or
// decode aborts the test via t.Fatalf.
func rpc(t *testing.T, srv *httptest.Server, method string, params any, token string) map[string]any {
	t.Helper()

	body := map[string]any{
		"jsonrpc": "2.0",
		"id":      1,
		"method":  method,
	}
	if params != nil {
		body["params"] = params
	}
	data, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("marshal request: %v", err)
	}

	req, err := http.NewRequest(http.MethodPost, srv.URL+"/mcp", bytes.NewReader(data))
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	// srv.Client() is configured for this test server and has its idle
	// connections closed by srv.Close, unlike the shared http.DefaultClient.
	resp, err := srv.Client().Do(req)
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	defer resp.Body.Close()

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("decode (HTTP %d): %v", resp.StatusCode, err)
	}
	return result
}
139
140
// --- tests ---
141
142
func TestAuthRequired(t *testing.T) {
143
srv := newTestServer(t)
144
defer srv.Close()
145
146
resp := rpc(t, srv, "initialize", nil, "") // no token
147
if resp["error"] == nil {
148
t.Error("expected error for missing auth, got none")
149
}
150
}
151
152
func TestAuthInvalid(t *testing.T) {
153
srv := newTestServer(t)
154
defer srv.Close()
155
156
resp := rpc(t, srv, "initialize", nil, "wrong-token")
157
if resp["error"] == nil {
158
t.Error("expected error for invalid token")
159
}
160
}
161
162
func TestInitialize(t *testing.T) {
163
srv := newTestServer(t)
164
defer srv.Close()
165
166
resp := rpc(t, srv, "initialize", map[string]any{
167
"protocolVersion": "2024-11-05",
168
"capabilities": map[string]any{},
169
"clientInfo": map[string]any{"name": "test", "version": "1"},
170
}, testToken)
171
172
result, ok := resp["result"].(map[string]any)
173
if !ok {
174
t.Fatalf("no result: %v", resp)
175
}
176
if result["protocolVersion"] == nil {
177
t.Error("missing protocolVersion in initialize response")
178
}
179
}
180
181
func TestToolsList(t *testing.T) {
182
srv := newTestServer(t)
183
defer srv.Close()
184
185
resp := rpc(t, srv, "tools/list", nil, testToken)
186
result, _ := resp["result"].(map[string]any)
187
tools, _ := result["tools"].([]any)
188
if len(tools) == 0 {
189
t.Error("expected at least one tool")
190
}
191
// Check all expected tool names are present.
192
want := map[string]bool{
193
"get_status": false, "list_channels": false,
194
"register_agent": false, "send_message": false, "get_history": false,
195
}
196
for _, tool := range tools {
197
m, _ := tool.(map[string]any)
198
if name, ok := m["name"].(string); ok {
199
want[name] = true
200
}
201
}
202
for name, found := range want {
203
if !found {
204
t.Errorf("tool %q missing from tools/list", name)
205
}
206
}
207
}
208
209
func TestToolGetStatus(t *testing.T) {
210
srv := newTestServer(t)
211
defer srv.Close()
212
213
resp := rpc(t, srv, "tools/call", map[string]any{
214
"name": "get_status",
215
"arguments": map[string]any{},
216
}, testToken)
217
218
if resp["error"] != nil {
219
t.Fatalf("unexpected rpc error: %v", resp["error"])
220
}
221
result := toolText(t, resp)
222
if result == "" {
223
t.Error("expected non-empty status text")
224
}
225
}
226
227
func TestToolListChannels(t *testing.T) {
228
srv := newTestServer(t)
229
defer srv.Close()
230
231
resp := rpc(t, srv, "tools/call", map[string]any{
232
"name": "list_channels",
233
"arguments": map[string]any{},
234
}, testToken)
235
236
text := toolText(t, resp)
237
if !contains(text, "#fleet") {
238
t.Errorf("expected #fleet in channel list, got: %s", text)
239
}
240
}
241
242
func TestToolRegisterAgent(t *testing.T) {
243
srv := newTestServer(t)
244
defer srv.Close()
245
246
resp := rpc(t, srv, "tools/call", map[string]any{
247
"name": "register_agent",
248
"arguments": map[string]any{
249
"nick": "mcp-agent",
250
"type": "worker",
251
"channels": []any{"#fleet"},
252
},
253
}, testToken)
254
255
if isToolError(resp) {
256
t.Fatalf("unexpected tool error: %s", toolText(t, resp))
257
}
258
text := toolText(t, resp)
259
if !contains(text, "mcp-agent") {
260
t.Errorf("expected nick in response, got: %s", text)
261
}
262
if !contains(text, "password") {
263
t.Errorf("expected password in response, got: %s", text)
264
}
265
}
266
267
func TestToolRegisterAgentMissingNick(t *testing.T) {
268
srv := newTestServer(t)
269
defer srv.Close()
270
271
resp := rpc(t, srv, "tools/call", map[string]any{
272
"name": "register_agent",
273
"arguments": map[string]any{},
274
}, testToken)
275
276
if !isToolError(resp) {
277
t.Error("expected tool error for missing nick")
278
}
279
}
280
281
func TestToolSendMessage(t *testing.T) {
282
srv := newTestServer(t)
283
defer srv.Close()
284
285
resp := rpc(t, srv, "tools/call", map[string]any{
286
"name": "send_message",
287
"arguments": map[string]any{
288
"channel": "#fleet",
289
"type": "task.update",
290
"payload": map[string]any{"status": "done"},
291
},
292
}, testToken)
293
294
if isToolError(resp) {
295
t.Fatalf("unexpected tool error: %s", toolText(t, resp))
296
}
297
}
298
299
func TestToolGetHistory(t *testing.T) {
300
srv := newTestServer(t)
301
defer srv.Close()
302
303
resp := rpc(t, srv, "tools/call", map[string]any{
304
"name": "get_history",
305
"arguments": map[string]any{
306
"channel": "#fleet",
307
"limit": float64(10),
308
},
309
}, testToken)
310
311
if isToolError(resp) {
312
t.Fatalf("unexpected tool error: %s", toolText(t, resp))
313
}
314
text := toolText(t, resp)
315
if !contains(text, "#fleet") {
316
t.Errorf("expected #fleet in history, got: %s", text)
317
}
318
}
319
320
func TestUnknownTool(t *testing.T) {
321
srv := newTestServer(t)
322
defer srv.Close()
323
324
resp := rpc(t, srv, "tools/call", map[string]any{
325
"name": "no_such_tool",
326
"arguments": map[string]any{},
327
}, testToken)
328
329
if resp["error"] == nil {
330
t.Error("expected rpc error for unknown tool")
331
}
332
}
333
334
func TestUnknownMethod(t *testing.T) {
335
srv := newTestServer(t)
336
defer srv.Close()
337
338
resp := rpc(t, srv, "no_such_method", nil, testToken)
339
if resp["error"] == nil {
340
t.Error("expected rpc error for unknown method")
341
}
342
}
343
344
// --- helpers ---
345
346
// toolText returns the "text" field of the first item in a tools/call
// response's result.content list. It returns "" when the result, the content
// list, the first item, or its text field is missing or has an unexpected
// shape.
func toolText(t *testing.T, resp map[string]any) string {
	t.Helper()

	var first map[string]any
	if result, ok := resp["result"].(map[string]any); ok {
		if items, ok := result["content"].([]any); ok && len(items) > 0 {
			first, _ = items[0].(map[string]any)
		}
	}
	s, _ := first["text"].(string)
	return s
}
357
358
// isToolError reports whether a tools/call response has result.isError set to
// true. A missing result, a missing flag, or a non-bool flag all count as false.
func isToolError(resp map[string]any) bool {
	if result, ok := resp["result"].(map[string]any); ok {
		if flag, ok := result["isError"].(bool); ok {
			return flag
		}
	}
	return false
}
363
364
// contains reports whether sub occurs within s. The empty string is contained
// in every string. It delegates to strings.Contains rather than a hand-rolled
// scan.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr is kept for any other callers in this package; it is equivalent
// to contains.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}
376

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button