|
1
|
package llm |
|
2
|
|
|
3
|
import ( |
|
4
|
"bytes" |
|
5
|
"context" |
|
6
|
"encoding/json" |
|
7
|
"fmt" |
|
8
|
"io" |
|
9
|
"net/http" |
|
10
|
"strings" |
|
11
|
) |
|
12
|
|
|
13
|
// openAIProvider implements Provider and ModelDiscoverer for any OpenAI-compatible API.
//
// Requests are issued against baseURL+"/chat/completions" (Summarize) and
// baseURL+"/models" (DiscoverModels). When apiKey is non-empty it is sent as
// an "Authorization: Bearer <key>" header on every request.
type openAIProvider struct {
	baseURL string       // API root; endpoint paths are appended to it
	apiKey  string       // optional bearer token; empty means no Authorization header
	model   string       // model ID placed in the chat-completion request body
	http    *http.Client // client used for all outgoing requests
}
|
20
|
|
|
21
|
func newOpenAIProvider(apiKey, baseURL, model string, hc *http.Client) *openAIProvider { |
|
22
|
return &openAIProvider{ |
|
23
|
baseURL: baseURL, |
|
24
|
apiKey: apiKey, |
|
25
|
model: model, |
|
26
|
http: hc, |
|
27
|
} |
|
28
|
} |
|
29
|
|
|
30
|
func (p *openAIProvider) Summarize(ctx context.Context, prompt string) (string, error) { |
|
31
|
text, status, data, err := p.summarizeWithTokenField(ctx, prompt, "max_tokens") |
|
32
|
if err == nil { |
|
33
|
return text, nil |
|
34
|
} |
|
35
|
if shouldRetryWithMaxCompletionTokens(status, data) { |
|
36
|
text, _, _, err := p.summarizeWithTokenField(ctx, prompt, "max_completion_tokens") |
|
37
|
return text, err |
|
38
|
} |
|
39
|
return "", err |
|
40
|
} |
|
41
|
|
|
42
|
func (p *openAIProvider) summarizeWithTokenField(ctx context.Context, prompt, tokenField string) (string, int, []byte, error) { |
|
43
|
body, _ := json.Marshal(map[string]any{ |
|
44
|
"model": p.model, |
|
45
|
"messages": []map[string]string{ |
|
46
|
{"role": "user", "content": prompt}, |
|
47
|
}, |
|
48
|
tokenField: 512, |
|
49
|
}) |
|
50
|
req, err := http.NewRequestWithContext(ctx, "POST", p.baseURL+"/chat/completions", bytes.NewReader(body)) |
|
51
|
if err != nil { |
|
52
|
return "", 0, nil, err |
|
53
|
} |
|
54
|
if p.apiKey != "" { |
|
55
|
req.Header.Set("Authorization", "Bearer "+p.apiKey) |
|
56
|
} |
|
57
|
req.Header.Set("Content-Type", "application/json") |
|
58
|
|
|
59
|
resp, err := p.http.Do(req) |
|
60
|
if err != nil { |
|
61
|
return "", 0, nil, fmt.Errorf("openai request: %w", err) |
|
62
|
} |
|
63
|
defer resp.Body.Close() |
|
64
|
|
|
65
|
data, _ := io.ReadAll(resp.Body) |
|
66
|
if resp.StatusCode != http.StatusOK { |
|
67
|
return "", resp.StatusCode, data, fmt.Errorf("openai error %d: %s", resp.StatusCode, string(data)) |
|
68
|
} |
|
69
|
|
|
70
|
var result struct { |
|
71
|
Choices []struct { |
|
72
|
Message struct { |
|
73
|
Content string `json:"content"` |
|
74
|
} `json:"message"` |
|
75
|
} `json:"choices"` |
|
76
|
} |
|
77
|
if err := json.Unmarshal(data, &result); err != nil { |
|
78
|
return "", resp.StatusCode, data, fmt.Errorf("openai parse: %w", err) |
|
79
|
} |
|
80
|
if len(result.Choices) == 0 { |
|
81
|
return "", resp.StatusCode, data, fmt.Errorf("openai returned no choices") |
|
82
|
} |
|
83
|
return result.Choices[0].Message.Content, resp.StatusCode, data, nil |
|
84
|
} |
|
85
|
|
|
86
|
// shouldRetryWithMaxCompletionTokens reports whether an HTTP 400 response
// indicates the server rejected the legacy "max_tokens" parameter, meaning
// the request should be retried with "max_completion_tokens" instead.
//
// It first checks the structured OpenAI error payload (error.param together
// with a "not supported" message), then falls back to a case-insensitive
// substring scan of the raw body for servers that omit the structured fields.
func shouldRetryWithMaxCompletionTokens(status int, data []byte) bool {
	if status != http.StatusBadRequest {
		return false
	}
	var payload struct {
		Error struct {
			Message string `json:"message"`
			Param   string `json:"param"`
		} `json:"error"`
	}
	if json.Unmarshal(data, &payload) == nil &&
		payload.Error.Param == "max_tokens" &&
		strings.Contains(strings.ToLower(payload.Error.Message), "not supported") {
		return true
	}
	body := strings.ToLower(string(data))
	return strings.Contains(body, "unsupported parameter") && strings.Contains(body, "max_tokens")
}
|
104
|
|
|
105
|
func (p *openAIProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) { |
|
106
|
req, err := http.NewRequestWithContext(ctx, "GET", p.baseURL+"/models", nil) |
|
107
|
if err != nil { |
|
108
|
return nil, err |
|
109
|
} |
|
110
|
if p.apiKey != "" { |
|
111
|
req.Header.Set("Authorization", "Bearer "+p.apiKey) |
|
112
|
} |
|
113
|
|
|
114
|
resp, err := p.http.Do(req) |
|
115
|
if err != nil { |
|
116
|
return nil, fmt.Errorf("models request: %w", err) |
|
117
|
} |
|
118
|
defer resp.Body.Close() |
|
119
|
|
|
120
|
data, _ := io.ReadAll(resp.Body) |
|
121
|
if resp.StatusCode != http.StatusOK { |
|
122
|
return nil, fmt.Errorf("models error %d: %s", resp.StatusCode, string(data)) |
|
123
|
} |
|
124
|
|
|
125
|
var result struct { |
|
126
|
Data []struct { |
|
127
|
ID string `json:"id"` |
|
128
|
} `json:"data"` |
|
129
|
} |
|
130
|
if err := json.Unmarshal(data, &result); err != nil { |
|
131
|
return nil, fmt.Errorf("models parse: %w", err) |
|
132
|
} |
|
133
|
|
|
134
|
models := make([]ModelInfo, len(result.Data)) |
|
135
|
for i, m := range result.Data { |
|
136
|
models[i] = ModelInfo{ID: m.ID} |
|
137
|
} |
|
138
|
return models, nil |
|
139
|
} |
|
140
|
|