From 5f81476d7cf339624b5fec57a06ee96d3e27d9c2 Mon Sep 17 00:00:00 2001
From: Co1a
Date: Wed, 21 Feb 2024 14:21:52 +0800
Subject: feat:support zstd compress and uncompressed (#1701)

* feat:support zstd compress and uncompressed

* fix:real & stackless write using different pool to avoid get stackless.writer

* fix:zstd normalize compress level

* Change empty string checks to be more idiomatic (#1684)

* chore:lint fix and rebase with master

* chore:remove 1.18 test & upgrade compress version

* fix:error default compress level

* Fix lint

---------

Co-authored-by: Erik Dubbelboer
---
 .github/workflows/lint.yml |   9 +++
 .github/workflows/test.yml |   2 +-
 fs.go                      |  74 ++++++++++++++----
 http.go                    |  70 +++++++++++++++++
 request_body.zst           | Bin 0 -> 31 bytes
 server.go                  |   9 ++-
 strings.go                 |   1 +
 zstd.go                    | 186 +++++++++++++++++++++++++++++++++++++++++++++
 zstd_test.go               | 102 +++++++++++++++++++++++++
 9 files changed, 435 insertions(+), 18 deletions(-)
 create mode 100644 request_body.zst
 create mode 100644 zstd.go
 create mode 100644 zstd_test.go

diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
index 48972fd..c250bf7 100644
--- a/.github/workflows/lint.yml
+++ b/.github/workflows/lint.yml
@@ -4,6 +4,15 @@ on:
     branches:
       - master
   pull_request:
+
+permissions:
+  # Required: allow read access to the content for analysis.
+  contents: read
+  # Optional: allow read access to pull request. Use with `only-new-issues` option.
+  pull-requests: read
+  # Optional: Allow write access to checks to allow the action to annotate code in the PR.
+  checks: write
+
 jobs:
   lint:
     runs-on: ubuntu-latest
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 5b0430d..5cfc4e5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -9,7 +9,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        go-version: [1.18.x, 1.19.x, 1.20.x, 1.21.x, 1.22.x]
+        go-version: [1.19.x, 1.20.x, 1.21.x, 1.22.x]
         os: [ubuntu-latest, macos-latest, windows-latest]
     runs-on: ${{ matrix.os }}
     steps:
diff --git a/fs.go b/fs.go
index 57260d1..4f2bbbf 100644
--- a/fs.go
+++ b/fs.go
@@ -18,6 +18,7 @@ import (
 
     "github.com/andybalholm/brotli"
     "github.com/klauspost/compress/gzip"
+    "github.com/klauspost/compress/zstd"
     "github.com/valyala/bytebufferpool"
 )
 
@@ -370,6 +371,7 @@ const FSCompressedFileSuffix = ".fasthttp.gz"
 var FSCompressedFileSuffixes = map[string]string{
     "gzip": ".fasthttp.gz",
     "br":   ".fasthttp.br",
+    "zstd": ".fasthttp.zst",
 }
 
 // FSHandlerCacheDuration is the default expiration duration for inactive
@@ -460,7 +462,9 @@ func (fs *FS) initRequestHandler() {
 
     compressedFileSuffixes := fs.CompressedFileSuffixes
     if compressedFileSuffixes["br"] == "" || compressedFileSuffixes["gzip"] == "" ||
-        compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] {
+        compressedFileSuffixes["zstd"] == "" || compressedFileSuffixes["br"] == compressedFileSuffixes["gzip"] ||
+        compressedFileSuffixes["br"] == compressedFileSuffixes["zstd"] ||
+        compressedFileSuffixes["gzip"] == compressedFileSuffixes["zstd"] {
         // Copy global map
         compressedFileSuffixes = make(map[string]string, len(FSCompressedFileSuffixes))
         for k, v := range FSCompressedFileSuffixes {
@@ -471,6 +475,7 @@ func (fs *FS) initRequestHandler() {
     if fs.CompressedFileSuffix != "" {
         compressedFileSuffixes["gzip"] = fs.CompressedFileSuffix
         compressedFileSuffixes["br"] = FSCompressedFileSuffixes["br"]
+        compressedFileSuffixes["zstd"] = FSCompressedFileSuffixes["zstd"]
     }
 
     h := &fsHandler{
@@ -794,6 +799,7 @@ const (
     defaultCacheKind CacheKind = iota
     brotliCacheKind
     gzipCacheKind
+    zstdCacheKind
 )
 
 func newCacheManager(fs *FS) cacheManager {
@@ -1032,14 +1038,19 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
     fileEncoding := ""
     byteRange := ctx.Request.Header.peek(strRange)
     if len(byteRange) == 0 && h.compress {
-        if h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr) {
+        switch {
+        case h.compressBrotli && ctx.Request.Header.HasAcceptEncodingBytes(strBr):
             mustCompress = true
             fileCacheKind = brotliCacheKind
             fileEncoding = "br"
-        } else if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) {
+        case ctx.Request.Header.HasAcceptEncodingBytes(strGzip):
             mustCompress = true
             fileCacheKind = gzipCacheKind
             fileEncoding = "gzip"
+        case ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
+            mustCompress = true
+            fileCacheKind = zstdCacheKind
+            fileEncoding = "zstd"
         }
     }
 
@@ -1097,10 +1108,13 @@ func (h *fsHandler) handleRequest(ctx *RequestCtx) {
 
     hdr := &ctx.Response.Header
     if ff.compressed {
-        if fileEncoding == "br" {
+        switch fileEncoding {
+        case "br":
             hdr.SetContentEncodingBytes(strBr)
-        } else if fileEncoding == "gzip" {
+        case "gzip":
             hdr.SetContentEncodingBytes(strGzip)
+        case "zstd":
+            hdr.SetContentEncodingBytes(strZstd)
         }
     }
 
@@ -1304,10 +1318,13 @@ nestedContinue:
 
     if mustCompress {
         var zbuf bytebufferpool.ByteBuffer
-        if fileEncoding == "br" {
+        switch fileEncoding {
+        case "br":
             zbuf.B = AppendBrotliBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
-        } else if fileEncoding == "gzip" {
+        case "gzip":
             zbuf.B = AppendGzipBytesLevel(zbuf.B, w.B, CompressDefaultCompression)
+        case "zstd":
+            zbuf.B = AppendZstdBytesLevel(zbuf.B, w.B, CompressZstdDefault)
         }
         w = &zbuf
     }
@@ -1406,20 +1423,28 @@ func (h *fsHandler) compressFileNolock(
         }
         return nil, errNoCreatePermission
     }
-    if fileEncoding == "br" {
+    switch fileEncoding {
+    case "br":
         zw := acquireStacklessBrotliWriter(zf, CompressDefaultCompression)
         _, err = copyZeroAlloc(zw, f)
         if err1 := zw.Flush(); err == nil {
             err = err1
         }
         releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
-    } else if fileEncoding == "gzip" {
+    case "gzip":
         zw := acquireStacklessGzipWriter(zf, CompressDefaultCompression)
         _, err = copyZeroAlloc(zw, f)
         if err1 := zw.Flush(); err == nil {
             err = err1
         }
         releaseStacklessGzipWriter(zw, CompressDefaultCompression)
+    case "zstd":
+        zw := acquireStacklessZstdWriter(zf, CompressZstdDefault)
+        _, err = copyZeroAlloc(zw, f)
+        if err1 := zw.Flush(); err == nil {
+            err = err1
+        }
+        releaseStacklessZstdWriter(zw, CompressZstdDefault)
     }
     _ = zf.Close()
     _ = f.Close()
@@ -1443,20 +1468,28 @@ func (h *fsHandler) newCompressedFSFileCache(f fs.File, fileInfo fs.FileInfo, fi
         err error
     )
 
-    if fileEncoding == "br" {
+    switch fileEncoding {
+    case "br":
         zw := acquireStacklessBrotliWriter(w, CompressDefaultCompression)
         _, err = copyZeroAlloc(zw, f)
         if err1 := zw.Flush(); err == nil {
             err = err1
         }
         releaseStacklessBrotliWriter(zw, CompressDefaultCompression)
-    } else if fileEncoding == "gzip" {
+    case "gzip":
         zw := acquireStacklessGzipWriter(w, CompressDefaultCompression)
         _, err = copyZeroAlloc(zw, f)
         if err1 := zw.Flush(); err == nil {
             err = err1
         }
         releaseStacklessGzipWriter(zw, CompressDefaultCompression)
+    case "zstd":
+        zw := acquireStacklessZstdWriter(w, CompressZstdDefault)
+        _, err = copyZeroAlloc(zw, f)
+        if err1 := zw.Flush(); err == nil {
+            err = err1
+        }
+        releaseStacklessZstdWriter(zw, CompressZstdDefault)
     }
 
     defer func() { _ = f.Close() }()
@@ -1600,21 +1633,28 @@ func (h *fsHandler) newFSFile(f fs.File, fileInfo fs.FileInfo, compressed bool,
 
 func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte, error) {
     r := f
     var (
-        br *brotli.Reader
-        zr *gzip.Reader
+        br  *brotli.Reader
+        zr  *gzip.Reader
+        zsr *zstd.Decoder
     )
     if compressed {
         var err error
-        if fileEncoding == "br" {
+        switch fileEncoding {
+        case "br":
             if br, err = acquireBrotliReader(f); err != nil {
                 return nil, err
             }
             r = br
-        } else if fileEncoding == "gzip" {
+        case "gzip":
             if zr, err = acquireGzipReader(f); err != nil {
                 return nil, err
             }
             r = zr
+        case "zstd":
+            if zsr, err = acquireZstdReader(f); err != nil {
+                return nil, err
+            }
+            r = zsr
         }
     }
 
@@ -1639,6 +1679,10 @@ func readFileHeader(f io.Reader, compressed bool, fileEncoding string) ([]byte,
         releaseGzipReader(zr)
     }
 
+    if zsr != nil {
+        releaseZstdReader(zsr)
+    }
+
     return data, err
 }
 
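To make the fs.go changes concrete, here is a minimal sketch (not part of the patch) of serving static files with zstd negotiation. It assumes a hypothetical ./static directory; everything else is existing fasthttp API. Once FS.Compress is set there is no new knob to turn: a client sending "Accept-Encoding: zstd" now gets a compressed variant, cached under the new ".fasthttp.zst" suffix alongside the gzip and brotli ones.

    package main

    import (
        "log"

        "github.com/valyala/fasthttp"
    )

    func main() {
        fs := &fasthttp.FS{
            Root:     "./static", // hypothetical document root
            Compress: true,       // compressed variants cached as *.fasthttp.gz/.br/.zst
        }
        if err := fasthttp.ListenAndServe(":8080", fs.NewHandler()); err != nil {
            log.Fatalf("cannot serve files: %v", err)
        }
    }
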
diff --git a/http.go b/http.go
index 74d66cb..e078809 100644
--- a/http.go
+++ b/http.go
@@ -528,6 +528,23 @@ func (ctx *RequestCtx) RequestBodyStream() io.Reader {
     return ctx.Request.bodyStream
 }
 
+func (req *Request) BodyUnzstd() ([]byte, error) {
+    return unzstdData(req.Body())
+}
+
+func (resp *Response) BodyUnzstd() ([]byte, error) {
+    return unzstdData(resp.Body())
+}
+
+func unzstdData(p []byte) ([]byte, error) {
+    var bb bytebufferpool.ByteBuffer
+    _, err := WriteUnzstd(&bb, p)
+    if err != nil {
+        return nil, err
+    }
+    return bb.B, nil
+}
+
 func inflateData(p []byte) ([]byte, error) {
     var bb bytebufferpool.ByteBuffer
     _, err := WriteInflate(&bb, p)
@@ -554,6 +571,8 @@ func (req *Request) BodyUncompressed() ([]byte, error) {
         return req.BodyGunzip()
     case "br":
         return req.BodyUnbrotli()
+    case "zstd":
+        return req.BodyUnzstd()
     default:
         return nil, ErrContentEncodingUnsupported
     }
@@ -574,6 +593,8 @@ func (resp *Response) BodyUncompressed() ([]byte, error) {
         return resp.BodyGunzip()
     case "br":
         return resp.BodyUnbrotli()
+    case "zstd":
+        return resp.BodyUnzstd()
     default:
         return nil, ErrContentEncodingUnsupported
     }
@@ -1849,6 +1870,55 @@ func (resp *Response) deflateBody(level int) error {
     return nil
 }
 
+func (resp *Response) zstdBody(level int) error {
+    if len(resp.Header.ContentEncoding()) > 0 {
+        return nil
+    }
+
+    if !resp.Header.isCompressibleContentType() {
+        return nil
+    }
+
+    if resp.bodyStream != nil {
+        // Reset Content-Length to -1, since it is impossible
+        // to determine body size beforehand of streamed compression.
+        // For
+        resp.Header.SetContentLength(-1)
+
+        // Do not care about memory allocations here, since zstd is slow
+        // and allocates a lot of memory by itself.
+        bs := resp.bodyStream
+        resp.bodyStream = NewStreamReader(func(sw *bufio.Writer) {
+            zw := acquireStacklessZstdWriter(sw, level)
+            fw := &flushWriter{
+                wf: zw,
+                bw: sw,
+            }
+            copyZeroAlloc(fw, bs) //nolint:errcheck
+            releaseStacklessZstdWriter(zw, level)
+            if bsc, ok := bs.(io.Closer); ok {
+                bsc.Close()
+            }
+        })
+    } else {
+        bodyBytes := resp.bodyBytes()
+        if len(bodyBytes) < minCompressLen {
+            return nil
+        }
+        w := responseBodyPool.Get()
+        w.B = AppendZstdBytesLevel(w.B, bodyBytes, level)
+
+        if resp.body != nil {
+            responseBodyPool.Put(resp.body)
+        }
+        resp.body = w
+        resp.bodyRaw = nil
+    }
+    resp.Header.SetContentEncodingBytes(strZstd)
+    resp.Header.addVaryBytes(strAcceptEncoding)
+    return nil
+}
+
 // Bodies with sizes smaller than minCompressLen aren't compressed at all.
 const minCompressLen = 200
 
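On the client side, the new helpers can be exercised as in the sketch below (not part of the patch; the URL is a placeholder). BodyUnzstd decodes a body known to be zstd-encoded, while BodyUncompressed now also dispatches to it when Content-Encoding is "zstd".

    package main

    import (
        "fmt"
        "log"

        "github.com/valyala/fasthttp"
    )

    func main() {
        req := fasthttp.AcquireRequest()
        resp := fasthttp.AcquireResponse()
        defer fasthttp.ReleaseRequest(req)
        defer fasthttp.ReleaseResponse(resp)

        req.SetRequestURI("http://localhost:8080/") // placeholder URL
        req.Header.Set(fasthttp.HeaderAcceptEncoding, "zstd")

        if err := fasthttp.Do(req, resp); err != nil {
            log.Fatal(err)
        }
        // BodyUncompressed picks the decoder from Content-Encoding;
        // resp.BodyUnzstd() would do the same when the encoding is known.
        body, err := resp.BodyUncompressed()
        if err != nil {
            log.Fatal(err)
        }
        fmt.Printf("got %d uncompressed bytes\n", len(body))
    }
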
diff --git a/request_body.zst b/request_body.zst
new file mode 100644
index 0000000..ea95e73
Binary files /dev/null and b/request_body.zst differ
diff --git a/server.go b/server.go
index e3593cd..426351b 100644
--- a/server.go
+++ b/server.go
@@ -523,10 +523,13 @@ func CompressHandler(h RequestHandler) RequestHandler {
 func CompressHandlerLevel(h RequestHandler, level int) RequestHandler {
     return func(ctx *RequestCtx) {
         h(ctx)
-        if ctx.Request.Header.HasAcceptEncodingBytes(strGzip) {
+        switch {
+        case ctx.Request.Header.HasAcceptEncodingBytes(strGzip):
             ctx.Response.gzipBody(level) //nolint:errcheck
-        } else if ctx.Request.Header.HasAcceptEncodingBytes(strDeflate) {
+        case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate):
             ctx.Response.deflateBody(level) //nolint:errcheck
+        case ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
+            ctx.Response.zstdBody(level) //nolint:errcheck
         }
     }
 }
@@ -559,6 +562,8 @@ func CompressHandlerBrotliLevel(h RequestHandler, brotliLevel, otherLevel int) R
             ctx.Response.gzipBody(otherLevel) //nolint:errcheck
         case ctx.Request.Header.HasAcceptEncodingBytes(strDeflate):
             ctx.Response.deflateBody(otherLevel) //nolint:errcheck
+        case ctx.Request.Header.HasAcceptEncodingBytes(strZstd):
+            ctx.Response.zstdBody(otherLevel) //nolint:errcheck
         }
     }
 }
diff --git a/strings.go b/strings.go
index 3374678..a9e4072 100644
--- a/strings.go
+++ b/strings.go
@@ -72,6 +72,7 @@ var (
     strClose               = []byte("close")
     strGzip                = []byte("gzip")
     strBr                  = []byte("br")
+    strZstd                = []byte("zstd")
     strDeflate             = []byte("deflate")
     strKeepAlive           = []byte("keep-alive")
     strUpgrade             = []byte("Upgrade")
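A server-side sketch of the updated handler wrapper (again illustrative, not part of the patch). One caveat worth noting: CompressHandlerLevel hands the same level to whichever encoder matches, and gzip and deflate are checked first, so the zstd branch fires only when the client advertises zstd without those; out-of-range zstd levels are normalized to CompressZstdDefault by the helpers added in zstd.go below.

    package main

    import (
        "log"
        "strings"

        "github.com/valyala/fasthttp"
    )

    func main() {
        h := func(ctx *fasthttp.RequestCtx) {
            ctx.SetContentType("text/plain; charset=utf-8")
            // Bodies shorter than minCompressLen (200 bytes) are sent as-is.
            ctx.SetBodyString(strings.Repeat("fasthttp ", 64))
        }
        log.Fatal(fasthttp.ListenAndServe(":8080",
            fasthttp.CompressHandlerLevel(h, fasthttp.CompressZstdDefault)))
    }
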
diff --git a/zstd.go b/zstd.go
new file mode 100644
index 0000000..226a126
--- /dev/null
+++ b/zstd.go
@@ -0,0 +1,186 @@
+package fasthttp
+
+import (
+    "bytes"
+    "fmt"
+    "io"
+    "sync"
+
+    "github.com/klauspost/compress/zstd"
+    "github.com/valyala/bytebufferpool"
+    "github.com/valyala/fasthttp/stackless"
+)
+
+const (
+    CompressZstdSpeedNotSet = iota
+    CompressZstdBestSpeed
+    CompressZstdDefault
+    CompressZstdSpeedBetter
+    CompressZstdBestCompression
+)
+
+var (
+    zstdDecoderPool            sync.Pool
+    zstdEncoderPool            sync.Pool
+    realZstdWriterPoolMap      = newCompressWriterPoolMap()
+    stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
+)
+
+func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
+    v := zstdDecoderPool.Get()
+    if v == nil {
+        return zstd.NewReader(r)
+    }
+    zr := v.(*zstd.Decoder)
+    if err := zr.Reset(r); err != nil {
+        return nil, err
+    }
+    return zr, nil
+}
+
+func releaseZstdReader(zr *zstd.Decoder) {
+    zstdDecoderPool.Put(zr)
+}
+
+func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) {
+    v := zstdEncoderPool.Get()
+    if v == nil {
+        return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level)))
+    }
+    zw := v.(*zstd.Encoder)
+    zw.Reset(w)
+    return zw, nil
+}
+
+func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused
+    zw.Close()
+    zstdEncoderPool.Put(zw)
+}
+
+func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
+    nLevel := normalizeZstdCompressLevel(compressLevel)
+    p := stacklessZstdWriterPoolMap[nLevel]
+    v := p.Get()
+    if v == nil {
+        return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
+            return acquireRealZstdWriter(w, compressLevel)
+        })
+    }
+    sw := v.(stackless.Writer)
+    sw.Reset(w)
+    return sw
+}
+
+func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) {
+    zf.Close()
+    nLevel := normalizeZstdCompressLevel(zstdDefault)
+    p := stacklessZstdWriterPoolMap[nLevel]
+    p.Put(zf)
+}
+
+func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
+    nLevel := normalizeZstdCompressLevel(level)
+    p := realZstdWriterPoolMap[nLevel]
+    v := p.Get()
+    if v == nil {
+        zw, err := acquireZstdWriter(w, level)
+        if err != nil {
+            panic(err)
+        }
+        return zw
+    }
+    zw := v.(*zstd.Encoder)
+    zw.Reset(w)
+    return zw
+}
+
+func releaseRealZstdWriter(zw *zstd.Encoder, level int) {
+    zw.Close()
+    nLevel := normalizeZstdCompressLevel(level)
+    p := realZstdWriterPoolMap[nLevel]
+    p.Put(zw)
+}
+
+func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
+    w := &byteSliceWriter{dst}
+    WriteZstdLevel(w, src, level) //nolint:errcheck
+    return w.b
+}
+
+func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
+    level = normalizeZstdCompressLevel(level)
+    switch w.(type) {
+    case *byteSliceWriter,
+        *bytes.Buffer,
+        *bytebufferpool.ByteBuffer:
+        ctx := &compressCtx{
+            w:     w,
+            p:     p,
+            level: level,
+        }
+        stacklessWriteZstd(ctx)
+        return len(p), nil
+    default:
+        zw := acquireStacklessZstdWriter(w, level)
+        n, err := zw.Write(p)
+        releaseStacklessZstdWriter(zw, level)
+        return n, err
+    }
+}
+
+var (
+    stacklessWriteZstdOnce sync.Once
+    stacklessWriteZstdFunc func(ctx any) bool
+)
+
+func stacklessWriteZstd(ctx any) {
+    stacklessWriteZstdOnce.Do(func() {
+        stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
+    })
+    stacklessWriteZstdFunc(ctx)
+}
+
+func nonblockingWriteZstd(ctxv any) {
+    ctx := ctxv.(*compressCtx)
+    zw := acquireRealZstdWriter(ctx.w, ctx.level)
+    zw.Write(ctx.p) //nolint:errcheck
+    releaseRealZstdWriter(zw, ctx.level)
+}
+
+// AppendZstdBytes appends zstd src to dst and returns the resulting dst.
+func AppendZstdBytes(dst, src []byte) []byte {
+    return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
+}
+
+// WriteUnzstd writes unzstd p to w and returns the number of uncompressed
+// bytes written to w.
+func WriteUnzstd(w io.Writer, p []byte) (int, error) {
+    r := &byteSliceReader{p}
+    zr, err := acquireZstdReader(r)
+    if err != nil {
+        return 0, err
+    }
+    n, err := copyZeroAlloc(w, zr)
+    releaseZstdReader(zr)
+    nn := int(n)
+    if int64(nn) != n {
+        return 0, fmt.Errorf("too much data unzstd: %d", n)
+    }
+    return nn, err
+}
+
+// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst.
+func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
+    w := &byteSliceWriter{dst}
+    _, err := WriteUnzstd(w, src)
+    return w.b, err
+}
+
+// normalizes compression level into [0..4], so it could be used as an index
+// in *PoolMap.
+func normalizeZstdCompressLevel(level int) int {
+    if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
+        level = CompressZstdDefault
+    }
+    return level
+}
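The CompressZstd* constants are laid out to match klauspost/compress's zstd.EncoderLevel values (SpeedFastest = 1 through SpeedBestCompression = 4), which is why acquireZstdWriter can cast the int level directly. A small round trip through the exported helpers (illustrative, not part of the patch):

    package main

    import (
        "fmt"
        "log"

        "github.com/valyala/fasthttp"
    )

    func main() {
        src := []byte("some payload that is worth compressing")

        // AppendZstdBytes uses CompressZstdDefault; the *Level variant
        // accepts any of the CompressZstd* constants.
        compressed := fasthttp.AppendZstdBytesLevel(nil, src, fasthttp.CompressZstdBestCompression)

        restored, err := fasthttp.AppendUnzstdBytes(nil, compressed)
        if err != nil {
            log.Fatal(err)
        }
        fmt.Println(string(restored) == string(src)) // true
    }
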
diff --git a/zstd_test.go b/zstd_test.go
new file mode 100644
index 0000000..dc0c45f
--- /dev/null
+++ b/zstd_test.go
@@ -0,0 +1,102 @@
+package fasthttp
+
+import (
+    "bytes"
+    "fmt"
+    "io"
+    "testing"
+)
+
+func TestZstdBytesSerial(t *testing.T) {
+    t.Parallel()
+
+    if err := testZstdBytes(); err != nil {
+        t.Fatal(err)
+    }
+}
+
+func TestZstdBytesConcurrent(t *testing.T) {
+    t.Parallel()
+
+    if err := testConcurrent(10, testZstdBytes); err != nil {
+        t.Fatal(err)
+    }
+}
+
+func testZstdBytes() error {
+    for _, s := range compressTestcases {
+        if err := testZstdBytesSingleCase(s); err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+func testZstdBytesSingleCase(s string) error {
+    prefix := []byte("foobar")
+    ZstdpedS := AppendZstdBytes(prefix, []byte(s))
+    if !bytes.Equal(ZstdpedS[:len(prefix)], prefix) {
+        return fmt.Errorf("unexpected prefix when compressing %q: %q. Expecting %q", s, ZstdpedS[:len(prefix)], prefix)
+    }
+
+    unZstdedS, err := AppendUnzstdBytes(prefix, ZstdpedS[len(prefix):])
+    if err != nil {
+        return fmt.Errorf("unexpected error when uncompressing %q: %w", s, err)
+    }
+    if !bytes.Equal(unZstdedS[:len(prefix)], prefix) {
+        return fmt.Errorf("unexpected prefix when uncompressing %q: %q. Expecting %q", s, unZstdedS[:len(prefix)], prefix)
+    }
+    unZstdedS = unZstdedS[len(prefix):]
+    if string(unZstdedS) != s {
+        return fmt.Errorf("unexpected uncompressed string %q. Expecting %q", unZstdedS, s)
+    }
+    return nil
+}
+
+func TestZstdCompressSerial(t *testing.T) {
+    t.Parallel()
+
+    if err := testZstdCompress(); err != nil {
+        t.Fatal(err)
+    }
+}
+
+func TestZstdCompressConcurrent(t *testing.T) {
+    t.Parallel()
+
+    if err := testConcurrent(10, testZstdCompress); err != nil {
+        t.Fatal(err)
+    }
+}
+
+func testZstdCompress() error {
+    for _, s := range compressTestcases {
+        if err := testZstdCompressSingleCase(s); err != nil {
+            return err
+        }
+    }
+    return nil
+}
+
+func testZstdCompressSingleCase(s string) error {
+    var buf bytes.Buffer
+    zw := acquireStacklessZstdWriter(&buf, CompressZstdDefault)
+    if _, err := zw.Write([]byte(s)); err != nil {
+        return fmt.Errorf("unexpected error: %w. s=%q", err, s)
+    }
+    releaseStacklessZstdWriter(zw, CompressZstdDefault)
+
+    zr, err := acquireZstdReader(&buf)
+    if err != nil {
+        return fmt.Errorf("unexpected error: %w. s=%q", err, s)
+    }
+    body, err := io.ReadAll(zr)
+    if err != nil {
+        return fmt.Errorf("unexpected error: %w. s=%q", err, s)
+    }
+    if string(body) != s {
+        return fmt.Errorf("unexpected string after decompression: %q. Expecting %q", body, s)
+    }
+    releaseZstdReader(zr)
+    return nil
+}
-- 
cgit v1.2.3