aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Max Denushev <mdenushev@ya.ru> 2024-04-22 11:45:33 +0500
committerGravatar GitHub <noreply@github.com> 2024-04-22 08:45:33 +0200
commit57b9352ad1cc93a0aaaa72b2130e03ace8a5b118 (patch)
tree2d70338e58a3243d05709b1fcd660cccc1ddcb94
parentrefactor: do not return error as it is always nil (#1759) (diff)
downloadfasthttp-57b9352ad1cc93a0aaaa72b2130e03ace8a5b118.tar.gz
fasthttp-57b9352ad1cc93a0aaaa72b2130e03ace8a5b118.tar.bz2
fasthttp-57b9352ad1cc93a0aaaa72b2130e03ace8a5b118.zip
fix: propagate body stream error to close function (#1743) (#1757)
* fix: propagate body stream error to close function (#1743) * fix: http test * fix: close body stream with error in encoding functions * fix: lint --------- Co-authored-by: Max Denushev <denushev@tochka.com>
-rw-r--r--client.go4
-rw-r--r--http.go76
-rw-r--r--http_test.go2
3 files changed, 51 insertions, 31 deletions
diff --git a/client.go b/client.go
index b5493e9..1f12d4a 100644
--- a/client.go
+++ b/client.go
@@ -2975,12 +2975,12 @@ func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (ret
closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
if customStreamBody && resp.bodyStream != nil {
rbs := resp.bodyStream
- resp.bodyStream = newCloseReader(rbs, func() error {
+ resp.bodyStream = newCloseReaderWithError(rbs, func(wErr error) error {
hc.releaseReader(br)
if r, ok := rbs.(*requestStream); ok {
releaseRequestStream(r)
}
- if closeConn || resp.ConnectionClose() {
+ if closeConn || resp.ConnectionClose() || wErr != nil {
hc.closeConn(cc)
} else {
hc.releaseConn(cc)
diff --git a/http.go b/http.go
index 30a500d..75d8b6c 100644
--- a/http.go
+++ b/http.go
@@ -321,26 +321,31 @@ func (resp *Response) BodyStream() io.Reader {
}
func (resp *Response) CloseBodyStream() error {
- return resp.closeBodyStream()
+ return resp.closeBodyStream(nil)
+}
+
+type ReadCloserWithError interface {
+ io.Reader
+ CloseWithError(err error) error
}
type closeReader struct {
io.Reader
- closeFunc func() error
+ closeFunc func(err error) error
}
-func newCloseReader(r io.Reader, closeFunc func() error) io.ReadCloser {
+func newCloseReaderWithError(r io.Reader, closeFunc func(err error) error) ReadCloserWithError {
if r == nil {
panic(`BUG: reader is nil`)
}
return &closeReader{Reader: r, closeFunc: closeFunc}
}
-func (c *closeReader) Close() error {
+func (c *closeReader) CloseWithError(err error) error {
if c.closeFunc == nil {
return nil
}
- return c.closeFunc()
+ return c.closeFunc(err)
}
// BodyWriter returns writer for populating request body.
@@ -394,7 +399,7 @@ func (resp *Response) Body() []byte {
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
_, err := copyZeroAlloc(bodyBuf, resp.bodyStream)
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bodyBuf.SetString(err.Error())
}
@@ -618,7 +623,7 @@ func (req *Request) BodyWriteTo(w io.Writer) error {
func (resp *Response) BodyWriteTo(w io.Writer) error {
if resp.bodyStream != nil {
_, err := copyZeroAlloc(w, resp.bodyStream)
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(err) //nolint:errcheck
return err
}
_, err := w.Write(resp.bodyBytes())
@@ -629,13 +634,13 @@ func (resp *Response) BodyWriteTo(w io.Writer) error {
//
// It is safe re-using p after the function returns.
func (resp *Response) AppendBody(p []byte) {
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().Write(p) //nolint:errcheck
}
// AppendBodyString appends s to response body.
func (resp *Response) AppendBodyString(s string) {
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
resp.bodyBuffer().WriteString(s) //nolint:errcheck
}
@@ -643,7 +648,7 @@ func (resp *Response) AppendBodyString(s string) {
//
// It is safe re-using body argument after the function returns.
func (resp *Response) SetBody(body []byte) {
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.Write(body) //nolint:errcheck
@@ -651,7 +656,7 @@ func (resp *Response) SetBody(body []byte) {
// SetBodyString sets response body.
func (resp *Response) SetBodyString(body string) {
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
bodyBuf := resp.bodyBuffer()
bodyBuf.Reset()
bodyBuf.WriteString(body) //nolint:errcheck
@@ -660,7 +665,7 @@ func (resp *Response) SetBodyString(body string) {
// ResetBody resets response body.
func (resp *Response) ResetBody() {
resp.bodyRaw = nil
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
if resp.body != nil {
if resp.keepBodyBuffer {
resp.body.Reset()
@@ -700,7 +705,7 @@ func (resp *Response) ReleaseBody(size int) {
return
}
if cap(resp.body.B) > size {
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(nil) //nolint:errcheck
resp.body = nil
}
}
@@ -734,7 +739,7 @@ func (resp *Response) SwapBody(body []byte) []byte {
if resp.bodyStream != nil {
bb.Reset()
_, err := copyZeroAlloc(bb, resp.bodyStream)
- resp.closeBodyStream() //nolint:errcheck
+ resp.closeBodyStream(err) //nolint:errcheck
if err != nil {
bb.Reset()
bb.SetString(err.Error())
@@ -1725,10 +1730,13 @@ func (resp *Response) brotliBody(level int) {
wf: zw,
bw: sw,
}
- copyZeroAlloc(fw, bs) //nolint:errcheck
+ _, wErr := copyZeroAlloc(fw, bs)
releaseStacklessBrotliWriter(zw, level)
- if bsc, ok := bs.(io.Closer); ok {
- bsc.Close()
+ switch v := bs.(type) {
+ case io.Closer:
+ v.Close()
+ case ReadCloserWithError:
+ v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
@@ -1780,10 +1788,13 @@ func (resp *Response) gzipBody(level int) {
wf: zw,
bw: sw,
}
- copyZeroAlloc(fw, bs) //nolint:errcheck
+ _, wErr := copyZeroAlloc(fw, bs)
releaseStacklessGzipWriter(zw, level)
- if bsc, ok := bs.(io.Closer); ok {
- bsc.Close()
+ switch v := bs.(type) {
+ case io.Closer:
+ v.Close()
+ case ReadCloserWithError:
+ v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
@@ -1835,10 +1846,13 @@ func (resp *Response) deflateBody(level int) {
wf: zw,
bw: sw,
}
- copyZeroAlloc(fw, bs) //nolint:errcheck
+ _, wErr := copyZeroAlloc(fw, bs)
releaseStacklessDeflateWriter(zw, level)
- if bsc, ok := bs.(io.Closer); ok {
- bsc.Close()
+ switch v := bs.(type) {
+ case io.Closer:
+ v.Close()
+ case ReadCloserWithError:
+ v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
@@ -1887,10 +1901,13 @@ func (resp *Response) zstdBody(level int) {
wf: zw,
bw: sw,
}
- copyZeroAlloc(fw, bs) //nolint:errcheck
+ _, wErr := copyZeroAlloc(fw, bs)
releaseStacklessZstdWriter(zw, level)
- if bsc, ok := bs.(io.Closer); ok {
- bsc.Close()
+ switch v := bs.(type) {
+ case io.Closer:
+ v.Close()
+ case ReadCloserWithError:
+ v.CloseWithError(wErr) //nolint:errcheck
}
})
} else {
@@ -2053,7 +2070,7 @@ func (resp *Response) writeBodyStream(w *bufio.Writer, sendBody bool) (err error
}
}
}
- errc := resp.closeBodyStream()
+ errc := resp.closeBodyStream(err)
if err == nil {
err = errc
}
@@ -2075,7 +2092,7 @@ func (req *Request) closeBodyStream() error {
return err
}
-func (resp *Response) closeBodyStream() error {
+func (resp *Response) closeBodyStream(wErr error) error {
if resp.bodyStream == nil {
return nil
}
@@ -2083,6 +2100,9 @@ func (resp *Response) closeBodyStream() error {
if bsc, ok := resp.bodyStream.(io.Closer); ok {
err = bsc.Close()
}
+ if bsc, ok := resp.bodyStream.(ReadCloserWithError); ok {
+ err = bsc.CloseWithError(wErr)
+ }
if bsr, ok := resp.bodyStream.(*requestStream); ok {
releaseRequestStream(bsr)
}
diff --git a/http_test.go b/http_test.go
index b83a487..a9f440a 100644
--- a/http_test.go
+++ b/http_test.go
@@ -2943,7 +2943,7 @@ func TestResponseBodyStream(t *testing.T) {
t.Fatalf("parse response find err: %v", err)
}
defer func() {
- if err := response.closeBodyStream(); err != nil {
+ if err := response.closeBodyStream(nil); err != nil {
t.Fatalf("close body stream err: %v", err)
}
}()