|
5ac549c…
|
lmata
|
1 |
package llm |
|
5ac549c…
|
lmata
|
2 |
|
|
5ac549c…
|
lmata
|
3 |
import ( |
|
5ac549c…
|
lmata
|
4 |
"bytes" |
|
5ac549c…
|
lmata
|
5 |
"context" |
|
5ac549c…
|
lmata
|
6 |
"crypto/hmac" |
|
5ac549c…
|
lmata
|
7 |
"crypto/sha256" |
|
5ac549c…
|
lmata
|
8 |
"encoding/hex" |
|
5ac549c…
|
lmata
|
9 |
"encoding/json" |
|
5ac549c…
|
lmata
|
10 |
"fmt" |
|
5ac549c…
|
lmata
|
11 |
"io" |
|
5ac549c…
|
lmata
|
12 |
"net/http" |
|
5ac549c…
|
lmata
|
13 |
"os" |
|
5ac549c…
|
lmata
|
14 |
"sort" |
|
5ac549c…
|
lmata
|
15 |
"strings" |
|
5ac549c…
|
lmata
|
16 |
"sync" |
|
5ac549c…
|
lmata
|
17 |
"time" |
|
5ac549c…
|
lmata
|
18 |
) |
|
5ac549c…
|
lmata
|
19 |
|
|
5ac549c…
|
lmata
|
20 |
// awsCreds is one resolved set of AWS credentials. Static credentials leave
// SessionToken empty and Expiry zero; temporary credentials from IAM roles
// carry a session token and an expiry time.
type awsCreds struct {
	KeyID        string
	SecretKey    string
	SessionToken string    // non-empty for temporary credentials from IAM roles
	Expiry       time.Time // zero for static credentials
}

// credCache memoizes resolved credentials so the metadata endpoints are not
// queried on every request. Cached temporary credentials are reported as
// stale once they are within 30 seconds of their expiry, which forces a
// refresh before they actually lapse.
type credCache struct {
	mu    sync.Mutex
	creds *awsCreds
}

// get returns the cached credentials, or nil when the cache is empty or the
// cached temporary credentials expire within the next 30 seconds.
func (c *credCache) get() *awsCreds {
	c.mu.Lock()
	defer c.mu.Unlock()

	cur := c.creds
	switch {
	case cur == nil:
		return nil
	case cur.Expiry.IsZero():
		// Static credentials have no expiry and are always usable.
		return cur
	case time.Until(cur.Expiry) > 30*time.Second:
		return cur
	default:
		// Expired, or close enough to expiry that a refresh is due.
		return nil
	}
}

