1
0

stream.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. package smux
  2. import (
  3. "bytes"
  4. "io"
  5. "net"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/pkg/errors"
  10. )
  11. // Stream implements net.Conn
  12. type Stream struct {
  13. id uint32
  14. rstflag int32
  15. sess *Session
  16. buffer bytes.Buffer
  17. bufferLock sync.Mutex
  18. frameSize int
  19. chReadEvent chan struct{} // notify a read event
  20. die chan struct{} // flag the stream has closed
  21. dieLock sync.Mutex
  22. readDeadline atomic.Value
  23. writeDeadline atomic.Value
  24. }
  25. // newStream initiates a Stream struct
  26. func newStream(id uint32, frameSize int, sess *Session) *Stream {
  27. s := new(Stream)
  28. s.id = id
  29. s.chReadEvent = make(chan struct{}, 1)
  30. s.frameSize = frameSize
  31. s.sess = sess
  32. s.die = make(chan struct{})
  33. return s
  34. }
  35. // ID returns the unique stream ID.
  36. func (s *Stream) ID() uint32 {
  37. return s.id
  38. }
  39. // Read implements net.Conn
  40. func (s *Stream) Read(b []byte) (n int, err error) {
  41. var deadline <-chan time.Time
  42. if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
  43. timer := time.NewTimer(d.Sub(time.Now()))
  44. defer timer.Stop()
  45. deadline = timer.C
  46. }
  47. READ:
  48. select {
  49. case <-s.die:
  50. return 0, errors.New(errBrokenPipe)
  51. case <-deadline:
  52. return n, errTimeout
  53. default:
  54. }
  55. s.bufferLock.Lock()
  56. n, err = s.buffer.Read(b)
  57. s.bufferLock.Unlock()
  58. if n > 0 {
  59. s.sess.returnTokens(n)
  60. return n, nil
  61. } else if atomic.LoadInt32(&s.rstflag) == 1 {
  62. _ = s.Close()
  63. return 0, io.EOF
  64. }
  65. select {
  66. case <-s.chReadEvent:
  67. goto READ
  68. case <-deadline:
  69. return n, errTimeout
  70. case <-s.die:
  71. return 0, errors.New(errBrokenPipe)
  72. }
  73. }
  74. // Write implements net.Conn
  75. func (s *Stream) Write(b []byte) (n int, err error) {
  76. var deadline <-chan time.Time
  77. if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
  78. timer := time.NewTimer(d.Sub(time.Now()))
  79. defer timer.Stop()
  80. deadline = timer.C
  81. }
  82. select {
  83. case <-s.die:
  84. return 0, errors.New(errBrokenPipe)
  85. default:
  86. }
  87. frames := s.split(b, cmdPSH, s.id)
  88. sent := 0
  89. for k := range frames {
  90. req := writeRequest{
  91. frame: frames[k],
  92. result: make(chan writeResult, 1),
  93. }
  94. select {
  95. case s.sess.writes <- req:
  96. case <-s.die:
  97. return sent, errors.New(errBrokenPipe)
  98. case <-deadline:
  99. return sent, errTimeout
  100. }
  101. select {
  102. case result := <-req.result:
  103. sent += result.n
  104. if result.err != nil {
  105. return sent, result.err
  106. }
  107. case <-s.die:
  108. return sent, errors.New(errBrokenPipe)
  109. case <-deadline:
  110. return sent, errTimeout
  111. }
  112. }
  113. return sent, nil
  114. }
  115. // Close implements net.Conn
  116. func (s *Stream) Close() error {
  117. s.dieLock.Lock()
  118. select {
  119. case <-s.die:
  120. s.dieLock.Unlock()
  121. return errors.New(errBrokenPipe)
  122. default:
  123. close(s.die)
  124. s.dieLock.Unlock()
  125. s.sess.streamClosed(s.id)
  126. _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id))
  127. return err
  128. }
  129. }
  130. // SetReadDeadline sets the read deadline as defined by
  131. // net.Conn.SetReadDeadline.
  132. // A zero time value disables the deadline.
  133. func (s *Stream) SetReadDeadline(t time.Time) error {
  134. s.readDeadline.Store(t)
  135. return nil
  136. }
  137. // SetWriteDeadline sets the write deadline as defined by
  138. // net.Conn.SetWriteDeadline.
  139. // A zero time value disables the deadline.
  140. func (s *Stream) SetWriteDeadline(t time.Time) error {
  141. s.writeDeadline.Store(t)
  142. return nil
  143. }
  144. // SetDeadline sets both read and write deadlines as defined by
  145. // net.Conn.SetDeadline.
  146. // A zero time value disables the deadlines.
  147. func (s *Stream) SetDeadline(t time.Time) error {
  148. if err := s.SetReadDeadline(t); err != nil {
  149. return err
  150. }
  151. if err := s.SetWriteDeadline(t); err != nil {
  152. return err
  153. }
  154. return nil
  155. }
  156. // session closes the stream
  157. func (s *Stream) sessionClose() {
  158. s.dieLock.Lock()
  159. defer s.dieLock.Unlock()
  160. select {
  161. case <-s.die:
  162. default:
  163. close(s.die)
  164. }
  165. }
  166. // LocalAddr satisfies net.Conn interface
  167. func (s *Stream) LocalAddr() net.Addr {
  168. if ts, ok := s.sess.conn.(interface {
  169. LocalAddr() net.Addr
  170. }); ok {
  171. return ts.LocalAddr()
  172. }
  173. return nil
  174. }
  175. // RemoteAddr satisfies net.Conn interface
  176. func (s *Stream) RemoteAddr() net.Addr {
  177. if ts, ok := s.sess.conn.(interface {
  178. RemoteAddr() net.Addr
  179. }); ok {
  180. return ts.RemoteAddr()
  181. }
  182. return nil
  183. }
  184. // pushBytes a slice into buffer
  185. func (s *Stream) pushBytes(p []byte) {
  186. s.bufferLock.Lock()
  187. s.buffer.Write(p)
  188. s.bufferLock.Unlock()
  189. }
  190. // recycleTokens transform remaining bytes to tokens(will truncate buffer)
  191. func (s *Stream) recycleTokens() (n int) {
  192. s.bufferLock.Lock()
  193. n = s.buffer.Len()
  194. s.buffer.Reset()
  195. s.bufferLock.Unlock()
  196. return
  197. }
  198. // split large byte buffer into smaller frames, reference only
  199. func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame {
  200. frames := make([]Frame, 0, len(bts)/s.frameSize+1)
  201. for len(bts) > s.frameSize {
  202. frame := newFrame(cmd, sid)
  203. frame.data = bts[:s.frameSize]
  204. bts = bts[s.frameSize:]
  205. frames = append(frames, frame)
  206. }
  207. if len(bts) > 0 {
  208. frame := newFrame(cmd, sid)
  209. frame.data = bts
  210. frames = append(frames, frame)
  211. }
  212. return frames
  213. }
  214. // notify read event
  215. func (s *Stream) notifyReadEvent() {
  216. select {
  217. case s.chReadEvent <- struct{}{}:
  218. default:
  219. }
  220. }
  221. // mark this stream has been reset
  222. func (s *Stream) markRST() {
  223. atomic.StoreInt32(&s.rstflag, 1)
  224. }
  // errTimeout is returned by Stream.Read and Stream.Write when the read or
// write deadline expires.
var errTimeout error = &timeoutError{}

// Compile-time guarantee that the deadline error satisfies net.Error. Callers
// of a net.Conn can then detect timeouts with a net.Error assertion and
// Timeout().
var _ net.Error = (*timeoutError)(nil)

// timeoutError is the concrete type behind errTimeout. It reports itself as
// both a timeout and a temporary condition.
type timeoutError struct{}

// Error returns the conventional "i/o timeout" message.
func (e *timeoutError) Error() string { return "i/o timeout" }

// Timeout reports true: this error always means a deadline expired.
func (e *timeoutError) Timeout() bool { return true }

// Temporary reports true, so retrying after extending the deadline is valid.
func (e *timeoutError) Temporary() bool { return true }