aboutsummaryrefslogtreecommitdiff
path: root/server_test.go
diff options
context:
space:
mode:
authorGravatar Sergio VS <savsgio.engineer@gmail.com> 2022-01-22 04:54:37 +0100
committerGravatar GitHub <noreply@github.com> 2022-01-22 04:54:37 +0100
commit436977654aa1d51bf45507353e5ff34a4f54ca1a (patch)
tree135f7424dd59f90cd5847b95f9683d6c6ff6a07c /server_test.go
parentfix(hijack): reset userValues after hijack handler execution (#1199) (diff)
downloadfasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.tar.gz
fasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.tar.bz2
fasthttp-436977654aa1d51bf45507353e5ff34a4f54ca1a.zip
fix(hijack): reuse RequestCtx (#1201)
* fix(hijack): reuse RequestCtx * fix(test/hijack): increase wait time * fix(test/hijack): wait for all connections to finish to check responses
Diffstat (limited to 'server_test.go')
-rw-r--r--server_test.go220
1 files changed, 169 insertions, 51 deletions
diff --git a/server_test.go b/server_test.go
index 47eb84d..cad3cf1 100644
--- a/server_test.go
+++ b/server_test.go
@@ -7,6 +7,7 @@ import (
"bytes"
"context"
"crypto/tls"
+ "errors"
"fmt"
"io"
"io/ioutil"
@@ -2088,6 +2089,63 @@ func TestServeConnKeepRequestAndResponseUntilResetUserValues(t *testing.T) {
}
}
+// TestServerErrorHandler tests unexpected cases the for loop will break
+// before request/response reset call. in such cases, call it before
+// release to fix #548.
+func TestServerErrorHandler(t *testing.T) {
+ t.Parallel()
+
+ var resultReqStr, resultRespStr string
+
+ s := &Server{
+ Handler: func(ctx *RequestCtx) {},
+ ErrorHandler: func(ctx *RequestCtx, err error) {
+ resultReqStr = ctx.Request.String()
+ resultRespStr = ctx.Response.String()
+ },
+ MaxRequestBodySize: 10,
+ }
+
+ reqStrTpl := "POST %s HTTP/1.1\r\nHost: example.com\r\nContent-Type: application/octet-stream\r\nContent-Length: %d\r\nConnection: keep-alive\r\n\r\n"
+ respRegex := regexp.MustCompile("HTTP/1.1 200 OK\r\nDate: (.*)\r\nContent-Length: 0\r\n\r\n")
+
+ rw := &readWriter{}
+
+ for i := 0; i < 100; i++ {
+ body := strings.Repeat("@", s.MaxRequestBodySize+1)
+ path := fmt.Sprintf("/%d", i)
+
+ reqStr := fmt.Sprintf(reqStrTpl, path, len(body))
+ expectedReqStr := fmt.Sprintf(reqStrTpl, path, 0)
+
+ rw.r.WriteString(reqStr)
+ rw.r.WriteString(body)
+
+ ch := make(chan struct{})
+ go func() {
+ err := s.ServeConn(rw)
+ if err != nil && !errors.Is(err, ErrBodyTooLarge) {
+ t.Errorf("unexpected error in ServeConn: %s", err)
+ }
+ close(ch)
+ }()
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ if resultReqStr != expectedReqStr {
+ t.Errorf("[iter: %d] Request == %s, want %s", i, resultReqStr, reqStr)
+ }
+
+ if !respRegex.MatchString(resultRespStr) {
+ t.Errorf("[iter: %d] Response == %s, want regex %s", i, resultRespStr, respRegex)
+ }
+ }
+}
+
func TestServeConnHijackResetUserValues(t *testing.T) {
t.Parallel()
@@ -2369,71 +2427,131 @@ func TestRequestCtxSendFile(t *testing.T) {
}
}
-func TestRequestCtxHijack(t *testing.T) {
- t.Parallel()
+func testRequestCtxHijack(t *testing.T, s *Server) {
+ t.Helper()
- hijackStartCh := make(chan struct{})
- hijackStopCh := make(chan struct{})
- s := &Server{
- Handler: func(ctx *RequestCtx) {
- if ctx.Hijacked() {
- t.Error("connection mustn't be hijacked")
- }
- ctx.Hijack(func(c net.Conn) {
- <-hijackStartCh
+ type hijackSignal struct {
+ id int
+ rw *readWriter
+ }
- b := make([]byte, 1)
- // ping-pong echo via hijacked conn
- for {
- n, err := c.Read(b)
- if n != 1 {
- if err == io.EOF {
- close(hijackStopCh)
- return
- }
- if err != nil {
- t.Errorf("unexpected error: %s", err)
- }
- t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
- }
- if _, err = c.Write(b); err != nil {
- t.Errorf("unexpected error when writing data: %s", err)
+ wg := sync.WaitGroup{}
+ totalConns := 100
+ hijackStartCh := make(chan *hijackSignal, totalConns)
+ hijackStopCh := make(chan *hijackSignal, totalConns)
+
+ s.Handler = func(ctx *RequestCtx) {
+ if ctx.Hijacked() {
+ t.Error("connection mustn't be hijacked")
+ }
+
+ ctx.Hijack(func(c net.Conn) {
+ signal := <-hijackStartCh
+ defer func() {
+ hijackStopCh <- signal
+ wg.Done()
+ }()
+
+ b := make([]byte, 1)
+ stop := false
+
+ // ping-pong echo via hijacked conn
+ for !stop {
+ n, err := c.Read(b)
+ if err != nil {
+ if errors.Is(err, io.EOF) {
+ stop = true
+
+ continue
}
+
+ t.Errorf("unexpected read error: %s", err)
+ } else if n != 1 {
+ t.Errorf("unexpected number of bytes read: %d. Expecting 1", n)
+ }
+
+ if _, err = c.Write(b); err != nil {
+ t.Errorf("unexpected error when writing data: %s", err)
}
- })
- if !ctx.Hijacked() {
- t.Error("connection must be hijacked")
}
- ctx.Success("foo/bar", []byte("hijack it!"))
- },
+ })
+
+ if !ctx.Hijacked() {
+ t.Error("connection must be hijacked")
+ }
+
+ ctx.Success("foo/bar", []byte("hijack it!"))
}
hijackedString := "foobar baz hijacked!!!"
- rw := &readWriter{}
- rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
- rw.r.WriteString(hijackedString)
- if err := s.ServeConn(rw); err != nil {
- t.Fatalf("Unexpected error from serveConn: %s", err)
- }
+ for i := 0; i < totalConns; i++ {
+ wg.Add(1)
- br := bufio.NewReader(&rw.w)
- verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
+ go func(t *testing.T, id int) {
+ t.Helper()
- close(hijackStartCh)
- select {
- case <-hijackStopCh:
- case <-time.After(100 * time.Millisecond):
- t.Fatal("timeout")
- }
+ rw := new(readWriter)
+ rw.r.WriteString("GET /foo HTTP/1.1\r\nHost: google.com\r\n\r\n")
+ rw.r.WriteString(hijackedString)
- data, err := ioutil.ReadAll(br)
- if err != nil {
- t.Fatalf("Unexpected error when reading remaining data: %s", err)
+ if err := s.ServeConn(rw); err != nil {
+ t.Errorf("[iter: %d] Unexpected error from serveConn: %s", id, err)
+ }
+
+ hijackStartCh <- &hijackSignal{id, rw}
+ }(t, i)
}
- if string(data) != hijackedString {
- t.Fatalf("Unexpected data read after the first response %q. Expecting %q", data, hijackedString)
+
+ wg.Wait()
+
+ count := 0
+ for count != totalConns {
+ select {
+ case signal := <-hijackStopCh:
+ count++
+
+ id := signal.id
+ rw := signal.rw
+
+ br := bufio.NewReader(&rw.w)
+ verifyResponse(t, br, StatusOK, "foo/bar", "hijack it!")
+
+ data, err := ioutil.ReadAll(br)
+ if err != nil {
+ t.Errorf("[iter: %d] Unexpected error when reading remaining data: %s", id, err)
+
+ return
+ }
+ if string(data) != hijackedString {
+ t.Errorf(
+ "[iter: %d] Unexpected response %s. Expecting %s",
+ id, data, hijackedString,
+ )
+
+ return
+ }
+ case <-time.After(200 * time.Millisecond):
+ t.Errorf("timeout")
+ }
}
+
+ close(hijackStartCh)
+ close(hijackStopCh)
+}
+
+func TestRequestCtxHijack(t *testing.T) {
+ t.Parallel()
+
+ testRequestCtxHijack(t, &Server{})
+}
+
+func TestRequestCtxHijackReduceMemoryUsage(t *testing.T) {
+ t.Parallel()
+
+ testRequestCtxHijack(t, &Server{
+ ReduceMemoryUsage: true,
+ })
}
func TestRequestCtxHijackNoResponse(t *testing.T) {