aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
authorGravatar Tim <xuchonglei@126.com> 2023-08-10 15:43:26 +0800
committerGravatar GitHub <noreply@github.com> 2023-08-10 09:43:26 +0200
commit54fdc7a73c6e8adb2b7913dea57811329958eb62 (patch)
tree482efe7b52ab625445727f482c6c8fc814e47e3b /client.go
parentfasthttpproxy support ipv6 (#1597) (diff)
downloadfasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.tar.gz
fasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.tar.bz2
fasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.zip
Abstracts the RoundTripper interface and provides a default implement (#1602)
* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601) * test: Add custom transport test case (#1601) * Make default RoundTripper implmention none public Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com> --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
Diffstat (limited to 'client.go')
-rw-r--r--client.go242
1 files changed, 129 insertions, 113 deletions
diff --git a/client.go b/client.go
index fa399af..5223031 100644
--- a/client.go
+++ b/client.go
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
// Request argument passed to RetryIfFunc, if there are any request errors.
type RetryIfFunc func(request *Request) bool
-// TransportFunc wraps every request/response.
-type TransportFunc func(*Request, *Response) error
+// RoundTripper wraps every request/response.
+type RoundTripper interface {
+ RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
+}
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
type ConnPoolStrategyType int
@@ -791,7 +793,7 @@ type HostClient struct {
RetryIf RetryIfFunc
// Transport defines a transport-like mechanism that wraps every request/response.
- Transport TransportFunc
+ Transport RoundTripper
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
}
}
- if c.Transport != nil {
- err := c.Transport(req, resp)
- return err == nil, err
- }
-
- var deadline time.Time
- if req.timeout > 0 {
- deadline = time.Now().Add(req.timeout)
- }
-
- cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
- if err != nil {
- return false, err
- }
- conn := cc.c
-
- resp.parseNetConn(conn)
-
- writeDeadline := deadline
- if c.WriteTimeout > 0 {
- tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
- if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
- writeDeadline = tmpWriteDeadline
- }
- }
- if err = conn.SetWriteDeadline(writeDeadline); err != nil {
- c.closeConn(cc)
- return true, err
- }
-
- resetConnection := false
- if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
- req.SetConnectionClose()
- resetConnection = true
- }
-
- bw := c.acquireWriter(conn)
- err = req.Write(bw)
-
- if resetConnection {
- req.Header.ResetConnectionClose()
- }
-
- if err == nil {
- err = bw.Flush()
- }
- c.releaseWriter(bw)
-
- // Return ErrTimeout on any timeout.
- if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
- err = ErrTimeout
- }
-
- isConnRST := isConnectionReset(err)
- if err != nil && !isConnRST {
- c.closeConn(cc)
- return true, err
- }
-
- readDeadline := deadline
- if c.ReadTimeout > 0 {
- tmpReadDeadline := time.Now().Add(c.ReadTimeout)
- if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
- readDeadline = tmpReadDeadline
- }
- }
-
- if err = conn.SetReadDeadline(readDeadline); err != nil {
- c.closeConn(cc)
- return true, err
- }
-
- if customSkipBody || req.Header.IsHead() {
- resp.SkipBody = true
- }
- if c.DisableHeaderNamesNormalizing {
- resp.Header.DisableNormalizing()
- }
-
- br := c.acquireReader(conn)
- err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
- c.releaseReader(br)
- if err != nil {
- c.closeConn(cc)
- // Don't retry in case of ErrBodyTooLarge since we will just get the same again.
- retry := err != ErrBodyTooLarge
- return retry, err
- }
-
- closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
- if customStreamBody && resp.bodyStream != nil {
- rbs := resp.bodyStream
- resp.bodyStream = newCloseReader(rbs, func() error {
- if r, ok := rbs.(*requestStream); ok {
- releaseRequestStream(r)
- }
- if closeConn {
- c.closeConn(cc)
- } else {
- c.releaseConn(cc)
- }
- return nil
- })
- return false, nil
- }
+ return c.transport().RoundTrip(c, req, resp)
+}
- if closeConn {
- c.closeConn(cc)
- } else {
- c.releaseConn(cc)
+func (c *HostClient) transport() RoundTripper {
+ if c.Transport == nil {
+ return DefaultTransport
}
- return false, nil
+ return c.Transport
}
var (
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
}
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
+
+var DefaultTransport RoundTripper = &transport{}
+
+type transport struct{}
+
+func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
+ customSkipBody := resp.SkipBody
+ customStreamBody := resp.StreamBody
+
+ var deadline time.Time
+ if req.timeout > 0 {
+ deadline = time.Now().Add(req.timeout)
+ }
+
+ cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
+ if err != nil {
+ return false, err
+ }
+ conn := cc.c
+
+ resp.parseNetConn(conn)
+
+ writeDeadline := deadline
+ if hc.WriteTimeout > 0 {
+ tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
+ if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
+ writeDeadline = tmpWriteDeadline
+ }
+ }
+
+ if err = conn.SetWriteDeadline(writeDeadline); err != nil {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ resetConnection := false
+ if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
+ req.SetConnectionClose()
+ resetConnection = true
+ }
+
+ bw := hc.acquireWriter(conn)
+ err = req.Write(bw)
+
+ if resetConnection {
+ req.Header.ResetConnectionClose()
+ }
+
+ if err == nil {
+ err = bw.Flush()
+ }
+ hc.releaseWriter(bw)
+
+ // Return ErrTimeout on any timeout.
+ if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+ err = ErrTimeout
+ }
+
+ isConnRST := isConnectionReset(err)
+ if err != nil && !isConnRST {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ readDeadline := deadline
+ if hc.ReadTimeout > 0 {
+ tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
+ if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
+ readDeadline = tmpReadDeadline
+ }
+ }
+
+ if err = conn.SetReadDeadline(readDeadline); err != nil {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ if customSkipBody || req.Header.IsHead() {
+ resp.SkipBody = true
+ }
+ if hc.DisableHeaderNamesNormalizing {
+ resp.Header.DisableNormalizing()
+ }
+
+ br := hc.acquireReader(conn)
+ err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
+ hc.releaseReader(br)
+ if err != nil {
+ hc.closeConn(cc)
+ // Don't retry in case of ErrBodyTooLarge since we will just get the same again.
+ needRetry := err != ErrBodyTooLarge
+ return needRetry, err
+ }
+
+ closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
+ if customStreamBody && resp.bodyStream != nil {
+ rbs := resp.bodyStream
+ resp.bodyStream = newCloseReader(rbs, func() error {
+ if r, ok := rbs.(*requestStream); ok {
+ releaseRequestStream(r)
+ }
+ if closeConn {
+ hc.closeConn(cc)
+ } else {
+ hc.releaseConn(cc)
+ }
+ return nil
+ })
+ return false, nil
+ }
+
+ if closeConn {
+ hc.closeConn(cc)
+ } else {
+ hc.releaseConn(cc)
+ }
+ return false, nil
+}