aboutsummaryrefslogtreecommitdiff
path: root/client_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'client_test.go')
-rw-r--r--client_test.go92
1 files changed, 92 insertions, 0 deletions
diff --git a/client_test.go b/client_test.go
index e6a1358..1802383 100644
--- a/client_test.go
+++ b/client_test.go
@@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"crypto/tls"
+ "errors"
"fmt"
"io"
"net"
@@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) {
})
}
}
+
+type clientDoTimeOuter interface {
+ DoTimeout(req *Request, resp *Response, timeout time.Duration) error
+}
+
+func TestDialTimeout(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ client clientDoTimeOuter
+ requestTimeout time.Duration
+ shouldFailFast bool
+ }{
+ {
+ name: "Client should fail after a millisecond due to request timeout",
+ client: &Client{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "Client should fail after a second due to no DialTimeout set",
+ client: &Client{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ {
+ name: "HostClient should fail after a millisecond due to request timeout",
+ client: &HostClient{
+ // should be ignored due to DialTimeout
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ // should be used
+ DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) {
+ time.Sleep(timeout)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: true,
+ },
+ {
+ name: "HostClient should fail after a second due to no DialTimeout set",
+ client: &HostClient{
+ Dial: func(addr string) (net.Conn, error) {
+ time.Sleep(time.Second)
+ return nil, errors.New("timeout")
+ },
+ },
+ requestTimeout: time.Millisecond,
+ shouldFailFast: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ start := time.Now()
+ err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout)
+ if err == nil {
+ t.Fatal("expected error (timeout)")
+ }
+ if tt.shouldFailFast {
+ if time.Since(start) > time.Second {
+ t.Fatal("expected timeout after a millisecond")
+ }
+ } else {
+ if time.Since(start) < time.Second {
+ t.Fatal("expected timeout after a second")
+ }
+ }
+ })
+ }
+}