// set replaces the cached credentials; passing nil empties the cache.
func (c *credCache) set(creds *awsCreds) {
	c.mu.Lock()
	c.creds = creds
	c.mu.Unlock()
}
|
5ac549c…
|
lmata
|
55 |
|
|
5ac549c…
|
lmata
|
56 |
type bedrockProvider struct { |
|
5ac549c…
|
lmata
|
57 |
region string |
|
5ac549c…
|
lmata
|
58 |
modelID string |
|
5ac549c…
|
lmata
|
59 |
baseURL string // for testing |
|
5ac549c…
|
lmata
|
60 |
cfg BackendConfig |
|
5ac549c…
|
lmata
|
61 |
cache credCache |
|
5ac549c…
|
lmata
|
62 |
http *http.Client |
|
5ac549c…
|
lmata
|
63 |
} |
|
5ac549c…
|
lmata
|
64 |
|
|
5ac549c…
|
lmata
|
65 |
func newBedrockProvider(cfg BackendConfig, hc *http.Client) (*bedrockProvider, error) { |
|
5ac549c…
|
lmata
|
66 |
if cfg.Region == "" { |
|
5ac549c…
|
lmata
|
67 |
return nil, fmt.Errorf("llm: bedrock requires region") |
|
5ac549c…
|
lmata
|
68 |
} |
|
5ac549c…
|
lmata
|
69 |
model := cfg.Model |
|
5ac549c…
|
lmata
|
70 |
if model == "" { |
|
5ac549c…
|
lmata
|
71 |
model = "anthropic.claude-3-5-sonnet-20241022-v2:0" |
|
5ac549c…
|
lmata
|
72 |
} |
|
5ac549c…
|
lmata
|
73 |
return &bedrockProvider{ |
|
5ac549c…
|
lmata
|
74 |
region: cfg.Region, |
|
5ac549c…
|
lmata
|
75 |
modelID: model, |
|
5ac549c…
|
lmata
|
76 |
baseURL: cfg.BaseURL, |
|
5ac549c…
|
lmata
|
77 |
cfg: cfg, |
|
5ac549c…
|
lmata
|
78 |
http: hc, |
|
5ac549c…
|
lmata
|
79 |
}, nil |
|
5ac549c…
|
lmata
|
80 |
} |
|
5ac549c…
|
lmata
|
81 |
|
|
5ac549c…
|
lmata
|
82 |
// Summarize calls the Bedrock Converse API, which provides a unified interface |
|
5ac549c…
|
lmata
|
83 |
// across all Bedrock-hosted models. |
|
5ac549c…
|
lmata
|
84 |
func (p *bedrockProvider) Summarize(ctx context.Context, prompt string) (string, error) { |
|
5ac549c…
|
lmata
|
85 |
url := p.baseURL |
|
5ac549c…
|
lmata
|
86 |
if url == "" { |
|
5ac549c…
|
lmata
|
87 |
url = fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com", p.region) |
|
5ac549c…
|
lmata
|
88 |
} |
|
5ac549c…
|
lmata
|
89 |
url = fmt.Sprintf("%s/model/%s/converse", url, p.modelID) |
|
5ac549c…
|
lmata
|
90 |
|
|
5ac549c…
|
lmata
|
91 |
body, _ := json.Marshal(map[string]any{ |
|
5ac549c…
|
lmata
|
92 |
"messages": []map[string]any{ |
|
5ac549c…
|
lmata
|
93 |
{ |
|
5ac549c…
|
lmata
|
94 |
"role": "user", |
|
5ac549c…
|
lmata
|
95 |
"content": []map[string]string{ |
|
5ac549c…
|
lmata
|
96 |
{"type": "text", "text": prompt}, |
|
5ac549c…
|
lmata
|
97 |
}, |
|
5ac549c…
|
lmata
|
98 |
}, |
|
5ac549c…
|
lmata
|
99 |
}, |
|
5ac549c…
|
lmata
|
100 |
"inferenceConfig": map[string]any{ |
|
5ac549c…
|
lmata
|
101 |
"maxTokens": 512, |
|
5ac549c…
|
lmata
|
102 |
}, |
|
5ac549c…
|
lmata
|
103 |
}) |
|
5ac549c…
|
lmata
|
104 |
|
|
5ac549c…
|
lmata
|
105 |
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) |
|
5ac549c…
|
lmata
|
106 |
if err != nil { |
|
5ac549c…
|
lmata
|
107 |
return "", err |
|
5ac549c…
|
lmata
|
108 |
} |
|
5ac549c…
|
lmata
|
109 |
req.Header.Set("Content-Type", "application/json") |
|
5ac549c…
|
lmata
|
110 |
if err := p.signRequest(ctx, req, body); err != nil { |
|
5ac549c…
|
lmata
|
111 |
return "", fmt.Errorf("bedrock sign: %w", err) |
|
5ac549c…
|
lmata
|
112 |
} |
|
5ac549c…
|
lmata
|
113 |
|
|
5ac549c…
|
lmata
|
114 |
resp, err := p.http.Do(req) |
|
5ac549c…
|
lmata
|
115 |
if err != nil { |
|
5ac549c…
|
lmata
|
116 |
return "", fmt.Errorf("bedrock request: %w", err) |
|
5ac549c…
|
lmata
|
117 |
} |
|
5ac549c…
|
lmata
|
118 |
defer resp.Body.Close() |
|
5ac549c…
|
lmata
|
119 |
|
|
5ac549c…
|
lmata
|
120 |
data, _ := io.ReadAll(resp.Body) |
|
5ac549c…
|
lmata
|
121 |
if resp.StatusCode != http.StatusOK { |
|
5ac549c…
|
lmata
|
122 |
return "", fmt.Errorf("bedrock error %d: %s", resp.StatusCode, string(data)) |
|
5ac549c…
|
lmata
|
123 |
} |
|
5ac549c…
|
lmata
|
124 |
|
|
5ac549c…
|
lmata
|
125 |
var result struct { |
|
5ac549c…
|
lmata
|
126 |
Output struct { |
|
5ac549c…
|
lmata
|
127 |
Message struct { |
|
5ac549c…
|
lmata
|
128 |
Content []struct { |
|
5ac549c…
|
lmata
|
129 |
Text string `json:"text"` |
|
5ac549c…
|
lmata
|
130 |
} `json:"content"` |
|
5ac549c…
|
lmata
|
131 |
} `json:"message"` |
|
5ac549c…
|
lmata
|
132 |
} `json:"output"` |
|
5ac549c…
|
lmata
|
133 |
} |
|
5ac549c…
|
lmata
|
134 |
if err := json.Unmarshal(data, &result); err != nil { |
|
5ac549c…
|
lmata
|
135 |
return "", fmt.Errorf("bedrock parse: %w", err) |
|
5ac549c…
|
lmata
|
136 |
} |
|
5ac549c…
|
lmata
|
137 |
if len(result.Output.Message.Content) == 0 { |
|
5ac549c…
|
lmata
|
138 |
return "", fmt.Errorf("bedrock returned no content") |
|
5ac549c…
|
lmata
|
139 |
} |
|
5ac549c…
|
lmata
|
140 |
return result.Output.Message.Content[0].Text, nil |
|
5ac549c…
|
lmata
|
141 |
} |
|
5ac549c…
|
lmata
|
142 |
|
|
5ac549c…
|
lmata
|
143 |
// DiscoverModels lists Bedrock foundation models available in the configured region. |
|
5ac549c…
|
lmata
|
144 |
func (p *bedrockProvider) DiscoverModels(ctx context.Context) ([]ModelInfo, error) { |
|
5ac549c…
|
lmata
|
145 |
url := p.baseURL |
|
5ac549c…
|
lmata
|
146 |
if url == "" { |
|
5ac549c…
|
lmata
|
147 |
url = fmt.Sprintf("https://bedrock.%s.amazonaws.com", p.region) |
|
5ac549c…
|
lmata
|
148 |
} |
|
5ac549c…
|
lmata
|
149 |
url = fmt.Sprintf("%s/foundation-models", url) |
|
5ac549c…
|
lmata
|
150 |
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) |
|
5ac549c…
|
lmata
|
151 |
if err != nil { |
|
5ac549c…
|
lmata
|
152 |
return nil, err |
|
5ac549c…
|
lmata
|
153 |
} |
|
5ac549c…
|
lmata
|
154 |
if err := p.signRequest(ctx, req, nil); err != nil { |
|
5ac549c…
|
lmata
|
155 |
return nil, fmt.Errorf("bedrock sign: %w", err) |
|
5ac549c…
|
lmata
|
156 |
} |
|
5ac549c…
|
lmata
|
157 |
|
|
5ac549c…
|
lmata
|
158 |
resp, err := p.http.Do(req) |
|
5ac549c…
|
lmata
|
159 |
if err != nil { |
|
5ac549c…
|
lmata
|
160 |
return nil, fmt.Errorf("bedrock models request: %w", err) |
|
5ac549c…
|
lmata
|
161 |
} |
|
5ac549c…
|
lmata
|
162 |
defer resp.Body.Close() |
|
5ac549c…
|
lmata
|
163 |
|
|
5ac549c…
|
lmata
|
164 |
data, _ := io.ReadAll(resp.Body) |
|
5ac549c…
|
lmata
|
165 |
if resp.StatusCode != http.StatusOK { |
|
5ac549c…
|
lmata
|
166 |
return nil, fmt.Errorf("bedrock models error %d: %s", resp.StatusCode, string(data)) |
|
5ac549c…
|
lmata
|
167 |
} |
|
5ac549c…
|
lmata
|
168 |
|
|
5ac549c…
|
lmata
|
169 |
var result struct { |
|
5ac549c…
|
lmata
|
170 |
ModelSummaries []struct { |
|
5ac549c…
|
lmata
|
171 |
ModelID string `json:"modelId"` |
|
5ac549c…
|
lmata
|
172 |
ModelName string `json:"modelName"` |
|
5ac549c…
|
lmata
|
173 |
} `json:"modelSummaries"` |
|
5ac549c…
|
lmata
|
174 |
} |
|
5ac549c…
|
lmata
|
175 |
if err := json.Unmarshal(data, &result); err != nil { |
|
5ac549c…
|
lmata
|
176 |
return nil, fmt.Errorf("bedrock models parse: %w", err) |
|
5ac549c…
|
lmata
|
177 |
} |
|
5ac549c…
|
lmata
|
178 |
|
|
5ac549c…
|
lmata
|
179 |
models := make([]ModelInfo, len(result.ModelSummaries)) |
|
5ac549c…
|
lmata
|
180 |
for i, m := range result.ModelSummaries { |
|
5ac549c…
|
lmata
|
181 |
models[i] = ModelInfo{ID: m.ModelID, Name: m.ModelName} |
|
5ac549c…
|
lmata
|
182 |
} |
|
5ac549c…
|
lmata
|
183 |
return models, nil |
|
5ac549c…
|
lmata
|
184 |
} |
|
5ac549c…
|
lmata
|
185 |
|
|
5ac549c…
|
lmata
|
186 |
// signRequest resolves credentials (with caching) and applies SigV4 headers. |
|
5ac549c…
|
lmata
|
187 |
func (p *bedrockProvider) signRequest(ctx context.Context, r *http.Request, body []byte) error { |
|
5ac549c…
|
lmata
|
188 |
creds := p.cache.get() |
|
5ac549c…
|
lmata
|
189 |
if creds == nil { |
|
5ac549c…
|
lmata
|
190 |
var err error |
|
5ac549c…
|
lmata
|
191 |
creds, err = resolveAWSCreds(ctx, p.cfg, p.http) |
|
5ac549c…
|
lmata
|
192 |
if err != nil { |
|
5ac549c…
|
lmata
|
193 |
return fmt.Errorf("resolve credentials: %w", err) |
|
5ac549c…
|
lmata
|
194 |
} |
|
5ac549c…
|
lmata
|
195 |
p.cache.set(creds) |
|
5ac549c…
|
lmata
|
196 |
} |
|
5ac549c…
|
lmata
|
197 |
return signSigV4(r, body, creds, p.region, "bedrock") |
|
5ac549c…
|
lmata
|
198 |
} |
|
5ac549c…
|
lmata
|
199 |
|
|
5ac549c…
|
lmata
|
200 |
// --- AWS credential resolution chain --- |
|
5ac549c…
|
lmata
|
201 |
|
|
5ac549c…
|
lmata
|
202 |
// resolveAWSCreds resolves credentials using the standard AWS chain: |
|
5ac549c…
|
lmata
|
203 |
// 1. Static credentials in BackendConfig (AWSKeyID + AWSSecretKey) |
|
5ac549c…
|
lmata
|
204 |
// 2. AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY / AWS_SESSION_TOKEN env vars |
|
5ac549c…
|
lmata
|
205 |
// 3. ECS task role via AWS_CONTAINER_CREDENTIALS_RELATIVE_URI or _FULL_URI |
|
5ac549c…
|
lmata
|
206 |
// 4. EC2/EKS instance profile via IMDSv2 |
|
5ac549c…
|
lmata
|
207 |
func resolveAWSCreds(ctx context.Context, cfg BackendConfig, hc *http.Client) (*awsCreds, error) { |
|
5ac549c…
|
lmata
|
208 |
// 1. Static config credentials. |
|
5ac549c…
|
lmata
|
209 |
if cfg.AWSKeyID != "" && cfg.AWSSecretKey != "" { |
|
5ac549c…
|
lmata
|
210 |
return &awsCreds{KeyID: cfg.AWSKeyID, SecretKey: cfg.AWSSecretKey}, nil |
|
5ac549c…
|
lmata
|
211 |
} |
|
5ac549c…
|
lmata
|
212 |
|
|
5ac549c…
|
lmata
|
213 |
// 2. Environment variables. |
|
5ac549c…
|
lmata
|
214 |
if id := os.Getenv("AWS_ACCESS_KEY_ID"); id != "" { |
|
5ac549c…
|
lmata
|
215 |
return &awsCreds{ |
|
5ac549c…
|
lmata
|
216 |
KeyID: id, |
|
5ac549c…
|
lmata
|
217 |
SecretKey: os.Getenv("AWS_SECRET_ACCESS_KEY"), |
|
5ac549c…
|
lmata
|
218 |
SessionToken: os.Getenv("AWS_SESSION_TOKEN"), |
|
5ac549c…
|
lmata
|
219 |
}, nil |
|
5ac549c…
|
lmata
|
220 |
} |
|
5ac549c…
|
lmata
|
221 |
|
|
5ac549c…
|
lmata
|
222 |
// 3. ECS container credentials. |
|
5ac549c…
|
lmata
|
223 |
if rel := os.Getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"); rel != "" { |
|
5ac549c…
|
lmata
|
224 |
return fetchContainerCreds(ctx, "http://169.254.170.2"+rel, "", hc) |
|
5ac549c…
|
lmata
|
225 |
} |
|
5ac549c…
|
lmata
|
226 |
if full := os.Getenv("AWS_CONTAINER_CREDENTIALS_FULL_URI"); full != "" { |
|
5ac549c…
|
lmata
|
227 |
token := os.Getenv("AWS_CONTAINER_AUTHORIZATION_TOKEN") |
|
5ac549c…
|
lmata
|
228 |
return fetchContainerCreds(ctx, full, token, hc) |
|
5ac549c…
|
lmata
|
229 |
} |
|
5ac549c…
|
lmata
|
230 |
|
|
5ac549c…
|
lmata
|
231 |
// 4. EC2 / EKS instance metadata (IMDSv2). |
|
5ac549c…
|
lmata
|
232 |
return fetchIMDSCreds(ctx, hc) |
|
5ac549c…
|
lmata
|
233 |
} |
|
5ac549c…
|
lmata
|
234 |
|
|
5ac549c…
|
lmata
|
235 |
// fetchContainerCreds fetches temporary credentials from the ECS task metadata endpoint. |
|
5ac549c…
|
lmata
|
236 |
func fetchContainerCreds(ctx context.Context, url, token string, hc *http.Client) (*awsCreds, error) { |
|
5ac549c…
|
lmata
|
237 |
req, err := http.NewRequestWithContext(ctx, "GET", url, nil) |
|
5ac549c…
|
lmata
|
238 |
if err != nil { |
|
5ac549c…
|
lmata
|
239 |
return nil, fmt.Errorf("bedrock ecs creds: %w", err) |
|
5ac549c…
|
lmata
|
240 |
} |
|
5ac549c…
|
lmata
|
241 |
if token != "" { |
|
5ac549c…
|
lmata
|
242 |
req.Header.Set("Authorization", token) |
|
5ac549c…
|
lmata
|
243 |
} |
|
5ac549c…
|
lmata
|
244 |
return parseTempCreds(hc, req, "ECS container credentials") |
|
5ac549c…
|
lmata
|
245 |
} |
|
5ac549c…
|
lmata
|
246 |
|
|
5ac549c…
|
lmata
|
247 |
// fetchIMDSCreds fetches temporary credentials via EC2 IMDSv2 (also works for EKS). |
|
5ac549c…
|
lmata
|
248 |
func fetchIMDSCreds(ctx context.Context, hc *http.Client) (*awsCreds, error) { |
|
5ac549c…
|
lmata
|
249 |
const imdsBase = "http://169.254.169.254/latest" |
|
5ac549c…
|
lmata
|
250 |
|
|
5ac549c…
|
lmata
|
251 |
// Step 1: obtain IMDSv2 session token. |
|
5ac549c…
|
lmata
|
252 |
tokenReq, err := http.NewRequestWithContext(ctx, "PUT", imdsBase+"/api/token", nil) |
|
5ac549c…
|
lmata
|
253 |
if err != nil { |
|
5ac549c…
|
lmata
|
254 |
return nil, fmt.Errorf("bedrock imds token request: %w", err) |
|
5ac549c…
|
lmata
|
255 |
} |
|
5ac549c…
|
lmata
|
256 |
tokenReq.Header.Set("X-aws-ec2-metadata-token-ttl-seconds", "21600") |
|
5ac549c…
|
lmata
|
257 |
tokenResp, err := hc.Do(tokenReq) |
|
5ac549c…
|
lmata
|
258 |
if err != nil { |
|
5ac549c…
|
lmata
|
259 |
return nil, fmt.Errorf("bedrock imds: not running on EC2/EKS or IMDS unreachable: %w", err) |
|
5ac549c…
|
lmata
|
260 |
} |
|
5ac549c…
|
lmata
|
261 |
defer tokenResp.Body.Close() |
|
5ac549c…
|
lmata
|
262 |
tokenBytes, _ := io.ReadAll(tokenResp.Body) |
|
5ac549c…
|
lmata
|
263 |
if tokenResp.StatusCode != http.StatusOK { |
|
5ac549c…
|
lmata
|
264 |
return nil, fmt.Errorf("bedrock imds: token request failed (%d)", tokenResp.StatusCode) |
|
5ac549c…
|
lmata
|
265 |
} |
|
5ac549c…
|
lmata
|
266 |
imdsToken := strings.TrimSpace(string(tokenBytes)) |
|
5ac549c…
|
lmata
|
267 |
|
|
5ac549c…
|
lmata
|
268 |
// Step 2: get the IAM role name. |
|
5ac549c…
|
lmata
|
269 |
roleReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/", nil) |
|
5ac549c…
|
lmata
|
270 |
roleReq.Header.Set("X-aws-ec2-metadata-token", imdsToken) |
|
5ac549c…
|
lmata
|
271 |
roleResp, err := hc.Do(roleReq) |
|
5ac549c…
|
lmata
|
272 |
if err != nil { |
|
5ac549c…
|
lmata
|
273 |
return nil, fmt.Errorf("bedrock imds: get role name: %w", err) |
|
5ac549c…
|
lmata
|
274 |
} |
|
5ac549c…
|
lmata
|
275 |
defer roleResp.Body.Close() |
|
5ac549c…
|
lmata
|
276 |
roleBytes, _ := io.ReadAll(roleResp.Body) |
|
5ac549c…
|
lmata
|
277 |
if roleResp.StatusCode != http.StatusOK { |
|
5ac549c…
|
lmata
|
278 |
return nil, fmt.Errorf("bedrock imds: no IAM role attached to instance") |
|
5ac549c…
|
lmata
|
279 |
} |
|
5ac549c…
|
lmata
|
280 |
role := strings.TrimSpace(string(roleBytes)) |
|
5ac549c…
|
lmata
|
281 |
|
|
5ac549c…
|
lmata
|
282 |
// Step 3: fetch credentials for the role. |
|
5ac549c…
|
lmata
|
283 |
credsReq, _ := http.NewRequestWithContext(ctx, "GET", imdsBase+"/meta-data/iam/security-credentials/"+role, nil) |
|
5ac549c…
|
lmata
|
284 |
credsReq.Header.Set("X-aws-ec2-metadata-token", imdsToken) |
|
5ac549c…
|
lmata
|
285 |
return parseTempCreds(hc, credsReq, "EC2 instance metadata") |
|
5ac549c…
|
lmata
|
286 |
} |
|
5ac549c…
|
lmata
|
287 |
|
|
5ac549c…
|
lmata
|
288 |
func parseTempCreds(hc *http.Client, req *http.Request, source string) (*awsCreds, error) { |
|
5ac549c…
|
lmata
|
289 |
resp, err := hc.Do(req) |
|
5ac549c…
|
lmata
|
290 |
if err != nil { |
|
5ac549c…
|
lmata
|
291 |
return nil, fmt.Errorf("bedrock %s: %w", source, err) |
|
5ac549c…
|
lmata
|
292 |
} |
|
5ac549c…
|
lmata
|
293 |
defer resp.Body.Close() |
|
5ac549c…
|
lmata
|
294 |
data, _ := io.ReadAll(resp.Body) |
|
5ac549c…
|
lmata
|
295 |
if resp.StatusCode != http.StatusOK { |
|
5ac549c…
|
lmata
|
296 |
return nil, fmt.Errorf("bedrock %s error %d: %s", source, resp.StatusCode, string(data)) |
|
5ac549c…
|
lmata
|
297 |
} |
|
5ac549c…
|
lmata
|
298 |
var result struct { |
|
5ac549c…
|
lmata
|
299 |
AccessKeyID string `json:"AccessKeyId"` |
|
5ac549c…
|
lmata
|
300 |
SecretAccessKey string `json:"SecretAccessKey"` |
|
5ac549c…
|
lmata
|
301 |
Token string `json:"Token"` |
|
5ac549c…
|
lmata
|
302 |
Expiration string `json:"Expiration"` |
|
5ac549c…
|
lmata
|
303 |
} |
|
5ac549c…
|
lmata
|
304 |
if err := json.Unmarshal(data, &result); err != nil { |
|
5ac549c…
|
lmata
|
305 |
return nil, fmt.Errorf("bedrock %s parse: %w", source, err) |
|
5ac549c…
|
lmata
|
306 |
} |
|
5ac549c…
|
lmata
|
307 |
creds := &awsCreds{ |
|
5ac549c…
|
lmata
|
308 |
KeyID: result.AccessKeyID, |
|
5ac549c…
|
lmata
|
309 |
SecretKey: result.SecretAccessKey, |
|
5ac549c…
|
lmata
|
310 |
SessionToken: result.Token, |
|
5ac549c…
|
lmata
|
311 |
} |
|
5ac549c…
|
lmata
|
312 |
if result.Expiration != "" { |
|
5ac549c…
|
lmata
|
313 |
if t, err := time.Parse(time.RFC3339, result.Expiration); err == nil { |
|
5ac549c…
|
lmata
|
314 |
creds.Expiry = t |
|
5ac549c…
|
lmata
|
315 |
} |
|
5ac549c…
|
lmata
|
316 |
} |
|
5ac549c…
|
lmata
|
317 |
return creds, nil |
|
5ac549c…
|
lmata
|
318 |
} |
|
5ac549c…
|
lmata
|
319 |
|
|
5ac549c…
|
lmata
|
320 |
// --- SigV4 signing --- |
|
5ac549c…
|
lmata
|
321 |
|
|
5ac549c…
|
lmata
|
322 |
// signSigV4 adds AWS Signature Version 4 authentication headers to r. |
|
5ac549c…
|
lmata
|
323 |
// Both bedrock.*.amazonaws.com and bedrock-runtime.*.amazonaws.com use service "bedrock". |
|
5ac549c…
|
lmata
|
324 |
func signSigV4(r *http.Request, body []byte, creds *awsCreds, region, service string) error { |
|
5ac549c…
|
lmata
|
325 |
now := time.Now().UTC() |
|
5ac549c…
|
lmata
|
326 |
dateTime := now.Format("20060102T150405Z") |
|
5ac549c…
|
lmata
|
327 |
date := now.Format("20060102") |
|
5ac549c…
|
lmata
|
328 |
|
|
5ac549c…
|
lmata
|
329 |
var bodyBytes []byte |
|
5ac549c…
|
lmata
|
330 |
if body != nil { |
|
5ac549c…
|
lmata
|
331 |
bodyBytes = body |
|
5ac549c…
|
lmata
|
332 |
} |
|
5ac549c…
|
lmata
|
333 |
bodyHash := sha256Hex(bodyBytes) |
|
5ac549c…
|
lmata
|
334 |
|
|
5ac549c…
|
lmata
|
335 |
r.Header.Set("x-amz-date", dateTime) |
|
5ac549c…
|
lmata
|
336 |
r.Header.Set("x-amz-content-sha256", bodyHash) |
|
5ac549c…
|
lmata
|
337 |
if creds.SessionToken != "" { |
|
5ac549c…
|
lmata
|
338 |
r.Header.Set("x-amz-security-token", creds.SessionToken) |
|
5ac549c…
|
lmata
|
339 |
} |
|
5ac549c…
|
lmata
|
340 |
if r.Host == "" { |
|
5ac549c…
|
lmata
|
341 |
r.Host = r.URL.Host |
|
5ac549c…
|
lmata
|
342 |
} |
|
5ac549c…
|
lmata
|
343 |
|
|
5ac549c…
|
lmata
|
344 |
canonHeaders, signedHeaders := buildHeaders(r) |
|
5ac549c…
|
lmata
|
345 |
path := r.URL.Path |
|
5ac549c…
|
lmata
|
346 |
if path == "" { |
|
5ac549c…
|
lmata
|
347 |
path = "/" |
|
5ac549c…
|
lmata
|
348 |
} |
|
5ac549c…
|
lmata
|
349 |
canonReq := strings.Join([]string{ |
|
5ac549c…
|
lmata
|
350 |
r.Method, |
|
5ac549c…
|
lmata
|
351 |
path, |
|
5ac549c…
|
lmata
|
352 |
r.URL.RawQuery, |
|
5ac549c…
|
lmata
|
353 |
canonHeaders, |
|
5ac549c…
|
lmata
|
354 |
signedHeaders, |
|
5ac549c…
|
lmata
|
355 |
bodyHash, |
|
5ac549c…
|
lmata
|
356 |
}, "\n") |
|
5ac549c…
|
lmata
|
357 |
|
|
5ac549c…
|
lmata
|
358 |
credScope := strings.Join([]string{date, region, service, "aws4_request"}, "/") |
|
5ac549c…
|
lmata
|
359 |
strToSign := strings.Join([]string{ |
|
5ac549c…
|
lmata
|
360 |
"AWS4-HMAC-SHA256", |
|
5ac549c…
|
lmata
|
361 |
dateTime, |
|
5ac549c…
|
lmata
|
362 |
credScope, |
|
5ac549c…
|
lmata
|
363 |
sha256Hex([]byte(canonReq)), |
|
5ac549c…
|
lmata
|
364 |
}, "\n") |
|
5ac549c…
|
lmata
|
365 |
|
|
5ac549c…
|
lmata
|
366 |
sigKey := deriveSigningKey(creds.SecretKey, date, region, service) |
|
5ac549c…
|
lmata
|
367 |
sig := hex.EncodeToString(hmacSHA256(sigKey, []byte(strToSign))) |
|
5ac549c…
|
lmata
|
368 |
|
|
5ac549c…
|
lmata
|
369 |
r.Header.Set("Authorization", fmt.Sprintf( |
|
5ac549c…
|
lmata
|
370 |
"AWS4-HMAC-SHA256 Credential=%s/%s,SignedHeaders=%s,Signature=%s", |
|
5ac549c…
|
lmata
|
371 |
creds.KeyID, credScope, signedHeaders, sig, |
|
5ac549c…
|
lmata
|
372 |
)) |
|
5ac549c…
|
lmata
|
373 |
return nil |
|
5ac549c…
|
lmata
|
374 |
} |
|
5ac549c…
|
lmata
|
375 |
|
|
5ac549c…
|
lmata
|
376 |
// buildHeaders returns the SigV4 canonical-headers string and the
// semicolon-separated signed-headers list for r.
//
// Every header currently on r is signed, with two special cases:
//   - Authorization is skipped. It holds the signature itself and may be left
//     over from an earlier signing.
//   - "host" is always included and taken from r.Host, because net/http does
//     not keep Host in r.Header.
//
// Names are lowercased and sorted. Each line is "name:values\n". Every value
// is trimmed and has runs of whitespace collapsed to one space, and multiple
// values are joined with ",". Headers stored under non-canonical keys are
// honored.
func buildHeaders(r *http.Request) (canonical, signed string) {
	keys := make([]string, 0, len(r.Header))
	for k := range r.Header {
		keys = append(keys, k)
	}
	// Sorting raw keys first makes value order deterministic when differently
	// cased keys map to the same lowercase name.
	sort.Strings(keys)

	values := make(map[string][]string, len(keys)+1)
	var names []string
	for _, k := range keys {
		lk := strings.ToLower(k)
		if lk == "authorization" {
			continue
		}
		vals, ok := values[lk]
		if !ok {
			names = append(names, lk)
		}
		for _, v := range r.Header[k] {
			vals = append(vals, strings.Join(strings.Fields(v), " "))
		}
		values[lk] = vals
	}
	if _, ok := values["host"]; !ok {
		names = append(names, "host")
	}
	sort.Strings(names)

	var sb strings.Builder
	for _, name := range names {
		sb.WriteString(name)
		sb.WriteByte(':')
		if name == "host" {
			sb.WriteString(r.Host)
		} else {
			sb.WriteString(strings.Join(values[name], ","))
		}
		sb.WriteByte('\n')
	}
	return sb.String(), strings.Join(names, ";")
}
|
5ac549c…
|
lmata
|
402 |
|
|
5ac549c…
|
lmata
|
403 |
// sha256Hex returns the lowercase hex encoding of the SHA-256 digest of data.
// nil and empty input both yield the digest of the empty string.
func sha256Hex(data []byte) string {
	hasher := sha256.New()
	_, _ = hasher.Write(data) // hash.Hash writes never return an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
5ac549c…
|
lmata
|
407 |
|
|
5ac549c…
|
lmata
|
408 |
// hmacSHA256 returns the 32-byte HMAC-SHA256 of data keyed with key.
func hmacSHA256(key, data []byte) []byte {
	mac := hmac.New(sha256.New, key)
	// Writes to a hash.Hash never fail, so the error is deliberately ignored.
	_, _ = mac.Write(data)
	return mac.Sum(make([]byte, 0, sha256.Size))
}
|
5ac549c…
|
lmata
|
413 |
|
|
5ac549c…
|
lmata
|
414 |
func deriveSigningKey(secret, date, region, service string) []byte { |
|
5ac549c…
|
lmata
|
415 |
kDate := hmacSHA256([]byte("AWS4"+secret), []byte(date)) |
|
5ac549c…
|
lmata
|
416 |
kRegion := hmacSHA256(kDate, []byte(region)) |
|
5ac549c…
|
lmata
|
417 |
kService := hmacSHA256(kRegion, []byte(service)) |
|
5ac549c…
|
lmata
|
418 |
return hmacSHA256(kService, []byte("aws4_request")) |
|
5ac549c…
|
lmata
|
419 |
} |