Skip to content

Commit 875d1d3

Browse files
authored
mcp: don't break the streamable client connection for transient errors (#723)
When POST requests in the streamableClientConn return a transient error, return this error to the caller rather than permanently breaking the connection. This is achieved by wrapping such errors with the special ErrRejected sentinel, which the jsonrpc2 layer returns to the caller without marking the connection as broken. In doing so, the change revealed a pre-existing bug: ErrRejected had the same code as ErrConnectionClosing, and jsonrpc2.WireError implements errors.Is, so the two sentinel values could be conflated. This is fixed by using a new internal code. There's more to do for #683: we should also retry transient errors in handleSSE. For #683
1 parent e009bac commit 875d1d3

File tree

4 files changed

+191
-31
lines changed

4 files changed

+191
-31
lines changed

internal/jsonrpc2/conn.go

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -792,13 +792,9 @@ func (c *Connection) write(ctx context.Context, msg Message) error {
792792
err = c.writer.Write(ctx, msg)
793793
}
794794

795-
// For rejected requests, we don't set the writeErr (which would break the
796-
// connection). They can just be returned to the caller.
797-
if errors.Is(err, ErrRejected) {
798-
return err
799-
}
800-
801-
if err != nil && ctx.Err() == nil {
795+
// For cancelled or rejected requests, we don't set the writeErr (which would
796+
// break the connection). They can just be returned to the caller.
797+
if err != nil && ctx.Err() == nil && !errors.Is(err, ErrRejected) {
802798
// The call to Write failed, and since ctx.Err() is nil we can't attribute
803799
// the failure (even indirectly) to Context cancellation. The writer appears
804800
// to be broken, and future writes are likely to also fail.

internal/jsonrpc2/wire.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ var (
4747
// Such failures do not indicate that the connection is broken, but rather
4848
// should be returned to the caller to indicate that the specific request is
4949
// invalid in the current context.
50-
ErrRejected = NewError(-32004, "rejected by transport")
50+
ErrRejected = NewError(-32005, "rejected by transport")
5151
)
5252

5353
const wireVersion = "2.0"

mcp/streamable.go

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1657,11 +1657,18 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
16571657

16581658
resp, err := c.client.Do(req)
16591659
if err != nil {
1660-
return fmt.Errorf("%s: %v", requestSummary, err)
1660+
// An error from client.Do means no HTTP response was received (typically the request never reached the server).
1661+
// Wrap with ErrRejected so the jsonrpc2 connection doesn't set writeErr
1662+
// and permanently break the connection.
1663+
return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, err)
16611664
}
16621665

16631666
if err := c.checkResponse(requestSummary, resp); err != nil {
1664-
c.fail(err)
1667+
// Only fail the connection for non-transient errors.
1668+
// Transient errors (wrapped with ErrRejected) should not break the connection.
1669+
if !errors.Is(err, jsonrpc2.ErrRejected) {
1670+
c.fail(err)
1671+
}
16651672
return err
16661673
}
16671674

@@ -1826,8 +1833,13 @@ func (c *streamableClientConn) checkResponse(requestSummary string, resp *http.R
18261833
// session is already gone.
18271834
return fmt.Errorf("%s: failed to connect (session ID: %v): %w", requestSummary, c.sessionID, errSessionMissing)
18281835
}
1836+
// Transient HTTP errors (429, 500, 502, 503, 504) should not break the connection.
1837+
// Wrap them with ErrRejected so the jsonrpc2 layer doesn't set writeErr.
1838+
if isTransientHTTPStatus(resp.StatusCode) {
1839+
return fmt.Errorf("%w: %s: %v", jsonrpc2.ErrRejected, requestSummary, http.StatusText(resp.StatusCode))
1840+
}
18291841
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
1830-
return fmt.Errorf("%s: failed to connect: %v", requestSummary, http.StatusText(resp.StatusCode))
1842+
return fmt.Errorf("%s: %v", requestSummary, http.StatusText(resp.StatusCode))
18311843
}
18321844
return nil
18331845
}
@@ -2012,3 +2024,17 @@ func calculateReconnectDelay(attempt int) time.Duration {
20122024

20132025
return backoffDuration + jitter
20142026
}
2027+
2028+
// isTransientHTTPStatus reports whether statusCode is an HTTP status that is
// treated as transient, meaning a failure with this status should not
// permanently break the connection.
//
// The transient statuses are 429 (Too Many Requests), 500 (Internal Server
// Error), 502 (Bad Gateway), 503 (Service Unavailable), and 504 (Gateway
// Timeout). Every other status is reported as non-transient.
func isTransientHTTPStatus(statusCode int) bool {
	switch statusCode {
	case http.StatusTooManyRequests,
		http.StatusInternalServerError,
		http.StatusBadGateway,
		http.StatusServiceUnavailable,
		http.StatusGatewayTimeout:
		return true
	}
	return false
}

