diff options
author | Sergio VS <savsgio.engineer@gmail.com> | 2022-01-22 04:54:37 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-01-22 04:54:37 +0100 |
commit | 436977654aa1d51bf45507353e5ff34a4f54ca1a (patch) | |
tree | 135f7424dd59f90cd5847b95f9683d6c6ff6a07c /server_test.go | |
parent | fix(hijack): reset userValues after hijack handler execution (#1199) (diff) | |
download | fasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.tar.gz fasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.tar.bz2 fasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.zip |
fix(hijack): reuse RequestCtx (#1201)
* fix(hijack): reuse RequestCtx
* fix(test/hijack): increase wait time
* fix(test/hijack): wait for all connections to finish to check responses
Diffstat (limited to 'server_test.go')
-rw-r--r-- | server_test.go | 220 |
1 files changed, 169 insertions, 51 deletions
diff --git a/server_test.go b/server_test.go index 47eb84d..cad3cf1 100644 --- a/server_test.go +++ b/server_test.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "crypto/tls" + "errors" "fmt" "io" "io/ioutil" @@ -2088,6 +2089,63 @@ func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) { } } +// TestServerErrorHandler tests unexpected cases the for loop will break +// before request/response reset call. in such cases, call it before +// release to fix #548. +func TestServerErrorHandler(t *testing.T) { + t.Parallel() + + var resultReqStr, resultRespStr string + + s := &Server{ + Handler: func(ctx *RequestCtx) {}, + ErrorHandler: func(ctx *RequestCtx, err error) { + resultReqStr = ctx.Request.String() + resultRespStr = ctx.Response.String() + }, + MaxRequestBodySize: 10, + } + + reqStrTpl := "POST %s HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nConnection: keep-alive\r\n\r\n" + respRegex := regexp.MustCompile("HTTP/1.1 200 OK\r\nDate: (.*)\r\nContent-Length: 0\r\n\r\n") + + rw := &readWriter{} + + for i := 0; i < 100; i++ { + body := strings.Repeat("@", s.MaxRequestBodySize+1) + path := fmt.Sprintf("/%d", i) + + reqStr := fmt.Sprintf(reqStrTpl, path, len(body)) + expectedReqStr := fmt.Sprintf(reqStrTpl, path, 0) + + rw.r.WriteString(reqStr) + rw.r.WriteString(body) + + ch := make(chan struct{}) + go func() { + err := s.ServeConn(rw) + if err != nil && !errors.Is(err, ErrBodyTooLarge) { + t.Errorf("unexpected error in ServeConn: %s", err) + } + close(ch) + }() + + select { + case <-ch: + case <-time.After(time.Second): + t.Fatal("timeout") + } + + if resultReqStr != expectedReqStr { + t.Errorf("[iter: %d] Request == %s, want %s", i, resultReqStr, reqStr) + } + + if !respRegex.MatchString(resultRespStr) { + t.Errorf("[iter: %d] Response == %s, want regex %s", i, resultRespStr, respRegex) + } + } +} + func TestServeConnHijackResetUserValues(t *testing.T) { t.Parallel() @@ -2369,71 +2427,131 @@ func TestRequestCtxSendFile(t *testing.T) { } } -func TestRequestCtxHijack(t *testing.T) { - t.Parallel() +func testRequestCtxHijack(t *testing.T, s *Server) { + t.Helper() - hijackStartCh := make(chan struct{}) - hijackStopCh := make(chan struct{}) - s := &Server{ - Handler: func(ctx *RequestCtx) { - if ctx.Hijacked() { - t.Error("connection mustn't be hijacked") - } - ctx.Hijack(func(c net.Conn) { - <-hijackStartCh + type hijackSignal struct { + id int + rw *readWriter + } - b := make([]byte, 1) - // ping-pong echo via hijacked conn - for { - n, err := c.Read(b) - if n != 1 { - if err == io.EOF { - close(hijackStopCh) - return - } - if err != nil { - t.Errorf("unexpected error: %s", err) - } - t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) - } - if _, err = c.Write(b); err != nil { - t.Errorf("unexpected error when writing data: %s", err) + wg := sync.WaitGroup{} + totalConns := 100 + hijackStartCh := make(chan *hijackSignal, totalConns) + hijackStopCh := make(chan *hijackSignal, totalConns) + + s.Handler = func(ctx *RequestCtx) { + if ctx.Hijacked() { + t.Error("connection mustn't be hijacked") + } + + ctx.Hijack(func(c net.Conn) { + signal := <-hijackStartCh + defer func() { + hijackStopCh <- signal + wg.Done() + }() + + b := make([]byte, 1) + stop := false + + // ping-pong echo via hijacked conn + for !stop { + n, err := c.Read(b) + if err != nil { + if errors.Is(err, io.EOF) { + stop = true + + continue } + + t.Errorf("unexpected read error: %s", err) + } else if n != 1 { + t.Errorf("unexpected number of bytes read: %d. Expecting 1", n) + } + + if _, err = c.Write(b); err != nil { + t.Errorf("unexpected error when writing data: %s", err) } - }) - if !ctx.Hijacked() { - t.Error("connection must be hijacked") } - ctx.Success("foo/bar", []byte("hijack it!")) - }, + }) + + if !ctx.Hijacked() { + t.Error("connection must be hijacked") + } + + ctx.Success("foo/bar", []byte("hijack it!")) } hijackedString := "foobar baz hijacked!!!" - rw := &readWriter{} - rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") - rw.r.WriteString(hijackedString) - if err := s.ServeConn(rw); err != nil { - t.Fatalf("Unexpected error from serveConn: %s", err) - } + for i := 0; i < totalConns; i++ { + wg.Add(1) - br := bufio.NewReader(&rw.w) - verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") + go func(t *testing.T, id int) { + t.Helper() - close(hijackStartCh) - select { - case <-hijackStopCh: - case <-time.After(100 * time.Millisecond): - t.Fatal("timeout") - } + rw := new(readWriter) + rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n") + rw.r.WriteString(hijackedString) - data, err := ioutil.ReadAll(br) - if err != nil { - t.Fatalf("Unexpected error when reading remaining data: %s", err) + if err := s.ServeConn(rw); err != nil { + t.Errorf("[iter: %d] Unexpected error from serveConn: %s", id, err) + } + + hijackStartCh <- &hijackSignal{id, rw} + }(t, i) } - if string(data) != hijackedString { - t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString) + + wg.Wait() + + count := 0 + for count != totalConns { + select { + case signal := <-hijackStopCh: + count++ + + id := signal.id + rw := signal.rw + + br := bufio.NewReader(&rw.w) + verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!") + + data, err := ioutil.ReadAll(br) + if err != nil { + t.Errorf("[iter: %d] Unexpected error when reading remaining data: %s", id, err) + + return + } + if string(data) != hijackedString { + t.Errorf( + "[iter: %d] Unexpected response %s. Expecting %s", + id, data, hijackedString, + ) + + return + } + case <-time.After(200 * time.Millisecond): + t.Errorf("timeout") + } } + + close(hijackStartCh) + close(hijackStopCh) +} + +func TestRequestCtxHijack(t *testing.T) { + t.Parallel() + + testRequestCtxHijack(t, &Server{}) +} + +func TestRequestCtxHijackReduceMemoryUsage(t *testing.T) { + t.Parallel() + + testRequestCtxHijack(t, &Server{ + ReduceMemoryUsage: true, + }) } func TestRequestCtxHijackNoResponse(t *testing.T) { |