123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148 |
- // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package websocket
- import (
- "compress/flate"
- "errors"
- "io"
- "strings"
- "sync"
- )
// Compression levels accepted by this package. They span the levels that
// compress/flate understands, from HuffmanOnly (-2) to BestCompression (9).
const (
	// minCompressionLevel equals flate.HuffmanOnly. It is written as a literal
	// because flate.HuffmanOnly does not exist before Go 1.7, and the literal
	// keeps the package building on those older releases.
	minCompressionLevel = -2
	maxCompressionLevel = flate.BestCompression
	// defaultCompressionLevel is numerically equal to flate.BestSpeed.
	defaultCompressionLevel = 1
)
- var (
- flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
- flateReaderPool = sync.Pool{New: func() interface{} {
- return flate.NewReader(nil)
- }}
- )
- func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
- const tail =
- // Add four bytes as specified in RFC
- "\x00\x00\xff\xff" +
- // Add final block to squelch unexpected EOF error from flate reader.
- "\x01\x00\x00\xff\xff"
- fr, _ := flateReaderPool.Get().(io.ReadCloser)
- fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
- return &flateReadWrapper{fr}
- }
- func isValidCompressionLevel(level int) bool {
- return minCompressionLevel <= level && level <= maxCompressionLevel
- }
- func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
- p := &flateWriterPools[level-minCompressionLevel]
- tw := &truncWriter{w: w}
- fw, _ := p.Get().(*flate.Writer)
- if fw == nil {
- fw, _ = flate.NewWriter(tw, level)
- } else {
- fw.Reset(tw)
- }
- return &flateWriteWrapper{fw: fw, tw: tw, p: p}
- }
// truncWriter is an io.Writer that forwards all but the last four bytes of the
// stream to w. The held-back bytes stay in p, so the owner can inspect them
// once the stream is complete.
type truncWriter struct {
	w io.WriteCloser
	n int     // number of valid bytes in p; grows to len(p) and stays there
	p [4]byte // the most recent bytes written, not yet forwarded to w
}

// Write implements io.Writer. Input first fills the hold-back buffer. After
// that, each call forwards the oldest buffered bytes plus the leading part of
// the input to w, and keeps the final len(w.p) bytes buffered.
//
// On success it reports len(p), as io.Writer requires, even though up to four
// of those bytes are still held back. On error it reports only the leading
// bytes of p that were accepted in order (buffered or forwarded).
func (w *truncWriter) Write(p []byte) (int, error) {
	n := 0

	// Fill the hold-back buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
	}
	if len(p) == 0 {
		return n, nil
	}

	// m is the number of the oldest buffered bytes that new input displaces.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}

	// The displaced bytes came from earlier calls, so they do not count
	// toward this call's total.
	if _, err := w.w.Write(w.p[:m]); err != nil {
		return n, err
	}

	// Slide the surviving buffered bytes down and stash the last m bytes of p.
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	// Forward everything in p except the bytes just stashed.
	if head := p[:len(p)-m]; len(head) > 0 {
		if nn, err := w.w.Write(head); err != nil {
			return n + nn, err
		}
	}
	// All of the caller's input has been either forwarded or buffered.
	return n + len(p), nil
}
- type flateWriteWrapper struct {
- fw *flate.Writer
- tw *truncWriter
- p *sync.Pool
- }
- func (w *flateWriteWrapper) Write(p []byte) (int, error) {
- if w.fw == nil {
- return 0, errWriteClosed
- }
- return w.fw.Write(p)
- }
- func (w *flateWriteWrapper) Close() error {
- if w.fw == nil {
- return errWriteClosed
- }
- err1 := w.fw.Flush()
- w.p.Put(w.fw)
- w.fw = nil
- if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
- return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
- }
- err2 := w.tw.w.Close()
- if err1 != nil {
- return err1
- }
- return err2
- }
- type flateReadWrapper struct {
- fr io.ReadCloser
- }
- func (r *flateReadWrapper) Read(p []byte) (int, error) {
- if r.fr == nil {
- return 0, io.ErrClosedPipe
- }
- n, err := r.fr.Read(p)
- if err == io.EOF {
- // Preemptively place the reader back in the pool. This helps with
- // scenarios where the application does not call NextReader() soon after
- // this final read.
- r.Close()
- }
- return n, err
- }
- func (r *flateReadWrapper) Close() error {
- if r.fr == nil {
- return io.ErrClosedPipe
- }
- err := r.fr.Close()
- flateReaderPool.Put(r.fr)
- r.fr = nil
- return err
- }
|