ScuttleBot

scuttlebot / internal / mcp / mcp_test.go
Source Blame History 375 lines
550b35e… lmata 1 package mcp_test
550b35e… lmata 2
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"sync"
	"testing"

	"github.com/conflicthq/scuttlebot/internal/mcp"
	"github.com/conflicthq/scuttlebot/internal/registry"
)
550b35e… lmata 18
// testLog is the logger handed to the MCP server under test. Its handler only
// emits records at slog.LevelError or above, so routine server logging stays
// out of the test output.
var testLog = slog.New(slog.NewTextHandler(
	os.Stderr,
	&slog.HandlerOptions{
		Level: slog.LevelError,
	},
))
550b35e… lmata 20
// testToken is the single token placed in the tokenSet given to the server
// built by newTestServer; rpc sends it as an "Authorization: Bearer" header.
const (
	testToken = "test-mcp-token"
)
550b35e… lmata 22
550b35e… lmata 23 // --- mocks ---
550b35e… lmata 24
// tokenSet is an in-memory set of accepted tokens. newTestServer passes one
// to mcp.New as the server's token validator.
type tokenSet map[string]struct{}

// ValidToken reports whether tok is a member of the set. A nil set accepts
// nothing.
func (ts tokenSet) ValidToken(tok string) bool {
	if _, found := ts[tok]; found {
		return true
	}
	return false
}
68677f9… noreply 31
// mockProvisioner is an in-memory account store handed to registry.New in
// place of a real provisioner. It maps account name to password; mu guards
// accounts so the mock is safe for concurrent use.
type mockProvisioner struct {
	mu       sync.Mutex
	accounts map[string]string
}

// newMock returns an empty mockProvisioner ready for use.
func newMock() *mockProvisioner {
	m := new(mockProvisioner)
	m.accounts = map[string]string{}
	return m
}

// store writes pass for name while holding the lock. wantExisting selects the
// precondition: false requires the account to be absent (registration), true
// requires it to be present (password change). On a failed precondition the
// map is left unchanged.
func (m *mockProvisioner) store(name, pass string, wantExisting bool) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	_, exists := m.accounts[name]
	switch {
	case !wantExisting && exists:
		return fmt.Errorf("ACCOUNT_EXISTS")
	case wantExisting && !exists:
		return fmt.Errorf("ACCOUNT_DOES_NOT_EXIST")
	}
	m.accounts[name] = pass
	return nil
}

// RegisterAccount creates name with password pass. It fails with
// "ACCOUNT_EXISTS" if name is already registered.
func (m *mockProvisioner) RegisterAccount(name, pass string) error {
	return m.store(name, pass, false)
}

// ChangePassword replaces the password of an existing account. It fails with
// "ACCOUNT_DOES_NOT_EXIST" if name is unknown.
func (m *mockProvisioner) ChangePassword(name, pass string) error {
	return m.store(name, pass, true)
}
550b35e… lmata 60
550b35e… lmata 61 type mockChannelLister struct {
550b35e… lmata 62 channels []mcp.ChannelInfo
550b35e… lmata 63 }
550b35e… lmata 64
550b35e… lmata 65 func (m *mockChannelLister) ListChannels() ([]mcp.ChannelInfo, error) {
550b35e… lmata 66 return m.channels, nil
550b35e… lmata 67 }
550b35e… lmata 68
// mockSender records each Send call as "<channel>/<msgType>" in sent. The
// context and payload are ignored, Send always returns nil, and sent is not
// protected by a lock.
type mockSender struct {
	sent []string
}

