From 7ea3b6330e929f5531600324c1da9fb7c61d1a8e Mon Sep 17 00:00:00 2001 From: Zhengkai Wang Date: Thu, 11 May 2023 16:07:58 +0800 Subject: add concurrency for client's HostClient map (#1550) * add the functions to get host clients * add concurrency for client's HostClient map * delete test code * add lock in once block --------- Co-authored-by: wangzhengkai.wzk --- client.go | 102 ++++++++++++++++++++++++++++++++------------------------------ 1 file changed, 53 insertions(+), 49 deletions(-) (limited to 'client.go') diff --git a/client.go b/client.go index a5c18bc..8800571 100644 --- a/client.go +++ b/client.go @@ -303,7 +303,8 @@ type Client struct { // ConfigureClient configures the fasthttp.HostClient. ConfigureClient func(hc *HostClient) error - mLock sync.Mutex + mLock sync.RWMutex + mOnce sync.Once m map[string]*HostClient ms map[string]*HostClient readerPool sync.Pool @@ -485,68 +486,71 @@ func (c *Client) Do(req *Request, resp *Response) error { return fmt.Errorf("unsupported protocol %q. http and https are supported", uri.Scheme()) } + c.mOnce.Do(func() { + c.mLock.Lock() + c.m = make(map[string]*HostClient) + c.ms = make(map[string]*HostClient) + c.mLock.Unlock() + }) + startCleaner := false - c.mLock.Lock() + c.mLock.RLock() m := c.m if isTLS { m = c.ms } - if m == nil { - m = make(map[string]*HostClient) - if isTLS { - c.ms = m - } else { - c.m = m - } - } hc := m[string(host)] + c.mLock.RUnlock() if hc == nil { - hc = &HostClient{ - Addr: AddMissingPort(string(host), isTLS), - Name: c.Name, - NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, - Dial: c.Dial, - DialDualStack: c.DialDualStack, - IsTLS: isTLS, - TLSConfig: c.TLSConfig, - MaxConns: c.MaxConnsPerHost, - MaxIdleConnDuration: c.MaxIdleConnDuration, - MaxConnDuration: c.MaxConnDuration, - MaxIdemponentCallAttempts: c.MaxIdemponentCallAttempts, - ReadBufferSize: c.ReadBufferSize, - WriteBufferSize: c.WriteBufferSize, - ReadTimeout: c.ReadTimeout, - WriteTimeout: c.WriteTimeout, - MaxResponseBodySize: c.MaxResponseBodySize, - DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, - DisablePathNormalizing: c.DisablePathNormalizing, - MaxConnWaitTimeout: c.MaxConnWaitTimeout, - RetryIf: c.RetryIf, - ConnPoolStrategy: c.ConnPoolStrategy, - StreamResponseBody: c.StreamResponseBody, - clientReaderPool: &c.readerPool, - clientWriterPool: &c.writerPool, - } - - if c.ConfigureClient != nil { - if err := c.ConfigureClient(hc); err != nil { - c.mLock.Unlock() - return err + c.mLock.Lock() + hc = m[string(host)] + if hc == nil { + hc = &HostClient{ + Addr: AddMissingPort(string(host), isTLS), + Name: c.Name, + NoDefaultUserAgentHeader: c.NoDefaultUserAgentHeader, + Dial: c.Dial, + DialDualStack: c.DialDualStack, + IsTLS: isTLS, + TLSConfig: c.TLSConfig, + MaxConns: c.MaxConnsPerHost, + MaxIdleConnDuration: c.MaxIdleConnDuration, + MaxConnDuration: c.MaxConnDuration, + MaxIdemponentCallAttempts: c.MaxIdemponentCallAttempts, + ReadBufferSize: c.ReadBufferSize, + WriteBufferSize: c.WriteBufferSize, + ReadTimeout: c.ReadTimeout, + WriteTimeout: c.WriteTimeout, + MaxResponseBodySize: c.MaxResponseBodySize, + DisableHeaderNamesNormalizing: c.DisableHeaderNamesNormalizing, + DisablePathNormalizing: c.DisablePathNormalizing, + MaxConnWaitTimeout: c.MaxConnWaitTimeout, + RetryIf: c.RetryIf, + ConnPoolStrategy: c.ConnPoolStrategy, + StreamResponseBody: c.StreamResponseBody, + clientReaderPool: &c.readerPool, + clientWriterPool: &c.writerPool, } - } - m[string(host)] = hc - if len(m) == 1 { - startCleaner = true + if c.ConfigureClient != nil { + if err := c.ConfigureClient(hc); err != nil { + c.mLock.Unlock() + return err + } + } + + m[string(host)] = hc + if len(m) == 1 { + startCleaner = true + } } + c.mLock.Unlock() } atomic.AddInt32(&hc.pendingClientRequests, 1) defer atomic.AddInt32(&hc.pendingClientRequests, -1) - c.mLock.Unlock() - if startCleaner { go c.mCleaner(m) } @@ -559,14 +563,14 @@ func (c *Client) Do(req *Request, resp *Response) error { // "keep-alive" state. It does not interrupt any connections currently // in use. func (c *Client) CloseIdleConnections() { - c.mLock.Lock() + c.mLock.RLock() for _, v := range c.m { v.CloseIdleConnections() } for _, v := range c.ms { v.CloseIdleConnections() } - c.mLock.Unlock() + c.mLock.RUnlock() } func (c *Client) mCleaner(m map[string]*HostClient) { -- cgit v1.2.3