aboutsummaryrefslogtreecommitdiff
path: root/zstd.go
diff options
context:
space:
mode:
authorGravatar Co1a <aaron9shire@gmail.com> 2024-02-21 14:21:52 +0800
committerGravatar GitHub <noreply@github.com> 2024-02-21 07:21:52 +0100
commit5f81476d7cf339624b5fec57a06ee96d3e27d9c2 (patch)
tree62b06111db30c5895dc775f53bf5577f1092c06a /zstd.go
parentLimit memory for fuzz testing (diff)
downloadfasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.tar.gz
fasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.tar.bz2
fasthttp-5f81476d7cf339624b5fec57a06ee96d3e27d9c2.zip
feat:support zstd compress and uncompressed (#1701)
* feat:support zstd compress and uncompressed * fix:real & stackless write using different pool to avoid get stackless.writer * fix:zstd normalize compress level * Change empty string checks to be more idiomatic (#1684) * chore:lint fix and rebase with master * chore:remove 1.18 test & upgrade compress version * fix:error default compress level * Fix lint --------- Co-authored-by: Erik Dubbelboer <erik@dubbelboer.com>
Diffstat (limited to 'zstd.go')
-rw-r--r--zstd.go186
1 file changed, 186 insertions, 0 deletions
diff --git a/zstd.go b/zstd.go
new file mode 100644
index 0000000..226a126
--- /dev/null
+++ b/zstd.go
@@ -0,0 +1,186 @@
+package fasthttp
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "sync"
+
+ "github.com/klauspost/compress/zstd"
+ "github.com/valyala/bytebufferpool"
+ "github.com/valyala/fasthttp/stackless"
+)
+
// Supported zstd compression levels, accepted by the *ZstdLevel functions
// below. The values (0..4) are cast directly to zstd.EncoderLevel in
// acquireZstdWriter, so their order must match the encoder levels of
// github.com/klauspost/compress/zstd.
const (
	CompressZstdSpeedNotSet = iota
	CompressZstdBestSpeed
	CompressZstdDefault
	CompressZstdSpeedBetter
	CompressZstdBestCompression
)
+
var (
	// zstdDecoderPool caches *zstd.Decoder values for acquireZstdReader.
	zstdDecoderPool sync.Pool
	// zstdEncoderPool caches *zstd.Encoder values for acquireZstdWriter.
	zstdEncoderPool sync.Pool
	// Real and stackless writers live in separate per-level pool maps so
	// the two acquire paths never hand out each other's writer type
	// (see commit message: avoids getting a stackless.Writer from the
	// real-writer pool).
	realZstdWriterPoolMap      = newCompressWriterPoolMap()
	stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
)
+
+func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
+ v := zstdDecoderPool.Get()
+ if v == nil {
+ return zstd.NewReader(r)
+ }
+ zr := v.(*zstd.Decoder)
+ if err := zr.Reset(r); err != nil {
+ return nil, err
+ }
+ return zr, nil
+}
+
+func releaseZstdReader(zr *zstd.Decoder) {
+ zstdDecoderPool.Put(zr)
+}
+
+func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) {
+ v := zstdEncoderPool.Get()
+ if v == nil {
+ return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level)))
+ }
+ zw := v.(*zstd.Encoder)
+ zw.Reset(w)
+ return zw, nil
+}
+
+func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused
+ zw.Close()
+ zstdEncoderPool.Put(zw)
+}
+
+func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
+ nLevel := normalizeZstdCompressLevel(compressLevel)
+ p := stacklessZstdWriterPoolMap[nLevel]
+ v := p.Get()
+ if v == nil {
+ return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
+ return acquireRealZstdWriter(w, compressLevel)
+ })
+ }
+ sw := v.(stackless.Writer)
+ sw.Reset(w)
+ return sw
+}
+
+func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) {
+ zf.Close()
+ nLevel := normalizeZstdCompressLevel(zstdDefault)
+ p := stacklessZstdWriterPoolMap[nLevel]
+ p.Put(zf)
+}
+
+func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
+ nLevel := normalizeZstdCompressLevel(level)
+ p := realZstdWriterPoolMap[nLevel]
+ v := p.Get()
+ if v == nil {
+ zw, err := acquireZstdWriter(w, level)
+ if err != nil {
+ panic(err)
+ }
+ return zw
+ }
+ zw := v.(*zstd.Encoder)
+ zw.Reset(w)
+ return zw
+}
+
+func releaseRealZstdWrter(zw *zstd.Encoder, level int) {
+ zw.Close()
+ nLevel := normalizeZstdCompressLevel(level)
+ p := realZstdWriterPoolMap[nLevel]
+ p.Put(zw)
+}
+
+func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
+ w := &byteSliceWriter{dst}
+ WriteZstdLevel(w, src, level) //nolint:errcheck
+ return w.b
+}
+
+func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
+ level = normalizeZstdCompressLevel(level)
+ switch w.(type) {
+ case *byteSliceWriter,
+ *bytes.Buffer,
+ *bytebufferpool.ByteBuffer:
+ ctx := &compressCtx{
+ w: w,
+ p: p,
+ level: level,
+ }
+ stacklessWriteZstd(ctx)
+ return len(p), nil
+ default:
+ zw := acquireStacklessZstdWriter(w, level)
+ n, err := zw.Write(p)
+ releaseStacklessZstdWriter(zw, level)
+ return n, err
+ }
+}
+
var (
	// stacklessWriteZstdOnce guards the lazy construction of
	// stacklessWriteZstdFunc below.
	stacklessWriteZstdOnce sync.Once
	// stacklessWriteZstdFunc is the stackless wrapper around
	// nonblockingWriteZstd; created on first use by stacklessWriteZstd.
	stacklessWriteZstdFunc func(ctx any) bool
)
+
+func stacklessWriteZstd(ctx any) {
+ stacklessWriteZstdOnce.Do(func() {
+ stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
+ })
+ stacklessWriteZstdFunc(ctx)
+}
+
+func nonblockingWriteZstd(ctxv any) {
+ ctx := ctxv.(*compressCtx)
+ zw := acquireRealZstdWriter(ctx.w, ctx.level)
+ zw.Write(ctx.p) //nolint:errcheck
+ releaseRealZstdWrter(zw, ctx.level)
+}
+
+// AppendZstdBytes appends zstd src to dst and returns the resulting dst.
+func AppendZstdBytes(dst, src []byte) []byte {
+ return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
+}
+
+// WriteUnzstd writes unzstd p to w and returns the number of uncompressed
+// bytes written to w.
+func WriteUnzstd(w io.Writer, p []byte) (int, error) {
+ r := &byteSliceReader{p}
+ zr, err := acquireZstdReader(r)
+ if err != nil {
+ return 0, err
+ }
+ n, err := copyZeroAlloc(w, zr)
+ releaseZstdReader(zr)
+ nn := int(n)
+ if int64(nn) != n {
+ return 0, fmt.Errorf("too much data unzstd: %d", n)
+ }
+ return nn, err
+}
+
+// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst.
+func AppendUnzstdBytes(dst, src []byte) ([]byte, error) {
+ w := &byteSliceWriter{dst}
+ _, err := WriteUnzstd(w, src)
+ return w.b, err
+}
+
+// normalizes compression level into [0..7], so it could be used as an index
+// in *PoolMap.
+func normalizeZstdCompressLevel(level int) int {
+ if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression {
+ level = CompressZstdDefault
+ }
+ return level
+}