// Send appends a "<channel>/<msgType>" record and reports success.
func (s *mockSender) Send(_ context.Context, channel, msgType string, _ any) error {
	record := channel + "/" + msgType
	s.sent = append(s.sent, record)
	return nil
}
550b35e… lmata 77
550b35e… lmata 78 type mockHistory struct {
550b35e… lmata 79 entries map[string][]mcp.HistoryEntry
550b35e… lmata 80 }
550b35e… lmata 81
550b35e… lmata 82 func (m *mockHistory) Query(channel string, limit int) ([]mcp.HistoryEntry, error) {
550b35e… lmata 83 entries := m.entries[channel]
550b35e… lmata 84 if len(entries) > limit {
550b35e… lmata 85 entries = entries[len(entries)-limit:]
550b35e… lmata 86 }
550b35e… lmata 87 return entries, nil
550b35e… lmata 88 }
550b35e… lmata 89
550b35e… lmata 90 // --- test server setup ---
550b35e… lmata 91
550b35e… lmata 92 func newTestServer(t *testing.T) *httptest.Server {
550b35e… lmata 93 t.Helper()
550b35e… lmata 94 reg := registry.New(newMock(), []byte("test-signing-key"))
550b35e… lmata 95 channels := &mockChannelLister{channels: []mcp.ChannelInfo{
550b35e… lmata 96 {Name: "#fleet", Topic: "main coordination", Count: 3},
550b35e… lmata 97 {Name: "#task.abc", Count: 1},
550b35e… lmata 98 }}
550b35e… lmata 99 sender := &mockSender{}
550b35e… lmata 100 hist := &mockHistory{entries: map[string][]mcp.HistoryEntry{
550b35e… lmata 101 "#fleet": {
550b35e… lmata 102 {Nick: "agent-01", MessageType: "task.create", MessageID: "01HX", Raw: `{"v":1}`},
550b35e… lmata 103 },
550b35e… lmata 104 }}
68677f9… noreply 105 srv := mcp.New(reg, channels, tokenSet{testToken: {}}, testLog).
550b35e… lmata 106 WithSender(sender).
550b35e… lmata 107 WithHistory(hist)
550b35e… lmata 108 return httptest.NewServer(srv.Handler())
550b35e… lmata 109 }
550b35e… lmata 110
// rpc POSTs a single JSON-RPC 2.0 request (id 1) for method to srv's /mcp
// endpoint and returns the decoded response object.
//
// params is omitted from the request when nil. token, when non-empty, is
// sent as "Authorization: Bearer <token>". The HTTP status is not asserted
// here, so callers can inspect error responses (e.g. failed auth) as decoded
// JSON. Any marshal, request-construction, transport or decode failure aborts
// the test.
func rpc(t *testing.T, srv *httptest.Server, method string, params any, token string) map[string]any {
	t.Helper()
	body := map[string]any{
		"jsonrpc": "2.0",
		"id":      1,
		"method":  method,
	}
	if params != nil {
		body["params"] = params
	}
	data, err := json.Marshal(body)
	if err != nil {
		t.Fatalf("marshal request: %v", err)
	}
	req, err := http.NewRequest(http.MethodPost, srv.URL+"/mcp", bytes.NewReader(data))
	if err != nil {
		// Without this check a nil req would panic on req.Header below.
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	// srv.Client() is the client httptest configures for this server; its idle
	// connections are closed by srv.Close().
	resp, err := srv.Client().Do(req)
	if err != nil {
		t.Fatalf("request: %v", err)
	}
	defer resp.Body.Close()

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		// Include the status so a non-JSON error page is easy to diagnose.
		t.Fatalf("decode (HTTP %d): %v", resp.StatusCode, err)
	}
	return result
}
550b35e… lmata 139
550b35e… lmata 140 // --- tests ---
550b35e… lmata 141
550b35e… lmata 142 func TestAuthRequired(t *testing.T) {
550b35e… lmata 143 srv := newTestServer(t)
550b35e… lmata 144 defer srv.Close()
550b35e… lmata 145
550b35e… lmata 146 resp := rpc(t, srv, "initialize", nil, "") // no token
550b35e… lmata 147 if resp["error"] == nil {
550b35e… lmata 148 t.Error("expected error for missing auth, got none")
550b35e… lmata 149 }
550b35e… lmata 150 }
550b35e… lmata 151
550b35e… lmata 152 func TestAuthInvalid(t *testing.T) {
550b35e… lmata 153 srv := newTestServer(t)
550b35e… lmata 154 defer srv.Close()
550b35e… lmata 155
550b35e… lmata 156 resp := rpc(t, srv, "initialize", nil, "wrong-token")
550b35e… lmata 157 if resp["error"] == nil {
550b35e… lmata 158 t.Error("expected error for invalid token")
550b35e… lmata 159 }
550b35e… lmata 160 }
550b35e… lmata 161
550b35e… lmata 162 func TestInitialize(t *testing.T) {
550b35e… lmata 163 srv := newTestServer(t)
550b35e… lmata 164 defer srv.Close()
550b35e… lmata 165
550b35e… lmata 166 resp := rpc(t, srv, "initialize", map[string]any{
550b35e… lmata 167 "protocolVersion": "2024-11-05",
550b35e… lmata 168 "capabilities": map[string]any{},
550b35e… lmata 169 "clientInfo": map[string]any{"name": "test", "version": "1"},
550b35e… lmata 170 }, testToken)
550b35e… lmata 171
550b35e… lmata 172 result, ok := resp["result"].(map[string]any)
550b35e… lmata 173 if !ok {
550b35e… lmata 174 t.Fatalf("no result: %v", resp)
550b35e… lmata 175 }
550b35e… lmata 176 if result["protocolVersion"] == nil {
550b35e… lmata 177 t.Error("missing protocolVersion in initialize response")
550b35e… lmata 178 }
550b35e… lmata 179 }
550b35e… lmata 180
550b35e… lmata 181 func TestToolsList(t *testing.T) {
550b35e… lmata 182 srv := newTestServer(t)
550b35e… lmata 183 defer srv.Close()
550b35e… lmata 184
550b35e… lmata 185 resp := rpc(t, srv, "tools/list", nil, testToken)
550b35e… lmata 186 result, _ := resp["result"].(map[string]any)
550b35e… lmata 187 tools, _ := result["tools"].([]any)
550b35e… lmata 188 if len(tools) == 0 {
550b35e… lmata 189 t.Error("expected at least one tool")
550b35e… lmata 190 }
550b35e… lmata 191 // Check all expected tool names are present.
550b35e… lmata 192 want := map[string]bool{
550b35e… lmata 193 "get_status": false, "list_channels": false,
550b35e… lmata 194 "register_agent": false, "send_message": false, "get_history": false,
550b35e… lmata 195 }
550b35e… lmata 196 for _, tool := range tools {
550b35e… lmata 197 m, _ := tool.(map[string]any)
550b35e… lmata 198 if name, ok := m["name"].(string); ok {
550b35e… lmata 199 want[name] = true
550b35e… lmata 200 }
550b35e… lmata 201 }
550b35e… lmata 202 for name, found := range want {
550b35e… lmata 203 if !found {
550b35e… lmata 204 t.Errorf("tool %q missing from tools/list", name)
550b35e… lmata 205 }
550b35e… lmata 206 }
550b35e… lmata 207 }
550b35e… lmata 208
550b35e… lmata 209 func TestToolGetStatus(t *testing.T) {
550b35e… lmata 210 srv := newTestServer(t)
550b35e… lmata 211 defer srv.Close()
550b35e… lmata 212
550b35e… lmata 213 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 214 "name": "get_status",
550b35e… lmata 215 "arguments": map[string]any{},
550b35e… lmata 216 }, testToken)
550b35e… lmata 217
550b35e… lmata 218 if resp["error"] != nil {
550b35e… lmata 219 t.Fatalf("unexpected rpc error: %v", resp["error"])
550b35e… lmata 220 }
550b35e… lmata 221 result := toolText(t, resp)
550b35e… lmata 222 if result == "" {
550b35e… lmata 223 t.Error("expected non-empty status text")
550b35e… lmata 224 }
550b35e… lmata 225 }
550b35e… lmata 226
550b35e… lmata 227 func TestToolListChannels(t *testing.T) {
550b35e… lmata 228 srv := newTestServer(t)
550b35e… lmata 229 defer srv.Close()
550b35e… lmata 230
550b35e… lmata 231 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 232 "name": "list_channels",
550b35e… lmata 233 "arguments": map[string]any{},
550b35e… lmata 234 }, testToken)
550b35e… lmata 235
550b35e… lmata 236 text := toolText(t, resp)
550b35e… lmata 237 if !contains(text, "#fleet") {
550b35e… lmata 238 t.Errorf("expected #fleet in channel list, got: %s", text)
550b35e… lmata 239 }
550b35e… lmata 240 }
550b35e… lmata 241
550b35e… lmata 242 func TestToolRegisterAgent(t *testing.T) {
550b35e… lmata 243 srv := newTestServer(t)
550b35e… lmata 244 defer srv.Close()
550b35e… lmata 245
550b35e… lmata 246 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 247 "name": "register_agent",
550b35e… lmata 248 "arguments": map[string]any{
550b35e… lmata 249 "nick": "mcp-agent",
550b35e… lmata 250 "type": "worker",
550b35e… lmata 251 "channels": []any{"#fleet"},
550b35e… lmata 252 },
550b35e… lmata 253 }, testToken)
550b35e… lmata 254
550b35e… lmata 255 if isToolError(resp) {
550b35e… lmata 256 t.Fatalf("unexpected tool error: %s", toolText(t, resp))
550b35e… lmata 257 }
550b35e… lmata 258 text := toolText(t, resp)
550b35e… lmata 259 if !contains(text, "mcp-agent") {
550b35e… lmata 260 t.Errorf("expected nick in response, got: %s", text)
550b35e… lmata 261 }
550b35e… lmata 262 if !contains(text, "password") {
550b35e… lmata 263 t.Errorf("expected password in response, got: %s", text)
550b35e… lmata 264 }
550b35e… lmata 265 }
550b35e… lmata 266
550b35e… lmata 267 func TestToolRegisterAgentMissingNick(t *testing.T) {
550b35e… lmata 268 srv := newTestServer(t)
550b35e… lmata 269 defer srv.Close()
550b35e… lmata 270
550b35e… lmata 271 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 272 "name": "register_agent",
550b35e… lmata 273 "arguments": map[string]any{},
550b35e… lmata 274 }, testToken)
550b35e… lmata 275
550b35e… lmata 276 if !isToolError(resp) {
550b35e… lmata 277 t.Error("expected tool error for missing nick")
550b35e… lmata 278 }
550b35e… lmata 279 }
550b35e… lmata 280
550b35e… lmata 281 func TestToolSendMessage(t *testing.T) {
550b35e… lmata 282 srv := newTestServer(t)
550b35e… lmata 283 defer srv.Close()
550b35e… lmata 284
550b35e… lmata 285 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 286 "name": "send_message",
550b35e… lmata 287 "arguments": map[string]any{
550b35e… lmata 288 "channel": "#fleet",
550b35e… lmata 289 "type": "task.update",
550b35e… lmata 290 "payload": map[string]any{"status": "done"},
550b35e… lmata 291 },
550b35e… lmata 292 }, testToken)
550b35e… lmata 293
550b35e… lmata 294 if isToolError(resp) {
550b35e… lmata 295 t.Fatalf("unexpected tool error: %s", toolText(t, resp))
550b35e… lmata 296 }
550b35e… lmata 297 }
550b35e… lmata 298
550b35e… lmata 299 func TestToolGetHistory(t *testing.T) {
550b35e… lmata 300 srv := newTestServer(t)
550b35e… lmata 301 defer srv.Close()
550b35e… lmata 302
550b35e… lmata 303 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 304 "name": "get_history",
550b35e… lmata 305 "arguments": map[string]any{
550b35e… lmata 306 "channel": "#fleet",
550b35e… lmata 307 "limit": float64(10),
550b35e… lmata 308 },
550b35e… lmata 309 }, testToken)
550b35e… lmata 310
550b35e… lmata 311 if isToolError(resp) {
550b35e… lmata 312 t.Fatalf("unexpected tool error: %s", toolText(t, resp))
550b35e… lmata 313 }
550b35e… lmata 314 text := toolText(t, resp)
550b35e… lmata 315 if !contains(text, "#fleet") {
550b35e… lmata 316 t.Errorf("expected #fleet in history, got: %s", text)
550b35e… lmata 317 }
550b35e… lmata 318 }
550b35e… lmata 319
550b35e… lmata 320 func TestUnknownTool(t *testing.T) {
550b35e… lmata 321 srv := newTestServer(t)
550b35e… lmata 322 defer srv.Close()
550b35e… lmata 323
550b35e… lmata 324 resp := rpc(t, srv, "tools/call", map[string]any{
550b35e… lmata 325 "name": "no_such_tool",
550b35e… lmata 326 "arguments": map[string]any{},
550b35e… lmata 327 }, testToken)
550b35e… lmata 328
550b35e… lmata 329 if resp["error"] == nil {
550b35e… lmata 330 t.Error("expected rpc error for unknown tool")
550b35e… lmata 331 }
550b35e… lmata 332 }
550b35e… lmata 333
550b35e… lmata 334 func TestUnknownMethod(t *testing.T) {
550b35e… lmata 335 srv := newTestServer(t)
550b35e… lmata 336 defer srv.Close()
550b35e… lmata 337
550b35e… lmata 338 resp := rpc(t, srv, "no_such_method", nil, testToken)
550b35e… lmata 339 if resp["error"] == nil {
550b35e… lmata 340 t.Error("expected rpc error for unknown method")
550b35e… lmata 341 }
550b35e… lmata 342 }
550b35e… lmata 343
550b35e… lmata 344 // --- helpers ---
550b35e… lmata 345
// toolText returns the "text" field of the first item in a tools/call
// result's "content" list. It returns "" when the result, the content list,
// the first item or its text is missing or has an unexpected type. t is only
// used to mark this function as a helper.
func toolText(t *testing.T, resp map[string]any) string {
	t.Helper()
	result, _ := resp["result"].(map[string]any)
	content, _ := result["content"].([]any)
	if len(content) > 0 {
		if first, ok := content[0].(map[string]any); ok {
			if text, ok := first["text"].(string); ok {
				return text
			}
		}
	}
	return ""
}
550b35e… lmata 357
// isToolError reports whether resp carries a tool-level failure, i.e. its
// "result" object has "isError": true. RPC-level errors (resp["error"]) are
// not considered.
func isToolError(resp map[string]any) bool {
	if result, ok := resp["result"].(map[string]any); ok {
		flag, _ := result["isError"].(bool)
		return flag
	}
	return false
}
550b35e… lmata 363
// contains reports whether sub occurs anywhere within s. An empty sub is
// always found. It delegates to strings.Contains rather than hand-rolling a
// substring scan.
func contains(s, sub string) bool {
	return strings.Contains(s, sub)
}

// containsStr reports whether sub occurs anywhere within s. It is retained,
// as a thin wrapper over strings.Contains, only so that any other callers in
// this package keep compiling; prefer strings.Contains in new code.
func containsStr(s, sub string) bool {
	return strings.Contains(s, sub)
}

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button