diff options
author | Aliaksandr Valialkin <valyala@gmail.com> | 2015-12-07 16:12:28 +0200 |
---|---|---|
committer | Aliaksandr Valialkin <valyala@gmail.com> | 2015-12-07 16:12:28 +0200 |
commit | 9c70042061feba5a5f3f5b137a227cfd709d671e (patch) | |
tree | dbb0a301ede04e848ba3797b8fdffb68c40910b9 /tcpdialer.go | |
parent | gofmt (diff) | |
download | fasthttp-9c70042061feba5a5f3f5b137a227cfd709d671e.tar.gz fasthttp-9c70042061feba5a5f3f5b137a227cfd709d671e.tar.bz2 fasthttp-9c70042061feba5a5f3f5b137a227cfd709d671e.zip |
Exported default TCP dialers used by clients, so custom wrappers may be implemented around these dialers
Diffstat (limited to 'tcpdialer.go')
-rw-r--r-- | tcpdialer.go | 194 |
1 files changed, 194 insertions, 0 deletions
diff --git a/tcpdialer.go b/tcpdialer.go new file mode 100644 index 0000000..57cdfe2 --- /dev/null +++ b/tcpdialer.go @@ -0,0 +1,194 @@ +package fasthttp + +import ( + "fmt" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// TCP dialers used by client. +// +// These dialers are intended for custom code wrapping before passing +// to Client.Dial or HostClient.Dial. +// +// For instance, per-host counters and/or limits may be implemented +// by such wrappers. +// +// The addr paseed to dial func may contain port. Example addr values: +// +// * google.com +// * foobar.baz:443 +// * foo.bar:80 +// * aaa.com:8080 +// +// Default port is appended to the addr if port is missing: +// +// * ':80' if Dial is used +// * ':443' if DialTLS is used +var ( + // Dial dials the given addr using tcp4. + // '80' port is used if port is missing in the addr passed to the func. + Dial = DialFunc((&tcpDialer{}).NewDial()) + + // DialTLS dials the given addr using tcp4. + // '443' port is used if port is missing in the addr passed to the func. + DialTLS = DialFunc((&tcpDialer{IsTLS: true}).NewDial()) + + // DialDualStack dials the given addr using both tcp4 and tcp6. + // '80' port is used if port is missing in the addr passed to the func. + DialDualStack = DialFunc((&tcpDialer{DualStack: true}).NewDial()) + + // DialTLSDualStack dials the given addr using both tcp4 and tcp6. + // '443' port is used if port is missing in the addr passed to the func. + DialTLSDualStack = DialFunc((&tcpDialer{IsTLS: true, DualStack: true}).NewDial()) +) + +// tcpDialer implements default TCP dialer for the Client and HostClient. +// +// tcpDialer instance copying is forbiddent. Create new instance instead. +type tcpDialer struct { + // Appends ':80' to the addr with missing port in Dial if set to false. + // Appends ':443' to the addr with missing port in Dial if set to true. + IsTLS bool + + // Set to true if you want simultaneously dialing tcp4 and tcp6. + DualStack bool + + tcpAddrsLock sync.Mutex + tcpAddrsMap map[string]*tcpAddrEntry +} + +func (d *tcpDialer) NewDial() DialFunc { + if d.tcpAddrsMap != nil { + panic("BUG: NewDial() already called") + } + + d.tcpAddrsMap = make(map[string]*tcpAddrEntry) + go d.tcpAddrsClean() + + return func(addr string) (net.Conn, error) { + tcpAddr, err := d.getTCPAddr(addr) + if err != nil { + return nil, err + } + network := "tcp4" + if d.DualStack { + network = "tcp" + } + return net.DialTCP(network, nil, tcpAddr) + } +} + +type tcpAddrEntry struct { + addrs []net.TCPAddr + addrsIdx uint32 + + resolveTime time.Time + pending bool +} + +var tcpAddrsCacheDuration = time.Minute + +func (d *tcpDialer) tcpAddrsClean() { + expireDuration := 2 * tcpAddrsCacheDuration + for { + time.Sleep(time.Second) + t := time.Now() + + d.tcpAddrsLock.Lock() + for k, e := range d.tcpAddrsMap { + if t.Sub(e.resolveTime) > expireDuration { + delete(d.tcpAddrsMap, k) + } + } + d.tcpAddrsLock.Unlock() + } +} + +func (d *tcpDialer) getTCPAddr(addr string) (*net.TCPAddr, error) { + addr = addMissingPort(addr, d.IsTLS) + + d.tcpAddrsLock.Lock() + e := d.tcpAddrsMap[addr] + if e != nil && !e.pending && time.Since(e.resolveTime) > tcpAddrsCacheDuration { + e.pending = true + e = nil + } + d.tcpAddrsLock.Unlock() + + if e == nil { + tcpAddrs, err := resolveTCPAddrs(addr, d.DualStack) + if err != nil { + d.tcpAddrsLock.Lock() + e = d.tcpAddrsMap[addr] + if e != nil && e.pending { + e.pending = false + } + d.tcpAddrsLock.Unlock() + return nil, err + } + + e = &tcpAddrEntry{ + addrs: tcpAddrs, + resolveTime: time.Now(), + } + + d.tcpAddrsLock.Lock() + d.tcpAddrsMap[addr] = e + d.tcpAddrsLock.Unlock() + } + + tcpAddr := &e.addrs[0] + n := len(e.addrs) + if n > 1 { + n := atomic.AddUint32(&e.addrsIdx, 1) + tcpAddr = &e.addrs[n%uint32(n)] + } + return tcpAddr, nil +} + +func addMissingPort(addr string, isTLS bool) string { + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return fmt.Sprintf("%s:%d", addr, port) +} + +func resolveTCPAddrs(addr string, dualStack bool) ([]net.TCPAddr, error) { + host, portS, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + port, err := strconv.Atoi(portS) + if err != nil { + return nil, err + } + + ips, err := net.LookupIP(host) + if err != nil { + return nil, err + } + + n := len(ips) + addrs := make([]net.TCPAddr, 0, n) + for i := 0; i < n; i++ { + ip := ips[i] + if !dualStack && ip.To4() == nil { + continue + } + addrs = append(addrs, net.TCPAddr{ + IP: ip, + Port: port, + }) + } + return addrs, nil +} |