Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ func (c *Context) Response() http.ResponseWriter {
return c.response
}

// SetResponse sets `*http.ResponseWriter`. Some middleware require that given ResponseWriter implements following
// method `Unwrap() http.ResponseWriter` which eventually should return echo.Response instance.
// SetResponse sets `*http.ResponseWriter`. Some context methods and/or middleware require that given ResponseWriter implements following
// method `Unwrap() http.ResponseWriter` which eventually should return *echo.Response instance.
func (c *Context) SetResponse(r http.ResponseWriter) {
c.response = r
}
Expand Down Expand Up @@ -415,6 +415,15 @@ func (c *Context) Render(code int, name string, data any) (err error) {
if c.echo.Renderer == nil {
return ErrRendererNotRegistered
}
// as Renderer.Render can fail, and in that case we need to delay sending status code to the client until
// (global) error handler decides the correct status code for the error to be sent to the client, so we need to write
// the rendered template to the buffer first.
//
// html.Template.ExecuteTemplate() documentations writes:
// > If an error occurs executing the template or writing its output,
// > execution stops, but partial results may already have been written to
// > the output writer.

buf := new(bytes.Buffer)
if err = c.echo.Renderer.Render(c, buf, name, data); err != nil {
return
Expand Down Expand Up @@ -454,7 +463,18 @@ func (c *Context) jsonPBlob(code int, callback string, i any) (err error) {

func (c *Context) json(code int, i any, indent string) error {
c.writeContentType(MIMEApplicationJSON)
c.response.WriteHeader(code)

// as JSONSerializer.Serialize can fail, and in that case we need to delay sending status code to the client until
// (global) error handler decides correct status code for the error to be sent to the client.
// For that we need to use writer that can store the proposed status code until the first Write is called.
if r, err := UnwrapResponse(c.response); err == nil {
r.Status = code
} else {
resp := c.Response()
c.SetResponse(&delayedStatusWriter{ResponseWriter: resp, status: code})
defer c.SetResponse(resp)
}

return c.echo.JSONSerializer.Serialize(c, i, indent)
}

Expand Down
43 changes: 39 additions & 4 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"io"
"io/fs"
"log/slog"
"math"
"mime/multipart"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -138,6 +139,24 @@ func TestContextRenderTemplate(t *testing.T) {
}
}

func TestContextRenderTemplateError(t *testing.T) {
// we test that when template rendering fails, no response is sent to the client yet, so the global error handler can decide what to do
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)

tmpl := &Template{
templates: template.Must(template.New("hello").Parse("Hello, {{.}}!")),
}
c.Echo().Renderer = tmpl
err := c.Render(http.StatusOK, "not_existing", "Jon Snow")

assert.EqualError(t, err, `template: no template "not_existing" associated with template "hello"`)
assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
}

func TestContextRenderErrorsOnNoRenderer(t *testing.T) {
e := New()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
Expand Down Expand Up @@ -173,10 +192,9 @@ func TestContextStream(t *testing.T) {
}

func TestContextHTML(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, rec)
c := NewContext(req, rec)

err := c.HTML(http.StatusOK, "Hi, Jon Snow")
if assert.NoError(t, err) {
Expand All @@ -187,10 +205,9 @@ func TestContextHTML(t *testing.T) {
}

func TestContextHTMLBlob(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/", nil)
c := e.NewContext(req, rec)
c := NewContext(req, rec)

err := c.HTMLBlob(http.StatusOK, []byte("Hi, Jon Snow"))
if assert.NoError(t, err) {
Expand Down Expand Up @@ -222,6 +239,24 @@ func TestContextJSONErrorsOut(t *testing.T) {

err := c.JSON(http.StatusOK, make(chan bool))
assert.EqualError(t, err, "json: unsupported type: chan bool")

assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
}

func TestContextJSONWithNotEchoResponse(t *testing.T) {
e := New()
rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(userJSON))
c := e.NewContext(req, rec)

c.SetResponse(rec)

err := c.JSON(http.StatusCreated, map[string]float64{"foo": math.NaN()})
assert.EqualError(t, err, "json: unsupported value: NaN")

assert.Equal(t, http.StatusOK, rec.Code) // status code must not be sent to the client
assert.Empty(t, rec.Body.String()) // body must not be sent to the client
}

func TestContextJSONPretty(t *testing.T) {
Expand Down
42 changes: 41 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,47 @@ func UnwrapResponse(rw http.ResponseWriter) (*Response, error) {
rw = t.Unwrap()
continue
default:
return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface")
return nil, errors.New("ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response")
}
}
}

// delayedStatusWriter is a wrapper around http.ResponseWriter that delays writing the status code until first Write is called.
// This allows (global) error handler to decide correct status code to be sent to the client.
type delayedStatusWriter struct {
http.ResponseWriter
commited bool
status int
}

func (w *delayedStatusWriter) WriteHeader(statusCode int) {
// in case something else writes status code explicitly before us we need mark response commited
w.commited = true
w.ResponseWriter.WriteHeader(statusCode)
}

func (w *delayedStatusWriter) Write(data []byte) (int, error) {
if !w.commited {
w.commited = true
if w.status == 0 {
w.status = http.StatusOK
}
w.ResponseWriter.WriteHeader(w.status)
}
return w.ResponseWriter.Write(data)
}

func (w *delayedStatusWriter) Flush() {
err := http.NewResponseController(w.ResponseWriter).Flush()
if err != nil && errors.Is(err, http.ErrNotSupported) {
panic(errors.New("response writer flushing is not supported"))
}
}

func (w *delayedStatusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return http.NewResponseController(w.ResponseWriter).Hijack()
}

func (w *delayedStatusWriter) Unwrap() http.ResponseWriter {
return w.ResponseWriter
}
16 changes: 16 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,19 @@ func TestResponse_FlushPanics(t *testing.T) {
res.Flush()
})
}

func TestResponse_UnwrapResponse(t *testing.T) {
orgRes := NewResponse(httptest.NewRecorder(), nil)
res, err := UnwrapResponse(orgRes)

assert.NotNil(t, res)
assert.NoError(t, err)
}

func TestResponse_UnwrapResponse_error(t *testing.T) {
rw := new(testResponseWriter)
res, err := UnwrapResponse(rw)

assert.Nil(t, res)
assert.EqualError(t, err, "ResponseWriter does not implement 'Unwrap() http.ResponseWriter' interface or unwrap to *echo.Response")
}
Loading