diff --git a/pkg/model/provider/anthropic/adapter.go b/pkg/model/provider/anthropic/adapter.go index 3eb776fa2..7133fdee5 100644 --- a/pkg/model/provider/anthropic/adapter.go +++ b/pkg/model/provider/anthropic/adapter.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "strconv" "strings" @@ -19,19 +18,16 @@ import ( // streamAdapter adapts the Anthropic stream to our interface type streamAdapter struct { - stream *ssestream.Stream[anthropic.MessageStreamEventUnion] - trackUsage bool - toolCall bool - toolID string - // For single retry on context length error - retryFn func() *streamAdapter - retried bool + retryableStream[anthropic.MessageStreamEventUnion] + trackUsage bool + toolCall bool + toolID string getResponseTrailer func() http.Header } func (c *Client) newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter { return &streamAdapter{ - stream: stream, + retryableStream: retryableStream[anthropic.MessageStreamEventUnion]{stream: stream}, trackUsage: trackUsage, getResponseTrailer: c.getResponseTrailer, } @@ -72,21 +68,9 @@ func isContextLengthError(err error) bool { // Recv gets the next completion chunk func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) { - if !a.stream.Next() { - err := a.stream.Err() - // Single retry on context length error - if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) { - a.retried = true - if retry := a.retryFn(); retry != nil { - a.stream.Close() - a.stream = retry.stream - return a.Recv() - } - } - if err != nil { - return chat.MessageStreamResponse{}, err - } - return chat.MessageStreamResponse{}, io.EOF + ok, err := a.next() + if !ok { + return chat.MessageStreamResponse{}, err } event := a.stream.Current() @@ -192,7 +176,5 @@ func parseHeaderInt64(headerValue string) int64 { // Close closes the stream func (a *streamAdapter) Close() { - if a.stream != nil { - a.stream.Close() - } + a.stream.Close() } diff --git a/pkg/model/provider/anthropic/beta_adapter.go b/pkg/model/provider/anthropic/beta_adapter.go index 57e13d14a..ca884b72b 100644 --- a/pkg/model/provider/anthropic/beta_adapter.go +++ b/pkg/model/provider/anthropic/beta_adapter.go @@ -2,7 +2,6 @@ package anthropic import ( "fmt" - "io" "log/slog" "net/http" @@ -15,20 +14,17 @@ import ( // betaStreamAdapter adapts the Anthropic Beta stream to our interface type betaStreamAdapter struct { - stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion] - trackUsage bool - toolCall bool - toolID string - // For single retry on context length error - retryFn func() *betaStreamAdapter - retried bool + retryableStream[anthropic.BetaRawMessageStreamEventUnion] + trackUsage bool + toolCall bool + toolID string getResponseTrailer func() http.Header } // newBetaStreamAdapter creates a new Beta stream adapter func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion], trackUsage bool) *betaStreamAdapter { return &betaStreamAdapter{ - stream: stream, + retryableStream: retryableStream[anthropic.BetaRawMessageStreamEventUnion]{stream: stream}, trackUsage: trackUsage, getResponseTrailer: c.getResponseTrailer, } @@ -36,21 +32,9 @@ func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRaw // Recv gets the next completion chunk from the Beta stream func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) { - if !a.stream.Next() { - err := a.stream.Err() - // Single retry on context length error - if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) { - a.retried = true - if retry := a.retryFn(); retry != nil { - a.stream.Close() - a.stream = retry.stream - return a.Recv() - } - } - if err != nil { - return chat.MessageStreamResponse{}, err - } - return chat.MessageStreamResponse{}, io.EOF + ok, err := a.next() + if !ok { + return chat.MessageStreamResponse{}, err } event := a.stream.Current() @@ -137,7 +121,5 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) { // Close closes the Beta stream func (a *betaStreamAdapter) Close() { - if a.stream != nil { - a.stream.Close() - } + a.stream.Close() } diff --git a/pkg/model/provider/anthropic/beta_client.go b/pkg/model/provider/anthropic/beta_client.go index fa412817f..55b5aa274 100644 --- a/pkg/model/provider/anthropic/beta_client.go +++ b/pkg/model/provider/anthropic/beta_client.go @@ -128,7 +128,7 @@ func (c *Client) createBetaStream( ad := c.newBetaStreamAdapter(stream, trackUsage) // Set up single retry for context length errors - ad.retryFn = func() *betaStreamAdapter { + ad.retryFn = func() *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion] { used, err := countAnthropicTokensBeta(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools) if err != nil { slog.Warn("Failed to count tokens for retry, skipping", "error", err) @@ -142,7 +142,7 @@ func (c *Client) createBetaStream( slog.Warn("Retrying with clamped max_tokens after context length error", "original", maxTokens, "clamped", newMaxTokens, "used", used) retryParams := params retryParams.MaxTokens = newMaxTokens - return c.newBetaStreamAdapter(client.Beta.Messages.NewStreaming(ctx, retryParams), trackUsage) + return client.Beta.Messages.NewStreaming(ctx, retryParams) } slog.Debug("Anthropic Beta API chat completion stream created successfully", "model", c.ModelConfig.Model) diff --git a/pkg/model/provider/anthropic/client.go b/pkg/model/provider/anthropic/client.go index 8bafe1cec..10e05b701 100644 --- a/pkg/model/provider/anthropic/client.go +++ b/pkg/model/provider/anthropic/client.go @@ -15,6 +15,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/option" "github.com/anthropics/anthropic-sdk-go/packages/param" + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" "github.com/docker/docker-agent/pkg/chat" "github.com/docker/docker-agent/pkg/config/latest" @@ -349,7 +350,7 @@ func (c *Client) CreateChatCompletionStream( ad := c.newStreamAdapter(stream, trackUsage) // Set up single retry for context length errors - ad.retryFn = func() *streamAdapter { + ad.retryFn = func() *ssestream.Stream[anthropic.MessageStreamEventUnion] { used, err := countAnthropicTokens(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools) if err != nil { slog.Warn("Failed to count tokens for retry, skipping", "error", err) @@ -363,7 +364,7 @@ func (c *Client) CreateChatCompletionStream( slog.Warn("Retrying with clamped max_tokens after context length error", "original max_tokens", maxTokens, "clamped max_tokens", newMaxTokens, "used tokens", used) retryParams := params retryParams.MaxTokens = newMaxTokens - return c.newStreamAdapter(client.Messages.NewStreaming(ctx, retryParams, betaHeader), trackUsage) + return client.Messages.NewStreaming(ctx, retryParams, betaHeader) } slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model) diff --git a/pkg/model/provider/anthropic/retry.go b/pkg/model/provider/anthropic/retry.go new file mode 100644 index 000000000..e9c495c41 --- /dev/null +++ b/pkg/model/provider/anthropic/retry.go @@ -0,0 +1,46 @@ +package anthropic + +import ( + "io" + + "github.com/anthropics/anthropic-sdk-go/packages/ssestream" +) + +// retryableStream wraps an ssestream.Stream and adds a single-retry mechanism +// for context length errors. Both the standard and Beta stream adapters embed +// this to share the retry logic. +type retryableStream[T any] struct { + stream *ssestream.Stream[T] + // retryFn, when non-nil, is called once on a context-length error. + // It should return a new stream to use, or nil to skip retrying. + retryFn func() *ssestream.Stream[T] + retried bool +} + +// next moves the stream forward. If the stream is exhausted it returns +// (false, io.EOF). If it encounters an error it attempts a single retry when +// the error is a context-length error and a retryFn is configured. +// On success it returns (true, nil). +func (r *retryableStream[T]) next() (bool, error) { + if r.stream.Next() { + return true, nil + } + + err := r.stream.Err() + if err != nil && !r.retried && r.retryFn != nil && isContextLengthError(err) { + r.retried = true + if newStream := r.retryFn(); newStream != nil { + r.stream.Close() + r.stream = newStream + ok, err := r.next() + if !ok && err != nil { + r.stream.Close() // Clean up on retry failure + } + return ok, err + } + } + if err != nil { + return false, err + } + return false, io.EOF +}