package fasthttp

import (
	"bytes"
	"fmt"
	"io"
	"sync"

	"github.com/klauspost/compress/zstd"
	"github.com/valyala/bytebufferpool"
	"github.com/valyala/fasthttp/stackless"
)

// Supported zstd compression levels.
//
// The values are converted directly to zstd.EncoderLevel when an encoder is
// created, and are also used (after normalization) as pool indexes.
const (
	CompressZstdSpeedNotSet = iota
	CompressZstdBestSpeed
	CompressZstdDefault
	CompressZstdSpeedBetter
	CompressZstdBestCompression
)

var (
	// zstdDecoderPool caches *zstd.Decoder values for reuse.
	zstdDecoderPool sync.Pool
	// zstdEncoderPool caches *zstd.Encoder values. It is NOT keyed by level.
	zstdEncoderPool sync.Pool

	// Writer pools indexed by the normalized compression level.
	realZstdWriterPoolMap      = newCompressWriterPoolMap()
	stacklessZstdWriterPoolMap = newCompressWriterPoolMap()
)

// acquireZstdReader returns a decoder that reads from r, reusing a pooled
// decoder when one is available and allocating a new one otherwise.
func acquireZstdReader(r io.Reader) (*zstd.Decoder, error) {
	v := zstdDecoderPool.Get()
	if v == nil {
		return zstd.NewReader(r)
	}
	zr := v.(*zstd.Decoder)
	if err := zr.Reset(r); err != nil {
		// The decoder is dropped rather than returned to the pool.
		return nil, err
	}
	return zr, nil
}

// releaseZstdReader returns zr to the decoder pool.
//
// The decoder is reset with a nil reader first so that a pooled decoder does
// not keep a reference to the caller's input while it sits idle in the pool.
func releaseZstdReader(zr *zstd.Decoder) {
	if err := zr.Reset(nil); err != nil {
		// The decoder cannot be reused; leave it to the garbage collector.
		return
	}
	zstdDecoderPool.Put(zr)
}

// acquireZstdWriter returns an encoder writing to w.
//
// level is only applied when a new encoder has to be created.
// NOTE(review): an encoder taken from zstdEncoderPool keeps whatever level it
// was created with, because that pool is shared across levels.
func acquireZstdWriter(w io.Writer, level int) (*zstd.Encoder, error) {
	v := zstdEncoderPool.Get()
	if v == nil {
		return zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(level)))
	}
	zw := v.(*zstd.Encoder)
	zw.Reset(w)
	return zw, nil
}

// releaseZstdWriter closes zw and returns it to zstdEncoderPool.
func releaseZstdWriter(zw *zstd.Encoder) { //nolint:unused
	zw.Close()
	zstdEncoderPool.Put(zw)
}

// acquireStacklessZstdWriter returns a stackless writer that compresses into
// w, taken from the pool that belongs to the normalized compressLevel.
func acquireStacklessZstdWriter(w io.Writer, compressLevel int) stackless.Writer {
	nLevel := normalizeZstdCompressLevel(compressLevel)
	p := stacklessZstdWriterPoolMap[nLevel]
	v := p.Get()
	if v == nil {
		// Capture the normalized level so the underlying encoder is created
		// with the same level as the pool this writer is later released to.
		return stackless.NewWriter(w, func(w io.Writer) stackless.Writer {
			return acquireRealZstdWriter(w, nLevel)
		})
	}
	sw := v.(stackless.Writer)
	sw.Reset(w)
	return sw
}

// releaseStacklessZstdWriter closes zf and puts it back into the stackless
// pool selected by the normalized zstdDefault level.
func releaseStacklessZstdWriter(zf stackless.Writer, zstdDefault int) {
	zf.Close()
	nLevel := normalizeZstdCompressLevel(zstdDefault)
	p := stacklessZstdWriterPoolMap[nLevel]
	p.Put(zf)
}

// acquireRealZstdWriter returns an encoder writing to w from the per-level
// pool, creating one when the pool is empty.
//
// It panics if a new encoder cannot be created.
func acquireRealZstdWriter(w io.Writer, level int) *zstd.Encoder {
	nLevel := normalizeZstdCompressLevel(level)
	p := realZstdWriterPoolMap[nLevel]
	v := p.Get()
	if v == nil {
		// Create the encoder with the normalized level: it must match the
		// pool it will be released into, and a raw out-of-range level would
		// otherwise be rejected by zstd and end up in the panic below.
		zw, err := acquireZstdWriter(w, nLevel)
		if err != nil {
			panic(err)
		}
		return zw
	}
	zw := v.(*zstd.Encoder)
	zw.Reset(w)
	return zw
}

