ScuttleBot

scuttlebot / internal / llm / bedrock_test.go
Blame History Raw 88 lines
1
package llm
2
3
import (
4
"context"
5
"encoding/json"
6
"net/http"
7
"net/http/httptest"
8
"testing"
9
)
10
11
func TestBedrockSummarize(t *testing.T) {
12
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13
if r.Method != "POST" {
14
t.Errorf("expected POST request, got %s", r.Method)
15
}
16
// Path: /model/{modelID}/converse
17
if r.URL.Path != "/model/test-model/converse" {
18
t.Errorf("unexpected path: %s", r.URL.Path)
19
}
20
21
resp := map[string]any{
22
"output": map[string]any{
23
"message": map[string]any{
24
"content": []map[string]any{
25
{"text": "bedrock response"},
26
},
27
},
28
},
29
}
30
_ = json.NewEncoder(w).Encode(resp)
31
}))
32
defer srv.Close()
33
34
p, _ := newBedrockProvider(BackendConfig{
35
Backend: "bedrock",
36
Region: "us-east-1",
37
Model: "test-model",
38
BaseURL: srv.URL,
39
AWSKeyID: "test-key",
40
AWSSecretKey: "test-secret",
41
}, srv.Client())
42
43
got, err := p.Summarize(context.Background(), "test prompt")
44
if err != nil {
45
t.Fatalf("Summarize failed: %v", err)
46
}
47
if got != "bedrock response" {
48
t.Errorf("got %q, want %q", got, "bedrock response")
49
}
50
}
51
52
func TestBedrockDiscoverModels(t *testing.T) {
53
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54
if r.Method != "GET" {
55
t.Errorf("expected GET request, got %s", r.Method)
56
}
57
if r.URL.Path != "/foundation-models" {
58
t.Errorf("unexpected path: %s", r.URL.Path)
59
}
60
61
resp := map[string]any{
62
"modelSummaries": []map[string]any{
63
{"modelId": "m1", "modelName": "Model 1"},
64
{"modelId": "m2", "modelName": "Model 2"},
65
},
66
}
67
_ = json.NewEncoder(w).Encode(resp)
68
}))
69
defer srv.Close()
70
71
p, _ := newBedrockProvider(BackendConfig{
72
Backend: "bedrock",
73
Region: "us-east-1",
74
BaseURL: srv.URL,
75
AWSKeyID: "test-key",
76
AWSSecretKey: "test-secret",
77
}, srv.Client())
78
79
models, err := p.DiscoverModels(context.Background())
80
if err != nil {
81
t.Fatalf("DiscoverModels failed: %v", err)
82
}
83
84
if len(models) != 2 {
85
t.Errorf("got %d models, want 2", len(models))
86
}
87
}
88

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button