ScuttleBot

scuttlebot / internal / llm / bedrock_test.go
Source Blame History 87 lines
5ac549c… lmata 1 package llm
5ac549c… lmata 2
5ac549c… lmata 3 import (
5ac549c… lmata 4 "context"
5ac549c… lmata 5 "encoding/json"
5ac549c… lmata 6 "net/http"
5ac549c… lmata 7 "net/http/httptest"
5ac549c… lmata 8 "testing"
5ac549c… lmata 9 )
5ac549c… lmata 10
5ac549c… lmata 11 func TestBedrockSummarize(t *testing.T) {
5ac549c… lmata 12 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5ac549c… lmata 13 if r.Method != "POST" {
5ac549c… lmata 14 t.Errorf("expected POST request, got %s", r.Method)
5ac549c… lmata 15 }
5ac549c… lmata 16 // Path: /model/{modelID}/converse
5ac549c… lmata 17 if r.URL.Path != "/model/test-model/converse" {
5ac549c… lmata 18 t.Errorf("unexpected path: %s", r.URL.Path)
5ac549c… lmata 19 }
5ac549c… lmata 20
5ac549c… lmata 21 resp := map[string]any{
5ac549c… lmata 22 "output": map[string]any{
5ac549c… lmata 23 "message": map[string]any{
5ac549c… lmata 24 "content": []map[string]any{
5ac549c… lmata 25 {"text": "bedrock response"},
5ac549c… lmata 26 },
5ac549c… lmata 27 },
5ac549c… lmata 28 },
5ac549c… lmata 29 }
5ac549c… lmata 30 _ = json.NewEncoder(w).Encode(resp)
5ac549c… lmata 31 }))
5ac549c… lmata 32 defer srv.Close()
5ac549c… lmata 33
5ac549c… lmata 34 p, _ := newBedrockProvider(BackendConfig{
5ac549c… lmata 35 Backend: "bedrock",
5ac549c… lmata 36 Region: "us-east-1",
5ac549c… lmata 37 Model: "test-model",
5ac549c… lmata 38 BaseURL: srv.URL,
5ac549c… lmata 39 AWSKeyID: "test-key",
5ac549c… lmata 40 AWSSecretKey: "test-secret",
5ac549c… lmata 41 }, srv.Client())
5ac549c… lmata 42
5ac549c… lmata 43 got, err := p.Summarize(context.Background(), "test prompt")
5ac549c… lmata 44 if err != nil {
5ac549c… lmata 45 t.Fatalf("Summarize failed: %v", err)
5ac549c… lmata 46 }
5ac549c… lmata 47 if got != "bedrock response" {
5ac549c… lmata 48 t.Errorf("got %q, want %q", got, "bedrock response")
5ac549c… lmata 49 }
5ac549c… lmata 50 }
5ac549c… lmata 51
5ac549c… lmata 52 func TestBedrockDiscoverModels(t *testing.T) {
5ac549c… lmata 53 srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
5ac549c… lmata 54 if r.Method != "GET" {
5ac549c… lmata 55 t.Errorf("expected GET request, got %s", r.Method)
5ac549c… lmata 56 }
5ac549c… lmata 57 if r.URL.Path != "/foundation-models" {
5ac549c… lmata 58 t.Errorf("unexpected path: %s", r.URL.Path)
5ac549c… lmata 59 }
5ac549c… lmata 60
5ac549c… lmata 61 resp := map[string]any{
5ac549c… lmata 62 "modelSummaries": []map[string]any{
5ac549c… lmata 63 {"modelId": "m1", "modelName": "Model 1"},
5ac549c… lmata 64 {"modelId": "m2", "modelName": "Model 2"},
5ac549c… lmata 65 },
5ac549c… lmata 66 }
5ac549c… lmata 67 _ = json.NewEncoder(w).Encode(resp)
5ac549c… lmata 68 }))
5ac549c… lmata 69 defer srv.Close()
5ac549c… lmata 70
5ac549c… lmata 71 p, _ := newBedrockProvider(BackendConfig{
5ac549c… lmata 72 Backend: "bedrock",
5ac549c… lmata 73 Region: "us-east-1",
5ac549c… lmata 74 BaseURL: srv.URL,
5ac549c… lmata 75 AWSKeyID: "test-key",
5ac549c… lmata 76 AWSSecretKey: "test-secret",
5ac549c… lmata 77 }, srv.Client())
5ac549c… lmata 78
5ac549c… lmata 79 models, err := p.DiscoverModels(context.Background())
5ac549c… lmata 80 if err != nil {
5ac549c… lmata 81 t.Fatalf("DiscoverModels failed: %v", err)
5ac549c… lmata 82 }
5ac549c… lmata 83
5ac549c… lmata 84 if len(models) != 2 {
5ac549c… lmata 85 t.Errorf("got %d models, want 2", len(models))
5ac549c… lmata 86 }
5ac549c… lmata 87 }

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button