@@ -12,6 +12,7 @@ import (
1212 "net/http/httptest"
1313 "strings"
1414 "sync"
15+ "sync/atomic"
1516 "testing"
1617 "time"
1718
@@ -29,13 +30,15 @@ type streamableRequestKey struct {
2930
3031type header map [string ]string
3132
33+ // TODO: replace body and status fields with responseFunc; add helpers to reduce duplication.
3234type streamableResponse struct {
33- header header // response headers
34- status int // or http.StatusOK
35- body string // or ""
36- optional bool // if set, request need not be sent
37- wantProtocolVersion string // if "", unchecked
38- done chan struct {} // if set, receive from this channel before terminating the request
35+ header header // response headers
36+ status int // or http.StatusOK; ignored if responseFunc is set
37+ body string // or ""; ignored if responseFunc is set
38+ responseFunc func (r * jsonrpc.Request ) (string , int ) // if set, overrides body and status
39+ optional bool // if set, request need not be sent
40+ wantProtocolVersion string // if "", unchecked
41+ done chan struct {} // if set, receive from this channel before terminating the request
3942}
4043
4144type fakeResponses map [streamableRequestKey ]* streamableResponse
@@ -44,17 +47,17 @@ type fakeStreamableServer struct {
4447 t * testing.T
4548 responses fakeResponses
4649
47- callMu sync.Mutex
48- calls map [streamableRequestKey ]int
50+ calledMu sync.Mutex
51+ called map [streamableRequestKey ]bool
4952}
5053
5154func (s * fakeStreamableServer ) missingRequests () []streamableRequestKey {
52- s .callMu .Lock ()
53- defer s .callMu .Unlock ()
55+ s .calledMu .Lock ()
56+ defer s .calledMu .Unlock ()
5457
5558 var unused []streamableRequestKey
5659 for k , resp := range s .responses {
57- if s . calls [k ] == 0 && ! resp .optional {
60+ if ! s . called [k ] && ! resp .optional {
5861 unused = append (unused , k )
5962 }
6063 }
@@ -67,6 +70,7 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
6770 sessionID : req .Header .Get (sessionIDHeader ),
6871 lastEventID : req .Header .Get ("Last-Event-ID" ), // TODO: extract this to a constant, like sessionIDHeader
6972 }
73+ var jsonrpcReq * jsonrpc.Request
7074 if req .Method == http .MethodPost {
7175 body , err := io .ReadAll (req .Body )
7276 if err != nil {
@@ -82,36 +86,44 @@ func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Reques
8286 }
8387 if r , ok := msg .(* jsonrpc.Request ); ok {
8488 key .jsonrpcMethod = r .Method
89+ jsonrpcReq = r
8590 }
8691 }
8792
88- s .callMu .Lock ()
89- if s .calls == nil {
90- s .calls = make (map [streamableRequestKey ]int )
93+ s .calledMu .Lock ()
94+ if s .called == nil {
95+ s .called = make (map [streamableRequestKey ]bool )
9196 }
92- s .calls [key ]++
93- s .callMu .Unlock ()
97+ s .called [key ] = true
98+ s .calledMu .Unlock ()
9499
95100 resp , ok := s .responses [key ]
96101 if ! ok {
97102 s .t .Errorf ("missing response for %v" , key )
98103 http .Error (w , "no response" , http .StatusInternalServerError )
99104 return
100105 }
101- for k , v := range resp . header {
102- w . Header (). Set ( k , v )
103- }
106+
107+ // Determine body and status, potentially using responseFunc for dynamic responses.
108+ body := resp . body
104109 status := resp .status
110+ if resp .responseFunc != nil {
111+ body , status = resp .responseFunc (jsonrpcReq )
112+ }
105113 if status == 0 {
106114 status = http .StatusOK
107115 }
116+
117+ for k , v := range resp .header {
118+ w .Header ().Set (k , v )
119+ }
108120 w .WriteHeader (status )
109121 w .(http.Flusher ).Flush () // flush response headers
110122
111123 if v := req .Header .Get (protocolVersionHeader ); v != resp .wantProtocolVersion && resp .wantProtocolVersion != "" {
112124 s .t .Errorf ("%v: bad protocol version header: got %q, want %q" , key , v , resp .wantProtocolVersion )
113125 }
114- w .Write ([]byte (resp . body ))
126+ w .Write ([]byte (body ))
115127 w .(http.Flusher ).Flush () // flush response
116128
117129 if resp .done != nil {
@@ -555,3 +567,129 @@ data: { "jsonrpc": "2.0", "method": "notifications/message", "params": { "level"
555567 })
556568 }
557569}
570+
571+ // TestStreamableClientTransientErrors verifies that transient errors (timeouts,
572+ // 5xx HTTP status codes) do not permanently break the client connection.
573+ // This tests the fix for issue #683.
574+ func TestStreamableClientTransientErrors (t * testing.T ) {
575+ ctx := context .Background ()
576+
577+ tests := []struct {
578+ transientStatus int // HTTP status to return for the transient call
579+ wantCallError bool // whether the transient call should error
580+ wantSessionBroken bool // whether the session should be broken after
581+ wantErrorContains string // substring expected in error message
582+ }{
583+ {
584+ transientStatus : http .StatusServiceUnavailable ,
585+ wantCallError : true ,
586+ wantSessionBroken : false ,
587+ wantErrorContains : "Service Unavailable" ,
588+ },
589+ {
590+ transientStatus : http .StatusBadGateway ,
591+ wantCallError : true ,
592+ wantSessionBroken : false ,
593+ wantErrorContains : "Bad Gateway" ,
594+ },
595+ {
596+ transientStatus : http .StatusGatewayTimeout ,
597+ wantCallError : true ,
598+ wantSessionBroken : false ,
599+ wantErrorContains : "Gateway Timeout" ,
600+ },
601+ {
602+ transientStatus : http .StatusTooManyRequests ,
603+ wantCallError : true ,
604+ wantSessionBroken : false ,
605+ wantErrorContains : "Too Many Requests" ,
606+ },
607+ {
608+ transientStatus : http .StatusUnauthorized ,
609+ wantCallError : true ,
610+ wantSessionBroken : true ,
611+ wantErrorContains : "Unauthorized" ,
612+ },
613+ {
614+ transientStatus : http .StatusNotFound ,
615+ wantCallError : true ,
616+ wantSessionBroken : true ,
617+ wantErrorContains : "not found" , // NotFound has special handling
618+ },
619+ }
620+
621+ for _ , test := range tests {
622+ t .Run (http .StatusText (test .transientStatus ), func (t * testing.T ) {
623+ var returnedError atomic.Bool
624+ fake := & fakeStreamableServer {
625+ t : t ,
626+ responses : fakeResponses {
627+ {"POST" , "" , methodInitialize , "" }: {
628+ header : header {
629+ "Content-Type" : "application/json" ,
630+ sessionIDHeader : "123" ,
631+ },
632+ body : jsonBody (t , initResp ),
633+ },
634+ {"POST" , "123" , notificationInitialized , "" }: {
635+ status : http .StatusAccepted ,
636+ wantProtocolVersion : latestProtocolVersion ,
637+ },
638+ {"GET" , "123" , "" , "" }: {
639+ status : http .StatusMethodNotAllowed ,
640+ },
641+ {"POST" , "123" , methodListTools , "" }: {
642+ header : header {
643+ "Content-Type" : "application/json" ,
644+ sessionIDHeader : "123" ,
645+ },
646+ responseFunc : func (r * jsonrpc.Request ) (string , int ) {
647+ // First call returns transient error, subsequent calls succeed.
648+ if ! returnedError .Swap (true ) && test .transientStatus != 0 {
649+ return "" , test .transientStatus
650+ }
651+ return jsonBody (t , resp (r .ID .Raw ().(int64 ), & ListToolsResult {Tools : []* Tool {}}, nil )), 0
652+ },
653+ optional : true ,
654+ },
655+ {"DELETE" , "123" , "" , "" }: {optional : true },
656+ },
657+ }
658+
659+ httpServer := httptest .NewServer (fake )
660+ defer httpServer .Close ()
661+
662+ transport := & StreamableClientTransport {Endpoint : httpServer .URL }
663+ client := NewClient (testImpl , nil )
664+ session , err := client .Connect (ctx , transport , nil )
665+ if err != nil {
666+ t .Fatalf ("Connect failed: %v" , err )
667+ }
668+ defer session .Close ()
669+
670+ // First call: should trigger transient error.
671+ _ , err = session .ListTools (ctx , nil )
672+ if test .wantCallError {
673+ if err == nil {
674+ t .Error ("ListTools succeeded unexpectedly, want error" )
675+ } else if test .wantErrorContains != "" && ! strings .Contains (err .Error (), test .wantErrorContains ) {
676+ t .Errorf ("ListTools error = %q, want containing %q" , err .Error (), test .wantErrorContains )
677+ }
678+ } else if err != nil {
679+ t .Errorf ("ListTools failed unexpectedly: %v" , err )
680+ }
681+
682+ // Second call: verifies whether the session is still usable.
683+ _ , err = session .ListTools (ctx , nil )
684+ if test .wantSessionBroken {
685+ if err == nil {
686+ t .Error ("second ListTools succeeded unexpectedly, want session broken" )
687+ }
688+ } else {
689+ if err != nil {
690+ t .Errorf ("second ListTools failed unexpectedly: %v (session should survive transient errors)" , err )
691+ }
692+ }
693+ })
694+ }
695+ }
0 commit comments