package smux

import (
	"encoding/binary"
	"io"
	"sync"
	"sync/atomic"
	"time"

	"github.com/pkg/errors"
)

const (
	defaultAcceptBacklog = 1024
)

const (
	errBrokenPipe      = "broken pipe"
	errInvalidProtocol = "invalid protocol version"
	errGoAway          = "stream id overflows, should start a new connection"
)

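// writeRequest pairs a frame to be written with a channel on which sendLoop
// reports the outcome, letting writeFrame block until the frame has actually
// been flushed to the underlying connection.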
type writeRequest struct {
	frame  Frame
	result chan writeResult
}

type writeResult struct {
	n   int
	err error
}

// Session defines a multiplexed connection for streams
type Session struct {
	conn   io.ReadWriteCloser
	config *Config

	nextStreamID     uint32 // next stream identifier
	nextStreamIDLock sync.Mutex

	bucket       int32         // token bucket
	bucketNotify chan struct{} // used for waiting for tokens

	streams    map[uint32]*Stream // all streams in this session
	streamLock sync.Mutex         // locks streams

	die     chan struct{} // flag session has died
	dieLock sync.Mutex

	chAccepts chan *Stream

	dataReady int32 // flag data has arrived

	goAway int32 // flag id exhausted

	deadline atomic.Value

	writes chan writeRequest
}

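// newSession wires up a Session around conn and starts its three worker
// goroutines (recvLoop, sendLoop and keepalive). Client sessions start the
// stream-ID counter at 1 and servers at 0, so with the +2 increment in
// OpenStream the two sides allocate disjoint (odd vs. even) stream IDs.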
func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
	s := new(Session)
	s.die = make(chan struct{})
	s.conn = conn
	s.config = config
	s.streams = make(map[uint32]*Stream)
	s.chAccepts = make(chan *Stream, defaultAcceptBacklog)
	s.bucket = int32(config.MaxReceiveBuffer)
	s.bucketNotify = make(chan struct{}, 1)
	s.writes = make(chan writeRequest)

	if client {
		s.nextStreamID = 1
	} else {
		s.nextStreamID = 0
	}
	go s.recvLoop()
	go s.sendLoop()
	go s.keepalive()
	return s
}

// OpenStream is used to create a new stream
func (s *Session) OpenStream() (*Stream, error) {
	if s.IsClosed() {
		return nil, errors.New(errBrokenPipe)
	}

	// generate stream id
	s.nextStreamIDLock.Lock()
	if s.goAway > 0 {
		s.nextStreamIDLock.Unlock()
		return nil, errors.New(errGoAway)
	}

	s.nextStreamID += 2
	sid := s.nextStreamID
	if sid == sid%2 { // stream-id overflows
		s.goAway = 1
		s.nextStreamIDLock.Unlock()
		return nil, errors.New(errGoAway)
	}
	s.nextStreamIDLock.Unlock()

	stream := newStream(sid, s.config.MaxFrameSize, s)

	if _, err := s.writeFrame(newFrame(cmdSYN, sid)); err != nil {
		return nil, errors.Wrap(err, "writeFrame")
	}

	s.streamLock.Lock()
	s.streams[sid] = stream
	s.streamLock.Unlock()
	return stream, nil
}

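// A minimal usage sketch for OpenStream (hypothetical caller code; the
// Stream read/write/close methods are assumed to be those implemented
// elsewhere in this package):
//
//	stream, err := session.OpenStream()
//	if err != nil {
//		// handle error
//	}
//	defer stream.Close()
//	if _, err := stream.Write([]byte("hello")); err != nil {
//		// handle write error
//	}
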
// AcceptStream is used to block until the next available stream
// is ready to be accepted.
func (s *Session) AcceptStream() (*Stream, error) {
	var deadline <-chan time.Time
	if d, ok := s.deadline.Load().(time.Time); ok && !d.IsZero() {
		timer := time.NewTimer(d.Sub(time.Now()))
		defer timer.Stop()
		deadline = timer.C
	}
	select {
	case stream := <-s.chAccepts:
		return stream, nil
	case <-deadline:
		return nil, errTimeout
	case <-s.die:
		return nil, errors.New(errBrokenPipe)
	}
}

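// A minimal server-side accept loop (sketch; handleStream is a hypothetical
// caller-defined handler, not part of this package):
//
//	for {
//		stream, err := session.AcceptStream()
//		if err != nil {
//			break // session closed or accept deadline reached
//		}
//		go handleStream(stream)
//	}
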
// Close is used to close the session and all streams.
func (s *Session) Close() (err error) {
	s.dieLock.Lock()

	select {
	case <-s.die:
		s.dieLock.Unlock()
		return errors.New(errBrokenPipe)
	default:
		close(s.die)
		s.dieLock.Unlock()
		s.streamLock.Lock()
		for k := range s.streams {
			s.streams[k].sessionClose()
		}
		s.streamLock.Unlock()
		s.notifyBucket()
		return s.conn.Close()
	}
}

// notifyBucket notifies recvLoop that bucket is available
func (s *Session) notifyBucket() {
	select {
	case s.bucketNotify <- struct{}{}:
	default:
	}
}

// IsClosed does a safe check to see if we have shutdown
func (s *Session) IsClosed() bool {
	select {
	case <-s.die:
		return true
	default:
		return false
	}
}

// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
	if s.IsClosed() {
		return 0
	}
	s.streamLock.Lock()
	defer s.streamLock.Unlock()
	return len(s.streams)
}

// SetDeadline sets a deadline used by Accept* calls.
// A zero time value disables the deadline.
func (s *Session) SetDeadline(t time.Time) error {
	s.deadline.Store(t)
	return nil
}

// streamClosed is called by a stream to notify the session that it has closed
func (s *Session) streamClosed(sid uint32) {
	s.streamLock.Lock()
	if n := s.streams[sid].recycleTokens(); n > 0 { // return remaining tokens to the bucket
		if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
			s.notifyBucket()
		}
	}
	delete(s.streams, sid)
	s.streamLock.Unlock()
}

// returnTokens is called by stream to return token after read
func (s *Session) returnTokens(n int) {
	if atomic.AddInt32(&s.bucket, int32(n)) > 0 {
		s.notifyBucket()
	}
}

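// Flow-control note: the session-wide token bucket (Session.bucket) is
// charged in recvLoop when payload bytes are pushed into a stream, and
// refilled in returnTokens and streamClosed once the stream consumes or
// abandons those bytes. recvLoop stops reading from the connection while the
// bucket is non-positive, which roughly bounds total receive-buffer memory to
// MaxReceiveBuffer.
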
// readFrame reads a single frame from the underlying connection;
// the returned frame's data slice points into the supplied buffer.
func (s *Session) readFrame(buffer []byte) (f Frame, err error) {
	if _, err := io.ReadFull(s.conn, buffer[:headerSize]); err != nil {
		return f, errors.Wrap(err, "readFrame")
	}

	dec := rawHeader(buffer)
	if dec.Version() != version {
		return f, errors.New(errInvalidProtocol)
	}

	f.ver = dec.Version()
	f.cmd = dec.Cmd()
	f.sid = dec.StreamID()
	if length := dec.Length(); length > 0 {
		if _, err := io.ReadFull(s.conn, buffer[headerSize:headerSize+length]); err != nil {
			return f, errors.Wrap(err, "readFrame")
		}
		f.data = buffer[headerSize : headerSize+length]
	}
	return f, nil
}

// recvLoop keeps on reading from underlying connection if tokens are available
func (s *Session) recvLoop() {
	buffer := make([]byte, (1<<16)+headerSize)
	for {
		for atomic.LoadInt32(&s.bucket) <= 0 && !s.IsClosed() {
			<-s.bucketNotify
		}

		if f, err := s.readFrame(buffer); err == nil {
			atomic.StoreInt32(&s.dataReady, 1)

			switch f.cmd {
			case cmdNOP:
			case cmdSYN:
				s.streamLock.Lock()
				if _, ok := s.streams[f.sid]; !ok {
					stream := newStream(f.sid, s.config.MaxFrameSize, s)
					s.streams[f.sid] = stream
					select {
					case s.chAccepts <- stream:
					case <-s.die:
					}
				}
				s.streamLock.Unlock()
			case cmdFIN:
				s.streamLock.Lock()
				if stream, ok := s.streams[f.sid]; ok {
					stream.markRST()
					stream.notifyReadEvent()
				}
				s.streamLock.Unlock()
			case cmdPSH:
				s.streamLock.Lock()
				if stream, ok := s.streams[f.sid]; ok {
					atomic.AddInt32(&s.bucket, -int32(len(f.data)))
					stream.pushBytes(f.data)
					stream.notifyReadEvent()
				}
				s.streamLock.Unlock()
			default:
				s.Close()
				return
			}
		} else {
			s.Close()
			return
		}
	}
}

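// keepalive periodically sends NOP frames so the peer sees traffic, and tears
// the session down if no frame at all has been received within
// KeepAliveTimeout (tracked via the dataReady flag set by recvLoop).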
func (s *Session) keepalive() {
	tickerPing := time.NewTicker(s.config.KeepAliveInterval)
	tickerTimeout := time.NewTicker(s.config.KeepAliveTimeout)
	defer tickerPing.Stop()
	defer tickerTimeout.Stop()
	for {
		select {
		case <-tickerPing.C:
			s.writeFrame(newFrame(cmdNOP, 0))
			s.notifyBucket() // force a signal to the recvLoop
		case <-tickerTimeout.C:
			if !atomic.CompareAndSwapInt32(&s.dataReady, 1, 0) {
				s.Close()
				return
			}
		case <-s.die:
			return
		}
	}
}

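// sendLoop is the single writer goroutine for the underlying connection: it
// drains the writes channel, encodes each frame as a headerSize-byte header
// (version, command, little-endian payload length, little-endian stream ID)
// followed by the payload, and reports the result back to the blocked
// writeFrame caller.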
func (s *Session) sendLoop() {
	buf := make([]byte, (1<<16)+headerSize)
	for {
		select {
		case <-s.die:
			return
		case request, ok := <-s.writes:
			if !ok {
				continue
			}
			buf[0] = request.frame.ver
			buf[1] = request.frame.cmd
			binary.LittleEndian.PutUint16(buf[2:], uint16(len(request.frame.data)))
			binary.LittleEndian.PutUint32(buf[4:], request.frame.sid)
			copy(buf[headerSize:], request.frame.data)
			n, err := s.conn.Write(buf[:headerSize+len(request.frame.data)])

			n -= headerSize
			if n < 0 {
				n = 0
			}

			result := writeResult{
				n:   n,
				err: err,
			}

			request.result <- result
			close(request.result)
		}
	}
}

// writeFrame writes the frame to the underlying connection
// and returns the number of payload bytes written if successful
func (s *Session) writeFrame(f Frame) (n int, err error) {
	req := writeRequest{
		frame:  f,
		result: make(chan writeResult, 1),
	}
	select {
	case <-s.die:
		return 0, errors.New(errBrokenPipe)
	case s.writes <- req:
	}

	result := <-req.result
	return result.n, result.err
}