From 57b9352ad1cc93a0aaaa72b2130e03ace8a5b118 Mon Sep 17 00:00:00 2001 From: Max Denushev Date: Mon, 22 Apr 2024 11:45:33 +0500 Subject: fix: propagate body stream error to close function (#1743) (#1757) * fix: propagate body stream error to close function (#1743) * fix: http test * fix: close body stream with error in encoding functions * fix: lint --------- Co-authored-by: Max Denushev --- client.go | 4 ++-- http.go | 76 ++++++++++++++++++++++++++++++++++++++---------------------- http_test.go | 2 +- 3 files changed, 51 insertions(+), 31 deletions(-) diff --git a/client.go b/client.go index b5493e9..1f12d4a 100644 --- a/client.go +++ b/client.go @@ -2975,12 +2975,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST if customStreamBody && resp.bodyStream != nil { rbs := resp.bodyStream - resp.bodyStream = newCloseReader(rbs, func() error { + resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error { hc.releaseReader(br) if r, ok := rbs.(*requestStream); ok { releaseRequestStream(r) } - if closeConn || resp.ConnectionClose() { + if closeConn || resp.ConnectionClose() || wErr != nil { hc.closeConn(cc) } else { hc.releaseConn(cc) diff --git a/http.go b/http.go index 30a500d..75d8b6c 100644 --- a/http.go +++ b/http.go @@ -321,26 +321,31 @@ func (resp *Response) BodyStream() io.Reader { } func (resp *Response) CloseBodyStream() error { - return resp.closeBodyStream() + return resp.closeBodyStream(nil) +} + +type ReadCloserWithError interface { + io.Reader + CloseWithError(err error) error } type closeReader struct { io.Reader - closeFunc func() error + closeFunc func(err error) error } -func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser { +func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError { if r == nil { panic(`BUG: reader is nil`) } return &closeReader{Reader: r, closeFunc: closeFunc} } -func (c *closeReader) Close() error { +func (c *closeReader) CloseWithError(err error) error { if c.closeFunc == nil { return nil } - return c.closeFunc() + return c.closeFunc(err) } // BodyWriter returns writer for populating request body. @@ -394,7 +399,7 @@ func (resp *Response) Body() []byte { bodyBuf := resp.bodyBuffer() bodyBuf.Reset() _, err := copyZeroAlloc(bodyBuf, resp.bodyStream) - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(err) //nolint:errcheck if err != nil { bodyBuf.SetString(err.Error()) } @@ -618,7 +623,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error { func (resp *Response) BodyWriteTo(w io.Writer) error { if resp.bodyStream != nil { _, err := copyZeroAlloc(w, resp.bodyStream) - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(err) //nolint:errcheck return err } _, err := w.Write(resp.bodyBytes()) @@ -629,13 +634,13 @@ func (resp *Response) BodyWriteTo(w io.Writer) error { // // It is safe re-using p after the function returns. func (resp *Response) AppendBody(p []byte) { - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck resp.bodyBuffer().Write(p) //nolint:errcheck } // AppendBodyString appends s to response body. func (resp *Response) AppendBodyString(s string) { - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck resp.bodyBuffer().WriteString(s) //nolint:errcheck } @@ -643,7 +648,7 @@ func (resp *Response) AppendBodyString(s string) { // // It is safe re-using body argument after the function returns. func (resp *Response) SetBody(body []byte) { - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck bodyBuf := resp.bodyBuffer() bodyBuf.Reset() bodyBuf.Write(body) //nolint:errcheck @@ -651,7 +656,7 @@ func (resp *Response) SetBody(body []byte) { // SetBodyString sets response body. func (resp *Response) SetBodyString(body string) { - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck bodyBuf := resp.bodyBuffer() bodyBuf.Reset() bodyBuf.WriteString(body) //nolint:errcheck @@ -660,7 +665,7 @@ func (resp *Response) SetBodyString(body string) { // ResetBody resets response body. func (resp *Response) ResetBody() { resp.bodyRaw = nil - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck if resp.body != nil { if resp.keepBodyBuffer { resp.body.Reset() @@ -700,7 +705,7 @@ func (resp *Response) ReleaseBody(size int) { return } if cap(resp.body.B) > size { - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(nil) //nolint:errcheck resp.body = nil } } @@ -734,7 +739,7 @@ func (resp *Response) SwapBody(body []byte) []byte { if resp.bodyStream != nil { bb.Reset() _, err := copyZeroAlloc(bb, resp.bodyStream) - resp.closeBodyStream() //nolint:errcheck + resp.closeBodyStream(err) //nolint:errcheck if err != nil { bb.Reset() bb.SetString(err.Error()) @@ -1725,10 +1730,13 @@ func (resp *Response) brotliBody(level int) { wf: zw, bw: sw, } - copyZeroAlloc(fw, bs) //nolint:errcheck + _, wErr := copyZeroAlloc(fw, bs) releaseStacklessBrotliWriter(zw, level) - if bsc, ok := bs.(io.Closer); ok { - bsc.Close() + switch v := bs.(type) { + case io.Closer: + v.Close() + case ReadCloserWithError: + v.CloseWithError(wErr) //nolint:errcheck } }) } else { @@ -1780,10 +1788,13 @@ func (resp *Response) gzipBody(level int) { wf: zw, bw: sw, } - copyZeroAlloc(fw, bs) //nolint:errcheck + _, wErr := copyZeroAlloc(fw, bs) releaseStacklessGzipWriter(zw, level) - if bsc, ok := bs.(io.Closer); ok { - bsc.Close() + switch v := bs.(type) { + case io.Closer: + v.Close() + case ReadCloserWithError: + v.CloseWithError(wErr) //nolint:errcheck } }) } else { @@ -1835,10 +1846,13 @@ func (resp *Response) deflateBody(level int) { wf: zw, bw: sw, } - copyZeroAlloc(fw, bs) //nolint:errcheck + _, wErr := copyZeroAlloc(fw, bs) releaseStacklessDeflateWriter(zw, level) - if bsc, ok := bs.(io.Closer); ok { - bsc.Close() + switch v := bs.(type) { + case io.Closer: + v.Close() + case ReadCloserWithError: + v.CloseWithError(wErr) //nolint:errcheck } }) } else { @@ -1887,10 +1901,13 @@ func (resp *Response) zstdBody(level int) { wf: zw, bw: sw, } - copyZeroAlloc(fw, bs) //nolint:errcheck + _, wErr := copyZeroAlloc(fw, bs) releaseStacklessZstdWriter(zw, level) - if bsc, ok := bs.(io.Closer); ok { - bsc.Close() + switch v := bs.(type) { + case io.Closer: + v.Close() + case ReadCloserWithError: + v.CloseWithError(wErr) //nolint:errcheck } }) } else { @@ -2053,7 +2070,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error } } } - errc := resp.closeBodyStream() + errc := resp.closeBodyStream(err) if err == nil { err = errc } @@ -2075,7 +2092,7 @@ func (req *Request) closeBodyStream() error { return err } -func (resp *Response) closeBodyStream() error { +func (resp *Response) closeBodyStream(wErr error) error { if resp.bodyStream == nil { return nil } @@ -2083,6 +2100,9 @@ func (resp *Response) closeBodyStream() error { if bsc, ok := resp.bodyStream.(io.Closer); ok { err = bsc.Close() } + if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok { + err = bsc.CloseWithError(wErr) + } if bsr, ok := resp.bodyStream.(*requestStream); ok { releaseRequestStream(bsr) } diff --git a/http_test.go b/http_test.go index b83a487..a9f440a 100644 --- a/http_test.go +++ b/http_test.go @@ -2943,7 +2943,7 @@ func TestResponseBodyStream(t *testing.T) { t.Fatalf("parse response find err: %v", err) } defer func() { - if err := response.closeBodyStream(); err != nil { + if err := response.closeBodyStream(nil); err != nil { t.Fatalf("close body stream err: %v", err) } }() -- cgit v1.2.3