aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar gilwo <gilwo@users.noreply.github.com> 2024-02-17 08:51:38 +0200
committerGravatar GitHub <noreply@github.com> 2024-02-17 14:51:38 +0800
commitaefd08067438735cb683bf773948e9f50ca1142e (patch)
tree65cfccf84e278837cb4c2998cee50566eaeda969
parentchore(deps): bump securego/gosec from 2.18.2 to 2.19.0 (#1720) (diff)
downloadfasthttp-aefd08067438735cb683bf773948e9f50ca1142e.tar.gz
fasthttp-aefd08067438735cb683bf773948e9f50ca1142e.tar.bz2
fasthttp-aefd08067438735cb683bf773948e9f50ca1142e.zip
adaptor ResponseWriter - adding Hijack method and pass proper fields (#1525)
* adding hijack method and pass proper fields * adding hijack method and pass proper fields - adding tests * improve hijack handling, use proper test for hijacking * extend hijackhandler propogation to NewFastHTTPHandlerFunc * align hijacking of fasthttp adaptor net request with fasthttp request, safe conn handling for proper release of resources and custom hijack handler for more controlled by hijacking implementation * Implement actual behaviour of net/http Hijacker --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
-rw-r--r--fasthttpadaptor/adaptor.go50
-rw-r--r--fasthttpadaptor/adaptor_test.go73
2 files changed, 121 insertions, 2 deletions
diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go
index dcd43e4..5e856fb 100644
--- a/fasthttpadaptor/adaptor.go
+++ b/fasthttpadaptor/adaptor.go
@@ -3,8 +3,11 @@
package fasthttpadaptor
import (
+ "bufio"
"io"
+ "net"
"net/http"
+ "sync"
"github.com/valyala/fasthttp"
)
@@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
-
- w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
+ w := netHTTPResponseWriter{
+ w: ctx.Response.BodyWriter(),
+ ctx: ctx,
+ }
h.ServeHTTP(&w, r.WithContext(ctx))
ctx.SetStatusCode(w.StatusCode())
@@ -86,6 +91,7 @@ type netHTTPResponseWriter struct {
statusCode int
h http.Header
w io.Writer
+ ctx *fasthttp.RequestCtx
}
func (w *netHTTPResponseWriter) StatusCode() int {
@@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
}
func (w *netHTTPResponseWriter) Flush() {}
+
+type wrappedConn struct {
+ net.Conn
+
+ wg sync.WaitGroup
+ once sync.Once
+}
+
+func (c *wrappedConn) Close() (err error) {
+ c.once.Do(func() {
+ err = c.Conn.Close()
+ c.wg.Done()
+ })
+ return
+}
+
+func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ // Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or
+ // doing anything else with it.
+ w.ctx.HijackSetNoResponse(true)
+
+ conn := &wrappedConn{Conn: w.ctx.Conn()}
+ conn.wg.Add(1)
+ w.ctx.Hijack(func(net.Conn) {
+ conn.wg.Wait()
+ })
+
+ bufW := bufio.NewWriter(conn)
+
+ // Write any unflushed body to the hijacked connection buffer.
+ unflushedBody := w.ctx.Response.Body()
+ if len(unflushedBody) > 0 {
+ if _, err := bufW.Write(unflushedBody); err != nil {
+ conn.Close()
+ return nil, nil, err
+ }
+ }
+
+ return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil
+}
diff --git a/fasthttpadaptor/adaptor_test.go b/fasthttpadaptor/adaptor_test.go
index 9f03858..b58950a 100644
--- a/fasthttpadaptor/adaptor_test.go
+++ b/fasthttpadaptor/adaptor_test.go
@@ -7,8 +7,10 @@ import (
"net/url"
"reflect"
"testing"
+ "time"
"github.com/valyala/fasthttp"
+ "github.com/valyala/fasthttp/fasthttputil"
)
func TestNewFastHTTPHandler(t *testing.T) {
@@ -143,3 +145,74 @@ func setContextValueMiddleware(next fasthttp.RequestHandler, key string, value a
next(ctx)
}
}
+
+func TestHijack(t *testing.T) {
+ t.Parallel()
+
+ nethttpH := func(w http.ResponseWriter, r *http.Request) {
+ if f, ok := w.(http.Hijacker); !ok {
+ t.Errorf("expected http.ResponseWriter to implement http.Hijacker")
+ } else {
+ if _, err := w.Write([]byte("foo")); err != nil {
+ t.Error(err)
+ }
+
+ if c, rw, err := f.Hijack(); err != nil {
+ t.Error(err)
+ } else {
+ if _, err := rw.Write([]byte("bar")); err != nil {
+ t.Error(err)
+ }
+
+ if err := rw.Flush(); err != nil {
+ t.Error(err)
+ }
+
+ if err := c.Close(); err != nil {
+ t.Error(err)
+ }
+ }
+ }
+ }
+
+ s := &fasthttp.Server{
+ Handler: NewFastHTTPHandler(http.HandlerFunc(nethttpH)),
+ }
+
+ ln := fasthttputil.NewInmemoryListener()
+
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ }()
+
+ clientCh := make(chan struct{})
+ go func() {
+ c, err := ln.Dial()
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ if _, err = c.Write([]byte("GET / HTTP/1.1\r\nHost: aa\r\n\r\n")); err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ buf, err := io.ReadAll(c)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ if string(buf) != "foobar" {
+ t.Errorf("unexpected response: %q. Expecting %q", buf, "foobar")
+ }
+
+ close(clientCh)
+ }()
+
+ select {
+ case <-clientCh:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+}