diff options
Diffstat (limited to 'fasthttpadaptor/adaptor.go')
-rw-r--r-- | fasthttpadaptor/adaptor.go | 50 |
1 files changed, 48 insertions, 2 deletions
diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go index dcd43e4..5e856fb 100644 --- a/fasthttpadaptor/adaptor.go +++ b/fasthttpadaptor/adaptor.go @@ -3,8 +3,11 @@ package fasthttpadaptor import ( + "bufio" "io" + "net" "net/http" + "sync" "github.com/valyala/fasthttp" ) @@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler { ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError) return } - - w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()} + w := netHTTPResponseWriter{ + w: ctx.Response.BodyWriter(), + ctx: ctx, + } h.ServeHTTP(&w, r.WithContext(ctx)) ctx.SetStatusCode(w.StatusCode()) @@ -86,6 +91,7 @@ type netHTTPResponseWriter struct { statusCode int h http.Header w io.Writer + ctx *fasthttp.RequestCtx } func (w *netHTTPResponseWriter) StatusCode() int { @@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) { } func (w *netHTTPResponseWriter) Flush() {} + +type wrappedConn struct { + net.Conn + + wg sync.WaitGroup + once sync.Once +} + +func (c *wrappedConn) Close() (err error) { + c.once.Do(func() { + err = c.Conn.Close() + c.wg.Done() + }) + return +} + +func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + // Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or + // doing anything else with it. + w.ctx.HijackSetNoResponse(true) + + conn := &wrappedConn{Conn: w.ctx.Conn()} + conn.wg.Add(1) + w.ctx.Hijack(func(net.Conn) { + conn.wg.Wait() + }) + + bufW := bufio.NewWriter(conn) + + // Write any unflushed body to the hijacked connection buffer. + unflushedBody := w.ctx.Response.Body() + if len(unflushedBody) > 0 { + if _, err := bufW.Write(unflushedBody); err != nil { + conn.Close() + return nil, nil, err + } + } + + return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil +} |