From 8ca7a9c89c43a97658352651f5f66e16278164f6 Mon Sep 17 00:00:00 2001 From: Aviv Carmi Date: Mon, 27 Nov 2023 14:46:43 +0200 Subject: add support for custom dial function with timeouts (#1669) * add support for custom dial function with timeouts * fix linting --------- Co-authored-by: Aviv Carmi --- client.go | 95 +++++++++++++++++++++++++++++++++++++++------------------- client_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ tcpdialer.go | 8 ++--- 3 files changed, 161 insertions(+), 34 deletions(-) diff --git a/client.go b/client.go index 217dde4..0bfc272 100644 --- a/client.go +++ b/client.go @@ -185,7 +185,15 @@ type Client struct { // Callback for establishing new connections to hosts. // - // Default Dial is used if not set. + // Default DialTimeout is used if not set. + DialTimeout DialFuncWithTimeout + + // Callback for establishing new connections to hosts. + // + // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout. + // If you want the tcp dial process to account for request timeouts, use DialTimeout instead. + // + // If not set, DialTimeout is used. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 addresses if set to true. @@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error { Name: c.Name, NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, Dial: c.Dial, + DialTimeout: c.DialTimeout, DialDualStack: c.DialDualStack, IsTLS: isTLS, TLSConfig: c.TLSConfig, @@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5 // - foobar.com:8080 type DialFunc func(addr string) (net.Conn, error) +// DialFuncWithTimeout must establish connection to addr. +// Unlike DialFunc, it also accepts a timeout. +// +// There is no need in establishing TLS (SSL) connection for https. +// The client automatically converts connection to TLS +// if HostClient.IsTLS is set. +// +// TCP address passed to DialFuncWithTimeout always contains host and port. +// Example TCP addr values: +// +// - foobar.com:80 +// - foobar.com:443 +// - foobar.com:8080 +type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error) + // RetryIfFunc signature of retry if function // // Request argument passed to RetryIfFunc, if there are any request errors. @@ -656,7 +680,7 @@ type HostClient struct { noCopy noCopy // Comma-separated list of upstream HTTP server host addresses, - // which are passed to Dial in a round-robin manner. + // which are passed to Dial or DialTimeout in a round-robin manner. // // Each address may contain port if default dialer is used. // For example, @@ -673,16 +697,24 @@ type HostClient struct { // User-Agent header to be excluded from the Request. NoDefaultUserAgentHeader bool - // Callback for establishing new connection to the host. + // Callback for establishing new connections to hosts. // - // Default Dial is used if not set. + // Default DialTimeout is used if not set. + DialTimeout DialFuncWithTimeout + + // Callback for establishing new connections to hosts. + // + // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout. + // If you want the tcp dial process to account for request timeouts, use DialTimeout instead. + // + // If not set, DialTimeout is used. Dial DialFunc // Attempt to connect to both ipv4 and ipv6 host addresses // if set to true. // // This option is used only if default TCP dialer is used, - // i.e. if Dial is blank. + // i.e. if Dial and DialTimeout are blank. // // By default client connects only to ipv4 addresses, // since unfortunately ipv6 remains broken in many networks worldwide :) @@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string { } func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) { - // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set. + // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if + // c.DialTimeout has not been set and c.Dial has been set. // attempt to dial all the available hosts before giving up. c.addrsLock.Lock() @@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err n = 1 } - dial := c.Dial - if dialTimeout != 0 && dial == nil { - dial = func(addr string) (net.Conn, error) { - if c.DialDualStack { - return DialDualStackTimeout(addr, dialTimeout) - } - return DialTimeout(addr, dialTimeout) - } - } - timeout := c.ReadTimeout + c.WriteTimeout if timeout <= 0 { timeout = DefaultDialTimeout @@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err for n > 0 { addr := c.nextAddr() tlsConfig := c.cachedTLSConfig(addr) - conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) + conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout) if err == nil { return conn, nil } @@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T return conn, nil } -func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) { - deadline := time.Now().Add(timeout) - if dial == nil { - if dialDualStack { - dial = DialDualStack - } else { - dial = Dial - } - addr = AddMissingPort(addr, isTLS) - } - conn, err := dial(addr) +func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) { + deadline := time.Now().Add(writeTimeout) + conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout) if err != nil { return nil, err } @@ -1939,7 +1954,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * _, isTLSAlready := conn.(interface{ Handshake() error }) if isTLS && !isTLSAlready { - if timeout == 0 { + if writeTimeout == 0 { return tls.Client(conn, tlsConfig), nil } return tlsClientHandshake(conn, tlsConfig, deadline) @@ -1947,6 +1962,26 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig * return conn, nil } +func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) { + if dialWithTimeout != nil { + return dialWithTimeout(addr, timeout) + } + if dial != nil { + return dial(addr) + } + addr = AddMissingPort(addr, isTLS) + if timeout > 0 { + if dialDualStack { + return DialDualStackTimeout(addr, timeout) + } + return DialTimeout(addr, timeout) + } + if dialDualStack { + return DialDualStack(addr) + } + return Dial(addr) +} + // AddMissingPort adds a port to a host if it is missing. // A literal IPv6 address in hostport must be enclosed in square // brackets, as in "[::1]:80", "[::1%lo0]:80". @@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() { func (c *pipelineConnClient) worker() error { tlsConfig := c.cachedTLSConfig() - conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout) + conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout) if err != nil { return err } diff --git a/client_test.go b/client_test.go index e6a1358..1802383 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "net" @@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) { }) } } + +type clientDoTimeOuter interface { + DoTimeout(req *Request, resp *Response, timeout time.Duration) error +} + +func TestDialTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + client clientDoTimeOuter + requestTimeout time.Duration + shouldFailFast bool + }{ + { + name: "Client should fail after a millisecond due to request timeout", + client: &Client{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "Client should fail after a second due to no DialTimeout set", + client: &Client{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + { + name: "HostClient should fail after a millisecond due to request timeout", + client: &HostClient{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "HostClient should fail after a second due to no DialTimeout set", + client: &HostClient{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start := time.Now() + err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout) + if err == nil { + t.Fatal("expected error (timeout)") + } + if tt.shouldFailFast { + if time.Since(start) > time.Second { + t.Fatal("expected timeout after a millisecond") + } + } else { + if time.Since(start) < time.Second { + t.Fatal("expected timeout after a second") + } + } + }) + } +} diff --git a/tcpdialer.go b/tcpdialer.go index 46611e7..77bb569 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -48,7 +48,7 @@ func Dial(addr string) (net.Conn, error) { // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing -// to Client.Dial or HostClient.Dial. +// to Client.DialTimeout or HostClient.DialTimeout. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. @@ -102,7 +102,7 @@ func DialDualStack(addr string) (net.Conn, error) { // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing -// to Client.Dial or HostClient.Dial. +// to Client.DialTimeout or HostClient.DialTimeout. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. @@ -199,7 +199,7 @@ func (d *TCPDialer) Dial(addr string) (net.Conn, error) { // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing -// to Client.Dial or HostClient.Dial. +// to Client.DialTimeout or HostClient.DialTimeout. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. @@ -253,7 +253,7 @@ func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) { // are temporarily unreachable. // // This dialer is intended for custom code wrapping before passing -// to Client.Dial or HostClient.Dial. +// to Client.DialTimeout or HostClient.DialTimeout. // // For instance, per-host counters and/or limits may be implemented // by such wrappers. -- cgit v1.2.3