ScuttleBot

scuttlebot / internal / llm / gemini_test.go
Blame History Raw 100 lines
1
package llm
2
3
import (
4
"context"
5
"encoding/json"
6
"net/http"
7
"net/http/httptest"
8
"testing"
9
)
10
11
func TestGeminiSummarize(t *testing.T) {
12
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
13
if r.Method != "POST" {
14
t.Errorf("expected POST request, got %s", r.Method)
15
}
16
if r.URL.Path != "/v1beta/models/gemini-1.5-flash:generateContent" {
17
t.Errorf("unexpected path: %s", r.URL.Path)
18
}
19
if r.URL.Query().Get("key") != "test-api-key" {
20
t.Errorf("expected api key test-api-key, got %s", r.URL.Query().Get("key"))
21
}
22
23
resp := map[string]any{
24
"candidates": []map[string]any{
25
{
26
"content": map[string]any{
27
"parts": []map[string]any{
28
{"text": "gemini response"},
29
},
30
},
31
},
32
},
33
}
34
_ = json.NewEncoder(w).Encode(resp)
35
}))
36
defer srv.Close()
37
38
p := newGeminiProvider(BackendConfig{
39
Backend: "gemini",
40
APIKey: "test-api-key",
41
BaseURL: srv.URL,
42
}, srv.Client())
43
44
got, err := p.Summarize(context.Background(), "test prompt")
45
if err != nil {
46
t.Fatalf("Summarize failed: %v", err)
47
}
48
if got != "gemini response" {
49
t.Errorf("got %q, want %q", got, "gemini response")
50
}
51
}
52
53
func TestGeminiDiscoverModels(t *testing.T) {
54
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
55
if r.Method != "GET" {
56
t.Errorf("expected GET request, got %s", r.Method)
57
}
58
if r.URL.Path != "/v1beta/models" {
59
t.Errorf("unexpected path: %s", r.URL.Path)
60
}
61
62
resp := map[string]any{
63
"models": []map[string]any{
64
{
65
"name": "models/gemini-1.5-flash",
66
"displayName": "Gemini 1.5 Flash",
67
"description": "Fast and versatile",
68
"supportedGenerationMethods": []string{"generateContent"},
69
},
70
{
71
"name": "models/other-model",
72
"displayName": "Other",
73
"description": "Other model",
74
"supportedGenerationMethods": []string{"other"},
75
},
76
},
77
}
78
_ = json.NewEncoder(w).Encode(resp)
79
}))
80
defer srv.Close()
81
82
p := newGeminiProvider(BackendConfig{
83
Backend: "gemini",
84
APIKey: "test-api-key",
85
BaseURL: srv.URL,
86
}, srv.Client())
87
88
models, err := p.DiscoverModels(context.Background())
89
if err != nil {
90
t.Fatalf("DiscoverModels failed: %v", err)
91
}
92
93
if len(models) != 1 {
94
t.Errorf("got %d models, want 1", len(models))
95
}
96
if models[0].ID != "gemini-1.5-flash" {
97
t.Errorf("got ID %q, want %q", models[0].ID, "gemini-1.5-flash")
98
}
99
}
100

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button