aboutsummaryrefslogtreecommitdiff
path: root/client.go
diff options
context:
space:
mode:
Diffstat (limited to 'client.go')
-rw-r--r--client.go95
1 files changed, 65 insertions, 30 deletions
diff --git a/client.go b/client.go
index 217dde4..0bfc272 100644
--- a/client.go
+++ b/client.go
@@ -185,7 +185,15 @@ type Client struct {
// Callback for establishing new connections to hosts.
//
- // Default Dial is used if not set.
+ // Default DialTimeout is used if not set.
+ DialTimeout DialFuncWithTimeout
+
+ // Callback for establishing new connections to hosts.
+ //
+ // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
+ // If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
+ //
+ // If not set, DialTimeout is used.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 addresses if set to true.
@@ -505,6 +513,7 @@ func (c *Client) Do(req *Request, resp *Response) error {
Name: c.Name,
NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader,
Dial: c.Dial,
+ DialTimeout: c.DialTimeout,
DialDualStack: c.DialDualStack,
IsTLS: isTLS,
TLSConfig: c.TLSConfig,
@@ -624,6 +633,21 @@ const DefaultMaxIdemponentCallAttempts = 5
// - foobar.com:8080
type DialFunc func(addr string) (net.Conn, error)
+// DialFuncWithTimeout must establish connection to addr.
+// Unlike DialFunc, it also accepts a timeout.
+//
+// There is no need in establishing TLS (SSL) connection for https.
+// The client automatically converts connection to TLS
+// if HostClient.IsTLS is set.
+//
+// TCP address passed to DialFuncWithTimeout always contains host and port.
+// Example TCP addr values:
+//
+// - foobar.com:80
+// - foobar.com:443
+// - foobar.com:8080
+type DialFuncWithTimeout func(addr string, timeout time.Duration) (net.Conn, error)
+
// RetryIfFunc signature of retry if function
//
// Request argument passed to RetryIfFunc, if there are any request errors.
@@ -656,7 +680,7 @@ type HostClient struct {
noCopy noCopy
// Comma-separated list of upstream HTTP server host addresses,
- // which are passed to Dial in a round-robin manner.
+ // which are passed to Dial or DialTimeout in a round-robin manner.
//
// Each address may contain port if default dialer is used.
// For example,
@@ -673,16 +697,24 @@ type HostClient struct {
// User-Agent header to be excluded from the Request.
NoDefaultUserAgentHeader bool
- // Callback for establishing new connection to the host.
+ // Callback for establishing new connections to hosts.
//
- // Default Dial is used if not set.
+ // Default DialTimeout is used if not set.
+ DialTimeout DialFuncWithTimeout
+
+ // Callback for establishing new connections to hosts.
+ //
+ // Note that if Dial is set instead of DialTimeout, Dial will ignore Request timeout.
+ // If you want the tcp dial process to account for request timeouts, use DialTimeout instead.
+ //
+ // If not set, DialTimeout is used.
Dial DialFunc
// Attempt to connect to both ipv4 and ipv6 host addresses
// if set to true.
//
// This option is used only if default TCP dialer is used,
- // i.e. if Dial is blank.
+ // i.e. if Dial and DialTimeout are blank.
//
// By default client connects only to ipv4 addresses,
// since unfortunately ipv6 remains broken in many networks worldwide :)
@@ -1827,7 +1859,8 @@ func (c *HostClient) nextAddr() string {
}
func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err error) {
- // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or dial has been set.
+ // use dialTimeout to control the timeout of each dial. It does not work if dialTimeout is 0 or if
+ // c.DialTimeout has not been set and c.Dial has been set.
// attempt to dial all the available hosts before giving up.
c.addrsLock.Lock()
@@ -1839,16 +1872,6 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
n = 1
}
- dial := c.Dial
- if dialTimeout != 0 && dial == nil {
- dial = func(addr string) (net.Conn, error) {
- if c.DialDualStack {
- return DialDualStackTimeout(addr, dialTimeout)
- }
- return DialTimeout(addr, dialTimeout)
- }
- }
-
timeout := c.ReadTimeout + c.WriteTimeout
if timeout <= 0 {
timeout = DefaultDialTimeout
@@ -1857,7 +1880,7 @@ func (c *HostClient) dialHostHard(dialTimeout time.Duration) (conn net.Conn, err
for n > 0 {
addr := c.nextAddr()
tlsConfig := c.cachedTLSConfig(addr)
- conn, err = dialAddr(addr, dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
+ conn, err = dialAddr(addr, c.Dial, c.DialTimeout, c.DialDualStack, c.IsTLS, tlsConfig, dialTimeout, c.WriteTimeout)
if err == nil {
return conn, nil
}
@@ -1916,17 +1939,9 @@ func tlsClientHandshake(rawConn net.Conn, tlsConfig *tls.Config, deadline time.T
return conn, nil
}
-func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *tls.Config, timeout time.Duration) (net.Conn, error) {
- deadline := time.Now().Add(timeout)
- if dial == nil {
- if dialDualStack {
- dial = DialDualStack
- } else {
- dial = Dial
- }
- addr = AddMissingPort(addr, isTLS)
- }
- conn, err := dial(addr)
+func dialAddr(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, tlsConfig *tls.Config, dialTimeout, writeTimeout time.Duration) (net.Conn, error) {
+ deadline := time.Now().Add(writeTimeout)
+ conn, err := callDialFunc(addr, dial, dialWithTimeout, dialDualStack, isTLS, dialTimeout)
if err != nil {
return nil, err
}
@@ -1939,7 +1954,7 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
_, isTLSAlready := conn.(interface{ Handshake() error })
if isTLS && !isTLSAlready {
- if timeout == 0 {
+ if writeTimeout == 0 {
return tls.Client(conn, tlsConfig), nil
}
return tlsClientHandshake(conn, tlsConfig, deadline)
@@ -1947,6 +1962,26 @@ func dialAddr(addr string, dial DialFunc, dialDualStack, isTLS bool, tlsConfig *
return conn, nil
}
+func callDialFunc(addr string, dial DialFunc, dialWithTimeout DialFuncWithTimeout, dialDualStack, isTLS bool, timeout time.Duration) (net.Conn, error) {
+ if dialWithTimeout != nil {
+ return dialWithTimeout(addr, timeout)
+ }
+ if dial != nil {
+ return dial(addr)
+ }
+ addr = AddMissingPort(addr, isTLS)
+ if timeout > 0 {
+ if dialDualStack {
+ return DialDualStackTimeout(addr, timeout)
+ }
+ return DialTimeout(addr, timeout)
+ }
+ if dialDualStack {
+ return DialDualStack(addr)
+ }
+ return Dial(addr)
+}
+
// AddMissingPort adds a port to a host if it is missing.
// A literal IPv6 address in hostport must be enclosed in square
// brackets, as in "[::1]:80", "[::1%lo0]:80".
@@ -2591,7 +2626,7 @@ func (c *pipelineConnClient) init() {
func (c *pipelineConnClient) worker() error {
tlsConfig := c.cachedTLSConfig()
- conn, err := dialAddr(c.Addr, c.Dial, c.DialDualStack, c.IsTLS, tlsConfig, c.WriteTimeout)
+ conn, err := dialAddr(c.Addr, c.Dial, nil, c.DialDualStack, c.IsTLS, tlsConfig, 0, c.WriteTimeout)
if err != nil {
return err
}