aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Aviv Carmi <avivcarmis@gmail.com> 2023-11-27 14:46:43 +0200
committerGravatar GitHub <noreply@github.com> 2023-11-27 13:46:43 +0100
commit8ca7a9c89c43a97658352651f5f66e16278164f6 (patch)
treec1ebf35c6e98ce83aab7564cfdf47a534e616bba
parentchore: Use 'any' instead of 'interface{}' (#1666) (diff)
downloadfasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.gz
fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.bz2
fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.zip
add support for custom dial function with timeouts (#1669)
* add support for custom dial function with timeouts * fix linting --------- Co-authored-by: Aviv Carmi <aviv@perimeterx.com>
-rw-r--r--client.go95
-rw-r--r--client_test.go92
-rw-r--r--tcpdialer.go8
3 files changed, 161 insertions, 34 deletions
diff --git a/client.go b/client.go
index 217dde4..0bfc272 100644
--- a/client.go
+++ b/client.go
@@ -185,7 +185,15 @@ type Client struct {
// Callback for establishing new connections to hosts.
//
- // Default Dial is used if not set.
+ // Default DialTimeout is used if not set.
+ DialTimeout DialFuncWithTimeout
+
+ // Callback for establishing new connections to hosts.
+ //
+ // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
+ // If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
+ //
+ // If not set, DialTimeout is used.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
@@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
+ DialTimeout: c.DialTimeout,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
@@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)
+// DialFuncWithTimeout must establish connection to addr.
+// Unlike DialFunc, it also accepts a timeout.
+//
+// There is no need in establishing TLS (SSL) connection for https.
+// The client automatically converts connection to TLS
+// if HostClient.IsTLS is set.
+//
+// TCP address passed to DialFuncWithTimeout always contains host and port.
+// Example TCP addr values:
+//
+// - foobar.com:80
+// - foobar.com:443
+// - foobar.com:8080
+type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error)
+
// RetryIfFunc signature of retry if function
//
// Request argument passed to RetryIfFunc, if there are any request errors.
@@ -656,7 +680,7 @@ type HostClient struct {
noCopy noCopy
// Comma-separated list of upstream HTTP server host addresses,
- // which are passed to Dial in a round-robin manner.
+ // which are passed to Dial or DialTimeout in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
@@ -673,16 +697,24 @@ type HostClient struct {
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
- // Callback for establishing new connection to the host.
+ // Callback for establishing new connections to hosts.
//
- // Default Dial is used if not set.
+ // Default DialTimeout is used if not set.
+ DialTimeout DialFuncWithTimeout
+
+ // Callback for establishing new connections to hosts.
+ //
+ // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
+ // If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
+ //
+ // If not set, DialTimeout is used.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
- // i.e. if Dial is blank.
+ // i.e. if Dial and DialTimeout are blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
@@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string {
}
func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) {
- // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set.
+ // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if
+ // c.DialTimeout has not been set and c.Dial has been set.
// attempt to dial all the available hosts before giving up.
c.addrsLock.Lock()
@@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
n = 1
}
- dial := c.Dial
- if dialTimeout != 0 && dial == nil {
- dial = func(addr string) (net.Conn, error) {
- if c.DialDualStack {
- return DialDualStackTimeout(addr, dialTimeout)
- }
- return DialTimeout(addr, dialTimeout)
- }
- }
-
timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
@@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
- conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
+ conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout)
if err == nil {
return conn, nil
}
@@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T
return conn, nil
}
-func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
- deadline := time.Now().Add(timeout)
- if dial == nil {
- if dialDualStack {
- dial = DialDualStack
- } else {
- dial = Dial
- }
- addr = AddMissingPort(addr, isTLS)
- }
- conn, err := dial(addr)
+func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) {
+ deadline := time.Now().Add(writeTimeout)
+ conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout)
if err != nil {
return nil, err
}
@@ -1939,7 +1954,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
_, isTLSAlready := conn.(interface{ Handshake() error })
if isTLS && !isTLSAlready {
- if timeout == 0 {
+ if writeTimeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, deadline)
@@ -1947,6 +1962,26 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
return conn, nil
}
+func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) {
+ if dialWithTimeout != nil {
+ return dialWithTimeout(addr, timeout)
+ }
+ if dial != nil {
+ return dial(addr)
+ }
+ addr = AddMissingPort(addr, isTLS)
+ if timeout > 0 {
+ if dialDualStack {
+ return DialDualStackTimeout(addr, timeout)
+ }
+ return DialTimeout(addr, timeout)
+ }
+ if dialDualStack {
+ return DialDualStack(addr)
+ }
+ return Dial(addr)
+}
+
// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
@@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() {
func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
- conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
+ conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
if err != nil {
return err
}
diff --git a/client_test.go b/client_test.go
index e6a1358..1802383 100644
--- a/client_test.go
+++ b/client_test.go
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/tls"
+ "errors"
"fmt"
"io"
"net"
@@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) {
})
}
}
+
+type clientDoTimeOuter interface {
+ DoTimeout(req *Request, resp *Response, timeout time.Duration) error
+}
+
+func TestDialTimeout(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ client clientDoTimeOuter
+ requestTimeout time.Duration
+ shouldFailFast bool
+ }{
+ {
+ name: "Client should fail after a millisecond due to request timeout",
+ client: &Client{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "Client should fail after a second due to no DialTimeout set",
+ client: &Client{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ {
+ name: "HostClient should fail after a millisecond due to request timeout",
+ client: &HostClient{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "HostClient should fail after a second due to no DialTimeout set",
+ client: &HostClient{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ start := time.Now()
+ err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
+ if err == nil {
+ t.Fatal("expected error (timeout)")
+ }
+ if tt.shouldFailFast {
+ if time.Since(start) > time.Second {
+ t.Fatal("expected timeout after a millisecond")
+ }
+ } else {
+ if time.Since(start) < time.Second {
+ t.Fatal("expected timeout after a second")
+ }
+ }
+ })
+ }
+}
diff --git a/tcpdialer.go b/tcpdialer.go
index 46611e7..77bb569 100644
--- a/tcpdialer.go
+++ b/tcpdialer.go
@@ -48,7 +48,7 @@ func Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
-// to Client.Dial or HostClient.Dial.
+// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
@@ -102,7 +102,7 @@ func DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
-// to Client.Dial or HostClient.Dial.
+// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
@@ -199,7 +199,7 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
-// to Client.Dial or HostClient.Dial.
+// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.
@@ -253,7 +253,7 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
// are temporarily unreachable.
//
// This dialer is intended for custom code wrapping before passing
-// to Client.Dial or HostClient.Dial.
+// to Client.DialTimeout or HostClient.DialTimeout.
//
// For instance, per-host counters and/or limits may be implemented
// by such wrappers.