aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Tim <xuchonglei@126.com> 2023-08-10 15:43:26 +0800
committerGravatar GitHub <noreply@github.com> 2023-08-10 09:43:26 +0200
commit54fdc7a73c6e8adb2b7913dea57811329958eb62 (patch)
tree482efe7b52ab625445727f482c6c8fc814e47e3b
parentfasthttpproxy support ipv6 (#1597) (diff)
downloadfasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.tar.gz
fasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.tar.bz2
fasthttp-54fdc7a73c6e8adb2b7913dea57811329958eb62.zip
Abstracts the RoundTripper interface and provides a default implement (#1602)
* Abstracts the RoundTripper interface and provides a default implementation for enhanced extensibility (#1601) * test: Add custom transport test case (#1601) * Make default RoundTripper implmention none public Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com> --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
-rw-r--r--client.go242
-rw-r--r--client_test.go114
2 files changed, 227 insertions, 129 deletions
diff --git a/client.go b/client.go
index fa399af..5223031 100644
--- a/client.go
+++ b/client.go
@@ -628,8 +628,10 @@ type DialFunc func(addr string) (net.Conn, error)
// Request argument passed to RetryIfFunc, if there are any request errors.
type RetryIfFunc func(request *Request) bool
-// TransportFunc wraps every request/response.
-type TransportFunc func(*Request, *Response) error
+// RoundTripper wraps every request/response.
+type RoundTripper interface {
+ RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error)
+}
// ConnPoolStrategyType define strategy of connection pool enqueue/dequeue
type ConnPoolStrategyType int
@@ -791,7 +793,7 @@ type HostClient struct {
RetryIf RetryIfFunc
// Transport defines a transport-like mechanism that wraps every request/response.
- Transport TransportFunc
+ Transport RoundTripper
// Connection pool strategy. Can be either LIFO or FIFO (default).
ConnPoolStrategy ConnPoolStrategyType
@@ -1343,119 +1345,15 @@ func (c *HostClient) doNonNilReqResp(req *Request, resp *Response) (bool, error)
req.Header.userAgent = append(req.Header.userAgent[:], userAgent...)
}
}
- if c.Transport != nil {
- err := c.Transport(req, resp)
- return err == nil, err
- }
-
- var deadline time.Time
- if req.timeout > 0 {
- deadline = time.Now().Add(req.timeout)
- }
-
- cc, err := c.acquireConn(req.timeout, req.ConnectionClose())
- if err != nil {
- return false, err
- }
- conn := cc.c
-
- resp.parseNetConn(conn)
-
- writeDeadline := deadline
- if c.WriteTimeout > 0 {
- tmpWriteDeadline := time.Now().Add(c.WriteTimeout)
- if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
- writeDeadline = tmpWriteDeadline
- }
- }
- if err = conn.SetWriteDeadline(writeDeadline); err != nil {
- c.closeConn(cc)
- return true, err
- }
-
- resetConnection := false
- if c.MaxConnDuration > 0 && time.Since(cc.createdTime) > c.MaxConnDuration && !req.ConnectionClose() {
- req.SetConnectionClose()
- resetConnection = true
- }
-
- bw := c.acquireWriter(conn)
- err = req.Write(bw)
-
- if resetConnection {
- req.Header.ResetConnectionClose()
- }
-
- if err == nil {
- err = bw.Flush()
- }
- c.releaseWriter(bw)
-
- // Return ErrTimeout on any timeout.
- if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
- err = ErrTimeout
- }
-
- isConnRST := isConnectionReset(err)
- if err != nil && !isConnRST {
- c.closeConn(cc)
- return true, err
- }
-
- readDeadline := deadline
- if c.ReadTimeout > 0 {
- tmpReadDeadline := time.Now().Add(c.ReadTimeout)
- if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
- readDeadline = tmpReadDeadline
- }
- }
-
- if err = conn.SetReadDeadline(readDeadline); err != nil {
- c.closeConn(cc)
- return true, err
- }
-
- if customSkipBody || req.Header.IsHead() {
- resp.SkipBody = true
- }
- if c.DisableHeaderNamesNormalizing {
- resp.Header.DisableNormalizing()
- }
-
- br := c.acquireReader(conn)
- err = resp.ReadLimitBody(br, c.MaxResponseBodySize)
- c.releaseReader(br)
- if err != nil {
- c.closeConn(cc)
- // Don't retry in case of ErrBodyTooLarge since we will just get the same again.
- retry := err != ErrBodyTooLarge
- return retry, err
- }
-
- closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
- if customStreamBody && resp.bodyStream != nil {
- rbs := resp.bodyStream
- resp.bodyStream = newCloseReader(rbs, func() error {
- if r, ok := rbs.(*requestStream); ok {
- releaseRequestStream(r)
- }
- if closeConn {
- c.closeConn(cc)
- } else {
- c.releaseConn(cc)
- }
- return nil
- })
- return false, nil
- }
+ return c.transport().RoundTrip(c, req, resp)
+}
- if closeConn {
- c.closeConn(cc)
- } else {
- c.releaseConn(cc)
+func (c *HostClient) transport() RoundTripper {
+ if c.Transport == nil {
+ return DefaultTransport
}
- return false, nil
+ return c.Transport
}
var (
@@ -2909,3 +2807,121 @@ func (c *pipelineConnClient) PendingRequests() int {
}
var errPipelineConnStopped = errors.New("pipeline connection has been stopped")
+
+var DefaultTransport RoundTripper = &transport{}
+
+type transport struct{}
+
+func (t *transport) RoundTrip(hc *HostClient, req *Request, resp *Response) (retry bool, err error) {
+ customSkipBody := resp.SkipBody
+ customStreamBody := resp.StreamBody
+
+ var deadline time.Time
+ if req.timeout > 0 {
+ deadline = time.Now().Add(req.timeout)
+ }
+
+ cc, err := hc.acquireConn(req.timeout, req.ConnectionClose())
+ if err != nil {
+ return false, err
+ }
+ conn := cc.c
+
+ resp.parseNetConn(conn)
+
+ writeDeadline := deadline
+ if hc.WriteTimeout > 0 {
+ tmpWriteDeadline := time.Now().Add(hc.WriteTimeout)
+ if writeDeadline.IsZero() || tmpWriteDeadline.Before(writeDeadline) {
+ writeDeadline = tmpWriteDeadline
+ }
+ }
+
+ if err = conn.SetWriteDeadline(writeDeadline); err != nil {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ resetConnection := false
+ if hc.MaxConnDuration > 0 && time.Since(cc.createdTime) > hc.MaxConnDuration && !req.ConnectionClose() {
+ req.SetConnectionClose()
+ resetConnection = true
+ }
+
+ bw := hc.acquireWriter(conn)
+ err = req.Write(bw)
+
+ if resetConnection {
+ req.Header.ResetConnectionClose()
+ }
+
+ if err == nil {
+ err = bw.Flush()
+ }
+ hc.releaseWriter(bw)
+
+ // Return ErrTimeout on any timeout.
+ if x, ok := err.(interface{ Timeout() bool }); ok && x.Timeout() {
+ err = ErrTimeout
+ }
+
+ isConnRST := isConnectionReset(err)
+ if err != nil && !isConnRST {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ readDeadline := deadline
+ if hc.ReadTimeout > 0 {
+ tmpReadDeadline := time.Now().Add(hc.ReadTimeout)
+ if readDeadline.IsZero() || tmpReadDeadline.Before(readDeadline) {
+ readDeadline = tmpReadDeadline
+ }
+ }
+
+ if err = conn.SetReadDeadline(readDeadline); err != nil {
+ hc.closeConn(cc)
+ return true, err
+ }
+
+ if customSkipBody || req.Header.IsHead() {
+ resp.SkipBody = true
+ }
+ if hc.DisableHeaderNamesNormalizing {
+ resp.Header.DisableNormalizing()
+ }
+
+ br := hc.acquireReader(conn)
+ err = resp.ReadLimitBody(br, hc.MaxResponseBodySize)
+ hc.releaseReader(br)
+ if err != nil {
+ hc.closeConn(cc)
+ // Don't retry in case of ErrBodyTooLarge since we will just get the same again.
+ needRetry := err != ErrBodyTooLarge
+ return needRetry, err
+ }
+
+ closeConn := resetConnection || req.ConnectionClose() || resp.ConnectionClose() || isConnRST
+ if customStreamBody && resp.bodyStream != nil {
+ rbs := resp.bodyStream
+ resp.bodyStream = newCloseReader(rbs, func() error {
+ if r, ok := rbs.(*requestStream); ok {
+ releaseRequestStream(r)
+ }
+ if closeConn {
+ hc.closeConn(cc)
+ } else {
+ hc.releaseConn(cc)
+ }
+ return nil
+ })
+ return false, nil
+ }
+
+ if closeConn {
+ hc.closeConn(cc)
+ } else {
+ hc.releaseConn(cc)
+ }
+ return false, nil
+}
diff --git a/client_test.go b/client_test.go
index 6609e55..8d07d7f 100644
--- a/client_test.go
+++ b/client_test.go
@@ -2111,6 +2111,22 @@ func TestClientRetryRequestWithCustomDecider(t *testing.T) {
}
}
+type TransportDemo struct {
+ br *bufio.Reader
+ bw *bufio.Writer
+}
+
+func (t TransportDemo) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
+ if err = req.Write(t.bw); err != nil {
+ return false, err
+ }
+ if err = t.bw.Flush(); err != nil {
+ return false, err
+ }
+ err = res.Read(t.br)
+ return err != nil, err
+}
+
func TestHostClientTransport(t *testing.T) {
t.Parallel()
@@ -2131,23 +2147,13 @@ func TestHostClientTransport(t *testing.T) {
c := &HostClient{
Addr: "foobar",
- Transport: func() TransportFunc {
+ Transport: func() RoundTripper {
c, _ := ln.Dial()
br := bufio.NewReader(c)
bw := bufio.NewWriter(c)
- return func(req *Request, res *Response) error {
- if err := req.Write(bw); err != nil {
- return err
- }
-
- if err := bw.Flush(); err != nil {
- return err
- }
-
- return res.Read(br)
- }
+ return TransportDemo{br: br, bw: bw}
}(),
}
@@ -3060,14 +3066,18 @@ func TestHostClientMaxConnWaitTimeoutWithEarlierDeadline(t *testing.T) {
}
}
+type TransportEmpty struct{}
+
+func (t TransportEmpty) RoundTrip(hc *HostClient, req *Request, res *Response) (retry bool, err error) {
+ return false, nil
+}
+
func TestHttpsRequestWithoutParsedURL(t *testing.T) {
t.Parallel()
client := HostClient{
- IsTLS: true,
- Transport: func(r1 *Request, r2 *Response) error {
- return nil
- },
+ IsTLS: true,
+ Transport: TransportEmpty{},
}
req := &Request{}
@@ -3182,3 +3192,75 @@ func Test_AddMissingPort(t *testing.T) {
})
}
}
+
+type TransportWrapper struct {
+ base RoundTripper
+ count *int
+ t *testing.T
+}
+
+func (tw *TransportWrapper) RoundTrip(hc *HostClient, req *Request, resp *Response) (bool, error) {
+ req.Header.Set("trace-id", "123")
+ tw.assertRequestLog(req.String())
+ retry, err := tw.transport().RoundTrip(hc, req, resp)
+ resp.Header.Set("trace-id", "124")
+ tw.assertResponseLog(resp.String())
+ *tw.count++
+ return retry, err
+}
+
+func (tw *TransportWrapper) transport() RoundTripper {
+ if tw.base == nil {
+ return DefaultTransport
+ }
+ return tw.base
+}
+
+func (tw *TransportWrapper) assertRequestLog(reqLog string) {
+ if !strings.Contains(reqLog, "Trace-Id: 123") {
+ tw.t.Errorf("request log should contains: %v", "Trace-Id: 123")
+ }
+}
+
+func (tw *TransportWrapper) assertResponseLog(respLog string) {
+ if !strings.Contains(respLog, "Trace-Id: 124") {
+ tw.t.Errorf("response log should contains: %v", "Trace-Id: 124")
+ }
+}
+
+func TestClientTransportEx(t *testing.T) {
+ sHTTP := startEchoServer(t, "tcp", "127.0.0.1:")
+ defer sHTTP.Stop()
+
+ sHTTPS := startEchoServerTLS(t, "tcp", "127.0.0.1:")
+ defer sHTTPS.Stop()
+
+ count := 0
+ c := &Client{
+ TLSConfig: &tls.Config{
+ InsecureSkipVerify: true,
+ },
+ ConfigureClient: func(hc *HostClient) error {
+ hc.Transport = &TransportWrapper{base: hc.Transport, count: &count, t: t}
+ return nil
+ },
+ }
+ // test transport
+ const loopCount = 4
+ const getCount = 20
+ const postCount = 10
+ for i := 0; i < loopCount; i++ {
+ addr := "http://" + sHTTP.Addr()
+ if i&1 != 0 {
+ addr = "https://" + sHTTPS.Addr()
+ }
+ // test get
+ testClientGet(t, c, addr, getCount)
+ // test post
+ testClientPost(t, c, addr, postCount)
+ }
+ roundTripCount := loopCount * (getCount + postCount)
+ if count != roundTripCount {
+ t.Errorf("round trip count should be: %v", roundTripCount)
+ }
+}