mcp/streamable_client_test.go

Lines changed: 158 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"net/http/httptest"
1313
"strings"
1414
"sync"
15+
"sync/atomic"
1516
"testing"
1617
"time"
1718

@@ -29,13 +30,15 @@ type streamableRequestKey struct {
2930

3031
type header map[string]string
3132

33+
// TODO: replace body and status fields with responseFunc; add helpers to reduce duplication.
3234
type streamableResponse struct {
33-
header header // response headers
34-
status int // or http.StatusOK
35-
body string // or ""
36-
optional bool // if set, request need not be sent
37-
wantProtocolVersion string // if "", unchecked
38-
done chan struct{} // if set, receive from this channel before terminating the request
35+
header header // response headers
36+
status int // or http.StatusOK; ignored if responseFunc is set
37+
body string // or ""; ignored if responseFunc is set
38+
responseFunc func(r *jsonrpc.Request) (string, int) // if set, overrides body and status
39+
optional bool // if set, request need not be sent
40+
wantProtocolVersion string // if "", unchecked
41+
done chan struct{} // if set, receive from this channel before terminating the request
3942
}
4043

4144
type fakeResponses map[streamableRequestKey]*streamableResponse
@@ -44,17 +47,17 @@ type fakeStreamableServer struct {
4447
t *testing.T
4548
responses fakeResponses
4649

47-
callMu sync.Mutex
48-
calls map[streamableRequestKey]int
50+
calledMu sync.Mutex
51+
called map[streamableRequestKey]bool
4952
}
5053

5154
func (s *fakeStreamableServer) missingRequests() []streamableRequestKey {
52-
s.callMu.Lock()
53-
defer s.callMu.Unlock()
55+
s.calledMu.Lock()
56+
defer s.calledMu.Unlock()
5457

5558
var unused []streamableRequestKey
5659
for k, resp := range s.responses {
57-
if s.calls[k] == 0 && !resp.optional {
60+
if !s.called[k] && !resp.optional {
5861
unused = append(unused, k)
5962
}
6063
}
@@ -67,6 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
6770
sessionID: req.Header.Get(sessionIDHeader),
6871
lastEventID: req.Header.Get("Last-Event-ID"), // TODO: extract this to a constant, like sessionIDHeader
6972
}
73+
var jsonrpcReq *jsonrpc.Request
7074
if req.Method == http.MethodPost {
7175
body, err := io.ReadAll(req.Body)
7276
if err != nil {
@@ -82,36 +86,44 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
8286
}
8387
if r, ok := msg.(*jsonrpc.Request); ok {
8488
key.jsonrpcMethod = r.Method
89+
jsonrpcReq = r
8590
}
8691
}
8792

88-
s.callMu.Lock()
89-
if s.calls == nil {
90-
s.calls = make(map[streamableRequestKey]int)
93+
s.calledMu.Lock()
94+
if s.called == nil {
95+
s.called = make(map[streamableRequestKey]bool)
9196
}
92-
s.calls[key]++
93-
s.callMu.Unlock()
97+
s.called[key] = true
98+
s.calledMu.Unlock()
9499

95100
resp, ok := s.responses[key]
96101
if !ok {
97102
s.t.Errorf("missing response for %v", key)
98103
http.Error(w, "no response", http.StatusInternalServerError)
99104
return
100105
}
101-
for k, v := range resp.header {
102-
w.Header().Set(k, v)
103-
}
106+
107+
// Determine body and status, potentially using responseFunc for dynamic responses.
108+
body := resp.body
104109
status := resp.status
110+
if resp.responseFunc != nil {
111+
body, status = resp.responseFunc(jsonrpcReq)
112+
}
105113
if status == 0 {
106114
status = http.StatusOK
107115
}
116+
117+
for k, v := range resp.header {
118+
w.Header().Set(k, v)
119+
}
108120
w.WriteHeader(status)
109121
w.(http.Flusher).Flush() // flush response headers
110122

111123
if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" {
112124
s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion)
113125
}
114-
w.Write([]byte(resp.body))
126+
w.Write([]byte(body))
115127
w.(http.Flusher).Flush() // flush response
116128

117129
if resp.done != nil {
@@ -555,3 +567,129 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level"
555567
})
556568
}
557569
}
570+
571+
// TestStreamableClientTransientErrors verifies that transient errors (timeouts,
572+
// 5xx HTTP status codes) do not permanently break the client connection.
573+
// This tests the fix for issue #683.
574+
func TestStreamableClientTransientErrors(t *testing.T) {
575+
ctx := context.Background()
576+
577+
tests := []struct {
578+
transientStatus int // HTTP status to return for the transient call
579+
wantCallError bool // whether the transient call should error
580+
wantSessionBroken bool // whether the session should be broken after
581+
wantErrorContains string // substring expected in error message
582+
}{
583+
{
584+
transientStatus: http.StatusServiceUnavailable,
585+
wantCallError: true,
586+
wantSessionBroken: false,
587+
wantErrorContains: "Service Unavailable",
588+
},
589+
{
590+
transientStatus: http.StatusBadGateway,
591+
wantCallError: true,
592+
wantSessionBroken: false,
593+
wantErrorContains: "Bad Gateway",
594+
},
595+
{
596+
transientStatus: http.StatusGatewayTimeout,
597+
wantCallError: true,
598+
wantSessionBroken: false,
599+
wantErrorContains: "Gateway Timeout",
600+
},
601+
{
602+
transientStatus: http.StatusTooManyRequests,
603+
wantCallError: true,
604+
wantSessionBroken: false,
605+
wantErrorContains: "Too Many Requests",
606+
},
607+
{
608+
transientStatus: http.StatusUnauthorized,
609+
wantCallError: true,
610+
wantSessionBroken: true,
611+
wantErrorContains: "Unauthorized",
612+
},
613+
{
614+
transientStatus: http.StatusNotFound,
615+
wantCallError: true,
616+
wantSessionBroken: true,
617+
wantErrorContains: "not found", // NotFound has special handling
618+
},
619+
}
620+
621+
for _, test := range tests {
622+
t.Run(http.StatusText(test.transientStatus), func(t *testing.T) {
623+
var returnedError atomic.Bool
624+
fake := &fakeStreamableServer{
625+
t: t,
626+
responses: fakeResponses{
627+
{"POST", "", methodInitialize, ""}: {
628+
header: header{
629+
"Content-Type": "application/json",
630+
sessionIDHeader: "123",
631+
},
632+
body: jsonBody(t, initResp),
633+
},
634+
{"POST", "123", notificationInitialized, ""}: {
635+
status: http.StatusAccepted,
636+
wantProtocolVersion: latestProtocolVersion,
637+
},
638+
{"GET", "123", "", ""}: {
639+
status: http.StatusMethodNotAllowed,
640+
},
641+
{"POST", "123", methodListTools, ""}: {
642+
header: header{
643+
"Content-Type": "application/json",
644+
sessionIDHeader: "123",
645+
},
646+
responseFunc: func(r *jsonrpc.Request) (string, int) {
647+
// First call returns transient error, subsequent calls succeed.
648+
if !returnedError.Swap(true) && test.transientStatus != 0 {
649+
return "", test.transientStatus
650+
}
651+
return jsonBody(t, resp(r.ID.Raw().(int64), &ListToolsResult{Tools: []*Tool{}}, nil)), 0
652+
},
653+
optional: true,
654+
},
655+
{"DELETE", "123", "", ""}: {optional: true},
656+
},
657+
}
658+
659+
httpServer := httptest.NewServer(fake)
660+
defer httpServer.Close()
661+
662+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
663+
client := NewClient(testImpl, nil)
664+
session, err := client.Connect(ctx, transport, nil)
665+
if err != nil {
666+
t.Fatalf("Connect failed: %v", err)
667+
}
668+
defer session.Close()
669+
670+
// First call: should trigger transient error.
671+
_, err = session.ListTools(ctx, nil)
672+
if test.wantCallError {
673+
if err == nil {
674+
t.Error("ListTools succeeded unexpectedly, want error")
675+
} else if test.wantErrorContains != "" && !strings.Contains(err.Error(), test.wantErrorContains) {
676+
t.Errorf("ListTools error = %q, want containing %q", err.Error(), test.wantErrorContains)
677+
}
678+
} else if err != nil {
679+
t.Errorf("ListTools failed unexpectedly: %v", err)
680+
}
681+
682+
// Second call: verifies whether the session is still usable.
683+
_, err = session.ListTools(ctx, nil)
684+
if test.wantSessionBroken {
685+
if err == nil {
686+
t.Error("second ListTools succeeded unexpectedly, want session broken")
687+
}
688+
} else {
689+
if err != nil {
690+
t.Errorf("second ListTools failed unexpectedly: %v (session should survive transient errors)", err)
691+
}
692+
}
693+
})
694+
}
695+
}

0 commit comments

Comments
 (0)