Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 9 additions & 27 deletions pkg/model/provider/anthropic/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
Expand All @@ -19,19 +18,16 @@ import (

// streamAdapter adapts the Anthropic stream to our interface
type streamAdapter struct {
stream *ssestream.Stream[anthropic.MessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
// For single retry on context length error
retryFn func() *streamAdapter
retried bool
retryableStream[anthropic.MessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
getResponseTrailer func() http.Header
}

func (c *Client) newStreamAdapter(stream *ssestream.Stream[anthropic.MessageStreamEventUnion], trackUsage bool) *streamAdapter {
return &streamAdapter{
stream: stream,
retryableStream: retryableStream[anthropic.MessageStreamEventUnion]{stream: stream},
trackUsage: trackUsage,
getResponseTrailer: c.getResponseTrailer,
}
Expand Down Expand Up @@ -72,21 +68,9 @@ func isContextLengthError(err error) bool {

// Recv gets the next completion chunk
func (a *streamAdapter) Recv() (chat.MessageStreamResponse, error) {
if !a.stream.Next() {
err := a.stream.Err()
// Single retry on context length error
if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) {
a.retried = true
if retry := a.retryFn(); retry != nil {
a.stream.Close()
a.stream = retry.stream
return a.Recv()
}
}
if err != nil {
return chat.MessageStreamResponse{}, err
}
return chat.MessageStreamResponse{}, io.EOF
ok, err := a.next()
if !ok {
return chat.MessageStreamResponse{}, err
}

event := a.stream.Current()
Expand Down Expand Up @@ -192,7 +176,5 @@ func parseHeaderInt64(headerValue string) int64 {

// Close closes the stream
func (a *streamAdapter) Close() {
if a.stream != nil {
a.stream.Close()
}
a.stream.Close()
}
36 changes: 9 additions & 27 deletions pkg/model/provider/anthropic/beta_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package anthropic

import (
"fmt"
"io"
"log/slog"
"net/http"

Expand All @@ -15,42 +14,27 @@ import (

// betaStreamAdapter adapts the Anthropic Beta stream to our interface
type betaStreamAdapter struct {
stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
// For single retry on context length error
retryFn func() *betaStreamAdapter
retried bool
retryableStream[anthropic.BetaRawMessageStreamEventUnion]
trackUsage bool
toolCall bool
toolID string
getResponseTrailer func() http.Header
}

// newBetaStreamAdapter creates a new Beta stream adapter
func (c *Client) newBetaStreamAdapter(stream *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion], trackUsage bool) *betaStreamAdapter {
return &betaStreamAdapter{
stream: stream,
retryableStream: retryableStream[anthropic.BetaRawMessageStreamEventUnion]{stream: stream},
trackUsage: trackUsage,
getResponseTrailer: c.getResponseTrailer,
}
}

// Recv gets the next completion chunk from the Beta stream
func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {
if !a.stream.Next() {
err := a.stream.Err()
// Single retry on context length error
if err != nil && !a.retried && a.retryFn != nil && isContextLengthError(err) {
a.retried = true
if retry := a.retryFn(); retry != nil {
a.stream.Close()
a.stream = retry.stream
return a.Recv()
}
}
if err != nil {
return chat.MessageStreamResponse{}, err
}
return chat.MessageStreamResponse{}, io.EOF
ok, err := a.next()
if !ok {
return chat.MessageStreamResponse{}, err
}

event := a.stream.Current()
Expand Down Expand Up @@ -137,7 +121,5 @@ func (a *betaStreamAdapter) Recv() (chat.MessageStreamResponse, error) {

// Close closes the Beta stream
func (a *betaStreamAdapter) Close() {
if a.stream != nil {
a.stream.Close()
}
a.stream.Close()
}
4 changes: 2 additions & 2 deletions pkg/model/provider/anthropic/beta_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func (c *Client) createBetaStream(
ad := c.newBetaStreamAdapter(stream, trackUsage)

// Set up single retry for context length errors
ad.retryFn = func() *betaStreamAdapter {
ad.retryFn = func() *ssestream.Stream[anthropic.BetaRawMessageStreamEventUnion] {
used, err := countAnthropicTokensBeta(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools)
if err != nil {
slog.Warn("Failed to count tokens for retry, skipping", "error", err)
Expand All @@ -142,7 +142,7 @@ func (c *Client) createBetaStream(
slog.Warn("Retrying with clamped max_tokens after context length error", "original", maxTokens, "clamped", newMaxTokens, "used", used)
retryParams := params
retryParams.MaxTokens = newMaxTokens
return c.newBetaStreamAdapter(client.Beta.Messages.NewStreaming(ctx, retryParams), trackUsage)
return client.Beta.Messages.NewStreaming(ctx, retryParams)
}

slog.Debug("Anthropic Beta API chat completion stream created successfully", "model", c.ModelConfig.Model)
Expand Down
5 changes: 3 additions & 2 deletions pkg/model/provider/anthropic/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/option"
"github.com/anthropics/anthropic-sdk-go/packages/param"
"github.com/anthropics/anthropic-sdk-go/packages/ssestream"

"github.com/docker/docker-agent/pkg/chat"
"github.com/docker/docker-agent/pkg/config/latest"
Expand Down Expand Up @@ -349,7 +350,7 @@ func (c *Client) CreateChatCompletionStream(
ad := c.newStreamAdapter(stream, trackUsage)

// Set up single retry for context length errors
ad.retryFn = func() *streamAdapter {
ad.retryFn = func() *ssestream.Stream[anthropic.MessageStreamEventUnion] {
used, err := countAnthropicTokens(ctx, client, anthropic.Model(c.ModelConfig.Model), converted, sys, allTools)
if err != nil {
slog.Warn("Failed to count tokens for retry, skipping", "error", err)
Expand All @@ -363,7 +364,7 @@ func (c *Client) CreateChatCompletionStream(
slog.Warn("Retrying with clamped max_tokens after context length error", "original max_tokens", maxTokens, "clamped max_tokens", newMaxTokens, "used tokens", used)
retryParams := params
retryParams.MaxTokens = newMaxTokens
return c.newStreamAdapter(client.Messages.NewStreaming(ctx, retryParams, betaHeader), trackUsage)
return client.Messages.NewStreaming(ctx, retryParams, betaHeader)
}

slog.Debug("Anthropic chat completion stream created successfully", "model", c.ModelConfig.Model)
Expand Down
46 changes: 46 additions & 0 deletions pkg/model/provider/anthropic/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package anthropic

import (
"io"

"github.com/anthropics/anthropic-sdk-go/packages/ssestream"
)

// retryableStream wraps an ssestream.Stream and adds a single-retry mechanism
// for context length errors. Both the standard and Beta stream adapters embed
// this to share the retry logic.
type retryableStream[T any] struct {
stream *ssestream.Stream[T]
// retryFn, when non-nil, is called once on a context-length error.
// It should return a new stream to use, or nil to skip retrying.
retryFn func() *ssestream.Stream[T]
retried bool
}

// next moves the stream forward. If the stream is exhausted it returns
// (false, io.EOF). If it encounters an error it attempts a single retry when
// the error is a context-length error and a retryFn is configured.
// On success it returns (true, nil).
func (r *retryableStream[T]) next() (bool, error) {
if r.stream.Next() {
return true, nil
}

err := r.stream.Err()
if err != nil && !r.retried && r.retryFn != nil && isContextLengthError(err) {
r.retried = true
if newStream := r.retryFn(); newStream != nil {
r.stream.Close()
r.stream = newStream
ok, err := r.next()
if !ok && err != nil {
r.stream.Close() // Clean up on retry failure
}
return ok, err
}
}
if err != nil {
return false, err
}
return false, io.EOF
}