ScuttleBot
| 5ac549c… | lmata | 1 | package llm |
| 5ac549c… | lmata | 2 | |
| 5ac549c… | lmata | 3 | import ( |
| 5ac549c… | lmata | 4 | "context" |
| 5ac549c… | lmata | 5 | "encoding/json" |
| 5ac549c… | lmata | 6 | "net/http" |
| 5ac549c… | lmata | 7 | "net/http/httptest" |
| 5ac549c… | lmata | 8 | "testing" |
| 5ac549c… | lmata | 9 | ) |
| 5ac549c… | lmata | 10 | |
| 5ac549c… | lmata | 11 | func TestBedrockSummarize(t *testing.T) { |
| 5ac549c… | lmata | 12 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 5ac549c… | lmata | 13 | if r.Method != "POST" { |
| 5ac549c… | lmata | 14 | t.Errorf("expected POST request, got %s", r.Method) |
| 5ac549c… | lmata | 15 | } |
| 5ac549c… | lmata | 16 | // Path: /model/{modelID}/converse |
| 5ac549c… | lmata | 17 | if r.URL.Path != "/model/test-model/converse" { |
| 5ac549c… | lmata | 18 | t.Errorf("unexpected path: %s", r.URL.Path) |
| 5ac549c… | lmata | 19 | } |
| 5ac549c… | lmata | 20 | |
| 5ac549c… | lmata | 21 | resp := map[string]any{ |
| 5ac549c… | lmata | 22 | "output": map[string]any{ |
| 5ac549c… | lmata | 23 | "message": map[string]any{ |
| 5ac549c… | lmata | 24 | "content": []map[string]any{ |
| 5ac549c… | lmata | 25 | {"text": "bedrock response"}, |
| 5ac549c… | lmata | 26 | }, |
| 5ac549c… | lmata | 27 | }, |
| 5ac549c… | lmata | 28 | }, |
| 5ac549c… | lmata | 29 | } |
| 5ac549c… | lmata | 30 | _ = json.NewEncoder(w).Encode(resp) |
| 5ac549c… | lmata | 31 | })) |
| 5ac549c… | lmata | 32 | defer srv.Close() |
| 5ac549c… | lmata | 33 | |
| 5ac549c… | lmata | 34 | p, _ := newBedrockProvider(BackendConfig{ |
| 5ac549c… | lmata | 35 | Backend: "bedrock", |
| 5ac549c… | lmata | 36 | Region: "us-east-1", |
| 5ac549c… | lmata | 37 | Model: "test-model", |
| 5ac549c… | lmata | 38 | BaseURL: srv.URL, |
| 5ac549c… | lmata | 39 | AWSKeyID: "test-key", |
| 5ac549c… | lmata | 40 | AWSSecretKey: "test-secret", |
| 5ac549c… | lmata | 41 | }, srv.Client()) |
| 5ac549c… | lmata | 42 | |
| 5ac549c… | lmata | 43 | got, err := p.Summarize(context.Background(), "test prompt") |
| 5ac549c… | lmata | 44 | if err != nil { |
| 5ac549c… | lmata | 45 | t.Fatalf("Summarize failed: %v", err) |
| 5ac549c… | lmata | 46 | } |
| 5ac549c… | lmata | 47 | if got != "bedrock response" { |
| 5ac549c… | lmata | 48 | t.Errorf("got %q, want %q", got, "bedrock response") |
| 5ac549c… | lmata | 49 | } |
| 5ac549c… | lmata | 50 | } |
| 5ac549c… | lmata | 51 | |
| 5ac549c… | lmata | 52 | func TestBedrockDiscoverModels(t *testing.T) { |
| 5ac549c… | lmata | 53 | srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
| 5ac549c… | lmata | 54 | if r.Method != "GET" { |
| 5ac549c… | lmata | 55 | t.Errorf("expected GET request, got %s", r.Method) |
| 5ac549c… | lmata | 56 | } |
| 5ac549c… | lmata | 57 | if r.URL.Path != "/foundation-models" { |
| 5ac549c… | lmata | 58 | t.Errorf("unexpected path: %s", r.URL.Path) |
| 5ac549c… | lmata | 59 | } |
| 5ac549c… | lmata | 60 | |
| 5ac549c… | lmata | 61 | resp := map[string]any{ |
| 5ac549c… | lmata | 62 | "modelSummaries": []map[string]any{ |
| 5ac549c… | lmata | 63 | {"modelId": "m1", "modelName": "Model 1"}, |
| 5ac549c… | lmata | 64 | {"modelId": "m2", "modelName": "Model 2"}, |
| 5ac549c… | lmata | 65 | }, |
| 5ac549c… | lmata | 66 | } |
| 5ac549c… | lmata | 67 | _ = json.NewEncoder(w).Encode(resp) |
| 5ac549c… | lmata | 68 | })) |
| 5ac549c… | lmata | 69 | defer srv.Close() |
| 5ac549c… | lmata | 70 | |
| 5ac549c… | lmata | 71 | p, _ := newBedrockProvider(BackendConfig{ |
| 5ac549c… | lmata | 72 | Backend: "bedrock", |
| 5ac549c… | lmata | 73 | Region: "us-east-1", |
| 5ac549c… | lmata | 74 | BaseURL: srv.URL, |
| 5ac549c… | lmata | 75 | AWSKeyID: "test-key", |
| 5ac549c… | lmata | 76 | AWSSecretKey: "test-secret", |
| 5ac549c… | lmata | 77 | }, srv.Client()) |
| 5ac549c… | lmata | 78 | |
| 5ac549c… | lmata | 79 | models, err := p.DiscoverModels(context.Background()) |
| 5ac549c… | lmata | 80 | if err != nil { |
| 5ac549c… | lmata | 81 | t.Fatalf("DiscoverModels failed: %v", err) |
| 5ac549c… | lmata | 82 | } |
| 5ac549c… | lmata | 83 | |
| 5ac549c… | lmata | 84 | if len(models) != 2 { |
| 5ac549c… | lmata | 85 | t.Errorf("got %d models, want 2", len(models)) |
| 5ac549c… | lmata | 86 | } |
| 5ac549c… | lmata | 87 | } |