diff options
-rw-r--r-- | .github/workflows/lint.yml | 2 | ||||
-rw-r--r-- | client.go | 5 | ||||
-rw-r--r-- | fs.go | 6 | ||||
-rw-r--r-- | fuzz_test.go | 10 | ||||
-rw-r--r-- | peripconn.go | 46 | ||||
-rw-r--r-- | peripconn_test.go | 2 | ||||
-rw-r--r-- | tcpdialer.go | 47 |
7 files changed, 96 insertions, 22 deletions
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index c250bf7..301af69 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -23,7 +23,7 @@ jobs: go-version: 1.20.x - run: go version - name: Run golangci-lint - uses: golangci/golangci-lint-action@v4 + uses: golangci/golangci-lint-action@v6 with: version: v1.56.2 args: --verbose @@ -2,6 +2,7 @@ package fasthttp import ( "bufio" + "bytes" "crypto/tls" "errors" "fmt" @@ -477,6 +478,10 @@ func (c *Client) Do(req *Request, resp *Response) error { host := uri.Host() + if bytes.ContainsRune(host, ',') { + return fmt.Errorf("invalid host %q. Use HostClient for multiple hosts", host) + } + isTLS := false if uri.isHTTPS() { isTLS = true @@ -1233,6 +1233,12 @@ func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress if err == nil { return ff, nil } + if mustCompress && err == errNoCreatePermission { + ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+ + "Allow write access to the directory with this file in order to improve fasthttp performance", indexFilePath) + mustCompress = false + return h.openFSFile(indexFilePath, mustCompress, fileEncoding) + } if !errors.Is(err, fs.ErrNotExist) { return nil, fmt.Errorf("cannot open file %q: %w", indexFilePath, err) } diff --git a/fuzz_test.go b/fuzz_test.go index 532c052..ba1737a 100644 --- a/fuzz_test.go +++ b/fuzz_test.go @@ -47,11 +47,10 @@ func FuzzResponseReadLimitBody(f *testing.F) { return } - var res Response + res := AcquireResponse() + defer ReleaseResponse(res) _ = res.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max) - w := bytes.Buffer{} - _, _ = res.WriteTo(&w) }) } @@ -63,11 +62,10 @@ func FuzzRequestReadLimitBody(f *testing.F) { return } - var req Request + req := AcquireRequest() + defer ReleaseRequest(req) _ = req.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max) - w := bytes.Buffer{} - _, _ = req.WriteTo(&w) }) } diff --git a/peripconn.go b/peripconn.go index 123c55e..46bddbf 100644 --- a/peripconn.go +++ b/peripconn.go @@ -1,14 +1,16 @@ package fasthttp import ( + "crypto/tls" "net" "sync" ) type perIPConnCounter struct { - pool sync.Pool - lock sync.Mutex - m map[uint32]int + perIPConnPool sync.Pool + perIPTLSConnPool sync.Pool + lock sync.Mutex + m map[uint32]int } func (cc *perIPConnCounter) Register(ip uint32) int { @@ -43,8 +45,30 @@ type perIPConn struct { perIPConnCounter *perIPConnCounter } -func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn { - v := counter.pool.Get() +type perIPTLSConn struct { + *tls.Conn + + ip uint32 + perIPConnCounter *perIPConnCounter +} + +func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn { + if tlcConn, ok := conn.(*tls.Conn); ok { + v := counter.perIPTLSConnPool.Get() + if v == nil { + return &perIPTLSConn{ + perIPConnCounter: counter, + Conn: tlcConn, + ip: ip, + } + } + c := v.(*perIPConn) + c.Conn = conn + c.ip = ip + return c + } + + v := counter.perIPConnPool.Get() if v == nil { return &perIPConn{ perIPConnCounter: counter, @@ -58,15 +82,19 @@ func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perI return c } -func releasePerIPConn(c *perIPConn) { +func (c *perIPConn) Close() error { + err := c.Conn.Close() + c.perIPConnCounter.Unregister(c.ip) c.Conn = nil - c.perIPConnCounter.pool.Put(c) + c.perIPConnCounter.perIPConnPool.Put(c) + return err } -func (c *perIPConn) Close() error { +func (c *perIPTLSConn) Close() error { err := c.Conn.Close() c.perIPConnCounter.Unregister(c.ip) - releasePerIPConn(c) + c.Conn = nil + c.perIPConnCounter.perIPTLSConnPool.Put(c) return err } diff --git a/peripconn_test.go b/peripconn_test.go index 5571654..6bfccf1 100644 --- a/peripconn_test.go +++ b/peripconn_test.go @@ -4,6 +4,8 @@ import ( "testing" ) +var _ connTLSer = &perIPTLSConn{} + func TestIPxUint32(t *testing.T) { t.Parallel() diff --git a/tcpdialer.go b/tcpdialer.go index e8430cb..e5f06bd 100644 --- a/tcpdialer.go +++ b/tcpdialer.go @@ -3,6 +3,7 @@ package fasthttp import ( "context" "errors" + "fmt" "net" "strconv" "sync" @@ -302,7 +303,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne if err == nil { return conn, nil } - if err == ErrDialTimeout { + if errors.Is(err, ErrDialTimeout) { return nil, err } idx++ @@ -316,7 +317,7 @@ func (d *TCPDialer) tryDial( ) (net.Conn, error) { timeout := time.Until(deadline) if timeout <= 0 { - return nil, ErrDialTimeout + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) } if concurrencyCh != nil { @@ -332,7 +333,7 @@ func (d *TCPDialer) tryDial( } ReleaseTimer(tc) if isTimeout { - return nil, ErrDialTimeout + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) } } defer func() { <-concurrencyCh }() @@ -346,15 +347,49 @@ func (d *TCPDialer) tryDial( ctx, cancelCtx := context.WithDeadline(context.Background(), deadline) defer cancelCtx() conn, err := dialer.DialContext(ctx, network, addr) - if err != nil && ctx.Err() == context.DeadlineExceeded { - return nil, ErrDialTimeout + if err != nil { + if ctx.Err() == context.DeadlineExceeded { + return nil, wrapDialWithUpstream(ErrDialTimeout, addr) + } + return nil, wrapDialWithUpstream(err, addr) } - return conn, err + return conn, nil } // ErrDialTimeout is returned when TCP dialing is timed out. var ErrDialTimeout = errors.New("dialing to the given TCP address timed out") +// ErrDialWithUpstream wraps dial error with upstream info. +// +// Should use errors.As to get upstream information from error: +// +// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"} +// err := hc.Do(req, res) +// +// var dialErr *fasthttp.ErrDialWithUpstream +// if errors.As(err, &dialErr) { +// upstream = dialErr.Upstream // 34.206.39.153:80 +// } +type ErrDialWithUpstream struct { + Upstream string + wrapErr error +} + +func (e *ErrDialWithUpstream) Error() string { + return fmt.Sprintf("error when dialing %s: %s", e.Upstream, e.wrapErr.Error()) +} + +func (e *ErrDialWithUpstream) Unwrap() error { + return e.wrapErr +} + +func wrapDialWithUpstream(err error, upstream string) error { + return &ErrDialWithUpstream{ + Upstream: upstream, + wrapErr: err, + } +} + // DefaultDialTimeout is timeout used by Dial and DialDualStack // for establishing TCP connections. const DefaultDialTimeout = 3 * time.Second |