ScuttleBot

scuttlebot / internal / llm / gemini.go
Blame History Raw 144 lines
1
package llm
2
3
import (
4
"bytes"
5
"context"
6
"encoding/json"
7
"fmt"
8
"io"
9
"net/http"
10
"strings"
11
)
12
13
// geminiAPIBase is the default Gemini (Generative Language API) endpoint,
// substituted by newGeminiProvider when the backend config leaves BaseURL empty.
const geminiAPIBase = "https://generativelanguage.googleapis.com"

// geminiProvider is the Gemini backend: it carries the credential, target
// model and endpoint used to build requests, plus the client that sends them.
type geminiProvider struct {
	// Request parameters, populated by newGeminiProvider.
	apiKey, model, baseURL string

	// http performs every outbound request made by this provider.
	http *http.Client
}
21
22
func newGeminiProvider(cfg BackendConfig, hc *http.Client) *geminiProvider {
23
model := cfg.Model
24
if model == "" {
25
model = "gemini-1.5-flash"
26
}
27
baseURL := cfg.BaseURL
28
if baseURL == "" {
29
baseURL = geminiAPIBase
30
}
31
return &geminiProvider{
32
apiKey: cfg.APIKey,
33
model: model,
34
baseURL: baseURL,
35
http: hc,
36
}
37
}
38
39
func (p *geminiProvider) Summarize(ctx context.Context, prompt string) (string, error) {
40
url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", p.baseURL, p.model, p.apiKey)
41
body, _ := json.Marshal(map[string]any{
42
"contents": []map[string]any{
43
{
44
"parts": []map[string]string{
45
{"text": prompt},
46
},
47
},
48
},
49
"generationConfig": map[string]any{
50
"maxOutputTokens": 512,
51
},
52
})
53
54
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
55
if err != nil {
56
return "", err
57
}
58
req.Header.Set("Content-Type", "application/json")
59
60
resp, err := p.http.Do(req)
61
if err != nil {
62
return "", fmt.Errorf("gemini request: %w", err)
63
}
64
defer resp.Body.Close()
65
66
data, _ := io.ReadAll(resp.Body)
67
if resp.StatusCode != http.StatusOK {
68
return "", fmt.Errorf("gemini error %d: %s", resp.StatusCode, string(data))
69
}
70
71
var result struct {
72
Candidates []struct {
73
Content struct {
74
Parts []struct {
75
Text string `json:"text"`
76
} `json:"parts"`
77
} `json:"content"`
78
} `json:"candidates"`
79
}
80
if err := json.Unmarshal(data, &result); err != nil {
81
return "", fmt.Errorf("gemini parse: %w", err)
82
}
83
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
84
return "", fmt.Errorf("gemini returned no candidates")
85
}
86
return result.Candidates[0].Content.Parts[0].Text, nil
87
}
88
89
func (p *geminiProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) {
90
url := fmt.Sprintf("%s/v1beta/models?key=%s", p.baseURL, p.apiKey)
91
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
92
if err != nil {
93
return nil, err
94
}
95
96
resp, err := p.http.Do(req)
97
if err != nil {
98
return nil, fmt.Errorf("gemini models request: %w", err)
99
}
100
defer resp.Body.Close()
101
102
data, _ := io.ReadAll(resp.Body)
103
if resp.StatusCode != http.StatusOK {
104
return nil, fmt.Errorf("gemini models error %d: %s", resp.StatusCode, string(data))
105
}
106
107
var result struct {
108
Models []struct {
109
Name string `json:"name"`
110
DisplayName string `json:"displayName"`
111
Description string `json:"description"`
112
SupportedMethods []string `json:"supportedGenerationMethods"`
113
} `json:"models"`
114
}
115
if err := json.Unmarshal(data, &result); err != nil {
116
return nil, fmt.Errorf("gemini models parse: %w", err)
117
}
118
119
var models []ModelInfo
120
for _, m := range result.Models {
121
// Only include models that support content generation.
122
if !supportsGenerate(m.SupportedMethods) {
123
continue
124
}
125
// Name is "models/gemini-1.5-flash" — strip the prefix.
126
id := strings.TrimPrefix(m.Name, "models/")
127
models = append(models, ModelInfo{
128
ID: id,
129
Name: m.DisplayName,
130
Description: m.Description,
131
})
132
}
133
return models, nil
134
}
135
136
// supportsGenerate reports whether methods contains the exact string
// "generateContent", i.e. whether the model accepts generateContent calls.
func supportsGenerate(methods []string) bool {
	found := false
	for i := 0; i < len(methods) && !found; i++ {
		found = methods[i] == "generateContent"
	}
	return found
}
144

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button