Skip to content

Commit ee2c4b4

Browse files
authored
feat: support IAM role-based authentication for AWS Bedrock (#265)
1 parent ab88767 commit ee2c4b4

3 files changed

Lines changed: 123 additions & 18 deletions

File tree

config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type Anthropic struct {
2626
type AWSBedrock struct {
2727
Region string
2828
AccessKey, AccessKeySecret string
29+
SessionToken string
2930
Model, SmallFastModel string
3031
// If set, requests will be sent to this URL instead of the default AWS Bedrock endpoint
3132
// (https://bedrock-runtime.{region}.amazonaws.com).

intercept/messages/base.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -264,42 +264,60 @@ func (i *interceptionBase) withBody() option.RequestOption {
264264
return option.WithRequestBody("application/json", []byte(i.reqPayload))
265265
}
266266

267+
// withAWSBedrockOptions returns request options for authenticating with AWS Bedrock.
268+
//
269+
// When both AccessKey and AccessKeySecret are set in the aibridge config, they are
270+
// used directly as static credentials (with an optional SessionToken for temporary credentials).
271+
// Otherwise, the AWS SDK default credential chain resolves credentials (environment variables,
272+
// shared config/credentials files, IAM roles, IRSA, SSO, IMDS, etc.).
267273
func (*interceptionBase) withAWSBedrockOptions(ctx context.Context, cfg *aibconfig.AWSBedrock) ([]option.RequestOption, error) {
268274
if cfg == nil {
269275
return nil, xerrors.New("nil config given")
270276
}
271277
if cfg.Region == "" && cfg.BaseURL == "" {
272278
return nil, xerrors.New("region or base url required")
273279
}
274-
if cfg.AccessKey == "" {
275-
return nil, xerrors.New("access key required")
276-
}
277-
if cfg.AccessKeySecret == "" {
278-
return nil, xerrors.New("access key secret required")
279-
}
280280
if cfg.Model == "" {
281281
return nil, xerrors.New("model required")
282282
}
283283
if cfg.SmallFastModel == "" {
284284
return nil, xerrors.New("small fast model required")
285285
}
286286

287-
opts := []func(*config.LoadOptions) error{
287+
loadOpts := []func(*config.LoadOptions) error{
288288
config.WithRegion(cfg.Region),
289-
config.WithCredentialsProvider(
289+
}
290+
291+
// Use static credentials when explicitly provided, otherwise fall back to the SDK default credential chain.
292+
switch {
293+
// Both set: use static credentials directly.
294+
case cfg.AccessKey != "" && cfg.AccessKeySecret != "":
295+
loadOpts = append(loadOpts, config.WithCredentialsProvider(
290296
credentials.NewStaticCredentialsProvider(
291297
cfg.AccessKey,
292298
cfg.AccessKeySecret,
293-
"",
299+
cfg.SessionToken, // optional
294300
),
295-
),
301+
))
302+
// Only one set: misconfiguration.
303+
case cfg.AccessKey != "" || cfg.AccessKeySecret != "":
304+
return nil, xerrors.New("both access key and access key secret must be provided together")
305+
// Neither set: SDK default credential chain resolves credentials.
306+
default:
296307
}
297308

298-
awsCfg, err := config.LoadDefaultConfig(ctx, opts...)
309+
awsCfg, err := config.LoadDefaultConfig(ctx, loadOpts...)
299310
if err != nil {
300311
return nil, xerrors.Errorf("failed to load AWS Bedrock config: %w", err)
301312
}
302313

314+
// Fail fast: ensure credentials can be resolved before making any requests.
315+
// awsCfg already carries the credentials provider, and the Bedrock middleware
316+
// will call Retrieve on it when signing each request.
317+
if _, err := awsCfg.Credentials.Retrieve(ctx); err != nil {
318+
return nil, xerrors.Errorf("no AWS credentials found: %w", err)
319+
}
320+
303321
var out []option.RequestOption
304322
out = append(out, bedrock.WithConfig(awsCfg))
305323

intercept/messages/base_test.go

Lines changed: 93 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,9 @@ func TestAWSBedrockValidation(t *testing.T) {
8484
expectError bool
8585
errorMsg string
8686
}{
87-
// Valid cases.
87+
// Valid cases: static credentials.
8888
{
89-
name: "valid with region",
89+
name: "static credentials with region",
9090
cfg: &config.AWSBedrock{
9191
Region: "us-east-1",
9292
AccessKey: "test-key",
@@ -96,7 +96,7 @@ func TestAWSBedrockValidation(t *testing.T) {
9696
},
9797
},
9898
{
99-
name: "valid with base url",
99+
name: "static credentials with base url",
100100
cfg: &config.AWSBedrock{
101101
BaseURL: "http://bedrock.internal",
102102
AccessKey: "test-key",
@@ -111,7 +111,7 @@ func TestAWSBedrockValidation(t *testing.T) {
111111
// which is internal to the anthropic SDK.
112112
//
113113
// See TestAWSBedrockIntegration which validates this.
114-
name: "valid with base url & region",
114+
name: "static credentials with base url & region",
115115
cfg: &config.AWSBedrock{
116116
Region: "us-east-1",
117117
AccessKey: "test-key",
@@ -120,6 +120,17 @@ func TestAWSBedrockValidation(t *testing.T) {
120120
SmallFastModel: "test-small-model",
121121
},
122122
},
123+
{
124+
name: "static credentials with session token",
125+
cfg: &config.AWSBedrock{
126+
Region: "us-east-1",
127+
AccessKey: "test-key",
128+
AccessKeySecret: "test-secret",
129+
SessionToken: "test-session-token",
130+
Model: "test-model",
131+
SmallFastModel: "test-small-model",
132+
},
133+
},
123134
// Invalid cases.
124135
{
125136
name: "missing region & base url",
@@ -137,13 +148,12 @@ func TestAWSBedrockValidation(t *testing.T) {
137148
name: "missing access key",
138149
cfg: &config.AWSBedrock{
139150
Region: "us-east-1",
140-
AccessKey: "",
141151
AccessKeySecret: "test-secret",
142152
Model: "test-model",
143153
SmallFastModel: "test-small-model",
144154
},
145155
expectError: true,
146-
errorMsg: "access key required",
156+
errorMsg: "both access key and access key secret must be provided together",
147157
},
148158
{
149159
name: "missing access key secret",
@@ -155,7 +165,7 @@ func TestAWSBedrockValidation(t *testing.T) {
155165
SmallFastModel: "test-small-model",
156166
},
157167
expectError: true,
158-
errorMsg: "access key secret required",
168+
errorMsg: "both access key and access key secret must be provided together",
159169
},
160170
{
161171
name: "missing model",
@@ -213,6 +223,82 @@ func TestAWSBedrockValidation(t *testing.T) {
213223
}
214224
}
215225

226+
// TestAWSBedrockCredentialChain tests credential resolution via the AWS SDK default credential chain.
227+
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
228+
func TestAWSBedrockCredentialChain(t *testing.T) {
229+
tests := []struct {
230+
name string
231+
cfg *config.AWSBedrock
232+
envVars map[string]string
233+
expectError bool
234+
errorMsg string
235+
}{
236+
{
237+
name: "temporary credentials via env",
238+
cfg: &config.AWSBedrock{
239+
Region: "us-east-1",
240+
Model: "test-model",
241+
SmallFastModel: "test-small-model",
242+
},
243+
envVars: map[string]string{
244+
"AWS_ACCESS_KEY_ID": "test-key",
245+
"AWS_SECRET_ACCESS_KEY": "test-secret",
246+
},
247+
},
248+
{
249+
name: "temporary credentials with session token via env",
250+
cfg: &config.AWSBedrock{
251+
Region: "us-east-1",
252+
Model: "test-model",
253+
SmallFastModel: "test-small-model",
254+
},
255+
envVars: map[string]string{
256+
"AWS_ACCESS_KEY_ID": "test-key",
257+
"AWS_SECRET_ACCESS_KEY": "test-secret",
258+
"AWS_SESSION_TOKEN": "test-session-token",
259+
},
260+
},
261+
{
262+
// When static credentials are not provided and no environment credentials are set,
263+
// the SDK default credential chain fails to resolve credentials.
264+
name: "error when no credential source is configured",
265+
cfg: &config.AWSBedrock{
266+
Region: "us-east-1",
267+
Model: "test-model",
268+
SmallFastModel: "test-small-model",
269+
},
270+
envVars: map[string]string{
271+
"AWS_ACCESS_KEY_ID": "",
272+
"AWS_SECRET_ACCESS_KEY": "",
273+
"AWS_SESSION_TOKEN": "",
274+
"AWS_PROFILE": "",
275+
"AWS_SHARED_CREDENTIALS_FILE": "/dev/null",
276+
"AWS_CONFIG_FILE": "/dev/null",
277+
},
278+
expectError: true,
279+
errorMsg: "no AWS credentials found",
280+
},
281+
}
282+
283+
for _, tt := range tests {
284+
t.Run(tt.name, func(t *testing.T) {
285+
for key, val := range tt.envVars {
286+
t.Setenv(key, val)
287+
}
288+
base := &interceptionBase{}
289+
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
290+
291+
if tt.expectError {
292+
require.Error(t, err)
293+
require.Contains(t, err.Error(), tt.errorMsg)
294+
} else {
295+
require.NotEmpty(t, opts)
296+
require.NoError(t, err)
297+
}
298+
})
299+
}
300+
}
301+
216302
func TestAccumulateUsage(t *testing.T) {
217303
t.Parallel()
218304

0 commit comments

Comments
 (0)