From 8ca7a9c89c43a97658352651f5f66e16278164f6 Mon Sep 17 00:00:00 2001 From: Aviv Carmi Date: Mon, 27 Nov 2023 14:46:43 +0200 Subject: add support for custom dial function with timeouts (#1669) * add support for custom dial function with timeouts * fix linting --------- Co-authored-by: Aviv Carmi --- client_test.go | 92 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) (limited to 'client_test.go') diff --git a/client_test.go b/client_test.go index e6a1358..1802383 100644 --- a/client_test.go +++ b/client_test.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "crypto/tls" + "errors" "fmt" "io" "net" @@ -3392,3 +3393,94 @@ func Test_getRedirectURL(t *testing.T) { }) } } + +type clientDoTimeOuter interface { + DoTimeout(req *Request, resp *Response, timeout time.Duration) error +} + +func TestDialTimeout(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + client clientDoTimeOuter + requestTimeout time.Duration + shouldFailFast bool + }{ + { + name: "Client should fail after a millisecond due to request timeout", + client: &Client{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "Client should fail after a second due to no DialTimeout set", + client: &Client{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + { + name: "HostClient should fail after a millisecond due to request timeout", + client: &HostClient{ + // should be ignored due to DialTimeout + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + // should be used + DialTimeout: func(addr string, timeout time.Duration) (net.Conn, error) { + time.Sleep(timeout) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: true, + }, + { + name: "HostClient should fail after a second due to no DialTimeout set", + client: &HostClient{ + Dial: func(addr string) (net.Conn, error) { + time.Sleep(time.Second) + return nil, errors.New("timeout") + }, + }, + requestTimeout: time.Millisecond, + shouldFailFast: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + start := time.Now() + err := tt.client.DoTimeout(&Request{}, &Response{}, tt.requestTimeout) + if err == nil { + t.Fatal("expected error (timeout)") + } + if tt.shouldFailFast { + if time.Since(start) > time.Second { + t.Fatal("expected timeout after a millisecond") + } + } else { + if time.Since(start) < time.Second { + t.Fatal("expected timeout after a second") + } + } + }) + } +} -- cgit v1.2.3