// releaseRealZstdWrter closes zw and puts it back into the per-level pool
// selected by the normalized level.
func releaseRealZstdWrter(zw *zstd.Encoder, level int) {
	zw.Close()
	nLevel := normalizeZstdCompressLevel(level)
	p := realZstdWriterPoolMap[nLevel]
	p.Put(zw)
}

// AppendZstdBytesLevel appends zstd-compressed src to dst using the given
// compression level and returns the resulting dst.
func AppendZstdBytesLevel(dst, src []byte, level int) []byte {
	w := &byteSliceWriter{dst}
	WriteZstdLevel(w, src, level) //nolint:errcheck
	return w.b
}

// WriteZstdLevel writes zstd-compressed p to w using the given compression
// level.
//
// For the in-memory writer types handled explicitly below it always returns
// len(p) and a nil error. For any other writer it returns the result of the
// stackless writer's Write call.
func WriteZstdLevel(w io.Writer, p []byte, level int) (int, error) {
	level = normalizeZstdCompressLevel(level)
	switch w.(type) {
	case *byteSliceWriter,
		*bytes.Buffer,
		*bytebufferpool.ByteBuffer:
		ctx := &compressCtx{
			w:     w,
			p:     p,
			level: level,
		}
		stacklessWriteZstd(ctx)
		return len(p), nil
	default:
		zw := acquireStacklessZstdWriter(w, level)
		n, err := zw.Write(p)
		releaseStacklessZstdWriter(zw, level)
		return n, err
	}
}

var (
	// stacklessWriteZstdOnce guards the lazy creation of
	// stacklessWriteZstdFunc.
	stacklessWriteZstdOnce sync.Once
	stacklessWriteZstdFunc func(ctx any) bool
)

// stacklessWriteZstd runs nonblockingWriteZstd for ctx through a stackless
// function wrapper that is created on first use.
//
// NOTE(review): the bool result of the wrapper is discarded; if it signals
// that the call was not executed, the output would be silently missing.
// Confirm against stackless.NewFunc.
func stacklessWriteZstd(ctx any) {
	stacklessWriteZstdOnce.Do(func() {
		stacklessWriteZstdFunc = stackless.NewFunc(nonblockingWriteZstd)
	})
	stacklessWriteZstdFunc(ctx)
}

// nonblockingWriteZstd compresses ctx.p into ctx.w at ctx.level.
// ctxv must be a *compressCtx.
func nonblockingWriteZstd(ctxv any) {
	ctx := ctxv.(*compressCtx)
	zw := acquireRealZstdWriter(ctx.w, ctx.level)

	zw.Write(ctx.p) //nolint:errcheck

	// Releasing closes the encoder before it is pooled.
	releaseRealZstdWrter(zw, ctx.level)
}

// AppendZstdBytes appends zstd src to dst and returns the resulting dst.
func AppendZstdBytes(dst, src []byte) []byte {
	return AppendZstdBytesLevel(dst, src, CompressZstdDefault)
}

// WriteUnzstd writes unzstd p to w and returns the number of uncompressed
// bytes written to w.
func WriteUnzstd(w io.Writer, p []byte) (int, error) {
	r := &byteSliceReader{p}
	zr, err := acquireZstdReader(r)
	if err != nil {
		return 0, err
	}
	n, err := copyZeroAlloc(w, zr)
	releaseZstdReader(zr)

	// Guard against the byte count not fitting into an int.
	nn := int(n)
	if int64(nn) != n {
		return 0, fmt.Errorf("too much data unzstd: %d", n)
	}
	return nn, err
}

// AppendUnzstdBytes appends unzstd src to dst and returns the resulting dst.
func AppendUnzstdBytes(dst, src []byte) ([]byte, error) { w := &byteSliceWriter{dst} _, err := WriteUnzstd(w, src) return w.b, err } // normalizes compression level into [0..7], so it could be used as an index // in *PoolMap. func normalizeZstdCompressLevel(level int) int { if level < CompressZstdSpeedNotSet || level > CompressZstdBestCompression { level = CompressZstdDefault } return level }