diff options
Diffstat (limited to 'zstd.go')
-rw-r--r-- | zstd.go | 186 |
1 files changed, 186 insertions, 0 deletions
@@ -0,0 +1,186 @@ +package fasthttp + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + "github.com/valyala/bytebufferpool" + "github.com/valyala/fasthttp/stackless" +) + +const ( + CompressZstdSpeedNotSet = iota + CompressZstdBestSpeed + CompressZstdDefault + CompressZstdSpeedBetter + CompressZstdBestCompression +) + +var ( + zstdDecoderPool sync.Pool + zstdEncoderPool sync.Pool + realZstdWriterPoolMap = newCompressWriterPoolMap() + stacklessZstdWriterPoolMap = newCompressWriterPoolMap() +) + +func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) { + v := zstdDecoderPool.Get() + if v == nil { + return zstd.NewReader(r) + } + zr := v.(*zstd.Decoder) + if err := zr.Reset(r); err != nil { + return nil, err + } + return zr, nil +} + +func releaseZstdReader(zr *zstd.Decoder) { + zstdDecoderPool.Put(zr) +} + +func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) { + v := zstdEncoderPool.Get() + if v == nil { + return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level))) + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw, nil +} + +func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused + zw.Close() + zstdEncoderPool.Put(zw) +} + +func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer { + nLevel := normalizeZstdCompressLevel(compressLevel) + p := stacklessZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + return stackless.NewWriter(w, func(w io.Writer) stackless.Writer { + return acquireRealZstdWriter(w, compressLevel) + }) + } + sw := v.(stackless.Writer) + sw.Reset(w) + return sw +} + +func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) { + zf.Close() + nLevel := normalizeZstdCompressLevel(zstdDefault) + p := stacklessZstdWriterPoolMap[nLevel] + p.Put(zf) +} + +func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder { + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + v := p.Get() + if v == nil { + zw, err := acquireZstdWriter(w, level) + if err != nil { + panic(err) + } + return zw + } + zw := v.(*zstd.Encoder) + zw.Reset(w) + return zw +} + +func releaseRealZstdWrter(zw *zstd.Encoder, level int) { + zw.Close() + nLevel := normalizeZstdCompressLevel(level) + p := realZstdWriterPoolMap[nLevel] + p.Put(zw) +} + +func AppendZstdBytesLevel(dst, src []byte, level int) []byte { + w := &byteSliceWriter{dst} + WriteZstdLevel(w, src, level) //nolint:errcheck + return w.b +} + +func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) { + level = normalizeZstdCompressLevel(level) + switch w.(type) { + case *byteSliceWriter, + *bytes.Buffer, + *bytebufferpool.ByteBuffer: + ctx := &compressCtx{ + w: w, + p: p, + level: level, + } + stacklessWriteZstd(ctx) + return len(p), nil + default: + zw := acquireStacklessZstdWriter(w, level) + n, err := zw.Write(p) + releaseStacklessZstdWriter(zw, level) + return n, err + } +} + +var ( + stacklessWriteZstdOnce sync.Once + stacklessWriteZstdFunc func(ctx any) bool +) + +func stacklessWriteZstd(ctx any) { + stacklessWriteZstdOnce.Do(func() { + stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd) + }) + stacklessWriteZstdFunc(ctx) +} + +func nonblockingWriteZstd(ctxv any) { + ctx := ctxv.(*compressCtx) + zw := acquireRealZstdWriter(ctx.w, ctx.level) + zw.Write(ctx.p) //nolint:errcheck + releaseRealZstdWrter(zw, ctx.level) +} + +// AppendZstdBytes appends zstd src to dst and returns the resulting dst. +func AppendZstdBytes(dst, src []byte) []byte { + return AppendZstdBytesLevel(dst, src, CompressZstdDefault) +} + +// WriteUnzstd writes unzstd p to w and returns the number of uncompressed +// bytes written to w. +func WriteUnzstd(w io.Writer, p []byte) (int, error) { + r := &byteSliceReader{p} + zr, err := acquireZstdReader(r) + if err != nil { + return 0, err + } + n, err := copyZeroAlloc(w, zr) + releaseZstdReader(zr) + nn := int(n) + if int64(nn) != n { + return 0, fmt.Errorf("too much data unzstd: %d", n) + } + return nn, err +} + +// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst. +func AppendUnzstdBytes(dst, src []byte) ([]byte, error) { + w := &byteSliceWriter{dst} + _, err := WriteUnzstd(w, src) + return w.b, err +} + +// normalizes compression level into [0..7], so it could be used as an index +// in *PoolMap. +func normalizeZstdCompressLevel(level int) int { + if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression { + level = CompressZstdDefault + } + return level +} |