diff options
author | Co1a <aaron9shire@gmail.com> | 2024-02-21 14:21:52 +0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-02-21 07:21:52 +0100 |
commit | 5f81476d7cf339624b5fec57a06ee96d3e27d9c2 (patch) | |
tree | 62b06111db30c5895dc775f53bf5577f1092c06a /zstd.go | |
parent | Limit memory for fuzz testing (diff) | |
download | fasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.tar.gz fasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.tar.bz2 fasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.zip |
feat:support zstd compress and uncompressed (#1701)
* feat:support zstd compress and uncompressed
* fix:real & stackless write using different pool to avoid get stackless.writer
* fix:zstd normalize compress level
* Change empty string checks to be more idiomatic (#1684)
* chore:lint fix and rebase with master
* chore:remove 1.18 test & upgrade compress version
* fix:error default compress level
* Fix lint
---------
Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
Diffstat (limited to 'zstd.go')
-rw-r--r-- | zstd.go | 186 |
1 files changed, 186 insertions, 0 deletions
@@ -0,0 +1,186 @@ +package fasthttp + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/stackless" +) + +const ( + CompressZstdSpeedNotSet = iota + CompressZstdBestSpeed + CompressZstdDefault + CompressZstdSpeedBetter + CompressZstdBestCompression +) + +var ( + zstdDecoderPool sync.Pool + zstdEncoderPool sync.Pool + realZstdWriterPoolMap = newCompressWriterPoolMap() + stacklessZstdWriterPoolMap = newCompressWriterPoolMap() +) + +func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) { + v := zstdDecoderPool.Get() + if v == nil { + return zstd.NewReader(r) + } + zr := v.(*zstd.Decoder) + if err := zr.Reset(r); err != nil { + return nil, err + } + return zr, nil +} + +func releaseZstdReader(zr *zstd.Decoder) { + zstdDecoderPool.Put(zr) +} + +func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) { + v := zstdEncoderPool.Get() + if v == nil { + return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw, nil +} + +func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused + zw.Close() + zstdEncoderPool.Put(zw) +} + +func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer { + nLevel := normalizeZstdCompressLevel(compressLevel) + p := stacklessZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealZstdWriter(w, compressLevel) + }) + } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} + +func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) { + zf.Close() + nLevel := normalizeZstdCompressLevel(zstdDefault) + p := stacklessZstdWriterPoolMap[nLevel] + p.Put(zf) +} + +func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder { + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + zw, err := acquireZstdWriter(w, level) + if err != nil { + panic(err) + } + return zw + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw +} + +func releaseRealZstdWrter(zw *zstd.Encoder, level int) { + zw.Close() + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + p.Put(zw) +} + +func AppendZstdBytesLevel(dst, src []byte, level int) []byte { + w := &byteSliceWriter{dst} + WriteZstdLevel(w, src, level) //nolint:errcheck + return w.b +} + +func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) { + level = normalizeZstdCompressLevel(level) + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteZstd(ctx) + return len(p), nil + default: + zw := acquireStacklessZstdWriter(w, level) + n, err := zw.Write(p) + releaseStacklessZstdWriter(zw, level) + return n, err + } +} + +var ( + stacklessWriteZstdOnce sync.Once + stacklessWriteZstdFunc func(ctx any) bool +) + +func stacklessWriteZstd(ctx any) { + stacklessWriteZstdOnce.Do(func() { + stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd) + }) + stacklessWriteZstdFunc(ctx) +} + +func nonblockingWriteZstd(ctxv any) { + ctx := ctxv.(*compressCtx) + zw := acquireRealZstdWriter(ctx.w, ctx.level) + zw.Write(ctx.p) //nolint:errcheck + releaseRealZstdWrter(zw, ctx.level) +} + +// AppendZstdBytes appends zstd src to dst and returns the resulting dst. +func AppendZstdBytes(dst, src []byte) []byte { + return AppendZstdBytesLevel(dst, src, CompressZstdDefault) +} + +// WriteUnzstd writes unzstd p to w and returns the number of uncompressed +// bytes written to w. +func WriteUnzstd(w io.Writer, p []byte) (int, error) { + r := &byteSliceReader{p} + zr, err := acquireZstdReader(r) + if err != nil { + return 0, err + } + n, err := copyZeroAlloc(w, zr) + releaseZstdReader(zr) + nn := int(n) + if int64(nn) != n { + return 0, fmt.Errorf("too much data unzstd: %d", n) + } + return nn, err +} + +// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst. +func AppendUnzstdBytes(dst, src []byte) ([]byte, error) { + w := &byteSliceWriter{dst} + _, err := WriteUnzstd(w, src) + return w.b, err +} + +// normalizes compression level into [0..7], so it could be used as an index +// in *PoolMap. +func normalizeZstdCompressLevel(level int) int { + if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression { + level = CompressZstdDefault + } + return level +} |