123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261 |
- package smux
- import (
- "bytes"
- "io"
- "net"
- "sync"
- "sync/atomic"
- "time"
- "github.com/pkg/errors"
- )
- // Stream implements net.Conn
- type Stream struct {
- id uint32
- rstflag int32
- sess *Session
- buffer bytes.Buffer
- bufferLock sync.Mutex
- frameSize int
- chReadEvent chan struct{} // notify a read event
- die chan struct{} // flag the stream has closed
- dieLock sync.Mutex
- readDeadline atomic.Value
- writeDeadline atomic.Value
- }
- // newStream initiates a Stream struct
- func newStream(id uint32, frameSize int, sess *Session) *Stream {
- s := new(Stream)
- s.id = id
- s.chReadEvent = make(chan struct{}, 1)
- s.frameSize = frameSize
- s.sess = sess
- s.die = make(chan struct{})
- return s
- }
- // ID returns the unique stream ID.
- func (s *Stream) ID() uint32 {
- return s.id
- }
- // Read implements net.Conn
- func (s *Stream) Read(b []byte) (n int, err error) {
- var deadline <-chan time.Time
- if d, ok := s.readDeadline.Load().(time.Time); ok && !d.IsZero() {
- timer := time.NewTimer(d.Sub(time.Now()))
- defer timer.Stop()
- deadline = timer.C
- }
- READ:
- select {
- case <-s.die:
- return 0, errors.New(errBrokenPipe)
- case <-deadline:
- return n, errTimeout
- default:
- }
- s.bufferLock.Lock()
- n, err = s.buffer.Read(b)
- s.bufferLock.Unlock()
- if n > 0 {
- s.sess.returnTokens(n)
- return n, nil
- } else if atomic.LoadInt32(&s.rstflag) == 1 {
- _ = s.Close()
- return 0, io.EOF
- }
- select {
- case <-s.chReadEvent:
- goto READ
- case <-deadline:
- return n, errTimeout
- case <-s.die:
- return 0, errors.New(errBrokenPipe)
- }
- }
- // Write implements net.Conn
- func (s *Stream) Write(b []byte) (n int, err error) {
- var deadline <-chan time.Time
- if d, ok := s.writeDeadline.Load().(time.Time); ok && !d.IsZero() {
- timer := time.NewTimer(d.Sub(time.Now()))
- defer timer.Stop()
- deadline = timer.C
- }
- select {
- case <-s.die:
- return 0, errors.New(errBrokenPipe)
- default:
- }
- frames := s.split(b, cmdPSH, s.id)
- sent := 0
- for k := range frames {
- req := writeRequest{
- frame: frames[k],
- result: make(chan writeResult, 1),
- }
- select {
- case s.sess.writes <- req:
- case <-s.die:
- return sent, errors.New(errBrokenPipe)
- case <-deadline:
- return sent, errTimeout
- }
- select {
- case result := <-req.result:
- sent += result.n
- if result.err != nil {
- return sent, result.err
- }
- case <-s.die:
- return sent, errors.New(errBrokenPipe)
- case <-deadline:
- return sent, errTimeout
- }
- }
- return sent, nil
- }
- // Close implements net.Conn
- func (s *Stream) Close() error {
- s.dieLock.Lock()
- select {
- case <-s.die:
- s.dieLock.Unlock()
- return errors.New(errBrokenPipe)
- default:
- close(s.die)
- s.dieLock.Unlock()
- s.sess.streamClosed(s.id)
- _, err := s.sess.writeFrame(newFrame(cmdFIN, s.id))
- return err
- }
- }
- // SetReadDeadline sets the read deadline as defined by
- // net.Conn.SetReadDeadline.
- // A zero time value disables the deadline.
- func (s *Stream) SetReadDeadline(t time.Time) error {
- s.readDeadline.Store(t)
- return nil
- }
- // SetWriteDeadline sets the write deadline as defined by
- // net.Conn.SetWriteDeadline.
- // A zero time value disables the deadline.
- func (s *Stream) SetWriteDeadline(t time.Time) error {
- s.writeDeadline.Store(t)
- return nil
- }
- // SetDeadline sets both read and write deadlines as defined by
- // net.Conn.SetDeadline.
- // A zero time value disables the deadlines.
- func (s *Stream) SetDeadline(t time.Time) error {
- if err := s.SetReadDeadline(t); err != nil {
- return err
- }
- if err := s.SetWriteDeadline(t); err != nil {
- return err
- }
- return nil
- }
- // session closes the stream
- func (s *Stream) sessionClose() {
- s.dieLock.Lock()
- defer s.dieLock.Unlock()
- select {
- case <-s.die:
- default:
- close(s.die)
- }
- }
- // LocalAddr satisfies net.Conn interface
- func (s *Stream) LocalAddr() net.Addr {
- if ts, ok := s.sess.conn.(interface {
- LocalAddr() net.Addr
- }); ok {
- return ts.LocalAddr()
- }
- return nil
- }
- // RemoteAddr satisfies net.Conn interface
- func (s *Stream) RemoteAddr() net.Addr {
- if ts, ok := s.sess.conn.(interface {
- RemoteAddr() net.Addr
- }); ok {
- return ts.RemoteAddr()
- }
- return nil
- }
- // pushBytes a slice into buffer
- func (s *Stream) pushBytes(p []byte) {
- s.bufferLock.Lock()
- s.buffer.Write(p)
- s.bufferLock.Unlock()
- }
- // recycleTokens transform remaining bytes to tokens(will truncate buffer)
- func (s *Stream) recycleTokens() (n int) {
- s.bufferLock.Lock()
- n = s.buffer.Len()
- s.buffer.Reset()
- s.bufferLock.Unlock()
- return
- }
- // split large byte buffer into smaller frames, reference only
- func (s *Stream) split(bts []byte, cmd byte, sid uint32) []Frame {
- frames := make([]Frame, 0, len(bts)/s.frameSize+1)
- for len(bts) > s.frameSize {
- frame := newFrame(cmd, sid)
- frame.data = bts[:s.frameSize]
- bts = bts[s.frameSize:]
- frames = append(frames, frame)
- }
- if len(bts) > 0 {
- frame := newFrame(cmd, sid)
- frame.data = bts
- frames = append(frames, frame)
- }
- return frames
- }
- // notify read event
- func (s *Stream) notifyReadEvent() {
- select {
- case s.chReadEvent <- struct{}{}:
- default:
- }
- }
- // mark this stream has been reset
- func (s *Stream) markRST() {
- atomic.StoreInt32(&s.rstflag, 1)
- }
// errTimeout is returned by Stream.Read and Stream.Write when their
// deadline expires.
var errTimeout error = &timeoutError{}

// Compile-time check that *timeoutError satisfies net.Error, so callers can
// recognise stream timeouts through a net.Error assertion (err.Timeout()).
var _ net.Error = (*timeoutError)(nil)

// timeoutError is the net.Error implementation behind errTimeout. It
// reports itself as both a timeout and temporary.
type timeoutError struct{}

// Error returns the fixed message "i/o timeout".
func (e *timeoutError) Error() string { return "i/o timeout" }

// Timeout always reports true.
func (e *timeoutError) Timeout() bool { return true }

// Temporary always reports true.
func (e *timeoutError) Temporary() bool { return true }
|