aboutsummaryrefslogtreecommitdiff
path: root/fasthttpadaptor/adaptor.go
diff options
context:
space:
mode:
Diffstat (limited to 'fasthttpadaptor/adaptor.go')
-rw-r--r--fasthttpadaptor/adaptor.go50
1 files changed, 48 insertions, 2 deletions
diff --git a/fasthttpadaptor/adaptor.go b/fasthttpadaptor/adaptor.go
index dcd43e4..5e856fb 100644
--- a/fasthttpadaptor/adaptor.go
+++ b/fasthttpadaptor/adaptor.go
@@ -3,8 +3,11 @@
package fasthttpadaptor
import (
+ "bufio"
"io"
+ "net"
"net/http"
+ "sync"
"github.com/valyala/fasthttp"
)
@@ -53,8 +56,10 @@ func NewFastHTTPHandler(h http.Handler) fasthttp.RequestHandler {
ctx.Error("Internal Server Error", fasthttp.StatusInternalServerError)
return
}
-
- w := netHTTPResponseWriter{w: ctx.Response.BodyWriter()}
+ w := netHTTPResponseWriter{
+ w: ctx.Response.BodyWriter(),
+ ctx: ctx,
+ }
h.ServeHTTP(&w, r.WithContext(ctx))
ctx.SetStatusCode(w.StatusCode())
@@ -86,6 +91,7 @@ type netHTTPResponseWriter struct {
statusCode int
h http.Header
w io.Writer
+ ctx *fasthttp.RequestCtx
}
func (w *netHTTPResponseWriter) StatusCode() int {
@@ -111,3 +117,43 @@ func (w *netHTTPResponseWriter) Write(p []byte) (int, error) {
}
func (w *netHTTPResponseWriter) Flush() {}
+
+type wrappedConn struct {
+ net.Conn
+
+ wg sync.WaitGroup
+ once sync.Once
+}
+
+func (c *wrappedConn) Close() (err error) {
+ c.once.Do(func() {
+ err = c.Conn.Close()
+ c.wg.Done()
+ })
+ return
+}
+
+func (w *netHTTPResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+ // Hijack assumes control of the connection, so we need to prevent fasthttp from closing it or
+ // doing anything else with it.
+ w.ctx.HijackSetNoResponse(true)
+
+ conn := &wrappedConn{Conn: w.ctx.Conn()}
+ conn.wg.Add(1)
+ w.ctx.Hijack(func(net.Conn) {
+ conn.wg.Wait()
+ })
+
+ bufW := bufio.NewWriter(conn)
+
+ // Write any unflushed body to the hijacked connection buffer.
+ unflushedBody := w.ctx.Response.Body()
+ if len(unflushedBody) > 0 {
+ if _, err := bufW.Write(unflushedBody); err != nil {
+ conn.Close()
+ return nil, nil, err
+ }
+ }
+
+ return conn, &bufio.ReadWriter{Reader: bufio.NewReader(conn), Writer: bufW}, nil
+}