diff options
author | Aviv Carmi <avivcarmis@gmail.com> | 2023-11-27 14:46:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-11-27 13:46:43 +0100 |
commit | 8ca7a9c89c43a97658352651f5f66e16278164f6 (patch) | |
tree | c1ebf35c6e98ce83aab7564cfdf47a534e616bba /client_test.go | |
parent | chore: Use 'any' instead of 'interface{}' (#1666) (diff) | |
download | fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.gz fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.tar.bz2 fasthttp-8ca7a9c89c43a97658352651f5f66e16278164f6.zip |
add support for custom dial function with timeouts (#1669)
* add support for custom dial function with timeouts
* fix linting
---------
Co-authored-by: Aviv Carmi <aviv@perimeterx.com>
Diffstat (limited to 'client_test.go')
-rw-r--r-- | client_test.go | 92 |
1 files changed, 92 insertions, 0 deletions
diff --git a/client_test.go b/client_test.go index e6a1358..1802383 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "net" @@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) { }) } } + +type clientDoTimeOuter interface { + DoTimeout(req *Request, resp *Response, timeout time.Duration) error +} + +func TestDialTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + client clientDoTimeOuter + requestTimeout time.Duration + shouldFailFast bool + }{ + { + name: "Client should fail after a millisecond due to request timeout", + client: &Client{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "Client should fail after a second due to no DialTimeout set", + client: &Client{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + { + name: "HostClient should fail after a millisecond due to request timeout", + client: &HostClient{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "HostClient should fail after a second due to no DialTimeout set", + client: &HostClient{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start := time.Now() + err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout) + if err == nil { + t.Fatal("expected error (timeout)") + } + if tt.shouldFailFast { + if time.Since(start) > time.Second { + t.Fatal("expected timeout after a millisecond") + } + } else { + if time.Since(start) < time.Second { + t.Fatal("expected timeout after a second") + } + } + }) + } +} |