ScuttleBot

scuttlebot / internal / llm / openai.go
Source Blame History 139 lines
5ac549c… lmata 1 package llm
5ac549c… lmata 2
5ac549c… lmata 3 import (
5ac549c… lmata 4 "bytes"
5ac549c… lmata 5 "context"
5ac549c… lmata 6 "encoding/json"
5ac549c… lmata 7 "fmt"
5ac549c… lmata 8 "io"
5ac549c… lmata 9 "net/http"
5ac549c… lmata 10 "strings"
5ac549c… lmata 11 )
5ac549c… lmata 12
// openAIProvider implements Provider and ModelDiscoverer for any OpenAI-compatible API.
type openAIProvider struct {
	baseURL string       // API root that "/chat/completions" and "/models" are appended to; no trailing slash
	apiKey  string       // sent as "Authorization: Bearer <apiKey>" when non-empty
	model   string       // model ID placed in every chat completion request
	http    *http.Client // client used for every request; never nil
}

// newOpenAIProvider builds a provider for the OpenAI-compatible API rooted at
// baseURL.
//
// Trailing slashes are stripped from baseURL so that endpoint paths such as
// "/chat/completions" join without producing a double slash (which many
// servers answer with 404). A nil hc falls back to http.DefaultClient instead
// of panicking on the first request; note that http.DefaultClient has no
// timeout, so callers should normally pass their own client.
func newOpenAIProvider(apiKey, baseURL, model string, hc *http.Client) *openAIProvider {
	if hc == nil {
		hc = http.DefaultClient
	}
	return &openAIProvider{
		baseURL: strings.TrimRight(baseURL, "/"),
		apiKey:  apiKey,
		model:   model,
		http:    hc,
	}
}
5ac549c… lmata 29
5ac549c… lmata 30 func (p *openAIProvider) Summarize(ctx context.Context, prompt string) (string, error) {
5ac549c… lmata 31 text, status, data, err := p.summarizeWithTokenField(ctx, prompt, "max_tokens")
5ac549c… lmata 32 if err == nil {
5ac549c… lmata 33 return text, nil
5ac549c… lmata 34 }
5ac549c… lmata 35 if shouldRetryWithMaxCompletionTokens(status, data) {
5ac549c… lmata 36 text, _, _, err := p.summarizeWithTokenField(ctx, prompt, "max_completion_tokens")
5ac549c… lmata 37 return text, err
5ac549c… lmata 38 }
5ac549c… lmata 39 return "", err
5ac549c… lmata 40 }
5ac549c… lmata 41
5ac549c… lmata 42 func (p *openAIProvider) summarizeWithTokenField(ctx context.Context, prompt, tokenField string) (string, int, []byte, error) {
5ac549c… lmata 43 body, _ := json.Marshal(map[string]any{
5ac549c… lmata 44 "model": p.model,
5ac549c… lmata 45 "messages": []map[string]string{
5ac549c… lmata 46 {"role": "user", "content": prompt},
5ac549c… lmata 47 },
5ac549c… lmata 48 tokenField: 512,
5ac549c… lmata 49 })
5ac549c… lmata 50 req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body))
5ac549c… lmata 51 if err != nil {
5ac549c… lmata 52 return "", 0, nil, err
5ac549c… lmata 53 }
5ac549c… lmata 54 if p.apiKey != "" {
5ac549c… lmata 55 req.Header.Set("Authorization", "Bearer "+p.apiKey)
5ac549c… lmata 56 }
5ac549c… lmata 57 req.Header.Set("Content-Type", "application/json")
5ac549c… lmata 58
5ac549c… lmata 59 resp, err := p.http.Do(req)
5ac549c… lmata 60 if err != nil {
5ac549c… lmata 61 return "", 0, nil, fmt.Errorf("openai request: %w", err)
5ac549c… lmata 62 }
5ac549c… lmata 63 defer resp.Body.Close()
5ac549c… lmata 64
5ac549c… lmata 65 data, _ := io.ReadAll(resp.Body)
5ac549c… lmata 66 if resp.StatusCode != http.StatusOK {
5ac549c… lmata 67 return "", resp.StatusCode, data, fmt.Errorf("openai error %d: %s", resp.StatusCode, string(data))
5ac549c… lmata 68 }
5ac549c… lmata 69
5ac549c… lmata 70 var result struct {
5ac549c… lmata 71 Choices []struct {
5ac549c… lmata 72 Message struct {
5ac549c… lmata 73 Content string `json:"content"`
5ac549c… lmata 74 } `json:"message"`
5ac549c… lmata 75 } `json:"choices"`
5ac549c… lmata 76 }
5ac549c… lmata 77 if err := json.Unmarshal(data, &result); err != nil {
5ac549c… lmata 78 return "", resp.StatusCode, data, fmt.Errorf("openai parse: %w", err)
5ac549c… lmata 79 }
5ac549c… lmata 80 if len(result.Choices) == 0 {
5ac549c… lmata 81 return "", resp.StatusCode, data, fmt.Errorf("openai returned no choices")
5ac549c… lmata 82 }
5ac549c… lmata 83 return result.Choices[0].Message.Content, resp.StatusCode, data, nil
5ac549c… lmata 84 }
5ac549c… lmata 85
// shouldRetryWithMaxCompletionTokens reports whether a failed chat completion
// response means the server rejected the "max_tokens" parameter, in which
// case the request should be resent with "max_completion_tokens".
//
// Only HTTP 400 responses qualify. The body is first decoded as a structured
// {"error": {"message", "param"}} object and matched when param is
// "max_tokens" and the message contains "not supported" (case-insensitive).
// Failing that, the raw body is searched case-insensitively for both
// "unsupported parameter" and "max_tokens".
func shouldRetryWithMaxCompletionTokens(status int, data []byte) bool {
	if status != http.StatusBadRequest {
		return false
	}

	var payload struct {
		Error struct {
			Message string `json:"message"`
			Param   string `json:"param"`
		} `json:"error"`
	}
	decoded := json.Unmarshal(data, &payload) == nil
	if decoded && payload.Error.Param == "max_tokens" &&
		strings.Contains(strings.ToLower(payload.Error.Message), "not supported") {
		return true
	}

	// Fallback for servers whose error body is not in the structured shape.
	text := strings.ToLower(string(data))
	if !strings.Contains(text, "unsupported parameter") {
		return false
	}
	return strings.Contains(text, "max_tokens")
}
5ac549c… lmata 104
5ac549c… lmata 105 func (p *openAIProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) {
5ac549c… lmata 106 req, err := http.NewRequestWithContext(ctx, "GET", p.baseURL+"/models", nil)
5ac549c… lmata 107 if err != nil {
5ac549c… lmata 108 return nil, err
5ac549c… lmata 109 }
5ac549c… lmata 110 if p.apiKey != "" {
5ac549c… lmata 111 req.Header.Set("Authorization", "Bearer "+p.apiKey)
5ac549c… lmata 112 }
5ac549c… lmata 113
5ac549c… lmata 114 resp, err := p.http.Do(req)
5ac549c… lmata 115 if err != nil {
5ac549c… lmata 116 return nil, fmt.Errorf("models request: %w", err)
5ac549c… lmata 117 }
5ac549c… lmata 118 defer resp.Body.Close()
5ac549c… lmata 119
5ac549c… lmata 120 data, _ := io.ReadAll(resp.Body)
5ac549c… lmata 121 if resp.StatusCode != http.StatusOK {
5ac549c… lmata 122 return nil, fmt.Errorf("models error %d: %s", resp.StatusCode, string(data))
5ac549c… lmata 123 }
5ac549c… lmata 124
5ac549c… lmata 125 var result struct {
5ac549c… lmata 126 Data []struct {
5ac549c… lmata 127 ID string `json:"id"`
5ac549c… lmata 128 } `json:"data"`
5ac549c… lmata 129 }
5ac549c… lmata 130 if err := json.Unmarshal(data, &result); err != nil {
5ac549c… lmata 131 return nil, fmt.Errorf("models parse: %w", err)
5ac549c… lmata 132 }
5ac549c… lmata 133
5ac549c… lmata 134 models := make([]ModelInfo, len(result.Data))
5ac549c… lmata 135 for i, m := range result.Data {
5ac549c… lmata 136 models[i] = ModelInfo{ID: m.ID}
5ac549c… lmata 137 }
5ac549c… lmata 138 return models, nil
5ac549c… lmata 139 }

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button