aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.github/workflows/lint.yml2
-rw-r--r--client.go5
-rw-r--r--fs.go6
-rw-r--r--fuzz_test.go10
-rw-r--r--peripconn.go46
-rw-r--r--peripconn_test.go2
-rw-r--r--tcpdialer.go47
7 files changed, 96 insertions, 22 deletions
diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index c250bf7..301af69 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -23,7 +23,7 @@ jobs:
go-version: 1.20.x
- run: go version
- name: Run golangci-lint
- uses: golangci/golangci-lint-action@v4
+ uses: golangci/golangci-lint-action@v6
with:
version: v1.56.2
args: --verbose
diff --git a/client.go b/client.go
index 1f12d4a..5cae78d 100644
--- a/client.go
+++ b/client.go
@@ -2,6 +2,7 @@ package fasthttp
import (
"bufio"
+ "bytes"
"crypto/tls"
"errors"
"fmt"
@@ -477,6 +478,10 @@ func (c *Client) Do(req *Request, resp *Response) error {
host := uri.Host()
+ if bytes.ContainsRune(host, ',') {
+ return fmt.Errorf("invalid host %q. Use HostClient for multiple hosts", host)
+ }
+
isTLS := false
if uri.isHTTPS() {
isTLS = true
diff --git a/fs.go b/fs.go
index 59638ad..9e15a0e 100644
--- a/fs.go
+++ b/fs.go
@@ -1233,6 +1233,12 @@ func (h *fsHandler) openIndexFile(ctx *RequestCtx, dirPath string, mustCompress
if err == nil {
return ff, nil
}
+ if mustCompress && err == errNoCreatePermission {
+ ctx.Logger().Printf("insufficient permissions for saving compressed file for %q. Serving uncompressed file. "+
+ "Allow write access to the directory with this file in order to improve fasthttp performance", indexFilePath)
+ mustCompress = false
+ return h.openFSFile(indexFilePath, mustCompress, fileEncoding)
+ }
if !errors.Is(err, fs.ErrNotExist) {
return nil, fmt.Errorf("cannot open file %q: %w", indexFilePath, err)
}
diff --git a/fuzz_test.go b/fuzz_test.go
index 532c052..ba1737a 100644
--- a/fuzz_test.go
+++ b/fuzz_test.go
@@ -47,11 +47,10 @@ func FuzzResponseReadLimitBody(f *testing.F) {
return
}
- var res Response
+ res := AcquireResponse()
+ defer ReleaseResponse(res)
_ = res.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max)
- w := bytes.Buffer{}
- _, _ = res.WriteTo(&w)
})
}
@@ -63,11 +62,10 @@ func FuzzRequestReadLimitBody(f *testing.F) {
return
}
- var req Request
+ req := AcquireRequest()
+ defer ReleaseRequest(req)
_ = req.ReadLimitBody(bufio.NewReader(bytes.NewReader(body)), max)
- w := bytes.Buffer{}
- _, _ = req.WriteTo(&w)
})
}
diff --git a/peripconn.go b/peripconn.go
index 123c55e..46bddbf 100644
--- a/peripconn.go
+++ b/peripconn.go
@@ -1,14 +1,16 @@
package fasthttp
import (
+ "crypto/tls"
"net"
"sync"
)
type perIPConnCounter struct {
- pool sync.Pool
- lock sync.Mutex
- m map[uint32]int
+ perIPConnPool sync.Pool
+ perIPTLSConnPool sync.Pool
+ lock sync.Mutex
+ m map[uint32]int
}
func (cc *perIPConnCounter) Register(ip uint32) int {
@@ -43,8 +45,30 @@ type perIPConn struct {
perIPConnCounter *perIPConnCounter
}
-func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perIPConn {
- v := counter.pool.Get()
+type perIPTLSConn struct {
+ *tls.Conn
+
+ ip uint32
+ perIPConnCounter *perIPConnCounter
+}
+
+func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) net.Conn {
+ if tlcConn, ok := conn.(*tls.Conn); ok {
+ v := counter.perIPTLSConnPool.Get()
+ if v == nil {
+ return &perIPTLSConn{
+ perIPConnCounter: counter,
+ Conn: tlcConn,
+ ip: ip,
+ }
+ }
+ c := v.(*perIPConn)
+ c.Conn = conn
+ c.ip = ip
+ return c
+ }
+
+ v := counter.perIPConnPool.Get()
if v == nil {
return &perIPConn{
perIPConnCounter: counter,
@@ -58,15 +82,19 @@ func acquirePerIPConn(conn net.Conn, ip uint32, counter *perIPConnCounter) *perI
return c
}
-func releasePerIPConn(c *perIPConn) {
+func (c *perIPConn) Close() error {
+ err := c.Conn.Close()
+ c.perIPConnCounter.Unregister(c.ip)
c.Conn = nil
- c.perIPConnCounter.pool.Put(c)
+ c.perIPConnCounter.perIPConnPool.Put(c)
+ return err
}
-func (c *perIPConn) Close() error {
+func (c *perIPTLSConn) Close() error {
err := c.Conn.Close()
c.perIPConnCounter.Unregister(c.ip)
- releasePerIPConn(c)
+ c.Conn = nil
+ c.perIPConnCounter.perIPTLSConnPool.Put(c)
return err
}
diff --git a/peripconn_test.go b/peripconn_test.go
index 5571654..6bfccf1 100644
--- a/peripconn_test.go
+++ b/peripconn_test.go
@@ -4,6 +4,8 @@ import (
"testing"
)
+var _ connTLSer = &perIPTLSConn{}
+
func TestIPxUint32(t *testing.T) {
t.Parallel()
diff --git a/tcpdialer.go b/tcpdialer.go
index e8430cb..e5f06bd 100644
--- a/tcpdialer.go
+++ b/tcpdialer.go
@@ -3,6 +3,7 @@ package fasthttp
import (
"context"
"errors"
+ "fmt"
"net"
"strconv"
"sync"
@@ -302,7 +303,7 @@ func (d *TCPDialer) dial(addr string, dualStack bool, timeout time.Duration) (ne
if err == nil {
return conn, nil
}
- if err == ErrDialTimeout {
+ if errors.Is(err, ErrDialTimeout) {
return nil, err
}
idx++
@@ -316,7 +317,7 @@ func (d *TCPDialer) tryDial(
) (net.Conn, error) {
timeout := time.Until(deadline)
if timeout <= 0 {
- return nil, ErrDialTimeout
+ return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
}
if concurrencyCh != nil {
@@ -332,7 +333,7 @@ func (d *TCPDialer) tryDial(
}
ReleaseTimer(tc)
if isTimeout {
- return nil, ErrDialTimeout
+ return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
}
}
defer func() { <-concurrencyCh }()
@@ -346,15 +347,49 @@ func (d *TCPDialer) tryDial(
ctx, cancelCtx := context.WithDeadline(context.Background(), deadline)
defer cancelCtx()
conn, err := dialer.DialContext(ctx, network, addr)
- if err != nil && ctx.Err() == context.DeadlineExceeded {
- return nil, ErrDialTimeout
+ if err != nil {
+ if ctx.Err() == context.DeadlineExceeded {
+ return nil, wrapDialWithUpstream(ErrDialTimeout, addr)
+ }
+ return nil, wrapDialWithUpstream(err, addr)
}
- return conn, err
+ return conn, nil
}
// ErrDialTimeout is returned when TCP dialing is timed out.
var ErrDialTimeout = errors.New("dialing to the given TCP address timed out")
+// ErrDialWithUpstream wraps dial error with upstream info.
+//
+// Should use errors.As to get upstream information from error:
+//
+// hc := fasthttp.HostClient{Addr: "foo.com,bar.com"}
+// err := hc.Do(req, res)
+//
+// var dialErr *fasthttp.ErrDialWithUpstream
+// if errors.As(err, &dialErr) {
+// upstream = dialErr.Upstream // 34.206.39.153:80
+// }
+type ErrDialWithUpstream struct {
+ Upstream string
+ wrapErr error
+}
+
+func (e *ErrDialWithUpstream) Error() string {
+ return fmt.Sprintf("error when dialing %s: %s", e.Upstream, e.wrapErr.Error())
+}
+
+func (e *ErrDialWithUpstream) Unwrap() error {
+ return e.wrapErr
+}
+
+func wrapDialWithUpstream(err error, upstream string) error {
+ return &ErrDialWithUpstream{
+ Upstream: upstream,
+ wrapErr: err,
+ }
+}
+
// DefaultDialTimeout is timeout used by Dial and DialDualStack
// for establishing TCP connections.
const DefaultDialTimeout = 3 * time.Second