aboutsummaryrefslogtreecommitdiff
path: root/tcpdialer.go
diff options
context:
space:
mode:
authorGravatar Erik Dubbelboer <erik@dubbelboer.com> 2019-01-03 16:31:58 +0100
committerGravatar Erik Dubbelboer <erik@dubbelboer.com> 2019-02-04 09:11:02 +0000
commit45548243d7c46ace0ca0ad69969bf3fec2c480dc (patch)
tree54e41f1c1e46ef423967fa8e9f36e40464186f70 /tcpdialer.go
parentchange timer to public api #525 (#527) (diff)
downloadfasthttp-45548243d7c46ace0ca0ad69969bf3fec2c480dc.tar.gz
fasthttp-45548243d7c46ace0ca0ad69969bf3fec2c480dc.tar.bz2
fasthttp-45548243d7c46ace0ca0ad69969bf3fec2c480dc.zip
Add TCPDialer
Diffstat (limited to 'tcpdialer.go')
-rw-r--r--tcpdialer.go247
1 files changed, 163 insertions, 84 deletions
diff --git a/tcpdialer.go b/tcpdialer.go
index 906dfdc..6a5cd3a 100644
--- a/tcpdialer.go
+++ b/tcpdialer.go
@@ -33,7 +33,7 @@ import (
// * foo.bar:80
// * aaa.com:8080
func Dial(addr string) (net.Conn, error) {
- return getDialer(DefaultDialTimeout, false)(addr)
+ return defaultDialer.Dial(addr)
}
// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
@@ -58,7 +58,7 @@ func Dial(addr string) (net.Conn, error) {
// * foo.bar:80
// * aaa.com:8080
func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
- return getDialer(timeout, false)(addr)
+ return defaultDialer.DialTimeout(addr, timeout)
}
// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
@@ -86,7 +86,7 @@ func DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
// * foo.bar:80
// * aaa.com:8080
func DialDualStack(addr string) (net.Conn, error) {
- return getDialer(DefaultDialTimeout, true)(addr)
+ return defaultDialer.DialDualStack(addr)
}
// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
@@ -112,45 +112,22 @@ func DialDualStack(addr string) (net.Conn, error) {
// * foo.bar:80
// * aaa.com:8080
func DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
- return getDialer(timeout, true)(addr)
-}
-
-func getDialer(timeout time.Duration, dualStack bool) DialFunc {
- if timeout <= 0 {
- timeout = DefaultDialTimeout
- }
- timeoutRounded := int(timeout.Seconds()*10 + 9)
-
- m := dialMap
- if dualStack {
- m = dialDualStackMap
- }
-
- dialMapLock.Lock()
- d := m[timeoutRounded]
- if d == nil {
- dialer := dialerStd
- if dualStack {
- dialer = dialerDualStack
- }
- d = dialer.NewDial(timeout)
- m[timeoutRounded] = d
- }
- dialMapLock.Unlock()
- return d
+ return defaultDialer.DialDualStackTimeout(addr, timeout)
}
var (
- dialerStd = &tcpDialer{}
- dialerDualStack = &tcpDialer{DualStack: true}
-
- dialMap = make(map[int]DialFunc)
- dialDualStackMap = make(map[int]DialFunc)
- dialMapLock sync.Mutex
+ defaultDialer = &TCPDialer{Concurrency: 1000}
)
-type tcpDialer struct {
- DualStack bool
+// TCPDialer contains options to control a group of Dial calls.
+type TCPDialer struct {
+ // Concurrency controls the maximum number of concurrent Dails
+ // that can be performed using this object.
+ // Setting this to 0 means unlimited.
+ //
+ // WARNING: This can only be changed before the first Dial.
+ // Changes made after the first Dial will not affect anything.
+ Concurrency int
tcpAddrsLock sync.Mutex
tcpAddrsMap map[string]*tcpAddrEntry
@@ -160,41 +137,145 @@ type tcpDialer struct {
once sync.Once
}
-const maxDialConcurrency = 1000
+// Dial dials the given TCP addr using tcp4.
+//
+// This function has the following additional features comparing to net.Dial:
+//
+// * It reduces load on DNS resolver by caching resolved TCP addressed
+// for DefaultDNSCacheDuration.
+// * It dials all the resolved TCP addresses in round-robin manner until
+// connection is established. This may be useful if certain addresses
+// are temporarily unreachable.
+// * It returns ErrDialTimeout if connection cannot be established during
+// DefaultDialTimeout seconds. Use DialTimeout for customizing dial timeout.
+//
+// This dialer is intended for custom code wrapping before passing
+// to Client.Dial or HostClient.Dial.
+//
+// For instance, per-host counters and/or limits may be implemented
+// by such wrappers.
+//
+// The addr passed to the function must contain port. Example addr values:
+//
+// * foobar.baz:443
+// * foo.bar:80
+// * aaa.com:8080
+func (d *TCPDialer) Dial(addr string) (net.Conn, error) {
+ return d.dial(addr, false, DefaultDialTimeout)
+}
+
+// DialTimeout dials the given TCP addr using tcp4 using the given timeout.
+//
+// This function has the following additional features comparing to net.Dial:
+//
+// * It reduces load on DNS resolver by caching resolved TCP addressed
+// for DefaultDNSCacheDuration.
+// * It dials all the resolved TCP addresses in round-robin manner until
+// connection is established. This may be useful if certain addresses
+// are temporarily unreachable.
+//
+// This dialer is intended for custom code wrapping before passing
+// to Client.Dial or HostClient.Dial.
+//
+// For instance, per-host counters and/or limits may be implemented
+// by such wrappers.
+//
+// The addr passed to the function must contain port. Example addr values:
+//
+// * foobar.baz:443
+// * foo.bar:80
+// * aaa.com:8080
+func (d *TCPDialer) DialTimeout(addr string, timeout time.Duration) (net.Conn, error) {
+ return d.dial(addr, false, timeout)
+}
+
+// DialDualStack dials the given TCP addr using both tcp4 and tcp6.
+//
+// This function has the following additional features comparing to net.Dial:
+//
+// * It reduces load on DNS resolver by caching resolved TCP addressed
+// for DefaultDNSCacheDuration.
+// * It dials all the resolved TCP addresses in round-robin manner until
+// connection is established. This may be useful if certain addresses
+// are temporarily unreachable.
+// * It returns ErrDialTimeout if connection cannot be established during
+// DefaultDialTimeout seconds. Use DialDualStackTimeout for custom dial
+// timeout.
+//
+// This dialer is intended for custom code wrapping before passing
+// to Client.Dial or HostClient.Dial.
+//
+// For instance, per-host counters and/or limits may be implemented
+// by such wrappers.
+//
+// The addr passed to the function must contain port. Example addr values:
+//
+// * foobar.baz:443
+// * foo.bar:80
+// * aaa.com:8080
+func (d *TCPDialer) DialDualStack(addr string) (net.Conn, error) {
+ return d.dial(addr, true, DefaultDialTimeout)
+}
+
+// DialDualStackTimeout dials the given TCP addr using both tcp4 and tcp6
+// using the given timeout.
+//
+// This function has the following additional features comparing to net.Dial:
+//
+// * It reduces load on DNS resolver by caching resolved TCP addressed
+// for DefaultDNSCacheDuration.
+// * It dials all the resolved TCP addresses in round-robin manner until
+// connection is established. This may be useful if certain addresses
+// are temporarily unreachable.
+//
+// This dialer is intended for custom code wrapping before passing
+// to Client.Dial or HostClient.Dial.
+//
+// For instance, per-host counters and/or limits may be implemented
+// by such wrappers.
+//
+// The addr passed to the function must contain port. Example addr values:
+//
+// * foobar.baz:443
+// * foo.bar:80
+// * aaa.com:8080
+func (d *TCPDialer) DialDualStackTimeout(addr string, timeout time.Duration) (net.Conn, error) {
+ return d.dial(addr, true, timeout)
+}
-func (d *tcpDialer) NewDial(timeout time.Duration) DialFunc {
+func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (net.Conn, error) {
d.once.Do(func() {
- d.concurrencyCh = make(chan struct{}, maxDialConcurrency)
+ if d.Concurrency > 0 {
+ d.concurrencyCh = make(chan struct{}, d.Concurrency)
+ }
d.tcpAddrsMap = make(map[string]*tcpAddrEntry)
go d.tcpAddrsClean()
})
- return func(addr string) (net.Conn, error) {
- addrs, idx, err := d.getTCPAddrs(addr)
- if err != nil {
- return nil, err
- }
- network := "tcp4"
- if d.DualStack {
- network = "tcp"
- }
+ addrs, idx, err := d.getTCPAddrs(addr, dualStack)
+ if err != nil {
+ return nil, err
+ }
+ network := "tcp4"
+ if dualStack {
+ network = "tcp"
+ }
- var conn net.Conn
- n := uint32(len(addrs))
- deadline := time.Now().Add(timeout)
- for n > 0 {
- conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh)
- if err == nil {
- return conn, nil
- }
- if err == ErrDialTimeout {
- return nil, err
- }
- idx++
- n--
+ var conn net.Conn
+ n := uint32(len(addrs))
+ deadline := time.Now().Add(timeout)
+ for n > 0 {
+ conn, err = tryDial(network, &addrs[idx%n], deadline, d.concurrencyCh)
+ if err == nil {
+ return conn, nil
}
- return nil, err
+ if err == ErrDialTimeout {
+ return nil, err
+ }
+ idx++
+ n--
}
+ return nil, err
}
func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyCh chan struct{}) (net.Conn, error) {
@@ -203,28 +284,24 @@ func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyC
return nil, ErrDialTimeout
}
- select {
- case concurrencyCh <- struct{}{}:
- default:
- tc := AcquireTimer(timeout)
- isTimeout := false
+ if concurrencyCh != nil {
select {
case concurrencyCh <- struct{}{}:
- case <-tc.C:
- isTimeout = true
- }
- ReleaseTimer(tc)
- if isTimeout {
- return nil, ErrDialTimeout
+ default:
+ tc := AcquireTimer(timeout)
+ isTimeout := false
+ select {
+ case concurrencyCh <- struct{}{}:
+ case <-tc.C:
+ isTimeout = true
+ }
+ ReleaseTimer(tc)
+ if isTimeout {
+ return nil, ErrDialTimeout
+ }
}
}
- timeout = -time.Since(deadline)
- if timeout <= 0 {
- <-concurrencyCh
- return nil, ErrDialTimeout
- }
-
chv := dialResultChanPool.Get()
if chv == nil {
chv = make(chan dialResult, 1)
@@ -234,7 +311,9 @@ func tryDial(network string, addr *net.TCPAddr, deadline time.Time, concurrencyC
var dr dialResult
dr.conn, dr.err = net.DialTCP(network, nil, addr)
ch <- dr
- <-concurrencyCh
+ if concurrencyCh != nil {
+ <-concurrencyCh
+ }
}()
var (
@@ -282,7 +361,7 @@ type tcpAddrEntry struct {
// by Dial* functions.
const DefaultDNSCacheDuration = time.Minute
-func (d *tcpDialer) tcpAddrsClean() {
+func (d *TCPDialer) tcpAddrsClean() {
expireDuration := 2 * DefaultDNSCacheDuration
for {
time.Sleep(time.Second)
@@ -298,7 +377,7 @@ func (d *tcpDialer) tcpAddrsClean() {
}
}
-func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) {
+func (d *TCPDialer) getTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, uint32, error) {
d.tcpAddrsLock.Lock()
e := d.tcpAddrsMap[addr]
if e != nil && !e.pending && time.Since(e.resolveTime) > DefaultDNSCacheDuration {
@@ -308,7 +387,7 @@ func (d *tcpDialer) getTCPAddrs(addr string) ([]net.TCPAddr, uint32, error) {
d.tcpAddrsLock.Unlock()
if e == nil {
- addrs, err := resolveTCPAddrs(addr, d.DualStack)
+ addrs, err := resolveTCPAddrs(addr, dualStack)
if err != nil {
d.tcpAddrsLock.Lock()
e = d.tcpAddrsMap[addr]