|
1
|
package llm |
|
2
|
|
|
3
|
import ( |
|
4
|
"bytes" |
|
5
|
"context" |
|
6
|
"crypto/hmac" |
|
7
|
"crypto/sha256" |
|
8
|
"encoding/hex" |
|
9
|
"encoding/json" |
|
10
|
"fmt" |
|
11
|
"io" |
|
12
|
"net/http" |
|
13
|
"os" |
|
14
|
"sort" |
|
15
|
"strings" |
|
16
|
"sync" |
|
17
|
"time" |
|
18
|
) |
|
19
|
|
|
20
|
// awsCreds is one resolved AWS credential set. Credentials taken from the
// backend config or from environment variables have a zero Expiry; those
// fetched from the ECS or EC2 metadata endpoints carry the Expiration the
// endpoint returned, plus a session token.
type awsCreds struct {
	KeyID     string // access key ID, placed in the SigV4 Credential scope
	SecretKey string // secret access key, used to derive the signing key
	// SessionToken is non-empty for temporary credentials (e.g. IAM roles) and
	// is sent as the x-amz-security-token header.
	SessionToken string
	// Expiry is the zero time for credentials that never expire.
	Expiry time.Time
}
|
27
|
|
|
28
|
// credCache caches resolved credentials to avoid hitting the metadata endpoint |
|
29
|
// on every request. Refreshes when credentials are within 30s of expiry. |
|
30
|
type credCache struct { |
|
31
|
mu sync.Mutex |
|
32
|
creds *awsCreds |
|
33
|
} |
|
34
|
|
|
35
|
func (c *credCache) get() *awsCreds { |
|
36
|
c.mu.Lock() |
|
37
|
defer c.mu.Unlock() |
|
38
|
if c.creds == nil { |
|
39
|
return nil |
|
40
|
} |
|
41
|
if c.creds.Expiry.IsZero() { |
|
42
|
return c.creds // static creds never expire |
|
43
|
} |
|
44
|
if time.Now().Before(c.creds.Expiry.Add(-30 * time.Second)) { |
|
45
|
return c.creds |
|
46
|
} |
|
47
|
return nil // expired or about to expire |
|
48
|
} |
|
49
|
|
|
50
|
func (c *credCache) set(creds *awsCreds) { |
|
51
|
c.mu.Lock() |
|
52
|
defer c.mu.Unlock() |
|
53
|
c.creds = creds |
|
54
|
} |
|
55
|
|
|
56
|
type bedrockProvider struct { |
|
57
|
region string |
|
58
|
modelID string |
|
59
|
baseURL string // for testing |
|
60
|
cfg BackendConfig |
|
61
|
cache credCache |
|
62
|
http *http.Client |
|
63
|
} |
|
64
|
|
|
65
|
func newBedrockProvider(cfg BackendConfig, hc *http.Client) (*bedrockProvider, error) { |
|
66
|
if cfg.Region == "" { |
|
67
|
return nil, fmt.Errorf("llm: bedrock requires region") |
|
68
|
} |
|
69
|
model := cfg.Model |
|
70
|
if model == "" { |
|
71
|
model = "anthropic.claude-3-5-sonnet-20241022-v2:0" |
|
72
|
} |
|
73
|
return &bedrockProvider{ |
|
74
|
region: cfg.Region, |
|
75
|
modelID: model, |
|
76
|
baseURL: cfg.BaseURL, |
|
77
|
cfg: cfg, |
|
78
|
http: hc, |
|
79
|
}, nil |
|
80
|
} |
|
81
|
|
|
82
|
// Summarize calls the Bedrock Converse API, which provides a unified interface |
|
83
|
// across all Bedrock-hosted models. |
|
84
|
func (p *bedrockProvider) Summarize(ctx context.Context, prompt string) (string, error) { |
|
85
|
url := p.baseURL |
|
86
|
if url == "" { |
|
87
|
url = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", p.region) |
|
88
|
} |
|
89
|
url = fmt.Sprintf("%s/model/%s/converse", url, p.modelID) |
|
90
|
|
|
91
|
body, _ := json.Marshal(map[string]any{ |
|
92
|
"messages": []map[string]any{ |
|
93
|
{ |
|
94
|
"role": "user", |
|
95
|
"content": []map[string]string{ |
|
96
|
{"type": "text", "text": prompt}, |
|
97
|
}, |
|
98
|
}, |
|
99
|
}, |
|
100
|
"inferenceConfig": map[string]any{ |
|
101
|
"maxTokens": 512, |
|
102
|
}, |
|
103
|
}) |
|
104
|
|
|
105
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
|
106
|
if err != nil { |
|
107
|
return "", err |
|
108
|
} |
|
109
|
req.Header.Set("Content-Type", "application/json") |
|
110
|
if err := p.signRequest(ctx, req, body); err != nil { |
|
111
|
return "", fmt.Errorf("bedrock sign: %w", err) |
|
112
|
} |
|
113
|
|
|
114
|
resp, err := p.http.Do(req) |
|
115
|
if err != nil { |
|
116
|
return "", fmt.Errorf("bedrock request: %w", err) |
|
117
|
} |
|
118
|
defer resp.Body.Close() |
|
119
|
|
|
120
|
data, _ := io.ReadAll(resp.Body) |
|
121
|
if resp.StatusCode != http.StatusOK { |
|
122
|
return "", fmt.Errorf("bedrock error %d: %s", resp.StatusCode, string(data)) |
|
123
|
} |
|
124
|
|
|
125
|
var result struct { |
|
126
|
Output struct { |
|
127
|
Message struct { |
|
128
|
Content []struct { |
|
129
|
Text string `json:"text"` |
|
130
|
} `json:"content"` |
|
131
|
} `json:"message"` |
|
132
|
} `json:"output"` |
|
133
|
} |
|
134
|
if err := json.Unmarshal(data, &result); err != nil { |
|
135
|
return "", fmt.Errorf("bedrock parse: %w", err) |
|
136
|
} |
|
137
|
if len(result.Output.Message.Content) == 0 { |
|
138
|
return "", fmt.Errorf("bedrock returned no content") |
|
139
|
} |
|
140
|
return result.Output.Message.Content[0].Text, nil |
|
141
|
} |
|
142
|
|
|
143
|
// DiscoverModels lists Bedrock foundation models available in the configured region. |
|
144
|
func (p *bedrockProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) { |
|
145
|
url := p.baseURL |
|
146
|
if url == "" { |
|
147
|
url = fmt.Sprintf("https://bedrock.%s.amazonaws.com", p.region) |
|
148
|
} |
|
149
|
url = fmt.Sprintf("%s/foundation-models", url) |
|
150
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) |
|
151
|
if err != nil { |
|
152
|
return nil, err |
|
153
|
} |
|
154
|
if err := p.signRequest(ctx, req, nil); err != nil { |
|
155
|
return nil, fmt.Errorf("bedrock sign: %w", err) |
|
156
|
} |
|
157
|
|
|
158
|
resp, err := p.http.Do(req) |
|
159
|
if err != nil { |
|
160
|
return nil, fmt.Errorf("bedrock models request: %w", err) |
|
161
|
} |
|
162
|
defer resp.Body.Close() |
|
163
|
|
|
164
|
data, _ := io.ReadAll(resp.Body) |
|
165
|
if resp.StatusCode != http.StatusOK { |
|
166
|
return nil, fmt.Errorf("bedrock models error %d: %s", resp.StatusCode, string(data)) |
|
167
|
} |
|
168
|
|
|
169
|
var result struct { |
|
170
|
ModelSummaries []struct { |
|
171
|
ModelID string `json:"modelId"` |
|
172
|
ModelName string `json:"modelName"` |
|
173
|
} `json:"modelSummaries"` |
|
174
|
} |
|
175
|
if err := json.Unmarshal(data, &result); err != nil { |
|
176
|
return nil, fmt.Errorf("bedrock models parse: %w", err) |
|
177
|
} |
|
178
|
|
|
179
|
models := make([]ModelInfo, len(result.ModelSummaries)) |
|
180
|
for i, m := range result.ModelSummaries { |
|
181
|
models[i] = ModelInfo{ID: m.ModelID, Name: m.ModelName} |
|
182
|
} |
|
183
|
return models, nil |
|
184
|
} |
|
185
|
|
|
186
|
// signRequest resolves credentials (with caching) and applies SigV4 headers. |
|
187
|
func (p *bedrockProvider) signRequest(ctx context.Context, r *http.Request, body []byte) error { |
|
188
|
creds := p.cache.get() |
|
189
|
if creds == nil { |
|
190
|
var err error |
|
191
|
creds, err = resolveAWSCreds(ctx, p.cfg, p.http) |
|
192
|
if err != nil { |
|
193
|
return fmt.Errorf("resolve credentials: %w", err) |
|
194
|
} |
|
195
|
p.cache.set(creds) |
|
196
|
} |
|
197
|
return signSigV4(r, body, creds, p.region, "bedrock") |
|
198
|
} |
|
199
|
|
|
200
|
// --- AWS credential resolution chain --- |
|
201
|
|
|
202
|
// resolveAWSCreds resolves credentials using the standard AWS chain: |
|
203
|
// 1. Static credentials in BackendConfig (AWSKeyID + AWSSecretKey) |
|
204
|
// 2. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN env vars |
|
205
|
// 3. ECS task role via AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or _FULL_URI |
|
206
|
// 4. EC2/EKS instance profile via IMDSv2 |
|
207
|
func resolveAWSCreds(ctx context.Context, cfg BackendConfig, hc *http.Client) (*awsCreds, error) { |
|
208
|
// 1. Static config credentials. |
|
209
|
if cfg.AWSKeyID != "" && cfg.AWSSecretKey != "" { |
|
210
|
return &awsCreds{KeyID: cfg.AWSKeyID, SecretKey: cfg.AWSSecretKey}, nil |
|
211
|
} |
|
212
|
|
|
213
|
// 2. Environment variables. |
|
214
|
if id := os.Getenv("AWS_ACCESS_KEY_ID"); id != "" { |
|
215
|
return &awsCreds{ |
|
216
|
KeyID: id, |
|
217
|
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), |
|
218
|
SessionToken: os.Getenv("AWS_SESSION_TOKEN"), |
|
219
|
}, nil |
|
220
|
} |
|
221
|
|
|
222
|
// 3. ECS container credentials. |
|
223
|
if rel := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); rel != "" { |
|
224
|
return fetchContainerCreds(ctx, "http://169.254.170.2"+rel, "", hc) |
|
225
|
} |
|
226
|
if full := os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI"); full != "" { |
|
227
|
token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") |
|
228
|
return fetchContainerCreds(ctx, full, token, hc) |
|
229
|
} |
|
230
|
|
|
231
|
// 4. EC2 / EKS instance metadata (IMDSv2). |
|
232
|
return fetchIMDSCreds(ctx, hc) |
|
233
|
} |
|
234
|
|
|
235
|
// fetchContainerCreds fetches temporary credentials from the ECS task metadata endpoint. |
|
236
|
func fetchContainerCreds(ctx context.Context, url, token string, hc *http.Client) (*awsCreds, error) { |
|
237
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) |
|
238
|
if err != nil { |
|
239
|
return nil, fmt.Errorf("bedrock ecs creds: %w", err) |
|
240
|
} |
|
241
|
if token != "" { |
|
242
|
req.Header.Set("Authorization", token) |
|
243
|
} |
|
244
|
return parseTempCreds(hc, req, "ECS container credentials") |
|
245
|
} |
|
246
|
|
|
247
|
// fetchIMDSCreds fetches temporary credentials via EC2 IMDSv2 (also works for EKS). |
|
248
|
func fetchIMDSCreds(ctx context.Context, hc *http.Client) (*awsCreds, error) { |
|
249
|
const imdsBase = "http://169.254.169.254/latest" |
|
250
|
|
|
251
|
// Step 1: obtain IMDSv2 session token. |
|
252
|
tokenReq, err := http.NewRequestWithContext(ctx, "PUT", imdsBase+"/api/token", nil) |
|
253
|
if err != nil { |
|
254
|
return nil, fmt.Errorf("bedrock imds token request: %w", err) |
|
255
|
} |
|
256
|
tokenReq.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") |
|
257
|
tokenResp, err := hc.Do(tokenReq) |
|
258
|
if err != nil { |
|
259
|
return nil, fmt.Errorf("bedrock imds: not running on EC2/EKS or IMDS unreachable: %w", err) |
|
260
|
} |
|
261
|
defer tokenResp.Body.Close() |
|
262
|
tokenBytes, _ := io.ReadAll(tokenResp.Body) |
|
263
|
if tokenResp.StatusCode != http.StatusOK { |
|
264
|
return nil, fmt.Errorf("bedrock imds: token request failed (%d)", tokenResp.StatusCode) |
|
265
|
} |
|
266
|
imdsToken := strings.TrimSpace(string(tokenBytes)) |
|
267
|
|
|
268
|
// Step 2: get the IAM role name. |
|
269
|
roleReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/", nil) |
|
270
|
roleReq.Header.Set("X-aws-ec2-metadata-token", imdsToken) |
|
271
|
roleResp, err := hc.Do(roleReq) |
|
272
|
if err != nil { |
|
273
|
return nil, fmt.Errorf("bedrock imds: get role name: %w", err) |
|
274
|
} |
|
275
|
defer roleResp.Body.Close() |
|
276
|
roleBytes, _ := io.ReadAll(roleResp.Body) |
|
277
|
if roleResp.StatusCode != http.StatusOK { |
|
278
|
return nil, fmt.Errorf("bedrock imds: no IAM role attached to instance") |
|
279
|
} |
|
280
|
role := strings.TrimSpace(string(roleBytes)) |
|
281
|
|
|
282
|
// Step 3: fetch credentials for the role. |
|
283
|
credsReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/"+role, nil) |
|
284
|
credsReq.Header.Set("X-aws-ec2-metadata-token", imdsToken) |
|
285
|
return parseTempCreds(hc, credsReq, "EC2 instance metadata") |
|
286
|
} |
|
287
|
|
|
288
|
func parseTempCreds(hc *http.Client, req *http.Request, source string) (*awsCreds, error) { |
|
289
|
resp, err := hc.Do(req) |
|
290
|
if err != nil { |
|
291
|
return nil, fmt.Errorf("bedrock %s: %w", source, err) |
|
292
|
} |
|
293
|
defer resp.Body.Close() |
|
294
|
data, _ := io.ReadAll(resp.Body) |
|
295
|
if resp.StatusCode != http.StatusOK { |
|
296
|
return nil, fmt.Errorf("bedrock %s error %d: %s", source, resp.StatusCode, string(data)) |
|
297
|
} |
|
298
|
var result struct { |
|
299
|
AccessKeyID string `json:"AccessKeyId"` |
|
300
|
SecretAccessKey string `json:"SecretAccessKey"` |
|
301
|
Token string `json:"Token"` |
|
302
|
Expiration string `json:"Expiration"` |
|
303
|
} |
|
304
|
if err := json.Unmarshal(data, &result); err != nil { |
|
305
|
return nil, fmt.Errorf("bedrock %s parse: %w", source, err) |
|
306
|
} |
|
307
|
creds := &awsCreds{ |
|
308
|
KeyID: result.AccessKeyID, |
|
309
|
SecretKey: result.SecretAccessKey, |
|
310
|
SessionToken: result.Token, |
|
311
|
} |
|
312
|
if result.Expiration != "" { |
|
313
|
if t, err := time.Parse(time.RFC3339, result.Expiration); err == nil { |
|
314
|
creds.Expiry = t |
|
315
|
} |
|
316
|
} |
|
317
|
return creds, nil |
|
318
|
} |
|
319
|
|
|
320
|
// --- SigV4 signing --- |
|
321
|
|
|
322
|
// signSigV4 adds AWS Signature Version 4 authentication headers to r. |
|
323
|
// Both bedrock.*.amazonaws.com and bedrock-runtime.*.amazonaws.com use service "bedrock". |
|
324
|
func signSigV4(r *http.Request, body []byte, creds *awsCreds, region, service string) error { |
|
325
|
now := time.Now().UTC() |
|
326
|
dateTime := now.Format("20060102T150405Z") |
|
327
|
date := now.Format("20060102") |
|
328
|
|
|
329
|
var bodyBytes []byte |
|
330
|
if body != nil { |
|
331
|
bodyBytes = body |
|
332
|
} |
|
333
|
bodyHash := sha256Hex(bodyBytes) |
|
334
|
|
|
335
|
r.Header.Set("x-amz-date", dateTime) |
|
336
|
r.Header.Set("x-amz-content-sha256", bodyHash) |
|
337
|
if creds.SessionToken != "" { |
|
338
|
r.Header.Set("x-amz-security-token", creds.SessionToken) |
|
339
|
} |
|
340
|
if r.Host == "" { |
|
341
|
r.Host = r.URL.Host |
|
342
|
} |
|
343
|
|
|
344
|
canonHeaders, signedHeaders := buildHeaders(r) |
|
345
|
path := r.URL.Path |
|
346
|
if path == "" { |
|
347
|
path = "/" |
|
348
|
} |
|
349
|
canonReq := strings.Join([]string{ |
|
350
|
r.Method, |
|
351
|
path, |
|
352
|
r.URL.RawQuery, |
|
353
|
canonHeaders, |
|
354
|
signedHeaders, |
|
355
|
bodyHash, |
|
356
|
}, "\n") |
|
357
|
|
|
358
|
credScope := strings.Join([]string{date, region, service, "aws4_request"}, "/") |
|
359
|
strToSign := strings.Join([]string{ |
|
360
|
"AWS4-HMAC-SHA256", |
|
361
|
dateTime, |
|
362
|
credScope, |
|
363
|
sha256Hex([]byte(canonReq)), |
|
364
|
}, "\n") |
|
365
|
|
|
366
|
sigKey := deriveSigningKey(creds.SecretKey, date, region, service) |
|
367
|
sig := hex.EncodeToString(hmacSHA256(sigKey, []byte(strToSign))) |
|
368
|
|
|
369
|
r.Header.Set("Authorization", fmt.Sprintf( |
|
370
|
"AWS4-HMAC-SHA256 Credential=%s/%s,SignedHeaders=%s,Signature=%s", |
|
371
|
creds.KeyID, credScope, signedHeaders, sig, |
|
372
|
)) |
|
373
|
return nil |
|
374
|
} |
|
375
|
|
|
376
|
// buildHeaders returns the SigV4 canonical-headers block and the matching
// SignedHeaders list, covering every header on r that has at least one value,
// plus "host".
//
// Names are lower-cased and sorted, and each canonical line is "name:value\n".
// Values are normalized as SigV4 requires: each value is trimmed, internal
// runs of spaces are collapsed to one, and repeated values are joined with ",".
// The host line always uses r.Host, because outgoing requests carry the host
// there rather than in r.Header.
func buildHeaders(r *http.Request) (canonical, signed string) {
	// Visit the raw keys in sorted order so that values stored under
	// differently-cased keys (possible when r.Header is written directly) merge
	// deterministically. Values are read through the actual map key, not
	// http.CanonicalHeaderKey, so non-canonical keys are not signed as empty.
	keys := make([]string, 0, len(r.Header))
	for k := range r.Header {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	values := make(map[string][]string, len(keys)+1)
	for _, k := range keys {
		lk := strings.ToLower(k)
		if lk == "host" {
			continue
		}
		for _, v := range r.Header[k] {
			values[lk] = append(values[lk], normalizeHeaderValue(v))
		}
	}
	values["host"] = []string{r.Host}

	names := make([]string, 0, len(values))
	for name := range values {
		names = append(names, name)
	}
	sort.Strings(names)

	var sb strings.Builder
	for _, name := range names {
		sb.WriteString(name)
		sb.WriteByte(':')
		sb.WriteString(strings.Join(values[name], ","))
		sb.WriteByte('\n')
	}
	return sb.String(), strings.Join(names, ";")
}

// normalizeHeaderValue trims surrounding whitespace from v and collapses each
// internal run of spaces to a single space (SigV4 canonical header rules).
func normalizeHeaderValue(v string) string {
	v = strings.TrimSpace(v)
	for strings.Contains(v, "  ") {
		v = strings.ReplaceAll(v, "  ", " ")
	}
	return v
}
|
402
|
|
|
403
|
// sha256Hex returns the lower-case hex encoding of the SHA-256 digest of data.
// A nil slice hashes the same as an empty one.
func sha256Hex(data []byte) string {
	hasher := sha256.New()
	hasher.Write(data)
	return hex.EncodeToString(hasher.Sum(nil))
}
|
407
|
|
|
408
|
// hmacSHA256 returns the raw HMAC-SHA256 of data keyed with key.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(data) // hash.Hash writes never return an error
	return mac.Sum(make([]byte, 0, sha256.Size))
}
|
413
|
|
|
414
|
func deriveSigningKey(secret, date, region, service string) []byte { |
|
415
|
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(date)) |
|
416
|
kRegion := hmacSHA256(kDate, []byte(region)) |
|
417
|
kService := hmacSHA256(kRegion, []byte(service)) |
|
418
|
return hmacSHA256(kService, []byte("aws4_request")) |
|
419
|
} |
|
420
|
|