Skip to content

Commit bb7fced

Browse files
authored
feat: make MaxRetries configurable in AI providers
Fixes: #271
1 parent c0d28d0 commit bb7fced

8 files changed

Lines changed: 56 additions & 8 deletions

File tree

config/config.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ type Anthropic struct {
2121
// with an access token. When set, the access token is used for upstream
2222
// LLM requests instead of the API key.
2323
BYOKBearerToken string
24+
// MaxRetries controls the number of automatic retries the SDK will perform
25+
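// on transient errors. If nil, NewAnthropic falls back to the ANTHROPIC_MAX_RETRIES environment variable; if that is unset or not an integer, the SDK default (2) is used.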
26+
// Set to 0 to disable retries entirely.
27+
MaxRetries *int
2428
}
2529

2630
type AWSBedrock struct {
@@ -43,6 +47,10 @@ type OpenAI struct {
4347
CircuitBreaker *CircuitBreaker
4448
SendActorHeaders bool
4549
ExtraHeaders map[string]string
50+
// MaxRetries controls the number of automatic retries the SDK will perform
51+
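// on transient errors. If nil, NewOpenAI falls back to the OPENAI_MAX_RETRIES environment variable; if that is unset or not an integer, the SDK default (2) is used.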
52+
// Set to 0 to disable retries entirely.
53+
MaxRetries *int
4654
}
4755

4856
type Copilot struct {
@@ -51,6 +59,10 @@ type Copilot struct {
5159
BaseURL string
5260
APIDumpDir string
5361
CircuitBreaker *CircuitBreaker
62+
// MaxRetries controls the number of automatic retries the SDK will perform
63+
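// on transient errors. If nil, NewCopilot falls back to the COPILOT_MAX_RETRIES environment variable; if that is unset or not an integer, the SDK default (2) is used.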
64+
// Set to 0 to disable retries entirely.
65+
MaxRetries *int
5466
}
5567

5668
// CircuitBreaker holds configuration for circuit breakers.

intercept/chatcompletions/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ type interceptionBase struct {
4747

4848
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
4949
opts := []option.RequestOption{option.WithAPIKey(i.cfg.Key), option.WithBaseURL(i.cfg.BaseURL)}
50+
if i.cfg.MaxRetries != nil {
51+
opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries))
52+
}
5053

5154
// Add extra headers if configured.
5255
// Some providers require additional headers that are not added by the SDK.

intercept/messages/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,9 @@ func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...optio
219219
opts = append(opts, option.WithAPIKey(i.cfg.Key))
220220
}
221221
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
222+
if i.cfg.MaxRetries != nil {
223+
opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries))
224+
}
222225

223226
// Add extra headers if configured.
224227
// Some providers require additional headers that are not added by the SDK.

intercept/responses/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ type responsesInterceptionBase struct {
5555

5656
func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
5757
opts := []option.RequestOption{option.WithBaseURL(i.cfg.BaseURL), option.WithAPIKey(i.cfg.Key)}
58+
if i.cfg.MaxRetries != nil {
59+
opts = append(opts, option.WithMaxRetries(*i.cfg.MaxRetries))
60+
}
5861

5962
// Add extra headers if configured.
6063
// Some providers require additional headers that are not added by the SDK.

internal/integrationtest/responses_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -596,9 +596,6 @@ func TestResponsesParallelToolsOverwritten(t *testing.T) {
596596
}
597597
}
598598

599-
// TODO set MaxRetries to speed up this test
600-
// option.WithMaxRetries(0), in base responses interceptor
601-
// https://github.com/coder/aibridge/issues/115
602599
func TestClientAndConnectionError(t *testing.T) {
603600
t.Parallel()
604601

@@ -642,7 +639,11 @@ func TestClientAndConnectionError(t *testing.T) {
642639
t.Cleanup(cancel)
643640

644641
// tc.addr may be an intentionally invalid URL; use withCustomProvider.
645-
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(openAICfg(tc.addr, apiKey))))
642+
// MaxRetries is set to 0 to disable SDK retries and speed up the test.
643+
cfg := openAICfg(tc.addr, apiKey)
644+
maxRetries := 0
645+
cfg.MaxRetries = &maxRetries
646+
bridgeServer := newBridgeTestServer(ctx, t, tc.addr, withCustomProvider(provider.NewOpenAI(cfg)))
646647

647648
reqBytes := responsesRequestBytes(t, tc.streaming)
648649
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)
@@ -660,9 +661,6 @@ func TestClientAndConnectionError(t *testing.T) {
660661
}
661662
}
662663

663-
// TODO set MaxRetries to speed up this test
664-
// option.WithMaxRetries(0), in base responses interceptor
665-
// https://github.com/coder/aibridge/issues/115
666664
func TestUpstreamError(t *testing.T) {
667665
t.Parallel()
668666

@@ -721,7 +719,11 @@ func TestUpstreamError(t *testing.T) {
721719
}))
722720
t.Cleanup(upstream.Close)
723721

724-
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL)
722+
// MaxRetries is set to 0 to disable SDK retries and speed up the test.
723+
cfg := openAICfg(upstream.URL, apiKey)
724+
maxRetries := 0
725+
cfg.MaxRetries = &maxRetries
726+
bridgeServer := newBridgeTestServer(ctx, t, upstream.URL, withCustomProvider(provider.NewOpenAI(cfg)))
725727

726728
reqBytes := responsesRequestBytes(t, tc.streaming)
727729
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathOpenAIResponses, reqBytes)

provider/anthropic.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"io"
66
"net/http"
77
"os"
8+
"strconv"
89
"strings"
910

1011
"github.com/google/uuid"
@@ -62,6 +63,13 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi
6263
if cfg.APIDumpDir == "" {
6364
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
6465
}
66+
if cfg.MaxRetries == nil {
67+
if v := os.Getenv("ANTHROPIC_MAX_RETRIES"); v != "" {
68+
if n, err := strconv.Atoi(v); err == nil {
69+
cfg.MaxRetries = &n
70+
}
71+
}
72+
}
6573
if cfg.CircuitBreaker != nil {
6674
cfg.CircuitBreaker.IsFailure = anthropicIsFailure
6775
cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse

provider/copilot.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77
"net/http"
88
"os"
9+
"strconv"
910
"strings"
1011

1112
"github.com/google/uuid"
@@ -63,6 +64,13 @@ func NewCopilot(cfg config.Copilot) *Copilot {
6364
if cfg.APIDumpDir == "" {
6465
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
6566
}
67+
if cfg.MaxRetries == nil {
68+
if v := os.Getenv("COPILOT_MAX_RETRIES"); v != "" {
69+
if n, err := strconv.Atoi(v); err == nil {
70+
cfg.MaxRetries = &n
71+
}
72+
}
73+
}
6674
if cfg.CircuitBreaker != nil {
6775
cfg.CircuitBreaker.OpenErrorResponse = copilotOpenErrorResponse
6876
}
@@ -145,6 +153,7 @@ func (p *Copilot) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trac
145153
APIDumpDir: p.cfg.APIDumpDir,
146154
CircuitBreaker: p.cfg.CircuitBreaker,
147155
ExtraHeaders: extractCopilotHeaders(r),
156+
MaxRetries: p.cfg.MaxRetries,
148157
}
149158

150159
cred := intercept.NewCredentialInfo(intercept.CredentialKindBYOK, key)

provider/openai.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"io"
77
"net/http"
88
"os"
9+
"strconv"
910
"strings"
1011

1112
"github.com/google/uuid"
@@ -51,6 +52,13 @@ func NewOpenAI(cfg config.OpenAI) *OpenAI {
5152
if cfg.APIDumpDir == "" {
5253
cfg.APIDumpDir = os.Getenv("BRIDGE_DUMP_DIR")
5354
}
55+
if cfg.MaxRetries == nil {
56+
if v := os.Getenv("OPENAI_MAX_RETRIES"); v != "" {
57+
if n, err := strconv.Atoi(v); err == nil {
58+
cfg.MaxRetries = &n
59+
}
60+
}
61+
}
5462
if cfg.CircuitBreaker != nil {
5563
cfg.CircuitBreaker.OpenErrorResponse = openAIOpenErrorResponse
5664
}

0 commit comments

Comments
 (0)