From 45548243d7c46ace0ca0ad69969bf3fec2c480dc Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Thu, 3 Jan 2019 16:31:58 +0100 Subject: Add TCPDialer --- tcpdialer.go | 247 +++++++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 163 insertions(+), 84 deletions(-) (limited to 'tcpdialer.go') diff --git a/tcpdialer.go b/tcpdialer.go index 906dfdc..6a5cd3a 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -33,7 +33,7 @@ import ( // * foo.bar:80 // * aaa.com:8080 func Dial(addr string) (net.Conn, error) { - return getDialer(DefaultDialTimeout, false)(addr) + return defaultDialer.Dial(addr) } // DialTimeout dials the given TCP addr using tcp4 using the given timeout. @@ -58,7 +58,7 @@ func Dial(addr string) (net.Conn, error) { // * foo.bar:80 // * aaa.com:8080 func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { - return getDialer(timeout, false)(addr) + return defaultDialer.DialTimeout(addr, timeout) } // DialDualStack dials the given TCP addr using both tcp4 and tcp6. @@ -86,7 +86,7 @@ func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { // * foo.bar:80 // * aaa.com:8080 func DialDualStack(addr string) (net.Conn, error) { - return getDialer(DefaultDialTimeout, true)(addr) + return defaultDialer.DialDualStack(addr) } // DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 @@ -112,45 +112,22 @@ func DialDualStack(addr string) (net.Conn, error) { // * foo.bar:80 // * aaa.com:8080 func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { - return getDialer(timeout, true)(addr) -} - -func getDialer(timeout time.Duration, dualStack bool) DialFunc { - if timeout <= 0 { - timeout = DefaultDialTimeout - } - timeoutRounded := int(timeout.Seconds()*10 + 9) - - m := dialMap - if dualStack { - m = dialDualStackMap - } - - dialMapLock.Lock() - d := m[timeoutRounded] - if d == nil { - dialer := dialerStd - if dualStack { - dialer = dialerDualStack - } - d = dialer.NewDial(timeout) - m[timeoutRounded] = d - } - dialMapLock.Unlock() - return d + return defaultDialer.DialDualStackTimeout(addr, timeout) } var ( - dialerStd = &tcpDialer{} - dialerDualStack = &tcpDialer{DualStack: true} - - dialMap = make(map[int]DialFunc) - dialDualStackMap = make(map[int]DialFunc) - dialMapLock sync.Mutex + defaultDialer = &TCPDialer{Concurrency: 1000} ) -type tcpDialer struct { - DualStack bool +// TCPDialer contains options to control a group of Dial calls. +type TCPDialer struct { + // Concurrency controls the maximum number of concurrent Dails + // that can be performed using this object. + // Setting this to 0 means unlimited. + // + // WARNING: This can only be changed before the first Dial. + // Changes made after the first Dial will not affect anything. + Concurrency int tcpAddrsLock sync.Mutex tcpAddrsMap map[string]*tcpAddrEntry @@ -160,41 +137,145 @@ type tcpDialer struct { once sync.Once } -const maxDialConcurrency = 1000 +// Dial dials the given TCP addr using tcp4. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DefaultDNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// * It returns ErrDialTimeout if connection cannot be established during +// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) Dial(addr string) (net.Conn, error) { + return d.dial(addr, false, DefaultDialTimeout) +} + +// DialTimeout dials the given TCP addr using tcp4 using the given timeout. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DefaultDNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) { + return d.dial(addr, false, timeout) +} + +// DialDualStack dials the given TCP addr using both tcp4 and tcp6. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DefaultDNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// * It returns ErrDialTimeout if connection cannot be established during +// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial +// timeout. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) { + return d.dial(addr, true, DefaultDialTimeout) +} + +// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6 +// using the given timeout. +// +// This function has the following additional features comparing to net.Dial: +// +// * It reduces load on DNS resolver by caching resolved TCP addressed +// for DefaultDNSCacheDuration. +// * It dials all the resolved TCP addresses in round-robin manner until +// connection is established. This may be useful if certain addresses +// are temporarily unreachable. +// +// This dialer is intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr passed to the function must contain port. Example addr values: +// +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) { + return d.dial(addr, true, timeout) +} -func (d *tcpDialer) NewDial(timeout time.Duration) DialFunc { +func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) { d.once.Do(func() { - d.concurrencyCh = make(chan struct{}, maxDialConcurrency) + if d.Concurrency > 0 { + d.concurrencyCh = make(chan struct{}, d.Concurrency) + } d.tcpAddrsMap = make(map[string]*tcpAddrEntry) go d.tcpAddrsClean() }) - return func(addr string) (net.Conn, error) { - addrs, idx, err := d.getTCPAddrs(addr) - if err != nil { - return nil, err - } - network := "tcp4" - if d.DualStack { - network = "tcp" - } + addrs, idx, err := d.getTCPAddrs(addr, dualStack) + if err != nil { + return nil, err + } + network := "tcp4" + if dualStack { + network = "tcp" + } - var conn net.Conn - n := uint32(len(addrs)) - deadline := time.Now().Add(timeout) - for n > 0 { - conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) - if err == nil { - return conn, nil - } - if err == ErrDialTimeout { - return nil, err - } - idx++ - n-- + var conn net.Conn + n := uint32(len(addrs)) + deadline := time.Now().Add(timeout) + for n > 0 { + conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh) + if err == nil { + return conn, nil } - return nil, err + if err == ErrDialTimeout { + return nil, err + } + idx++ + n-- } + return nil, err } func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) { @@ -203,28 +284,24 @@ func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyC return nil, ErrDialTimeout } - select { - case concurrencyCh <- struct{}{}: - default: - tc := AcquireTimer(timeout) - isTimeout := false + if concurrencyCh != nil { select { case concurrencyCh <- struct{}{}: - case <-tc.C: - isTimeout = true - } - ReleaseTimer(tc) - if isTimeout { - return nil, ErrDialTimeout + default: + tc := AcquireTimer(timeout) + isTimeout := false + select { + case concurrencyCh <- struct{}{}: + case <-tc.C: + isTimeout = true + } + ReleaseTimer(tc) + if isTimeout { + return nil, ErrDialTimeout + } } } - timeout = -time.Since(deadline) - if timeout <= 0 { - <-concurrencyCh - return nil, ErrDialTimeout - } - chv := dialResultChanPool.Get() if chv == nil { chv = make(chan dialResult, 1) @@ -234,7 +311,9 @@ func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyC var dr dialResult dr.conn, dr.err = net.DialTCP(network, nil, addr) ch <- dr - <-concurrencyCh + if concurrencyCh != nil { + <-concurrencyCh + } }() var ( @@ -282,7 +361,7 @@ type tcpAddrEntry struct { // by Dial* functions. const DefaultDNSCacheDuration = time.Minute -func (d *tcpDialer) tcpAddrsClean() { +func (d *TCPDialer) tcpAddrsClean() { expireDuration := 2 * DefaultDNSCacheDuration for { time.Sleep(time.Second) @@ -298,7 +377,7 @@ func (d *tcpDialer) tcpAddrsClean() { } } -func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { +func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) { d.tcpAddrsLock.Lock() e := d.tcpAddrsMap[addr] if e != nil && !e.pending && time.Since(e.resolveTime) > DefaultDNSCacheDuration { @@ -308,7 +387,7 @@ func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) { d.tcpAddrsLock.Unlock() if e == nil { - addrs, err := resolveTCPAddrs(addr, d.DualStack) + addrs, err := resolveTCPAddrs(addr, dualStack) if err != nil { d.tcpAddrsLock.Lock() e = d.tcpAddrsMap[addr] -- cgit v1.2.3