ScuttleBot

scuttlebot / internal / llm / bedrock.go
Blame History Raw 420 lines
1
package llm
2
3
import (
4
"bytes"
5
"context"
6
"crypto/hmac"
7
"crypto/sha256"
8
"encoding/hex"
9
"encoding/json"
10
"fmt"
11
"io"
12
"net/http"
13
"os"
14
"sort"
15
"strings"
16
"sync"
17
"time"
18
)
19
20
// awsCreds is one resolved AWS credential set, either long-lived static keys
// or temporary keys.
type awsCreds struct {
	// KeyID and SecretKey are the access key pair used for signing.
	KeyID, SecretKey string

	// SessionToken is non-empty for temporary credentials from IAM roles.
	SessionToken string

	// Expiry is the zero time for static credentials.
	Expiry time.Time
}
27
28
// credCache caches resolved credentials to avoid hitting the metadata endpoint
29
// on every request. Refreshes when credentials are within 30s of expiry.
30
type credCache struct {
31
mu sync.Mutex
32
creds *awsCreds
33
}
34
35
// get returns the cached credentials while they are still usable, or nil when
// nothing is cached or the cached set expires within the next 30 seconds.
func (c *credCache) get() *awsCreds {
	c.mu.Lock()
	defer c.mu.Unlock()

	cur := c.creds
	switch {
	case cur == nil:
		// Nothing resolved yet.
		return nil
	case cur.Expiry.IsZero():
		// Static credentials never expire.
		return cur
	case time.Now().Before(cur.Expiry.Add(-30 * time.Second)):
		// Temporary credentials with more than 30s of life left.
		return cur
	default:
		// Expired or about to expire.
		return nil
	}
}
49
50
// set replaces the cached credentials with creds while holding the cache lock.
func (c *credCache) set(creds *awsCreds) {
	c.mu.Lock()
	c.creds = creds
	c.mu.Unlock()
}
55
56
type bedrockProvider struct {
57
region string
58
modelID string
59
baseURL string // for testing
60
cfg BackendConfig
61
cache credCache
62
http *http.Client
63
}
64
65
func newBedrockProvider(cfg BackendConfig, hc *http.Client) (*bedrockProvider, error) {
66
if cfg.Region == "" {
67
return nil, fmt.Errorf("llm: bedrock requires region")
68
}
69
model := cfg.Model
70
if model == "" {
71
model = "anthropic.claude-3-5-sonnet-20241022-v2:0"
72
}
73
return &bedrockProvider{
74
region: cfg.Region,
75
modelID: model,
76
baseURL: cfg.BaseURL,
77
cfg: cfg,
78
http: hc,
79
}, nil
80
}
81
82
// Summarize calls the Bedrock Converse API, which provides a unified interface
83
// across all Bedrock-hosted models.
84
func (p *bedrockProvider) Summarize(ctx context.Context, prompt string) (string, error) {
85
url := p.baseURL
86
if url == "" {
87
url = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", p.region)
88
}
89
url = fmt.Sprintf("%s/model/%s/converse", url, p.modelID)
90
91
body, _ := json.Marshal(map[string]any{
92
"messages": []map[string]any{
93
{
94
"role": "user",
95
"content": []map[string]string{
96
{"type": "text", "text": prompt},
97
},
98
},
99
},
100
"inferenceConfig": map[string]any{
101
"maxTokens": 512,
102
},
103
})
104
105
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
106
if err != nil {
107
return "", err
108
}
109
req.Header.Set("Content-Type", "application/json")
110
if err := p.signRequest(ctx, req, body); err != nil {
111
return "", fmt.Errorf("bedrock sign: %w", err)
112
}
113
114
resp, err := p.http.Do(req)
115
if err != nil {
116
return "", fmt.Errorf("bedrock request: %w", err)
117
}
118
defer resp.Body.Close()
119
120
data, _ := io.ReadAll(resp.Body)
121
if resp.StatusCode != http.StatusOK {
122
return "", fmt.Errorf("bedrock error %d: %s", resp.StatusCode, string(data))
123
}
124
125
var result struct {
126
Output struct {
127
Message struct {
128
Content []struct {
129
Text string `json:"text"`
130
} `json:"content"`
131
} `json:"message"`
132
} `json:"output"`
133
}
134
if err := json.Unmarshal(data, &result); err != nil {
135
return "", fmt.Errorf("bedrock parse: %w", err)
136
}
137
if len(result.Output.Message.Content) == 0 {
138
return "", fmt.Errorf("bedrock returned no content")
139
}
140
return result.Output.Message.Content[0].Text, nil
141
}
142
143
// DiscoverModels lists Bedrock foundation models available in the configured region.
144
func (p *bedrockProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) {
145
url := p.baseURL
146
if url == "" {
147
url = fmt.Sprintf("https://bedrock.%s.amazonaws.com", p.region)
148
}
149
url = fmt.Sprintf("%s/foundation-models", url)
150
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
151
if err != nil {
152
return nil, err
153
}
154
if err := p.signRequest(ctx, req, nil); err != nil {
155
return nil, fmt.Errorf("bedrock sign: %w", err)
156
}
157
158
resp, err := p.http.Do(req)
159
if err != nil {
160
return nil, fmt.Errorf("bedrock models request: %w", err)
161
}
162
defer resp.Body.Close()
163
164
data, _ := io.ReadAll(resp.Body)
165
if resp.StatusCode != http.StatusOK {
166
return nil, fmt.Errorf("bedrock models error %d: %s", resp.StatusCode, string(data))
167
}
168
169
var result struct {
170
ModelSummaries []struct {
171
ModelID string `json:"modelId"`
172
ModelName string `json:"modelName"`
173
} `json:"modelSummaries"`
174
}
175
if err := json.Unmarshal(data, &result); err != nil {
176
return nil, fmt.Errorf("bedrock models parse: %w", err)
177
}
178
179
models := make([]ModelInfo, len(result.ModelSummaries))
180
for i, m := range result.ModelSummaries {
181
models[i] = ModelInfo{ID: m.ModelID, Name: m.ModelName}
182
}
183
return models, nil
184
}
185
186
// signRequest resolves credentials (with caching) and applies SigV4 headers.
187
func (p *bedrockProvider) signRequest(ctx context.Context, r *http.Request, body []byte) error {
188
creds := p.cache.get()
189
if creds == nil {
190
var err error
191
creds, err = resolveAWSCreds(ctx, p.cfg, p.http)
192
if err != nil {
193
return fmt.Errorf("resolve credentials: %w", err)
194
}
195
p.cache.set(creds)
196
}
197
return signSigV4(r, body, creds, p.region, "bedrock")
198
}
199
200
// --- AWS credential resolution chain ---
201
202
// resolveAWSCreds resolves credentials using the standard AWS chain:
203
// 1. Static credentials in BackendConfig (AWSKeyID + AWSSecretKey)
204
// 2. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN env vars
205
// 3. ECS task role via AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or _FULL_URI
206
// 4. EC2/EKS instance profile via IMDSv2
207
func resolveAWSCreds(ctx context.Context, cfg BackendConfig, hc *http.Client) (*awsCreds, error) {
208
// 1. Static config credentials.
209
if cfg.AWSKeyID != "" && cfg.AWSSecretKey != "" {
210
return &awsCreds{KeyID: cfg.AWSKeyID, SecretKey: cfg.AWSSecretKey}, nil
211
}
212
213
// 2. Environment variables.
214
if id := os.Getenv("AWS_ACCESS_KEY_ID"); id != "" {
215
return &awsCreds{
216
KeyID: id,
217
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"),
218
SessionToken: os.Getenv("AWS_SESSION_TOKEN"),
219
}, nil
220
}
221
222
// 3. ECS container credentials.
223
if rel := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); rel != "" {
224
return fetchContainerCreds(ctx, "http://169.254.170.2"+rel, "", hc)
225
}
226
if full := os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI"); full != "" {
227
token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN")
228
return fetchContainerCreds(ctx, full, token, hc)
229
}
230
231
// 4. EC2 / EKS instance metadata (IMDSv2).
232
return fetchIMDSCreds(ctx, hc)
233
}
234
235
// fetchContainerCreds fetches temporary credentials from the ECS task metadata endpoint.
236
func fetchContainerCreds(ctx context.Context, url, token string, hc *http.Client) (*awsCreds, error) {
237
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
238
if err != nil {
239
return nil, fmt.Errorf("bedrock ecs creds: %w", err)
240
}
241
if token != "" {
242
req.Header.Set("Authorization", token)
243
}
244
return parseTempCreds(hc, req, "ECS container credentials")
245
}
246
247
// fetchIMDSCreds fetches temporary credentials via EC2 IMDSv2 (also works for EKS).
248
func fetchIMDSCreds(ctx context.Context, hc *http.Client) (*awsCreds, error) {
249
const imdsBase = "http://169.254.169.254/latest"
250
251
// Step 1: obtain IMDSv2 session token.
252
tokenReq, err := http.NewRequestWithContext(ctx, "PUT", imdsBase+"/api/token", nil)
253
if err != nil {
254
return nil, fmt.Errorf("bedrock imds token request: %w", err)
255
}
256
tokenReq.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600")
257
tokenResp, err := hc.Do(tokenReq)
258
if err != nil {
259
return nil, fmt.Errorf("bedrock imds: not running on EC2/EKS or IMDS unreachable: %w", err)
260
}
261
defer tokenResp.Body.Close()
262
tokenBytes, _ := io.ReadAll(tokenResp.Body)
263
if tokenResp.StatusCode != http.StatusOK {
264
return nil, fmt.Errorf("bedrock imds: token request failed (%d)", tokenResp.StatusCode)
265
}
266
imdsToken := strings.TrimSpace(string(tokenBytes))
267
268
// Step 2: get the IAM role name.
269
roleReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/", nil)
270
roleReq.Header.Set("X-aws-ec2-metadata-token", imdsToken)
271
roleResp, err := hc.Do(roleReq)
272
if err != nil {
273
return nil, fmt.Errorf("bedrock imds: get role name: %w", err)
274
}
275
defer roleResp.Body.Close()
276
roleBytes, _ := io.ReadAll(roleResp.Body)
277
if roleResp.StatusCode != http.StatusOK {
278
return nil, fmt.Errorf("bedrock imds: no IAM role attached to instance")
279
}
280
role := strings.TrimSpace(string(roleBytes))
281
282
// Step 3: fetch credentials for the role.
283
credsReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/"+role, nil)
284
credsReq.Header.Set("X-aws-ec2-metadata-token", imdsToken)
285
return parseTempCreds(hc, credsReq, "EC2 instance metadata")
286
}
287
288
func parseTempCreds(hc *http.Client, req *http.Request, source string) (*awsCreds, error) {
289
resp, err := hc.Do(req)
290
if err != nil {
291
return nil, fmt.Errorf("bedrock %s: %w", source, err)
292
}
293
defer resp.Body.Close()
294
data, _ := io.ReadAll(resp.Body)
295
if resp.StatusCode != http.StatusOK {
296
return nil, fmt.Errorf("bedrock %s error %d: %s", source, resp.StatusCode, string(data))
297
}
298
var result struct {
299
AccessKeyID string `json:"AccessKeyId"`
300
SecretAccessKey string `json:"SecretAccessKey"`
301
Token string `json:"Token"`
302
Expiration string `json:"Expiration"`
303
}
304
if err := json.Unmarshal(data, &result); err != nil {
305
return nil, fmt.Errorf("bedrock %s parse: %w", source, err)
306
}
307
creds := &awsCreds{
308
KeyID: result.AccessKeyID,
309
SecretKey: result.SecretAccessKey,
310
SessionToken: result.Token,
311
}
312
if result.Expiration != "" {
313
if t, err := time.Parse(time.RFC3339, result.Expiration); err == nil {
314
creds.Expiry = t
315
}
316
}
317
return creds, nil
318
}
319
320
// --- SigV4 signing ---
321
322
// signSigV4 adds AWS Signature Version 4 authentication headers to r.
323
// Both bedrock.*.amazonaws.com and bedrock-runtime.*.amazonaws.com use service "bedrock".
324
func signSigV4(r *http.Request, body []byte, creds *awsCreds, region, service string) error {
325
now := time.Now().UTC()
326
dateTime := now.Format("20060102T150405Z")
327
date := now.Format("20060102")
328
329
var bodyBytes []byte
330
if body != nil {
331
bodyBytes = body
332
}
333
bodyHash := sha256Hex(bodyBytes)
334
335
r.Header.Set("x-amz-date", dateTime)
336
r.Header.Set("x-amz-content-sha256", bodyHash)
337
if creds.SessionToken != "" {
338
r.Header.Set("x-amz-security-token", creds.SessionToken)
339
}
340
if r.Host == "" {
341
r.Host = r.URL.Host
342
}
343
344
canonHeaders, signedHeaders := buildHeaders(r)
345
path := r.URL.Path
346
if path == "" {
347
path = "/"
348
}
349
canonReq := strings.Join([]string{
350
r.Method,
351
path,
352
r.URL.RawQuery,
353
canonHeaders,
354
signedHeaders,
355
bodyHash,
356
}, "\n")
357
358
credScope := strings.Join([]string{date, region, service, "aws4_request"}, "/")
359
strToSign := strings.Join([]string{
360
"AWS4-HMAC-SHA256",
361
dateTime,
362
credScope,
363
sha256Hex([]byte(canonReq)),
364
}, "\n")
365
366
sigKey := deriveSigningKey(creds.SecretKey, date, region, service)
367
sig := hex.EncodeToString(hmacSHA256(sigKey, []byte(strToSign)))
368
369
r.Header.Set("Authorization", fmt.Sprintf(
370
"AWS4-HMAC-SHA256 Credential=%s/%s,SignedHeaders=%s,Signature=%s",
371
creds.KeyID, credScope, signedHeaders, sig,
372
))
373
return nil
374
}
375
376
// buildHeaders returns the SigV4 canonical header block and the matching
// signed-header list for r.
//
// canonical holds one "name:value\n" line per header, with names lower-cased
// and sorted; signed is the same names joined with ';'. The host header is
// always included and takes its value from r.Host.
//
// Values are read under the key they are actually stored with, so headers
// placed in the map under a non-canonical key keep their value. Each value is
// trimmed and its inner runs of whitespace collapsed to a single space;
// multiple values are joined with ','. Any existing Authorization header is
// left out because it is replaced by the signature itself, and keys that hold
// no values are skipped.
func buildHeaders(r *http.Request) (canonical, signed string) {
	// Visit the stored keys in sorted order so that keys differing only in
	// case are merged deterministically.
	rawKeys := make([]string, 0, len(r.Header))
	for k := range r.Header {
		rawKeys = append(rawKeys, k)
	}
	sort.Strings(rawKeys)

	values := make(map[string][]string, len(rawKeys)+1)
	for _, k := range rawKeys {
		name := strings.ToLower(k)
		if name == "host" || name == "authorization" {
			continue
		}
		for _, v := range r.Header[k] {
			values[name] = append(values[name], strings.Join(strings.Fields(v), " "))
		}
	}
	values["host"] = []string{r.Host}

	names := make([]string, 0, len(values))
	for name := range values {
		names = append(names, name)
	}
	sort.Strings(names)

	var sb strings.Builder
	for _, name := range names {
		sb.WriteString(name)
		sb.WriteByte(':')
		sb.WriteString(strings.Join(values[name], ","))
		sb.WriteByte('\n')
	}
	return sb.String(), strings.Join(names, ";")
}
402
403
// sha256Hex returns the SHA-256 digest of data as a lower-case hex string.
// A nil slice is hashed as empty input.
func sha256Hex(data []byte) string {
	hasher := sha256.New()
	hasher.Write(data)
	return hex.EncodeToString(hasher.Sum(nil))
}
407
408
// hmacSHA256 returns the raw HMAC-SHA256 of data keyed with key.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	mac.Write(data)
	return mac.Sum(make([]byte, 0, sha256.Size))
}
413
414
func deriveSigningKey(secret, date, region, service string) []byte {
415
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(date))
416
kRegion := hmacSHA256(kDate, []byte(region))
417
kService := hmacSHA256(kRegion, []byte(service))
418
return hmacSHA256(kService, []byte("aws4_request"))
419
}
420

Keyboard Shortcuts

Open search /
Next entry (timeline) j
Previous entry (timeline) k
Open focused entry Enter
Show this help ?
Toggle theme Top nav button