ScuttleBot

scuttlebot / internal / llm / gemini.go
Source Blame History 143 lines
5ac549c… lmata 1 package llm
5ac549c… lmata 2
5ac549c… lmata 3 import (
5ac549c… lmata 4 "bytes"
5ac549c… lmata 5 "context"
5ac549c… lmata 6 "encoding/json"
5ac549c… lmata 7 "fmt"
5ac549c… lmata 8 "io"
5ac549c… lmata 9 "net/http"
5ac549c… lmata 10 "strings"
5ac549c… lmata 11 )
5ac549c… lmata 12
// geminiAPIBase is the base URL used for Gemini requests when the backend
// configuration does not supply its own BaseURL.
const geminiAPIBase = "https://generativelanguage.googleapis.com"
5ac549c… lmata 14
// geminiProvider talks to the Gemini REST API (the "v1beta" endpoints) to
// summarize text and to list the models available to the configured key.
type geminiProvider struct {
	// apiKey is the credential sent with every request.
	apiKey string
	// model is the model ID placed in the generateContent request path.
	model string
	// baseURL is the prefix that "/v1beta/..." paths are appended to.
	baseURL string
	// http is the client used to execute all requests.
	http *http.Client
}
5ac549c… lmata 21
5ac549c… lmata 22 func newGeminiProvider(cfg BackendConfig, hc *http.Client) *geminiProvider {
5ac549c… lmata 23 model := cfg.Model
5ac549c… lmata 24 if model == "" {
5ac549c… lmata 25 model = "gemini-1.5-flash"
5ac549c… lmata 26 }
5ac549c… lmata 27 baseURL := cfg.BaseURL
5ac549c… lmata 28 if baseURL == "" {
5ac549c… lmata 29 baseURL = geminiAPIBase
5ac549c… lmata 30 }
5ac549c… lmata 31 return &geminiProvider{
5ac549c… lmata 32 apiKey: cfg.APIKey,
5ac549c… lmata 33 model: model,
5ac549c… lmata 34 baseURL: baseURL,
5ac549c… lmata 35 http: hc,
5ac549c… lmata 36 }
5ac549c… lmata 37 }
5ac549c… lmata 38
5ac549c… lmata 39 func (p *geminiProvider) Summarize(ctx context.Context, prompt string) (string, error) {
5ac549c… lmata 40 url := fmt.Sprintf("%s/v1beta/models/%s:generateContent?key=%s", p.baseURL, p.model, p.apiKey)
5ac549c… lmata 41 body, _ := json.Marshal(map[string]any{
5ac549c… lmata 42 "contents": []map[string]any{
5ac549c… lmata 43 {
5ac549c… lmata 44 "parts": []map[string]string{
5ac549c… lmata 45 {"text": prompt},
5ac549c… lmata 46 },
5ac549c… lmata 47 },
5ac549c… lmata 48 },
5ac549c… lmata 49 "generationConfig": map[string]any{
5ac549c… lmata 50 "maxOutputTokens": 512,
5ac549c… lmata 51 },
5ac549c… lmata 52 })
5ac549c… lmata 53
5ac549c… lmata 54 req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
5ac549c… lmata 55 if err != nil {
5ac549c… lmata 56 return "", err
5ac549c… lmata 57 }
5ac549c… lmata 58 req.Header.Set("Content-Type", "application/json")
5ac549c… lmata 59
5ac549c… lmata 60 resp, err := p.http.Do(req)
5ac549c… lmata 61 if err != nil {
5ac549c… lmata 62 return "", fmt.Errorf("gemini request: %w", err)
5ac549c… lmata 63 }
5ac549c… lmata 64 defer resp.Body.Close()
5ac549c… lmata 65
5ac549c… lmata 66 data, _ := io.ReadAll(resp.Body)
5ac549c… lmata 67 if resp.StatusCode != http.StatusOK {
5ac549c… lmata 68 return "", fmt.Errorf("gemini error %d: %s", resp.StatusCode, string(data))
5ac549c… lmata 69 }
5ac549c… lmata 70
5ac549c… lmata 71 var result struct {
5ac549c… lmata 72 Candidates []struct {
5ac549c… lmata 73 Content struct {
5ac549c… lmata 74 Parts []struct {
5ac549c… lmata 75 Text string `json:"text"`
5ac549c… lmata 76 } `json:"parts"`
5ac549c… lmata 77 } `json:"content"`
5ac549c… lmata 78 } `json:"candidates"`
5ac549c… lmata 79 }
5ac549c… lmata 80 if err := json.Unmarshal(data, &result); err != nil {
5ac549c… lmata 81 return "", fmt.Errorf("gemini parse: %w", err)
5ac549c… lmata 82 }
5ac549c… lmata 83 if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
5ac549c… lmata 84 return "", fmt.Errorf("gemini returned no candidates")
5ac549c… lmata 85 }
5ac549c… lmata 86 return result.Candidates[0].Content.Parts[0].Text, nil
5ac549c… lmata 87 }
5ac549c… lmata 88
5ac549c… lmata 89 func (p *geminiProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) {
5ac549c… lmata 90 url := fmt.Sprintf("%s/v1beta/models?key=%s", p.baseURL, p.apiKey)
5ac549c… lmata 91 req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
5ac549c… lmata 92 if err != nil {
5ac549c… lmata 93 return nil, err
5ac549c… lmata 94 }
5ac549c… lmata 95
5ac549c… lmata 96 resp, err := p.http.Do(req)
5ac549c… lmata 97 if err != nil {
5ac549c… lmata 98 return nil, fmt.Errorf("gemini models request: %w", err)
5ac549c… lmata 99 }
5ac549c… lmata 100 defer resp.Body.Close()
5ac549c… lmata 101
5ac549c… lmata 102 data, _ := io.ReadAll(resp.Body)
5ac549c… lmata 103 if resp.StatusCode != http.StatusOK {
5ac549c… lmata 104 return nil, fmt.Errorf("gemini models error %d: %s", resp.StatusCode, string(data))
5ac549c… lmata 105 }
5ac549c… lmata 106
5ac549c… lmata 107 var result struct {
5ac549c… lmata 108 Models []struct {
1066004… lmata 109 Name string `json:"name"`
1066004… lmata 110 DisplayName string `json:"displayName"`
1066004… lmata 111 Description string `json:"description"`
5ac549c… lmata 112 SupportedMethods []string `json:"supportedGenerationMethods"`
5ac549c… lmata 113 } `json:"models"`
5ac549c… lmata 114 }
5ac549c… lmata 115 if err := json.Unmarshal(data, &result); err != nil {
5ac549c… lmata 116 return nil, fmt.Errorf("gemini models parse: %w", err)
5ac549c… lmata 117 }
5ac549c… lmata 118
5ac549c… lmata 119 var models []ModelInfo
5ac549c… lmata 120 for _, m := range result.Models {
5ac549c… lmata 121 // Only include models that support content generation.
5ac549c… lmata 122 if !supportsGenerate(m.SupportedMethods) {
5ac549c… lmata 123 continue
5ac549c… lmata 124 }
5ac549c… lmata 125 // Name is "models/gemini-1.5-flash" — strip the prefix.
5ac549c… lmata 126 id := strings.TrimPrefix(m.Name, "models/")
5ac549c… lmata 127 models = append(models, ModelInfo{
5ac549c… lmata 128 ID: id,
5ac549c… lmata 129 Name: m.DisplayName,
5ac549c… lmata 130 Description: m.Description,
5ac549c… lmata 131 })
5ac549c… lmata 132 }
5ac549c… lmata 133 return models, nil
5ac549c… lmata 134 }
5ac549c… lmata 135
// supportsGenerate reports whether the given list of generation methods
// contains the exact entry "generateContent". A nil or empty list yields
// false.
func supportsGenerate(methods []string) bool {
	for i := 0; i < len(methods); i++ {
		if methods[i] != "generateContent" {
			continue
		}
		return true
	}
	return false
}

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button