aboutsummaryrefslogtreecommitdiff
path: root/client_test.go
diff options
context:
space:
mode:
authorGravatar Aviv Carmi <avivcarmis@gmail.com> 2023-11-27 14:46:43 +0200
committerGravatar GitHub <noreply@github.com> 2023-11-27 13:46:43 +0100
commit8ca7a9c89c43a97658352651f5f66e16278164f6 (patch)
treec1ebf35c6e98ce83aab7564cfdf47a534e616bba /client_test.go
parentchore: Use 'any' instead of 'interface{}' (#1666) (diff)
downloadfasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.gz
fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.bz2
fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.zip
add support for custom dial function with timeouts (#1669)
* add support for custom dial function with timeouts * fix linting --------- Co-authored-by: Aviv Carmi <aviv@perimeterx.com>
Diffstat (limited to 'client_test.go')
-rw-r--r--client_test.go92
1 files changed, 92 insertions, 0 deletions
diff --git a/client_test.go b/client_test.go
index e6a1358..1802383 100644
--- a/client_test.go
+++ b/client_test.go
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/tls"
+ "errors"
"fmt"
"io"
"net"
@@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) {
})
}
}
+
+type clientDoTimeOuter interface {
+ DoTimeout(req *Request, resp *Response, timeout time.Duration) error
+}
+
+func TestDialTimeout(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ client clientDoTimeOuter
+ requestTimeout time.Duration
+ shouldFailFast bool
+ }{
+ {
+ name: "Client should fail after a millisecond due to request timeout",
+ client: &Client{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "Client should fail after a second due to no DialTimeout set",
+ client: &Client{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ {
+ name: "HostClient should fail after a millisecond due to request timeout",
+ client: &HostClient{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "HostClient should fail after a second due to no DialTimeout set",
+ client: &HostClient{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ start := time.Now()
+ err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
+ if err == nil {
+ t.Fatal("expected error (timeout)")
+ }
+ if tt.shouldFailFast {
+ if time.Since(start) > time.Second {
+ t.Fatal("expected timeout after a millisecond")
+ }
+ } else {
+ if time.Since(start) < time.Second {
+ t.Fatal("expected timeout after a second")
+ }
+ }
+ })
+ }
+}