123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496 |
- // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package websocket
- import (
- "bufio"
- "bytes"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "reflect"
- "testing"
- "testing/iotest"
- "time"
- )
- var _ net.Error = errWriteTimeout
- type fakeNetConn struct {
- io.Reader
- io.Writer
- }
- func (c fakeNetConn) Close() error { return nil }
- func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
- func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
- func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
- func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
- func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
// fakeAddr is a placeholder address type whose Network and String methods
// return fixed strings regardless of the underlying value.
type fakeAddr int

// Distinct placeholder addresses handed out by fakeNetConn.
var (
	localAddr  fakeAddr = 1
	remoteAddr fakeAddr = 2
)

// Network always returns "net".
func (fakeAddr) Network() string { return "net" }

// String always returns "str".
func (fakeAddr) String() string { return "str" }
- func TestFraming(t *testing.T) {
- frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
- var readChunkers = []struct {
- name string
- f func(io.Reader) io.Reader
- }{
- {"half", iotest.HalfReader},
- {"one", iotest.OneByteReader},
- {"asis", func(r io.Reader) io.Reader { return r }},
- }
- writeBuf := make([]byte, 65537)
- for i := range writeBuf {
- writeBuf[i] = byte(i)
- }
- var writers = []struct {
- name string
- f func(w io.Writer, n int) (int, error)
- }{
- {"iocopy", func(w io.Writer, n int) (int, error) {
- nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
- return int(nn), err
- }},
- {"write", func(w io.Writer, n int) (int, error) {
- return w.Write(writeBuf[:n])
- }},
- {"string", func(w io.Writer, n int) (int, error) {
- return io.WriteString(w, string(writeBuf[:n]))
- }},
- }
- for _, compress := range []bool{false, true} {
- for _, isServer := range []bool{true, false} {
- for _, chunker := range readChunkers {
- var connBuf bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
- rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
- if compress {
- wc.newCompressionWriter = compressNoContextTakeover
- rc.newDecompressionReader = decompressNoContextTakeover
- }
- for _, n := range frameSizes {
- for _, writer := range writers {
- name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
- w, err := wc.NextWriter(TextMessage)
- if err != nil {
- t.Errorf("%s: wc.NextWriter() returned %v", name, err)
- continue
- }
- nn, err := writer.f(w, n)
- if err != nil || nn != n {
- t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
- continue
- }
- err = w.Close()
- if err != nil {
- t.Errorf("%s: w.Close() returned %v", name, err)
- continue
- }
- opCode, r, err := rc.NextReader()
- if err != nil || opCode != TextMessage {
- t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
- continue
- }
- rbuf, err := ioutil.ReadAll(r)
- if err != nil {
- t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
- continue
- }
- if len(rbuf) != n {
- t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
- continue
- }
- for i, b := range rbuf {
- if byte(i) != b {
- t.Errorf("%s: bad byte at offset %d", name, i)
- break
- }
- }
- }
- }
- }
- }
- }
- }
- func TestControl(t *testing.T) {
- const message = "this is a ping/pong messsage"
- for _, isServer := range []bool{true, false} {
- for _, isWriteControl := range []bool{true, false} {
- name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
- var connBuf bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
- rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
- if isWriteControl {
- wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
- } else {
- w, err := wc.NextWriter(PongMessage)
- if err != nil {
- t.Errorf("%s: wc.NextWriter() returned %v", name, err)
- continue
- }
- if _, err := w.Write([]byte(message)); err != nil {
- t.Errorf("%s: w.Write() returned %v", name, err)
- continue
- }
- if err := w.Close(); err != nil {
- t.Errorf("%s: w.Close() returned %v", name, err)
- continue
- }
- var actualMessage string
- rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
- rc.NextReader()
- if actualMessage != message {
- t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
- continue
- }
- }
- }
- }
- }
- func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
- const bufSize = 512
- expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
- var b1, b2 bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
- rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
- w, _ := wc.NextWriter(BinaryMessage)
- w.Write(make([]byte, bufSize+bufSize/2))
- wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
- w.Close()
- op, r, err := rc.NextReader()
- if op != BinaryMessage || err != nil {
- t.Fatalf("NextReader() returned %d, %v", op, err)
- }
- _, err = io.Copy(ioutil.Discard, r)
- if !reflect.DeepEqual(err, expectedErr) {
- t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
- }
- _, _, err = rc.NextReader()
- if !reflect.DeepEqual(err, expectedErr) {
- t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
- }
- }
- func TestEOFWithinFrame(t *testing.T) {
- const bufSize = 64
- for n := 0; ; n++ {
- var b bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
- rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
- w, _ := wc.NextWriter(BinaryMessage)
- w.Write(make([]byte, bufSize))
- w.Close()
- if n >= b.Len() {
- break
- }
- b.Truncate(n)
- op, r, err := rc.NextReader()
- if err == errUnexpectedEOF {
- continue
- }
- if op != BinaryMessage || err != nil {
- t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
- }
- _, err = io.Copy(ioutil.Discard, r)
- if err != errUnexpectedEOF {
- t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
- }
- _, _, err = rc.NextReader()
- if err != errUnexpectedEOF {
- t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
- }
- }
- }
- func TestEOFBeforeFinalFrame(t *testing.T) {
- const bufSize = 512
- var b1, b2 bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
- rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
- w, _ := wc.NextWriter(BinaryMessage)
- w.Write(make([]byte, bufSize+bufSize/2))
- op, r, err := rc.NextReader()
- if op != BinaryMessage || err != nil {
- t.Fatalf("NextReader() returned %d, %v", op, err)
- }
- _, err = io.Copy(ioutil.Discard, r)
- if err != errUnexpectedEOF {
- t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
- }
- _, _, err = rc.NextReader()
- if err != errUnexpectedEOF {
- t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
- }
- }
- func TestWriteAfterMessageWriterClose(t *testing.T) {
- wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
- w, _ := wc.NextWriter(BinaryMessage)
- io.WriteString(w, "hello")
- if err := w.Close(); err != nil {
- t.Fatalf("unxpected error closing message writer, %v", err)
- }
- if _, err := io.WriteString(w, "world"); err == nil {
- t.Fatalf("no error writing after close")
- }
- w, _ = wc.NextWriter(BinaryMessage)
- io.WriteString(w, "hello")
- // close w by getting next writer
- _, err := wc.NextWriter(BinaryMessage)
- if err != nil {
- t.Fatalf("unexpected error getting next writer, %v", err)
- }
- if _, err := io.WriteString(w, "world"); err == nil {
- t.Fatalf("no error writing after close")
- }
- }
- func TestReadLimit(t *testing.T) {
- const readLimit = 512
- message := make([]byte, readLimit+1)
- var b1, b2 bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
- rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
- rc.SetReadLimit(readLimit)
- // Send message at the limit with interleaved pong.
- w, _ := wc.NextWriter(BinaryMessage)
- w.Write(message[:readLimit-1])
- wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
- w.Write(message[:1])
- w.Close()
- // Send message larger than the limit.
- wc.WriteMessage(BinaryMessage, message[:readLimit+1])
- op, _, err := rc.NextReader()
- if op != BinaryMessage || err != nil {
- t.Fatalf("1: NextReader() returned %d, %v", op, err)
- }
- op, r, err := rc.NextReader()
- if op != BinaryMessage || err != nil {
- t.Fatalf("2: NextReader() returned %d, %v", op, err)
- }
- _, err = io.Copy(ioutil.Discard, r)
- if err != ErrReadLimit {
- t.Fatalf("io.Copy() returned %v", err)
- }
- }
- func TestAddrs(t *testing.T) {
- c := newConn(&fakeNetConn{}, true, 1024, 1024)
- if c.LocalAddr() != localAddr {
- t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
- }
- if c.RemoteAddr() != remoteAddr {
- t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
- }
- }
- func TestUnderlyingConn(t *testing.T) {
- var b1, b2 bytes.Buffer
- fc := fakeNetConn{Reader: &b1, Writer: &b2}
- c := newConn(fc, true, 1024, 1024)
- ul := c.UnderlyingConn()
- if ul != fc {
- t.Fatalf("Underlying conn is not what it should be.")
- }
- }
- func TestBufioReadBytes(t *testing.T) {
- // Test calling bufio.ReadBytes for value longer than read buffer size.
- m := make([]byte, 512)
- m[len(m)-1] = '\n'
- var b1, b2 bytes.Buffer
- wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
- rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
- w, _ := wc.NextWriter(BinaryMessage)
- w.Write(m)
- w.Close()
- op, r, err := rc.NextReader()
- if op != BinaryMessage || err != nil {
- t.Fatalf("NextReader() returned %d, %v", op, err)
- }
- br := bufio.NewReader(r)
- p, err := br.ReadBytes('\n')
- if err != nil {
- t.Fatalf("ReadBytes() returned %v", err)
- }
- if len(p) != len(m) {
- t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
- }
- }
- var closeErrorTests = []struct {
- err error
- codes []int
- ok bool
- }{
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
- {errors.New("hello"), []int{CloseNormalClosure}, false},
- }
- func TestCloseError(t *testing.T) {
- for _, tt := range closeErrorTests {
- ok := IsCloseError(tt.err, tt.codes...)
- if ok != tt.ok {
- t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
- }
- }
- }
- var unexpectedCloseErrorTests = []struct {
- err error
- codes []int
- ok bool
- }{
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
- {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
- {errors.New("hello"), []int{CloseNormalClosure}, false},
- }
- func TestUnexpectedCloseErrors(t *testing.T) {
- for _, tt := range unexpectedCloseErrorTests {
- ok := IsUnexpectedCloseError(tt.err, tt.codes...)
- if ok != tt.ok {
- t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
- }
- }
- }
// blockingWriter is an io.Writer whose Write signals on c1 that it has been
// entered and then blocks until c2 is closed or receives a value.
type blockingWriter struct {
	c1 chan struct{} // closed by Write on entry
	c2 chan struct{} // Write waits on this before returning
}

// Write closes c1, waits on c2, and then reports that all of p was written.
// Because it closes c1, a second call on the same value would panic.
func (w blockingWriter) Write(p []byte) (int, error) {
	// Let the test's main goroutine proceed.
	close(w.c1)

	// Stay blocked until the main goroutine releases us.
	<-w.c2

	return len(p), nil
}
- func TestConcurrentWritePanic(t *testing.T) {
- w := blockingWriter{make(chan struct{}), make(chan struct{})}
- c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
- go func() {
- c.WriteMessage(TextMessage, []byte{})
- }()
- // wait for goroutine to block in write.
- <-w.c1
- defer func() {
- close(w.c2)
- if v := recover(); v != nil {
- return
- }
- }()
- c.WriteMessage(TextMessage, []byte{})
- t.Fatal("should not get here")
- }
// failingReader is an io.Reader that never yields data.
type failingReader struct{}

// Read ignores p and always returns 0 bytes together with io.EOF.
func (failingReader) Read([]byte) (int, error) { return 0, io.EOF }
- func TestFailedConnectionReadPanic(t *testing.T) {
- c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
- defer func() {
- if v := recover(); v != nil {
- return
- }
- }()
- for i := 0; i < 20000; i++ {
- c.ReadMessage()
- }
- t.Fatal("should not get here")
- }
- func TestBufioReuse(t *testing.T) {
- brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
- c := newConnBRW(nil, false, 0, 0, brw)
- if c.br != brw.Reader {
- t.Error("connection did not reuse bufio.Reader")
- }
- var wh writeHook
- brw.Writer.Reset(&wh)
- brw.WriteByte(0)
- brw.Flush()
- if &c.writeBuf[0] != &wh.p[0] {
- t.Error("connection did not reuse bufio.Writer")
- }
- brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
- c = newConnBRW(nil, false, 0, 0, brw)
- if c.br == brw.Reader {
- t.Error("connection used bufio.Reader with small size")
- }
- brw.Writer.Reset(&wh)
- brw.WriteByte(0)
- brw.Flush()
- if &c.writeBuf[0] != &wh.p[0] {
- t.Error("connection used bufio.Writer with small size")
- }
- }
|