1
0

compression.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "compress/flate"
  7. "errors"
  8. "io"
  9. "strings"
  10. "sync"
  11. )
  // Bounds of the compression levels for which flateWriterPools has a pool,
// plus the level used when none is chosen explicitly. The range
// [minCompressionLevel, maxCompressionLevel] is -2 through 9.
const (
	// minCompressionLevel equals flate.HuffmanOnly; it is written as a
	// literal because that constant is not defined in Go < 1.6.
	minCompressionLevel = -2

	// maxCompressionLevel is the strongest flate level.
	maxCompressionLevel = flate.BestCompression

	// defaultCompressionLevel is flate.BestSpeed, i.e. 1.
	defaultCompressionLevel = flate.BestSpeed
)
  17. var (
  18. flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
  19. flateReaderPool = sync.Pool{New: func() interface{} {
  20. return flate.NewReader(nil)
  21. }}
  22. )
  23. func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
  24. const tail =
  25. // Add four bytes as specified in RFC
  26. "\x00\x00\xff\xff" +
  27. // Add final block to squelch unexpected EOF error from flate reader.
  28. "\x01\x00\x00\xff\xff"
  29. fr, _ := flateReaderPool.Get().(io.ReadCloser)
  30. fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
  31. return &flateReadWrapper{fr}
  32. }
  33. func isValidCompressionLevel(level int) bool {
  34. return minCompressionLevel <= level && level <= maxCompressionLevel
  35. }
  36. func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
  37. p := &flateWriterPools[level-minCompressionLevel]
  38. tw := &truncWriter{w: w}
  39. fw, _ := p.Get().(*flate.Writer)
  40. if fw == nil {
  41. fw, _ = flate.NewWriter(tw, level)
  42. } else {
  43. fw.Reset(tw)
  44. }
  45. return &flateWriteWrapper{fw: fw, tw: tw, p: p}
  46. }
  // truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to the wrapped io.WriteCloser.
//
// The withheld bytes stay in p; only p[:n] is valid while fewer than four
// bytes have been written in total. flateWriteWrapper inspects p after
// flushing to confirm the stream ended with the expected marker.
type truncWriter struct {
	w io.WriteCloser // destination for everything except the final four bytes
	n int            // number of valid bytes in p (at most len(p))
	p [4]byte        // the most recent bytes, withheld from w
}

// Write appends p to the stream and forwards to w every byte that is no
// longer among the last four seen. On success it returns len(p), as io.Writer
// requires: bytes retained in the holdback buffer count as consumed. On
// failure it returns the number of bytes of p that were consumed before the
// error.
func (w *truncWriter) Write(p []byte) (int, error) {
	n := 0

	// Fill the holdback buffer first for simplicity.
	if w.n < len(w.p) {
		n = copy(w.p[w.n:], p)
		p = p[n:]
		w.n += n
		if len(p) == 0 {
			return n, nil
		}
	}

	// The buffer is full. Its m oldest bytes are released to w and replaced by
	// the last m bytes of p; everything before those goes straight to w.
	m := len(p)
	if m > len(w.p) {
		m = len(w.p)
	}
	if _, err := w.w.Write(w.p[:m]); err != nil {
		// The bytes written here came from the buffer, not from p; only the n
		// bytes copied into the buffer above were taken from p.
		return n, err
	}
	copy(w.p[:], w.p[m:])
	copy(w.p[len(w.p)-m:], p[len(p)-m:])

	nn, err := w.w.Write(p[:len(p)-m])
	if err != nil {
		return n + nn, err
	}
	// The last m bytes of p now sit in the buffer, so they are consumed too.
	return n + nn + m, nil
}
  77. type flateWriteWrapper struct {
  78. fw *flate.Writer
  79. tw *truncWriter
  80. p *sync.Pool
  81. }
  82. func (w *flateWriteWrapper) Write(p []byte) (int, error) {
  83. if w.fw == nil {
  84. return 0, errWriteClosed
  85. }
  86. return w.fw.Write(p)
  87. }
  88. func (w *flateWriteWrapper) Close() error {
  89. if w.fw == nil {
  90. return errWriteClosed
  91. }
  92. err1 := w.fw.Flush()
  93. w.p.Put(w.fw)
  94. w.fw = nil
  95. if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
  96. return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
  97. }
  98. err2 := w.tw.w.Close()
  99. if err1 != nil {
  100. return err1
  101. }
  102. return err2
  103. }
  104. type flateReadWrapper struct {
  105. fr io.ReadCloser
  106. }
  107. func (r *flateReadWrapper) Read(p []byte) (int, error) {
  108. if r.fr == nil {
  109. return 0, io.ErrClosedPipe
  110. }
  111. n, err := r.fr.Read(p)
  112. if err == io.EOF {
  113. // Preemptively place the reader back in the pool. This helps with
  114. // scenarios where the application does not call NextReader() soon after
  115. // this final read.
  116. r.Close()
  117. }
  118. return n, err
  119. }
  120. func (r *flateReadWrapper) Close() error {
  121. if r.fr == nil {
  122. return io.ErrClosedPipe
  123. }
  124. err := r.fr.Close()
  125. flateReaderPool.Put(r.fr)
  126. r.fr = nil
  127. return err
  128. }