aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Erik Dubbelboer <erik@dubbelboer.com> 2021-02-16 21:53:53 +0100
committerGravatar GitHub <noreply@github.com> 2021-02-16 21:53:53 +0100
commit3cd0862fbb17abb02d43e538e41819e22ffd512a (patch)
treebf5be9dabdbc55a466228f676cdc8cb378e146a9
parentAdded Protocol() as a replacement of hardcoded strHTTP11 (#969) (diff)
downloadfasthttp-3cd0862fbb17abb02d43e538e41819e22ffd512a.tar.gz
fasthttp-3cd0862fbb17abb02d43e538e41819e22ffd512a.tar.bz2
fasthttp-3cd0862fbb17abb02d43e538e41819e22ffd512a.zip
Streaming fixes (#970)v1.21.0
- Allow DisablePreParseMultipartForm in combination with StreamRequestBody. - Support streaming into MultipartForm instead of reading the whole body first. - Support calling ctx.PostBody() when streaming is enabled.
-rw-r--r--http.go140
-rw-r--r--server_test.go183
-rw-r--r--streaming.go11
-rw-r--r--streaming_test.go90
4 files changed, 264 insertions, 160 deletions
diff --git a/http.go b/http.go
index 89bc2b5..5ea7847 100644
--- a/http.go
+++ b/http.go
@@ -3,6 +3,7 @@ package fasthttp
import (
"bufio"
"bytes"
+ "compress/gzip"
"encoding/base64"
"errors"
"fmt"
@@ -345,6 +346,15 @@ func (req *Request) bodyBytes() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
}
+ if req.bodyStream != nil {
+ bodyBuf := req.bodyBuffer()
+ bodyBuf.Reset()
+ _, err := copyZeroAlloc(bodyBuf, req.bodyStream)
+ req.closeBodyStream() //nolint:errcheck
+ if err != nil {
+ bodyBuf.SetString(err.Error())
+ }
+ }
if req.body == nil {
return nil
}
@@ -630,14 +640,6 @@ func (req *Request) SwapBody(body []byte) []byte {
func (req *Request) Body() []byte {
if req.bodyRaw != nil {
return req.bodyRaw
- } else if req.bodyStream != nil {
- bodyBuf := req.bodyBuffer()
- bodyBuf.Reset()
- _, err := copyZeroAlloc(bodyBuf, req.bodyStream)
- req.closeBodyStream() //nolint:errcheck
- if err != nil {
- bodyBuf.SetString(err.Error())
- }
} else if req.onlyMultipartForm() {
body, err := marshalMultipartForm(req.multipartForm, req.multipartFormBoundary)
if err != nil {
@@ -814,24 +816,43 @@ func (req *Request) MultipartForm() (*multipart.Form, error) {
return nil, ErrNoMultipartForm
}
+ var err error
ce := req.Header.peek(strContentEncoding)
- body := req.bodyBytes()
- if bytes.Equal(ce, strGzip) {
- // Do not care about memory usage here.
- var err error
- if body, err = AppendGunzipBytes(nil, body); err != nil {
- return nil, fmt.Errorf("cannot gunzip request body: %s", err)
+
+ if req.bodyStream != nil {
+ bodyStream := req.bodyStream
+ if bytes.Equal(ce, strGzip) {
+ // Do not care about memory usage here.
+ if bodyStream, err = gzip.NewReader(bodyStream); err != nil {
+ return nil, fmt.Errorf("cannot gunzip request body: %s", err)
+ }
+ } else if len(ce) > 0 {
+ return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
}
- } else if len(ce) > 0 {
- return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
- }
- f, err := readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
- if err != nil {
- return nil, err
+ mr := multipart.NewReader(bodyStream, req.multipartFormBoundary)
+ req.multipartForm, err = mr.ReadForm(8 * 1024)
+ if err != nil {
+ return nil, fmt.Errorf("cannot read multipart/form-data body: %s", err)
+ }
+ } else {
+ body := req.bodyBytes()
+ if bytes.Equal(ce, strGzip) {
+ // Do not care about memory usage here.
+ if body, err = AppendGunzipBytes(nil, body); err != nil {
+ return nil, fmt.Errorf("cannot gunzip request body: %s", err)
+ }
+ } else if len(ce) > 0 {
+ return nil, fmt.Errorf("unsupported Content-Encoding: %q", ce)
+ }
+
+ req.multipartForm, err = readMultipartForm(bytes.NewReader(body), req.multipartFormBoundary, len(body), len(body))
+ if err != nil {
+ return nil, err
+ }
}
- req.multipartForm = f
- return f, nil
+
+ return req.multipartForm, nil
}
func marshalMultipartForm(f *multipart.Form, boundary string) ([]byte, error) {
@@ -1022,6 +1043,9 @@ func (req *Request) readLimitBody(r *bufio.Reader, maxBodySize int, getOnly bool
}
func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly bool, preParseMultipartForm bool) error {
+ // Do not reset the request here - the caller must reset it before
+ // calling this method.
+
if getOnly && !req.Header.IsGet() {
return ErrGetOnly
}
@@ -1033,39 +1057,7 @@ func (req *Request) readBodyStream(r *bufio.Reader, maxBodySize int, getOnly boo
return nil
}
- var err error
- contentLength := req.Header.realContentLength()
- if contentLength > 0 {
- if preParseMultipartForm {
- // Pre-read multipart form data of known length.
- // This way we limit memory usage for large file uploads, since their contents
- // is streamed into temporary files if file size exceeds defaultMaxInMemoryFileSize.
- req.multipartFormBoundary = b2s(req.Header.MultipartFormBoundary())
- if len(req.multipartFormBoundary) > 0 && len(req.Header.peek(strContentEncoding)) == 0 {
- req.multipartForm, err = readMultipartForm(r, req.multipartFormBoundary, contentLength, defaultMaxInMemoryFileSize)
- if err != nil {
- req.Reset()
- }
- return err
- }
- }
- }
-
- if contentLength == -2 {
- // identity body has no sense for http requests, since
- // the end of body is determined by connection close.
- // So just ignore request body for requests without
- // 'Content-Length' and 'Transfer-Encoding' headers.
- req.Header.SetContentLength(0)
- return nil
- }
-
- bodyBuf := req.bodyBuffer()
- bodyBuf.Reset()
-
- req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
-
- return nil
+ return req.ContinueReadBodyStream(r, maxBodySize, preParseMultipartForm)
}
// MayContinue returns true if the request contains
@@ -1170,21 +1162,15 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
bodyBuf := req.bodyBuffer()
bodyBuf.Reset()
bodyBuf.B, err = readBodyWithStreaming(r, contentLength, maxBodySize, bodyBuf.B)
- bodyBufLen := maxBodySize
- if contentLength < maxBodySize {
- bodyBufLen = cap(bodyBuf.B)
- }
if err != nil {
if err == ErrBodyTooLarge {
req.Header.SetContentLength(contentLength)
req.body = bodyBuf
- req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
return nil
}
if err == errChunkedStream {
req.body = bodyBuf
- req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, -1)
return nil
}
@@ -1193,7 +1179,6 @@ func (req *Request) ContinueReadBodyStream(r *bufio.Reader, maxBodySize int, pre
}
req.body = bodyBuf
- req.bodyRaw = bodyBuf.B[:bodyBufLen]
req.bodyStream = acquireRequestStream(bodyBuf, r, contentLength)
req.Header.SetContentLength(len(bodyBuf.B))
return nil
@@ -1936,24 +1921,27 @@ func readBody(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (
var errChunkedStream = errors.New("chunked stream")
func readBodyWithStreaming(r *bufio.Reader, contentLength int, maxBodySize int, dst []byte) (b []byte, err error) {
+ if contentLength == -1 {
+ // handled in requestStream.Read()
+ return b, errChunkedStream
+ }
+
dst = dst[:0]
- switch {
- case contentLength >= 0 && maxBodySize >= contentLength:
- readN := maxBodySize
- if contentLength > 8*1024 {
- readN = 8 * 1024
- }
+
+ readN := maxBodySize
+ if readN > contentLength {
+ readN = contentLength
+ }
+ if readN > 8*1024 {
+ readN = 8 * 1024
+ }
+
+ if contentLength >= 0 && maxBodySize >= contentLength {
b, err = appendBodyFixedSize(r, dst, readN)
- case contentLength == -1:
- // handled in requestStream.Read()
- err = errChunkedStream
- default:
- readN := maxBodySize
- if contentLength > 8*1024 {
- readN = 8 * 1024
- }
+ } else {
b, err = readBodyIdentity(r, readN, dst)
}
+
if err != nil {
return b, err
}
diff --git a/server_test.go b/server_test.go
index c7baa3e..a0583f6 100644
--- a/server_test.go
+++ b/server_test.go
@@ -1073,7 +1073,16 @@ func TestServerServeTLSEmbed(t *testing.T) {
func TestServerMultipartFormDataRequest(t *testing.T) {
t.Parallel()
- reqS := `POST /upload HTTP/1.1
+ for _, test := range []struct {
+ StreamRequestBody bool
+ DisablePreParseMultipartForm bool
+ }{
+ {false, false},
+ {false, true},
+ {true, false},
+ {true, true},
+ } {
+ reqS := `POST /upload HTTP/1.1
Host: qwerty.com
Content-Length: 521
Content-Type: multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg
@@ -1100,91 +1109,94 @@ Connection: close
`
- ln := fasthttputil.NewInmemoryListener()
-
- s := &Server{
- Handler: func(ctx *RequestCtx) {
- switch string(ctx.Path()) {
- case "/upload":
- f, err := ctx.MultipartForm()
- if err != nil {
- t.Errorf("unexpected error: %s", err)
- }
- if len(f.Value) != 1 {
- t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
- }
- if len(f.File) != 1 {
- t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
- }
- fv := ctx.FormValue("f1")
- if string(fv) != "value1" {
- t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
+ ln := fasthttputil.NewInmemoryListener()
+
+ s := &Server{
+ StreamRequestBody: test.StreamRequestBody,
+ DisablePreParseMultipartForm: test.DisablePreParseMultipartForm,
+ Handler: func(ctx *RequestCtx) {
+ switch string(ctx.Path()) {
+ case "/upload":
+ f, err := ctx.MultipartForm()
+ if err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ if len(f.Value) != 1 {
+ t.Errorf("unexpected values %d. Expecting %d", len(f.Value), 1)
+ }
+ if len(f.File) != 1 {
+ t.Errorf("unexpected file values %d. Expecting %d", len(f.File), 1)
+ }
+ fv := ctx.FormValue("f1")
+ if string(fv) != "value1" {
+ t.Errorf("unexpected form value: %q. Expecting %q", fv, "value1")
+ }
+ ctx.Redirect("/", StatusSeeOther)
+ default:
+ ctx.WriteString("non-upload") //nolint:errcheck
}
- ctx.Redirect("/", StatusSeeOther)
- default:
- ctx.WriteString("non-upload") //nolint:errcheck
- }
- },
- }
-
- ch := make(chan struct{})
- go func() {
- if err := s.Serve(ln); err != nil {
- t.Errorf("unexpected error: %s", err)
+ },
}
- close(ch)
- }()
- conn, err := ln.Dial()
- if err != nil {
- t.Fatalf("unexpected error: %s", err)
- }
- if _, err = conn.Write([]byte(reqS)); err != nil {
- t.Fatalf("unexpected error: %s", err)
- }
+ ch := make(chan struct{})
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ close(ch)
+ }()
- var resp Response
- br := bufio.NewReader(conn)
- respCh := make(chan struct{})
- go func() {
- if err := resp.Read(br); err != nil {
- t.Errorf("error when reading response: %s", err)
- }
- if resp.StatusCode() != StatusSeeOther {
- t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
+ conn, err := ln.Dial()
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
}
- loc := resp.Header.Peek(HeaderLocation)
- if string(loc) != "http://qwerty.com/" {
- t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
+ if _, err = conn.Write([]byte(reqS)); err != nil {
+ t.Fatalf("unexpected error: %s", err)
}
- if err := resp.Read(br); err != nil {
- t.Errorf("error when reading the second response: %s", err)
- }
- if resp.StatusCode() != StatusOK {
- t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
- }
- body := resp.Body()
- if string(body) != "non-upload" {
- t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
- }
- close(respCh)
- }()
+ var resp Response
+ br := bufio.NewReader(conn)
+ respCh := make(chan struct{})
+ go func() {
+ if err := resp.Read(br); err != nil {
+ t.Errorf("error when reading response: %s", err)
+ }
+ if resp.StatusCode() != StatusSeeOther {
+ t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusSeeOther)
+ }
+ loc := resp.Header.Peek(HeaderLocation)
+ if string(loc) != "http://qwerty.com/" {
+ t.Errorf("unexpected location %q. Expecting %q", loc, "http://qwerty.com/")
+ }
- select {
- case <-respCh:
- case <-time.After(time.Second):
- t.Fatal("timeout")
- }
+ if err := resp.Read(br); err != nil {
+ t.Errorf("error when reading the second response: %s", err)
+ }
+ if resp.StatusCode() != StatusOK {
+ t.Errorf("unexpected status code: %d. Expecting %d", resp.StatusCode(), StatusOK)
+ }
+ body := resp.Body()
+ if string(body) != "non-upload" {
+ t.Errorf("unexpected body %q. Expecting %q", body, "non-upload")
+ }
+ close(respCh)
+ }()
- if err := ln.Close(); err != nil {
- t.Fatalf("error when closing listener: %s", err)
- }
+ select {
+ case <-respCh:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
- select {
- case <-ch:
- case <-time.After(time.Second):
- t.Fatal("timeout when waiting for the server to stop")
+ if err := ln.Close(); err != nil {
+ t.Fatalf("error when closing listener: %s", err)
+ }
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout when waiting for the server to stop")
+ }
}
}
@@ -3413,8 +3425,8 @@ func TestMaxBodySizePerRequest(t *testing.T) {
func TestStreamRequestBody(t *testing.T) {
t.Parallel()
- part1 := strings.Repeat("1", 1<<10)
- part2 := strings.Repeat("2", 1<<20-1<<10)
+ part1 := strings.Repeat("1", 1<<15)
+ part2 := strings.Repeat("2", 1<<16)
contentLength := len(part1) + len(part2)
next := make(chan struct{})
@@ -3424,15 +3436,17 @@ func TestStreamRequestBody(t *testing.T) {
close(next)
checkReader(t, ctx.RequestBodyStream(), part2)
},
- DisableKeepalive: true,
StreamRequestBody: true,
}
pipe := fasthttputil.NewPipeConns()
cc, sc := pipe.Conn1(), pipe.Conn2()
//write headers and part1 body
- if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n%s", contentLength, part1))); err != nil {
- t.Error(err)
+ if _, err := cc.Write([]byte(fmt.Sprintf("POST /foo2 HTTP/1.1\r\nHost: aaa.com\r\nContent-Length: %d\r\nContent-Type: aa\r\n\r\n", contentLength))); err != nil {
+ t.Fatal(err)
+ }
+ if _, err := cc.Write([]byte(part1)); err != nil {
+ t.Fatal(err)
}
ch := make(chan error)
@@ -3447,12 +3461,15 @@ func TestStreamRequestBody(t *testing.T) {
}
if _, err := cc.Write([]byte(part2)); err != nil {
- t.Error(err)
+ t.Fatal(err)
+ }
+ if err := sc.Close(); err != nil {
+ t.Fatal(err)
}
select {
case err := <-ch:
- if err != nil {
+ if err == nil || err.Error() != "connection closed" { // fasthttputil.errConnectionClosed is private so do a string match.
t.Fatalf("Unexpected error from serveConn: %s", err)
}
case <-time.After(500 * time.Millisecond):
diff --git a/streaming.go b/streaming.go
index a6ad0a9..39000a2 100644
--- a/streaming.go
+++ b/streaming.go
@@ -45,7 +45,12 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
var n int
var err error
- if int(rs.prefetchedBytes.Size()) > rs.totalBytesRead {
+ prefetchedSize := int(rs.prefetchedBytes.Size())
+ if prefetchedSize > rs.totalBytesRead {
+ left := prefetchedSize - rs.totalBytesRead
+ if len(p) > left {
+ p = p[:left]
+ }
n, err := rs.prefetchedBytes.Read(p)
rs.totalBytesRead += n
if n == rs.contentLength {
@@ -53,6 +58,10 @@ func (rs *requestStream) Read(p []byte) (int, error) {
}
return n, err
} else {
+ left := rs.contentLength - rs.totalBytesRead
+ if len(p) > left {
+ p = p[:left]
+ }
n, err = rs.reader.Read(p)
rs.totalBytesRead += n
if err != nil {
diff --git a/streaming_test.go b/streaming_test.go
index e99033c..a943cb8 100644
--- a/streaming_test.go
+++ b/streaming_test.go
@@ -6,10 +6,100 @@ import (
"io/ioutil"
"sync"
"testing"
+ "time"
"github.com/valyala/fasthttp/fasthttputil"
)
+func TestStreamingPipeline(t *testing.T) {
+ t.Parallel()
+
+ reqS := `POST /one HTTP/1.1
+Host: example.com
+Content-Length: 10
+
+aaaaaaaaaa
+POST /two HTTP/1.1
+Host: example.com
+Content-Length: 10
+
+aaaaaaaaaa`
+
+ ln := fasthttputil.NewInmemoryListener()
+
+ s := &Server{
+ StreamRequestBody: true,
+ Handler: func(ctx *RequestCtx) {
+ body := ""
+ expected := "aaaaaaaaaa"
+ if string(ctx.Path()) == "/one" {
+ body = string(ctx.PostBody())
+ } else {
+ all, err := ioutil.ReadAll(ctx.RequestBodyStream())
+ if err != nil {
+ t.Error(err)
+ }
+ body = string(all)
+ }
+ if body != expected {
+ t.Errorf("expected %q got %q", expected, body)
+ }
+ },
+ }
+
+ ch := make(chan struct{})
+ go func() {
+ if err := s.Serve(ln); err != nil {
+ t.Errorf("unexpected error: %s", err)
+ }
+ close(ch)
+ }()
+
+ conn, err := ln.Dial()
+ if err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+ if _, err = conn.Write([]byte(reqS)); err != nil {
+ t.Fatalf("unexpected error: %s", err)
+ }
+
+ var resp Response
+ br := bufio.NewReader(conn)
+ respCh := make(chan struct{})
+ go func() {
+ if err := resp.Read(br); err != nil {
+ t.Errorf("error when reading response: %s", err)
+ }
+ if resp.StatusCode() != StatusOK {
+ t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
+ }
+
+ if err := resp.Read(br); err != nil {
+ t.Errorf("error when reading response: %s", err)
+ }
+ if resp.StatusCode() != StatusOK {
+ t.Errorf("unexpected status code %d. Expecting %d", resp.StatusCode(), StatusOK)
+ }
+ close(respCh)
+ }()
+
+ select {
+ case <-respCh:
+ case <-time.After(time.Second):
+ t.Fatal("timeout")
+ }
+
+ if err := ln.Close(); err != nil {
+ t.Fatalf("error when closing listener: %s", err)
+ }
+
+ select {
+ case <-ch:
+ case <-time.After(time.Second):
+ t.Fatal("timeout when waiting for the server to stop")
+ }
+}
+
func TestRequestStream(t *testing.T) {
body := createFixedBody(3)
chunkedBody := createChunkedBody(body)