From bce576699a322ab33b618773a4456a25e602682d Mon Sep 17 00:00:00 2001 From: Erik Dubbelboer Date: Sun, 11 Feb 2024 15:08:56 +0800 Subject: Prevent request smuggling (#1719) * Prevent request smuggling Prevent request smuggling when fasthttp is behind a reverse proxy that might interprets headers differently by being stricter. Should also prevent request smuggling when fasthttp is used as the reverse proxy. * Make header value comparison case-insensitive --- header.go | 19 ++++++++++++++++++- header_test.go | 13 +++++++++---- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/header.go b/header.go index bdee768..c20af2c 100644 --- a/header.go +++ b/header.go @@ -3029,6 +3029,8 @@ func (h *ResponseHeader) parseHeaders(buf []byte) (int, error) { func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { h.contentLength = -2 + contentLengthSeen := false + var s headerScanner s.b = buf s.disableNormalizing = h.disableNormalizing @@ -3064,6 +3066,11 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { continue } if caseInsensitiveCompare(s.key, strContentLength) { + if contentLengthSeen { + return 0, fmt.Errorf("duplicate Content-Length header") + } + contentLengthSeen = true + if h.contentLength != -1 { var nerr error if h.contentLength, nerr = parseContentLength(s.value); nerr != nil { @@ -3088,7 +3095,17 @@ func (h *RequestHeader) parseHeaders(buf []byte) (int, error) { } case 't': if caseInsensitiveCompare(s.key, strTransferEncoding) { - if !bytes.Equal(s.value, strIdentity) { + isIdentity := caseInsensitiveCompare(s.value, strIdentity) + isChunked := caseInsensitiveCompare(s.value, strChunked) + + if !isIdentity && !isChunked { + if h.secureErrorLogMessage { + return 0, fmt.Errorf("unsupported Transfer-Encoding") + } + return 0, fmt.Errorf("unsupported Transfer-Encoding: %q", s.value) + } + + if isChunked { h.contentLength = -1 h.h = setArgBytes(h.h, strTransferEncoding, strChunked, argsHasValue) } diff --git a/header_test.go b/header_test.go index c0f98dc..d6da8e2 100644 --- a/header_test.go +++ b/header_test.go @@ -2618,10 +2618,6 @@ func TestRequestHeaderReadSuccess(t *testing.T) { testRequestHeaderReadSuccess(t, h, "POST /a HTTP/1.1\r\nHost: aa\r\nContent-Type: ab\r\nContent-Length: 123\r\nContent-Type: xx\r\n\r\n", 123, "/a", "aa", "", "xx", nil) - // post with duplicate content-length - testRequestHeaderReadSuccess(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n", - 1, "/xx", "aa", "", "s", nil) - // non-post with content-type testRequestHeaderReadSuccess(t, h, "GET /aaa HTTP/1.1\r\nHost: bbb.com\r\nContent-Type: aaab\r\n\r\n", -2, "/aaa", "bbb.com", "", "aaab", nil) @@ -2756,6 +2752,9 @@ func TestRequestHeaderReadError(t *testing.T) { // forbidden trailer testRequestHeaderReadError(t, h, "POST /a HTTP/1.1\r\nContent-Length: -1\r\nTrailer: Foo, Content-Length\r\n\r\n") + + // post with duplicate content-length + testRequestHeaderReadError(t, h, "POST /xx HTTP/1.1\r\nHost: aa\r\nContent-Type: s\r\nContent-Length: 13\r\nContent-Length: 1\r\n\r\n") } func TestRequestHeaderReadSecuredError(t *testing.T) { @@ -2805,6 +2804,8 @@ func testResponseHeaderReadSecuredError(t *testing.T, h *ResponseHeader, headers } func testRequestHeaderReadError(t *testing.T, h *RequestHeader, headers string) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) @@ -2835,6 +2836,8 @@ func testRequestHeaderReadSecuredError(t *testing.T, h *RequestHeader, headers s func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers string, expectedStatusCode, expectedContentLength int, expectedContentType string, ) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) @@ -2847,6 +2850,8 @@ func testResponseHeaderReadSuccess(t *testing.T, h *ResponseHeader, headers stri func testRequestHeaderReadSuccess(t *testing.T, h *RequestHeader, headers string, expectedContentLength int, expectedRequestURI, expectedHost, expectedReferer, expectedContentType string, expectedTrailer map[string]string, ) { + t.Helper() + r := bytes.NewBufferString(headers) br := bufio.NewReader(r) err := h.Read(br) -- cgit v1.2.3