session.go 7.6 KB


  1. package smux
  2. import (
  3. "encoding/binary"
  4. "io"
  5. "sync"
  6. "sync/atomic"
  7. "time"
  8. "github.com/pkg/errors"
  9. )
  // Session-wide constants: the accept queue capacity and the messages used
  // when constructing session errors.
  const (
  	// defaultAcceptBacklog is the capacity of the channel that queues
  	// incoming streams until AcceptStream picks them up.
  	defaultAcceptBacklog = 1024

  	// errBrokenPipe is the message used once the session has been closed.
  	errBrokenPipe = "broken pipe"
  	// errInvalidProtocol is the message used when a received frame carries
  	// an unexpected protocol version.
  	errInvalidProtocol = "invalid protocol version"
  	// errGoAway is the message used when the stream id space is exhausted.
  	errGoAway = "stream id overflows, should start a new connection"
  )
  18. type writeRequest struct {
  19. frame Frame
  20. result chan writeResult
  21. }
  // writeResult is the outcome of writing one frame to the connection.
  type writeResult struct {
  	// n is the number of payload bytes written (header bytes excluded).
  	n int
  	// err is the error returned by the underlying write, if any.
  	err error
  }
  26. // Session defines a multiplexed connection for streams
  27. type Session struct {
  28. conn io.ReadWriteCloser
  29. config *Config
  30. nextStreamID uint32 // next stream identifier
  31. nextStreamIDLock sync.Mutex
  32. bucket int32 // token bucket
  33. bucketNotify chan struct{} // used for waiting for tokens
  34. streams map[uint32]*Stream // all streams in this session
  35. streamLock sync.Mutex // locks streams
  36. die chan struct{} // flag session has died
  37. dieLock sync.Mutex
  38. chAccepts chan *Stream
  39. dataReady int32 // flag data has arrived
  40. goAway int32 // flag id exhausted
  41. deadline atomic.Value
  42. writes chan writeRequest
  43. }
  44. func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
  45. s := new(Session)
  46. s.die = make(chan struct{})
  47. s.conn = conn
  48. s.config = config
  49. s.streams = make(map[uint32]*Stream)
  50. s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
  51. s.bucket = int32(config.MaxReceiveBuffer)
  52. s.bucketNotify = make(chan struct{}, 1)
  53. s.writes = make(chan writeRequest)
  54. if client {
  55. s.nextStreamID = 1
  56. } else {
  57. s.nextStreamID = 0
  58. }
  59. go s.recvLoop()
  60. go s.sendLoop()
  61. go s.keepalive()
  62. return s
  63. }
  64. // OpenStream is used to create a new stream
  65. func (s *Session) OpenStream() (*Stream, error) {
  66. if s.IsClosed() {
  67. return nil, errors.New(errBrokenPipe)
  68. }
  69. // generate stream id
  70. s.nextStreamIDLock.Lock()
  71. if s.goAway > 0 {
  72. s.nextStreamIDLock.Unlock()
  73. return nil, errors.New(errGoAway)
  74. }
  75. s.nextStreamID += 2
  76. sid := s.nextStreamID
  77. if sid == sid%2 { // stream-id overflows
  78. s.goAway = 1
  79. s.nextStreamIDLock.Unlock()
  80. return nil, errors.New(errGoAway)
  81. }
  82. s.nextStreamIDLock.Unlock()
  83. stream := newStream(sid, s.config.MaxFrameSize, s)
  84. if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
  85. return nil, errors.Wrap(err, "writeFrame")
  86. }
  87. s.streamLock.Lock()
  88. s.streams[sid] = stream
  89. s.streamLock.Unlock()
  90. return stream, nil
  91. }
  92. // AcceptStream is used to block until the next available stream
  93. // is ready to be accepted.
  94. func (s *Session) AcceptStream() (*Stream, error) {
  95. var deadline <-chan time.Time
  96. if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
  97. timer := time.NewTimer(d.Sub(time.Now()))
  98. defer timer.Stop()
  99. deadline = timer.C
  100. }
  101. select {
  102. case stream := <-s.chAccepts:
  103. return stream, nil
  104. case <-deadline:
  105. return nil, errTimeout
  106. case <-s.die:
  107. return nil, errors.New(errBrokenPipe)
  108. }
  109. }
  110. // Close is used to close the session and all streams.
  111. func (s *Session) Close() (err error) {
  112. s.dieLock.Lock()
  113. select {
  114. case <-s.die:
  115. s.dieLock.Unlock()
  116. return errors.New(errBrokenPipe)
  117. default:
  118. close(s.die)
  119. s.dieLock.Unlock()
  120. s.streamLock.Lock()
  121. for k := range s.streams {
  122. s.streams[k].sessionClose()
  123. }
  124. s.streamLock.Unlock()
  125. s.notifyBucket()
  126. return s.conn.Close()
  127. }
  128. }
  129. // notifyBucket notifies recvLoop that bucket is available
  130. func (s *Session) notifyBucket() {
  131. select {
  132. case s.bucketNotify <- struct{}{}:
  133. default:
  134. }
  135. }
  136. // IsClosed does a safe check to see if we have shutdown
  137. func (s *Session) IsClosed() bool {
  138. select {
  139. case <-s.die:
  140. return true
  141. default:
  142. return false
  143. }
  144. }
  145. // NumStreams returns the number of currently open streams
  146. func (s *Session) NumStreams() int {
  147. if s.IsClosed() {
  148. return 0
  149. }
  150. s.streamLock.Lock()
  151. defer s.streamLock.Unlock()
  152. return len(s.streams)
  153. }
  154. // SetDeadline sets a deadline used by Accept* calls.
  155. // A zero time value disables the deadline.
  156. func (s *Session) SetDeadline(t time.Time) error {
  157. s.deadline.Store(t)
  158. return nil
  159. }
  160. // notify the session that a stream has closed
  161. func (s *Session) streamClosed(sid uint32) {
  162. s.streamLock.Lock()
  163. if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
  164. if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
  165. s.notifyBucket()
  166. }
  167. }
  168. delete(s.streams, sid)
  169. s.streamLock.Unlock()
  170. }
  171. // returnTokens is called by stream to return token after read
  172. func (s *Session) returnTokens(n int) {
  173. if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
  174. s.notifyBucket()
  175. }
  176. }
  177. // session read a frame from underlying connection
  178. // it's data is pointed to the input buffer
  179. func (s *Session) readFrame(buffer []byte) (f Frame, err error) {
  180. if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil {
  181. return f, errors.Wrap(err, "readFrame")
  182. }
  183. dec := rawHeader(buffer)
  184. if dec.Version() != version {
  185. return f, errors.New(errInvalidProtocol)
  186. }
  187. f.ver = dec.Version()
  188. f.cmd = dec.Cmd()
  189. f.sid = dec.StreamID()
  190. if length := dec.Length(); length > 0 {
  191. if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil {
  192. return f, errors.Wrap(err, "readFrame")
  193. }
  194. f.data = buffer[headerSize : headerSize+length]
  195. }
  196. return f, nil
  197. }
  198. // recvLoop keeps on reading from underlying connection if tokens are available
  199. func (s *Session) recvLoop() {
  200. buffer := make([]byte, (1<<16)+headerSize)
  201. for {
  202. for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
  203. <-s.bucketNotify
  204. }
  205. if f, err := s.readFrame(buffer); err == nil {
  206. atomic.StoreInt32(&s.dataReady, 1)
  207. switch f.cmd {
  208. case cmdNOP:
  209. case cmdSYN:
  210. s.streamLock.Lock()
  211. if _, ok := s.streams[f.sid]; !ok {
  212. stream := newStream(f.sid, s.config.MaxFrameSize, s)
  213. s.streams[f.sid] = stream
  214. select {
  215. case s.chAccepts <- stream:
  216. case <-s.die:
  217. }
  218. }
  219. s.streamLock.Unlock()
  220. case cmdFIN:
  221. s.streamLock.Lock()
  222. if stream, ok := s.streams[f.sid]; ok {
  223. stream.markRST()
  224. stream.notifyReadEvent()
  225. }
  226. s.streamLock.Unlock()
  227. case cmdPSH:
  228. s.streamLock.Lock()
  229. if stream, ok := s.streams[f.sid]; ok {
  230. atomic.AddInt32(&s.bucket, -int32(len(f.data)))
  231. stream.pushBytes(f.data)
  232. stream.notifyReadEvent()
  233. }
  234. s.streamLock.Unlock()
  235. default:
  236. s.Close()
  237. return
  238. }
  239. } else {
  240. s.Close()
  241. return
  242. }
  243. }
  244. }
  245. func (s *Session) keepalive() {
  246. tickerPing := time.NewTicker(s.config.KeepAliveInterval)
  247. tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
  248. defer tickerPing.Stop()
  249. defer tickerTimeout.Stop()
  250. for {
  251. select {
  252. case <-tickerPing.C:
  253. s.writeFrame(newFrame(cmdNOP, 0))
  254. s.notifyBucket() // force a signal to the recvLoop
  255. case <-tickerTimeout.C:
  256. if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
  257. s.Close()
  258. return
  259. }
  260. case <-s.die:
  261. return
  262. }
  263. }
  264. }
  265. func (s *Session) sendLoop() {
  266. buf := make([]byte, (1<<16)+headerSize)
  267. for {
  268. select {
  269. case <-s.die:
  270. return
  271. case request, ok := <-s.writes:
  272. if !ok {
  273. continue
  274. }
  275. buf[0] = request.frame.ver
  276. buf[1] = request.frame.cmd
  277. binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
  278. binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
  279. copy(buf[headerSize:], request.frame.data)
  280. n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)])
  281. n -= headerSize
  282. if n < 0 {
  283. n = 0
  284. }
  285. result := writeResult{
  286. n: n,
  287. err: err,
  288. }
  289. request.result <- result
  290. close(request.result)
  291. }
  292. }
  293. }
  294. // writeFrame writes the frame to the underlying connection
  295. // and returns the number of bytes written if successful
  296. func (s *Session) writeFrame(f Frame) (n int, err error) {
  297. req := writeRequest{
  298. frame: f,
  299. result: make(chan writeResult, 1),
  300. }
  301. select {
  302. case <-s.die:
  303. return 0, errors.New(errBrokenPipe)
  304. case s.writes <- req:
  305. }
  306. result := <-req.result
  307. return result.n, result.err
  308. }