session_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. package smux
  2. import (
  3. crand "crypto/rand"
  4. "encoding/binary"
  5. "fmt"
  6. "io"
  7. "log"
  8. "math/rand"
  9. "net"
  10. "net/http"
  11. _ "net/http/pprof"
  12. "strings"
  13. "sync"
  14. "testing"
  15. "time"
  16. )
  17. func init() {
  18. go func() {
  19. log.Println(http.ListenAndServe("localhost:6060", nil))
  20. }()
  21. log.SetFlags(log.LstdFlags | log.Lshortfile)
  22. ln, err := net.Listen("tcp", "127.0.0.1:19999")
  23. if err != nil {
  24. // handle error
  25. panic(err)
  26. }
  27. go func() {
  28. for {
  29. conn, err := ln.Accept()
  30. if err != nil {
  31. // handle error
  32. }
  33. go handleConnection(conn)
  34. }
  35. }()
  36. }
  37. func handleConnection(conn net.Conn) {
  38. session, _ := Server(conn, nil)
  39. for {
  40. if stream, err := session.AcceptStream(); err == nil {
  41. go func(s io.ReadWriteCloser) {
  42. buf := make([]byte, 65536)
  43. for {
  44. n, err := s.Read(buf)
  45. if err != nil {
  46. return
  47. }
  48. s.Write(buf[:n])
  49. }
  50. }(stream)
  51. } else {
  52. return
  53. }
  54. }
  55. }
  56. func TestEcho(t *testing.T) {
  57. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  58. if err != nil {
  59. t.Fatal(err)
  60. }
  61. session, _ := Client(cli, nil)
  62. stream, _ := session.OpenStream()
  63. const N = 100
  64. buf := make([]byte, 10)
  65. var sent string
  66. var received string
  67. for i := 0; i < N; i++ {
  68. msg := fmt.Sprintf("hello%v", i)
  69. stream.Write([]byte(msg))
  70. sent += msg
  71. if n, err := stream.Read(buf); err != nil {
  72. t.Fatal(err)
  73. } else {
  74. received += string(buf[:n])
  75. }
  76. }
  77. if sent != received {
  78. t.Fatal("data mimatch")
  79. }
  80. session.Close()
  81. }
  82. func TestSpeed(t *testing.T) {
  83. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  84. if err != nil {
  85. t.Fatal(err)
  86. }
  87. session, _ := Client(cli, nil)
  88. stream, _ := session.OpenStream()
  89. t.Log(stream.LocalAddr(), stream.RemoteAddr())
  90. start := time.Now()
  91. var wg sync.WaitGroup
  92. wg.Add(1)
  93. go func() {
  94. buf := make([]byte, 1024*1024)
  95. nrecv := 0
  96. for {
  97. n, err := stream.Read(buf)
  98. if err != nil {
  99. t.Fatal(err)
  100. break
  101. } else {
  102. nrecv += n
  103. if nrecv == 4096*4096 {
  104. break
  105. }
  106. }
  107. }
  108. stream.Close()
  109. t.Log("time for 16MB rtt", time.Since(start))
  110. wg.Done()
  111. }()
  112. msg := make([]byte, 8192)
  113. for i := 0; i < 2048; i++ {
  114. stream.Write(msg)
  115. }
  116. wg.Wait()
  117. session.Close()
  118. }
  119. func TestParallel(t *testing.T) {
  120. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  121. if err != nil {
  122. t.Fatal(err)
  123. }
  124. session, _ := Client(cli, nil)
  125. par := 1000
  126. messages := 100
  127. var wg sync.WaitGroup
  128. wg.Add(par)
  129. for i := 0; i < par; i++ {
  130. stream, _ := session.OpenStream()
  131. go func(s *Stream) {
  132. buf := make([]byte, 20)
  133. for j := 0; j < messages; j++ {
  134. msg := fmt.Sprintf("hello%v", j)
  135. s.Write([]byte(msg))
  136. if _, err := s.Read(buf); err != nil {
  137. break
  138. }
  139. }
  140. s.Close()
  141. wg.Done()
  142. }(stream)
  143. }
  144. t.Log("created", session.NumStreams(), "streams")
  145. wg.Wait()
  146. session.Close()
  147. }
  148. func TestCloseThenOpen(t *testing.T) {
  149. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  150. if err != nil {
  151. t.Fatal(err)
  152. }
  153. session, _ := Client(cli, nil)
  154. session.Close()
  155. if _, err := session.OpenStream(); err == nil {
  156. t.Fatal("opened after close")
  157. }
  158. }
  159. func TestStreamDoubleClose(t *testing.T) {
  160. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  161. if err != nil {
  162. t.Fatal(err)
  163. }
  164. session, _ := Client(cli, nil)
  165. stream, _ := session.OpenStream()
  166. stream.Close()
  167. if err := stream.Close(); err == nil {
  168. t.Log("double close doesn't return error")
  169. }
  170. session.Close()
  171. }
  172. func TestConcurrentClose(t *testing.T) {
  173. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  174. if err != nil {
  175. t.Fatal(err)
  176. }
  177. session, _ := Client(cli, nil)
  178. numStreams := 100
  179. streams := make([]*Stream, 0, numStreams)
  180. var wg sync.WaitGroup
  181. wg.Add(numStreams)
  182. for i := 0; i < 100; i++ {
  183. stream, _ := session.OpenStream()
  184. streams = append(streams, stream)
  185. }
  186. for _, s := range streams {
  187. stream := s
  188. go func() {
  189. stream.Close()
  190. wg.Done()
  191. }()
  192. }
  193. session.Close()
  194. wg.Wait()
  195. }
  196. func TestTinyReadBuffer(t *testing.T) {
  197. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  198. if err != nil {
  199. t.Fatal(err)
  200. }
  201. session, _ := Client(cli, nil)
  202. stream, _ := session.OpenStream()
  203. const N = 100
  204. tinybuf := make([]byte, 6)
  205. var sent string
  206. var received string
  207. for i := 0; i < N; i++ {
  208. msg := fmt.Sprintf("hello%v", i)
  209. sent += msg
  210. nsent, err := stream.Write([]byte(msg))
  211. if err != nil {
  212. t.Fatal("cannot write")
  213. }
  214. nrecv := 0
  215. for nrecv < nsent {
  216. if n, err := stream.Read(tinybuf); err == nil {
  217. nrecv += n
  218. received += string(tinybuf[:n])
  219. } else {
  220. t.Fatal("cannot read with tiny buffer")
  221. }
  222. }
  223. }
  224. if sent != received {
  225. t.Fatal("data mimatch")
  226. }
  227. session.Close()
  228. }
  229. func TestIsClose(t *testing.T) {
  230. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  231. if err != nil {
  232. t.Fatal(err)
  233. }
  234. session, _ := Client(cli, nil)
  235. session.Close()
  236. if session.IsClosed() != true {
  237. t.Fatal("still open after close")
  238. }
  239. }
  240. func TestKeepAliveTimeout(t *testing.T) {
  241. ln, err := net.Listen("tcp", "127.0.0.1:29999")
  242. if err != nil {
  243. // handle error
  244. panic(err)
  245. }
  246. go func() {
  247. ln.Accept()
  248. }()
  249. cli, err := net.Dial("tcp", "127.0.0.1:29999")
  250. if err != nil {
  251. t.Fatal(err)
  252. }
  253. config := DefaultConfig()
  254. config.KeepAliveInterval = time.Second
  255. config.KeepAliveTimeout = 2 * time.Second
  256. session, _ := Client(cli, config)
  257. <-time.After(3 * time.Second)
  258. if session.IsClosed() != true {
  259. t.Fatal("keepalive-timeout failed")
  260. }
  261. }
  262. func TestServerEcho(t *testing.T) {
  263. ln, err := net.Listen("tcp", "127.0.0.1:39999")
  264. if err != nil {
  265. // handle error
  266. panic(err)
  267. }
  268. go func() {
  269. if conn, err := ln.Accept(); err == nil {
  270. session, _ := Server(conn, nil)
  271. if stream, err := session.OpenStream(); err == nil {
  272. const N = 100
  273. buf := make([]byte, 10)
  274. for i := 0; i < N; i++ {
  275. msg := fmt.Sprintf("hello%v", i)
  276. stream.Write([]byte(msg))
  277. if n, err := stream.Read(buf); err != nil {
  278. t.Fatal(err)
  279. } else if string(buf[:n]) != msg {
  280. t.Fatal(err)
  281. }
  282. }
  283. stream.Close()
  284. } else {
  285. t.Fatal(err)
  286. }
  287. } else {
  288. t.Fatal(err)
  289. }
  290. }()
  291. cli, err := net.Dial("tcp", "127.0.0.1:39999")
  292. if err != nil {
  293. t.Fatal(err)
  294. }
  295. if session, err := Client(cli, nil); err == nil {
  296. if stream, err := session.AcceptStream(); err == nil {
  297. buf := make([]byte, 65536)
  298. for {
  299. n, err := stream.Read(buf)
  300. if err != nil {
  301. break
  302. }
  303. stream.Write(buf[:n])
  304. }
  305. } else {
  306. t.Fatal(err)
  307. }
  308. } else {
  309. t.Fatal(err)
  310. }
  311. }
  312. func TestSendWithoutRecv(t *testing.T) {
  313. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  314. if err != nil {
  315. t.Fatal(err)
  316. }
  317. session, _ := Client(cli, nil)
  318. stream, _ := session.OpenStream()
  319. const N = 100
  320. for i := 0; i < N; i++ {
  321. msg := fmt.Sprintf("hello%v", i)
  322. stream.Write([]byte(msg))
  323. }
  324. buf := make([]byte, 1)
  325. if _, err := stream.Read(buf); err != nil {
  326. t.Fatal(err)
  327. }
  328. stream.Close()
  329. }
  330. func TestWriteAfterClose(t *testing.T) {
  331. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  332. if err != nil {
  333. t.Fatal(err)
  334. }
  335. session, _ := Client(cli, nil)
  336. stream, _ := session.OpenStream()
  337. stream.Close()
  338. if _, err := stream.Write([]byte("write after close")); err == nil {
  339. t.Fatal("write after close failed")
  340. }
  341. }
  342. func TestReadStreamAfterSessionClose(t *testing.T) {
  343. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  344. if err != nil {
  345. t.Fatal(err)
  346. }
  347. session, _ := Client(cli, nil)
  348. stream, _ := session.OpenStream()
  349. session.Close()
  350. buf := make([]byte, 10)
  351. if _, err := stream.Read(buf); err != nil {
  352. t.Log(err)
  353. } else {
  354. t.Fatal("read stream after session close succeeded")
  355. }
  356. }
  357. func TestWriteStreamAfterConnectionClose(t *testing.T) {
  358. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  359. if err != nil {
  360. t.Fatal(err)
  361. }
  362. session, _ := Client(cli, nil)
  363. stream, _ := session.OpenStream()
  364. session.conn.Close()
  365. if _, err := stream.Write([]byte("write after connection close")); err == nil {
  366. t.Fatal("write after connection close failed")
  367. }
  368. }
  369. func TestNumStreamAfterClose(t *testing.T) {
  370. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  371. if err != nil {
  372. t.Fatal(err)
  373. }
  374. session, _ := Client(cli, nil)
  375. if _, err := session.OpenStream(); err == nil {
  376. if session.NumStreams() != 1 {
  377. t.Fatal("wrong number of streams after opened")
  378. }
  379. session.Close()
  380. if session.NumStreams() != 0 {
  381. t.Fatal("wrong number of streams after session closed")
  382. }
  383. } else {
  384. t.Fatal(err)
  385. }
  386. cli.Close()
  387. }
  388. func TestRandomFrame(t *testing.T) {
  389. // pure random
  390. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  391. if err != nil {
  392. t.Fatal(err)
  393. }
  394. session, _ := Client(cli, nil)
  395. for i := 0; i < 100; i++ {
  396. rnd := make([]byte, rand.Uint32()%1024)
  397. io.ReadFull(crand.Reader, rnd)
  398. session.conn.Write(rnd)
  399. }
  400. cli.Close()
  401. // double syn
  402. cli, err = net.Dial("tcp", "127.0.0.1:19999")
  403. if err != nil {
  404. t.Fatal(err)
  405. }
  406. session, _ = Client(cli, nil)
  407. for i := 0; i < 100; i++ {
  408. f := newFrame(cmdSYN, 1000)
  409. session.writeFrame(f)
  410. }
  411. cli.Close()
  412. // random cmds
  413. cli, err = net.Dial("tcp", "127.0.0.1:19999")
  414. if err != nil {
  415. t.Fatal(err)
  416. }
  417. allcmds := []byte{cmdSYN, cmdFIN, cmdPSH, cmdNOP}
  418. session, _ = Client(cli, nil)
  419. for i := 0; i < 100; i++ {
  420. f := newFrame(allcmds[rand.Int()%len(allcmds)], rand.Uint32())
  421. session.writeFrame(f)
  422. }
  423. cli.Close()
  424. // random cmds & sids
  425. cli, err = net.Dial("tcp", "127.0.0.1:19999")
  426. if err != nil {
  427. t.Fatal(err)
  428. }
  429. session, _ = Client(cli, nil)
  430. for i := 0; i < 100; i++ {
  431. f := newFrame(byte(rand.Uint32()), rand.Uint32())
  432. session.writeFrame(f)
  433. }
  434. cli.Close()
  435. // random version
  436. cli, err = net.Dial("tcp", "127.0.0.1:19999")
  437. if err != nil {
  438. t.Fatal(err)
  439. }
  440. session, _ = Client(cli, nil)
  441. for i := 0; i < 100; i++ {
  442. f := newFrame(byte(rand.Uint32()), rand.Uint32())
  443. f.ver = byte(rand.Uint32())
  444. session.writeFrame(f)
  445. }
  446. cli.Close()
  447. // incorrect size
  448. cli, err = net.Dial("tcp", "127.0.0.1:19999")
  449. if err != nil {
  450. t.Fatal(err)
  451. }
  452. session, _ = Client(cli, nil)
  453. f := newFrame(byte(rand.Uint32()), rand.Uint32())
  454. rnd := make([]byte, rand.Uint32()%1024)
  455. io.ReadFull(crand.Reader, rnd)
  456. f.data = rnd
  457. buf := make([]byte, headerSize+len(f.data))
  458. buf[0] = f.ver
  459. buf[1] = f.cmd
  460. binary.LittleEndian.PutUint16(buf[2:], uint16(len(rnd)+1)) /// incorrect size
  461. binary.LittleEndian.PutUint32(buf[4:], f.sid)
  462. copy(buf[headerSize:], f.data)
  463. session.conn.Write(buf)
  464. t.Log(rawHeader(buf))
  465. cli.Close()
  466. }
  467. func TestReadDeadline(t *testing.T) {
  468. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  469. if err != nil {
  470. t.Fatal(err)
  471. }
  472. session, _ := Client(cli, nil)
  473. stream, _ := session.OpenStream()
  474. const N = 100
  475. buf := make([]byte, 10)
  476. var readErr error
  477. for i := 0; i < N; i++ {
  478. msg := fmt.Sprintf("hello%v", i)
  479. stream.Write([]byte(msg))
  480. stream.SetReadDeadline(time.Now().Add(-1 * time.Minute))
  481. if _, readErr = stream.Read(buf); readErr != nil {
  482. break
  483. }
  484. }
  485. if readErr != nil {
  486. if !strings.Contains(readErr.Error(), "i/o timeout") {
  487. t.Fatalf("Wrong error: %v", readErr)
  488. }
  489. } else {
  490. t.Fatal("No error when reading with past deadline")
  491. }
  492. session.Close()
  493. }
  494. func TestWriteDeadline(t *testing.T) {
  495. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  496. if err != nil {
  497. t.Fatal(err)
  498. }
  499. session, _ := Client(cli, nil)
  500. stream, _ := session.OpenStream()
  501. buf := make([]byte, 10)
  502. var writeErr error
  503. for {
  504. stream.SetWriteDeadline(time.Now().Add(-1 * time.Minute))
  505. if _, writeErr = stream.Write(buf); writeErr != nil {
  506. if !strings.Contains(writeErr.Error(), "i/o timeout") {
  507. t.Fatalf("Wrong error: %v", writeErr)
  508. }
  509. break
  510. }
  511. }
  512. session.Close()
  513. }
  514. func BenchmarkAcceptClose(b *testing.B) {
  515. cli, err := net.Dial("tcp", "127.0.0.1:19999")
  516. if err != nil {
  517. b.Fatal(err)
  518. }
  519. session, _ := Client(cli, nil)
  520. for i := 0; i < b.N; i++ {
  521. if stream, err := session.OpenStream(); err == nil {
  522. stream.Close()
  523. } else {
  524. b.Fatal(err)
  525. }
  526. }
  527. }
  528. func BenchmarkConnSmux(b *testing.B) {
  529. cs, ss, err := getSmuxStreamPair()
  530. if err != nil {
  531. b.Fatal(err)
  532. }
  533. defer cs.Close()
  534. defer ss.Close()
  535. bench(b, cs, ss)
  536. }
  537. func BenchmarkConnTCP(b *testing.B) {
  538. cs, ss, err := getTCPConnectionPair()
  539. if err != nil {
  540. b.Fatal(err)
  541. }
  542. defer cs.Close()
  543. defer ss.Close()
  544. bench(b, cs, ss)
  545. }
  546. func getSmuxStreamPair() (*Stream, *Stream, error) {
  547. c1, c2, err := getTCPConnectionPair()
  548. if err != nil {
  549. return nil, nil, err
  550. }
  551. s, err := Server(c2, nil)
  552. if err != nil {
  553. return nil, nil, err
  554. }
  555. c, err := Client(c1, nil)
  556. if err != nil {
  557. return nil, nil, err
  558. }
  559. var ss *Stream
  560. done := make(chan error)
  561. go func() {
  562. var rerr error
  563. ss, rerr = s.AcceptStream()
  564. done <- rerr
  565. close(done)
  566. }()
  567. cs, err := c.OpenStream()
  568. if err != nil {
  569. return nil, nil, err
  570. }
  571. err = <-done
  572. if err != nil {
  573. return nil, nil, err
  574. }
  575. return cs, ss, nil
  576. }
  577. func getTCPConnectionPair() (net.Conn, net.Conn, error) {
  578. lst, err := net.Listen("tcp", "127.0.0.1:0")
  579. if err != nil {
  580. return nil, nil, err
  581. }
  582. var conn0 net.Conn
  583. var err0 error
  584. done := make(chan struct{})
  585. go func() {
  586. conn0, err0 = lst.Accept()
  587. close(done)
  588. }()
  589. conn1, err := net.Dial("tcp", lst.Addr().String())
  590. if err != nil {
  591. return nil, nil, err
  592. }
  593. <-done
  594. if err0 != nil {
  595. return nil, nil, err0
  596. }
  597. return conn0, conn1, nil
  598. }
  599. func bench(b *testing.B, rd io.Reader, wr io.Writer) {
  600. buf := make([]byte, 128*1024)
  601. buf2 := make([]byte, 128*1024)
  602. b.SetBytes(128 * 1024)
  603. b.ResetTimer()
  604. var wg sync.WaitGroup
  605. wg.Add(1)
  606. go func() {
  607. defer wg.Done()
  608. count := 0
  609. for {
  610. n, _ := rd.Read(buf2)
  611. count += n
  612. if count == 128*1024*b.N {
  613. return
  614. }
  615. }
  616. }()
  617. for i := 0; i < b.N; i++ {
  618. wr.Write(buf)
  619. }
  620. wg.Wait()
  621. }