conn_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package websocket
  5. import (
  6. "bufio"
  7. "bytes"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "io/ioutil"
  12. "net"
  13. "reflect"
  14. "testing"
  15. "testing/iotest"
  16. "time"
  17. )
  18. var _ net.Error = errWriteTimeout
  19. type fakeNetConn struct {
  20. io.Reader
  21. io.Writer
  22. }
  23. func (c fakeNetConn) Close() error { return nil }
  24. func (c fakeNetConn) LocalAddr() net.Addr { return localAddr }
  25. func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr }
  26. func (c fakeNetConn) SetDeadline(t time.Time) error { return nil }
  27. func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil }
  28. func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil }
  29. type fakeAddr int
  30. var (
  31. localAddr = fakeAddr(1)
  32. remoteAddr = fakeAddr(2)
  33. )
  34. func (a fakeAddr) Network() string {
  35. return "net"
  36. }
  37. func (a fakeAddr) String() string {
  38. return "str"
  39. }
  40. func TestFraming(t *testing.T) {
  41. frameSizes := []int{0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, 65536, 65537}
  42. var readChunkers = []struct {
  43. name string
  44. f func(io.Reader) io.Reader
  45. }{
  46. {"half", iotest.HalfReader},
  47. {"one", iotest.OneByteReader},
  48. {"asis", func(r io.Reader) io.Reader { return r }},
  49. }
  50. writeBuf := make([]byte, 65537)
  51. for i := range writeBuf {
  52. writeBuf[i] = byte(i)
  53. }
  54. var writers = []struct {
  55. name string
  56. f func(w io.Writer, n int) (int, error)
  57. }{
  58. {"iocopy", func(w io.Writer, n int) (int, error) {
  59. nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n]))
  60. return int(nn), err
  61. }},
  62. {"write", func(w io.Writer, n int) (int, error) {
  63. return w.Write(writeBuf[:n])
  64. }},
  65. {"string", func(w io.Writer, n int) (int, error) {
  66. return io.WriteString(w, string(writeBuf[:n]))
  67. }},
  68. }
  69. for _, compress := range []bool{false, true} {
  70. for _, isServer := range []bool{true, false} {
  71. for _, chunker := range readChunkers {
  72. var connBuf bytes.Buffer
  73. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  74. rc := newConn(fakeNetConn{Reader: chunker.f(&connBuf), Writer: nil}, !isServer, 1024, 1024)
  75. if compress {
  76. wc.newCompressionWriter = compressNoContextTakeover
  77. rc.newDecompressionReader = decompressNoContextTakeover
  78. }
  79. for _, n := range frameSizes {
  80. for _, writer := range writers {
  81. name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name)
  82. w, err := wc.NextWriter(TextMessage)
  83. if err != nil {
  84. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  85. continue
  86. }
  87. nn, err := writer.f(w, n)
  88. if err != nil || nn != n {
  89. t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err)
  90. continue
  91. }
  92. err = w.Close()
  93. if err != nil {
  94. t.Errorf("%s: w.Close() returned %v", name, err)
  95. continue
  96. }
  97. opCode, r, err := rc.NextReader()
  98. if err != nil || opCode != TextMessage {
  99. t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err)
  100. continue
  101. }
  102. rbuf, err := ioutil.ReadAll(r)
  103. if err != nil {
  104. t.Errorf("%s: ReadFull() returned rbuf, %v", name, err)
  105. continue
  106. }
  107. if len(rbuf) != n {
  108. t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n)
  109. continue
  110. }
  111. for i, b := range rbuf {
  112. if byte(i) != b {
  113. t.Errorf("%s: bad byte at offset %d", name, i)
  114. break
  115. }
  116. }
  117. }
  118. }
  119. }
  120. }
  121. }
  122. }
  123. func TestControl(t *testing.T) {
  124. const message = "this is a ping/pong messsage"
  125. for _, isServer := range []bool{true, false} {
  126. for _, isWriteControl := range []bool{true, false} {
  127. name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl)
  128. var connBuf bytes.Buffer
  129. wc := newConn(fakeNetConn{Reader: nil, Writer: &connBuf}, isServer, 1024, 1024)
  130. rc := newConn(fakeNetConn{Reader: &connBuf, Writer: nil}, !isServer, 1024, 1024)
  131. if isWriteControl {
  132. wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second))
  133. } else {
  134. w, err := wc.NextWriter(PongMessage)
  135. if err != nil {
  136. t.Errorf("%s: wc.NextWriter() returned %v", name, err)
  137. continue
  138. }
  139. if _, err := w.Write([]byte(message)); err != nil {
  140. t.Errorf("%s: w.Write() returned %v", name, err)
  141. continue
  142. }
  143. if err := w.Close(); err != nil {
  144. t.Errorf("%s: w.Close() returned %v", name, err)
  145. continue
  146. }
  147. var actualMessage string
  148. rc.SetPongHandler(func(s string) error { actualMessage = s; return nil })
  149. rc.NextReader()
  150. if actualMessage != message {
  151. t.Errorf("%s: pong=%q, want %q", name, actualMessage, message)
  152. continue
  153. }
  154. }
  155. }
  156. }
  157. }
  158. func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) {
  159. const bufSize = 512
  160. expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"}
  161. var b1, b2 bytes.Buffer
  162. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  163. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  164. w, _ := wc.NextWriter(BinaryMessage)
  165. w.Write(make([]byte, bufSize+bufSize/2))
  166. wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second))
  167. w.Close()
  168. op, r, err := rc.NextReader()
  169. if op != BinaryMessage || err != nil {
  170. t.Fatalf("NextReader() returned %d, %v", op, err)
  171. }
  172. _, err = io.Copy(ioutil.Discard, r)
  173. if !reflect.DeepEqual(err, expectedErr) {
  174. t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr)
  175. }
  176. _, _, err = rc.NextReader()
  177. if !reflect.DeepEqual(err, expectedErr) {
  178. t.Fatalf("NextReader() returned %v, want %v", err, expectedErr)
  179. }
  180. }
  181. func TestEOFWithinFrame(t *testing.T) {
  182. const bufSize = 64
  183. for n := 0; ; n++ {
  184. var b bytes.Buffer
  185. wc := newConn(fakeNetConn{Reader: nil, Writer: &b}, false, 1024, 1024)
  186. rc := newConn(fakeNetConn{Reader: &b, Writer: nil}, true, 1024, 1024)
  187. w, _ := wc.NextWriter(BinaryMessage)
  188. w.Write(make([]byte, bufSize))
  189. w.Close()
  190. if n >= b.Len() {
  191. break
  192. }
  193. b.Truncate(n)
  194. op, r, err := rc.NextReader()
  195. if err == errUnexpectedEOF {
  196. continue
  197. }
  198. if op != BinaryMessage || err != nil {
  199. t.Fatalf("%d: NextReader() returned %d, %v", n, op, err)
  200. }
  201. _, err = io.Copy(ioutil.Discard, r)
  202. if err != errUnexpectedEOF {
  203. t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF)
  204. }
  205. _, _, err = rc.NextReader()
  206. if err != errUnexpectedEOF {
  207. t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF)
  208. }
  209. }
  210. }
  211. func TestEOFBeforeFinalFrame(t *testing.T) {
  212. const bufSize = 512
  213. var b1, b2 bytes.Buffer
  214. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize)
  215. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  216. w, _ := wc.NextWriter(BinaryMessage)
  217. w.Write(make([]byte, bufSize+bufSize/2))
  218. op, r, err := rc.NextReader()
  219. if op != BinaryMessage || err != nil {
  220. t.Fatalf("NextReader() returned %d, %v", op, err)
  221. }
  222. _, err = io.Copy(ioutil.Discard, r)
  223. if err != errUnexpectedEOF {
  224. t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF)
  225. }
  226. _, _, err = rc.NextReader()
  227. if err != errUnexpectedEOF {
  228. t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF)
  229. }
  230. }
  231. func TestWriteAfterMessageWriterClose(t *testing.T) {
  232. wc := newConn(fakeNetConn{Reader: nil, Writer: &bytes.Buffer{}}, false, 1024, 1024)
  233. w, _ := wc.NextWriter(BinaryMessage)
  234. io.WriteString(w, "hello")
  235. if err := w.Close(); err != nil {
  236. t.Fatalf("unxpected error closing message writer, %v", err)
  237. }
  238. if _, err := io.WriteString(w, "world"); err == nil {
  239. t.Fatalf("no error writing after close")
  240. }
  241. w, _ = wc.NextWriter(BinaryMessage)
  242. io.WriteString(w, "hello")
  243. // close w by getting next writer
  244. _, err := wc.NextWriter(BinaryMessage)
  245. if err != nil {
  246. t.Fatalf("unexpected error getting next writer, %v", err)
  247. }
  248. if _, err := io.WriteString(w, "world"); err == nil {
  249. t.Fatalf("no error writing after close")
  250. }
  251. }
  252. func TestReadLimit(t *testing.T) {
  253. const readLimit = 512
  254. message := make([]byte, readLimit+1)
  255. var b1, b2 bytes.Buffer
  256. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, readLimit-2)
  257. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024)
  258. rc.SetReadLimit(readLimit)
  259. // Send message at the limit with interleaved pong.
  260. w, _ := wc.NextWriter(BinaryMessage)
  261. w.Write(message[:readLimit-1])
  262. wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second))
  263. w.Write(message[:1])
  264. w.Close()
  265. // Send message larger than the limit.
  266. wc.WriteMessage(BinaryMessage, message[:readLimit+1])
  267. op, _, err := rc.NextReader()
  268. if op != BinaryMessage || err != nil {
  269. t.Fatalf("1: NextReader() returned %d, %v", op, err)
  270. }
  271. op, r, err := rc.NextReader()
  272. if op != BinaryMessage || err != nil {
  273. t.Fatalf("2: NextReader() returned %d, %v", op, err)
  274. }
  275. _, err = io.Copy(ioutil.Discard, r)
  276. if err != ErrReadLimit {
  277. t.Fatalf("io.Copy() returned %v", err)
  278. }
  279. }
  280. func TestAddrs(t *testing.T) {
  281. c := newConn(&fakeNetConn{}, true, 1024, 1024)
  282. if c.LocalAddr() != localAddr {
  283. t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr)
  284. }
  285. if c.RemoteAddr() != remoteAddr {
  286. t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr)
  287. }
  288. }
  289. func TestUnderlyingConn(t *testing.T) {
  290. var b1, b2 bytes.Buffer
  291. fc := fakeNetConn{Reader: &b1, Writer: &b2}
  292. c := newConn(fc, true, 1024, 1024)
  293. ul := c.UnderlyingConn()
  294. if ul != fc {
  295. t.Fatalf("Underlying conn is not what it should be.")
  296. }
  297. }
  298. func TestBufioReadBytes(t *testing.T) {
  299. // Test calling bufio.ReadBytes for value longer than read buffer size.
  300. m := make([]byte, 512)
  301. m[len(m)-1] = '\n'
  302. var b1, b2 bytes.Buffer
  303. wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64)
  304. rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64)
  305. w, _ := wc.NextWriter(BinaryMessage)
  306. w.Write(m)
  307. w.Close()
  308. op, r, err := rc.NextReader()
  309. if op != BinaryMessage || err != nil {
  310. t.Fatalf("NextReader() returned %d, %v", op, err)
  311. }
  312. br := bufio.NewReader(r)
  313. p, err := br.ReadBytes('\n')
  314. if err != nil {
  315. t.Fatalf("ReadBytes() returned %v", err)
  316. }
  317. if len(p) != len(m) {
  318. t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m))
  319. }
  320. }
  321. var closeErrorTests = []struct {
  322. err error
  323. codes []int
  324. ok bool
  325. }{
  326. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true},
  327. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false},
  328. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true},
  329. {errors.New("hello"), []int{CloseNormalClosure}, false},
  330. }
  331. func TestCloseError(t *testing.T) {
  332. for _, tt := range closeErrorTests {
  333. ok := IsCloseError(tt.err, tt.codes...)
  334. if ok != tt.ok {
  335. t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  336. }
  337. }
  338. }
  339. var unexpectedCloseErrorTests = []struct {
  340. err error
  341. codes []int
  342. ok bool
  343. }{
  344. {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false},
  345. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true},
  346. {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false},
  347. {errors.New("hello"), []int{CloseNormalClosure}, false},
  348. }
  349. func TestUnexpectedCloseErrors(t *testing.T) {
  350. for _, tt := range unexpectedCloseErrorTests {
  351. ok := IsUnexpectedCloseError(tt.err, tt.codes...)
  352. if ok != tt.ok {
  353. t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok)
  354. }
  355. }
  356. }
  357. type blockingWriter struct {
  358. c1, c2 chan struct{}
  359. }
  360. func (w blockingWriter) Write(p []byte) (int, error) {
  361. // Allow main to continue
  362. close(w.c1)
  363. // Wait for panic in main
  364. <-w.c2
  365. return len(p), nil
  366. }
  367. func TestConcurrentWritePanic(t *testing.T) {
  368. w := blockingWriter{make(chan struct{}), make(chan struct{})}
  369. c := newConn(fakeNetConn{Reader: nil, Writer: w}, false, 1024, 1024)
  370. go func() {
  371. c.WriteMessage(TextMessage, []byte{})
  372. }()
  373. // wait for goroutine to block in write.
  374. <-w.c1
  375. defer func() {
  376. close(w.c2)
  377. if v := recover(); v != nil {
  378. return
  379. }
  380. }()
  381. c.WriteMessage(TextMessage, []byte{})
  382. t.Fatal("should not get here")
  383. }
  384. type failingReader struct{}
  385. func (r failingReader) Read(p []byte) (int, error) {
  386. return 0, io.EOF
  387. }
  388. func TestFailedConnectionReadPanic(t *testing.T) {
  389. c := newConn(fakeNetConn{Reader: failingReader{}, Writer: nil}, false, 1024, 1024)
  390. defer func() {
  391. if v := recover(); v != nil {
  392. return
  393. }
  394. }()
  395. for i := 0; i < 20000; i++ {
  396. c.ReadMessage()
  397. }
  398. t.Fatal("should not get here")
  399. }
  400. func TestBufioReuse(t *testing.T) {
  401. brw := bufio.NewReadWriter(bufio.NewReader(nil), bufio.NewWriter(nil))
  402. c := newConnBRW(nil, false, 0, 0, brw)
  403. if c.br != brw.Reader {
  404. t.Error("connection did not reuse bufio.Reader")
  405. }
  406. var wh writeHook
  407. brw.Writer.Reset(&wh)
  408. brw.WriteByte(0)
  409. brw.Flush()
  410. if &c.writeBuf[0] != &wh.p[0] {
  411. t.Error("connection did not reuse bufio.Writer")
  412. }
  413. brw = bufio.NewReadWriter(bufio.NewReaderSize(nil, 0), bufio.NewWriterSize(nil, 0))
  414. c = newConnBRW(nil, false, 0, 0, brw)
  415. if c.br == brw.Reader {
  416. t.Error("connection used bufio.Reader with small size")
  417. }
  418. brw.Writer.Reset(&wh)
  419. brw.WriteByte(0)
  420. brw.Flush()
  421. if &c.writeBuf[0] != &wh.p[0] {
  422. t.Error("connection used bufio.Writer with small size")
  423. }
  424. }