mux_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "io/ioutil"
  8. "sync"
  9. "testing"
  10. )
  11. func muxPair() (*mux, *mux) {
  12. a, b := memPipe()
  13. s := newMux(a)
  14. c := newMux(b)
  15. return s, c
  16. }
  17. // Returns both ends of a channel, and the mux for the the 2nd
  18. // channel.
  19. func channelPair(t *testing.T) (*channel, *channel, *mux) {
  20. c, s := muxPair()
  21. res := make(chan *channel, 1)
  22. go func() {
  23. newCh, ok := <-s.incomingChannels
  24. if !ok {
  25. t.Fatalf("No incoming channel")
  26. }
  27. if newCh.ChannelType() != "chan" {
  28. t.Fatalf("got type %q want chan", newCh.ChannelType())
  29. }
  30. ch, _, err := newCh.Accept()
  31. if err != nil {
  32. t.Fatalf("Accept %v", err)
  33. }
  34. res <- ch.(*channel)
  35. }()
  36. ch, err := c.openChannel("chan", nil)
  37. if err != nil {
  38. t.Fatalf("OpenChannel: %v", err)
  39. }
  40. return <-res, ch, c
  41. }
  42. // Test that stderr and stdout can be addressed from different
  43. // goroutines. This is intended for use with the race detector.
  44. func TestMuxChannelExtendedThreadSafety(t *testing.T) {
  45. writer, reader, mux := channelPair(t)
  46. defer writer.Close()
  47. defer reader.Close()
  48. defer mux.Close()
  49. var wr, rd sync.WaitGroup
  50. magic := "hello world"
  51. wr.Add(2)
  52. go func() {
  53. io.WriteString(writer, magic)
  54. wr.Done()
  55. }()
  56. go func() {
  57. io.WriteString(writer.Stderr(), magic)
  58. wr.Done()
  59. }()
  60. rd.Add(2)
  61. go func() {
  62. c, err := ioutil.ReadAll(reader)
  63. if string(c) != magic {
  64. t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
  65. }
  66. rd.Done()
  67. }()
  68. go func() {
  69. c, err := ioutil.ReadAll(reader.Stderr())
  70. if string(c) != magic {
  71. t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
  72. }
  73. rd.Done()
  74. }()
  75. wr.Wait()
  76. writer.CloseWrite()
  77. rd.Wait()
  78. }
  79. func TestMuxReadWrite(t *testing.T) {
  80. s, c, mux := channelPair(t)
  81. defer s.Close()
  82. defer c.Close()
  83. defer mux.Close()
  84. magic := "hello world"
  85. magicExt := "hello stderr"
  86. go func() {
  87. _, err := s.Write([]byte(magic))
  88. if err != nil {
  89. t.Fatalf("Write: %v", err)
  90. }
  91. _, err = s.Extended(1).Write([]byte(magicExt))
  92. if err != nil {
  93. t.Fatalf("Write: %v", err)
  94. }
  95. err = s.Close()
  96. if err != nil {
  97. t.Fatalf("Close: %v", err)
  98. }
  99. }()
  100. var buf [1024]byte
  101. n, err := c.Read(buf[:])
  102. if err != nil {
  103. t.Fatalf("server Read: %v", err)
  104. }
  105. got := string(buf[:n])
  106. if got != magic {
  107. t.Fatalf("server: got %q want %q", got, magic)
  108. }
  109. n, err = c.Extended(1).Read(buf[:])
  110. if err != nil {
  111. t.Fatalf("server Read: %v", err)
  112. }
  113. got = string(buf[:n])
  114. if got != magicExt {
  115. t.Fatalf("server: got %q want %q", got, magic)
  116. }
  117. }
  118. func TestMuxChannelOverflow(t *testing.T) {
  119. reader, writer, mux := channelPair(t)
  120. defer reader.Close()
  121. defer writer.Close()
  122. defer mux.Close()
  123. wDone := make(chan int, 1)
  124. go func() {
  125. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  126. t.Errorf("could not fill window: %v", err)
  127. }
  128. writer.Write(make([]byte, 1))
  129. wDone <- 1
  130. }()
  131. writer.remoteWin.waitWriterBlocked()
  132. // Send 1 byte.
  133. packet := make([]byte, 1+4+4+1)
  134. packet[0] = msgChannelData
  135. marshalUint32(packet[1:], writer.remoteId)
  136. marshalUint32(packet[5:], uint32(1))
  137. packet[9] = 42
  138. if err := writer.mux.conn.writePacket(packet); err != nil {
  139. t.Errorf("could not send packet")
  140. }
  141. if _, err := reader.SendRequest("hello", true, nil); err == nil {
  142. t.Errorf("SendRequest succeeded.")
  143. }
  144. <-wDone
  145. }
  146. func TestMuxChannelCloseWriteUnblock(t *testing.T) {
  147. reader, writer, mux := channelPair(t)
  148. defer reader.Close()
  149. defer writer.Close()
  150. defer mux.Close()
  151. wDone := make(chan int, 1)
  152. go func() {
  153. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  154. t.Errorf("could not fill window: %v", err)
  155. }
  156. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  157. t.Errorf("got %v, want EOF for unblock write", err)
  158. }
  159. wDone <- 1
  160. }()
  161. writer.remoteWin.waitWriterBlocked()
  162. reader.Close()
  163. <-wDone
  164. }
  165. func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
  166. reader, writer, mux := channelPair(t)
  167. defer reader.Close()
  168. defer writer.Close()
  169. defer mux.Close()
  170. wDone := make(chan int, 1)
  171. go func() {
  172. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  173. t.Errorf("could not fill window: %v", err)
  174. }
  175. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  176. t.Errorf("got %v, want EOF for unblock write", err)
  177. }
  178. wDone <- 1
  179. }()
  180. writer.remoteWin.waitWriterBlocked()
  181. mux.Close()
  182. <-wDone
  183. }
  184. func TestMuxReject(t *testing.T) {
  185. client, server := muxPair()
  186. defer server.Close()
  187. defer client.Close()
  188. go func() {
  189. ch, ok := <-server.incomingChannels
  190. if !ok {
  191. t.Fatalf("Accept")
  192. }
  193. if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
  194. t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
  195. }
  196. ch.Reject(RejectionReason(42), "message")
  197. }()
  198. ch, err := client.openChannel("ch", []byte("extra"))
  199. if ch != nil {
  200. t.Fatal("openChannel not rejected")
  201. }
  202. ocf, ok := err.(*OpenChannelError)
  203. if !ok {
  204. t.Errorf("got %#v want *OpenChannelError", err)
  205. } else if ocf.Reason != 42 || ocf.Message != "message" {
  206. t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
  207. }
  208. want := "ssh: rejected: unknown reason 42 (message)"
  209. if err.Error() != want {
  210. t.Errorf("got %q, want %q", err.Error(), want)
  211. }
  212. }
  213. func TestMuxChannelRequest(t *testing.T) {
  214. client, server, mux := channelPair(t)
  215. defer server.Close()
  216. defer client.Close()
  217. defer mux.Close()
  218. var received int
  219. var wg sync.WaitGroup
  220. wg.Add(1)
  221. go func() {
  222. for r := range server.incomingRequests {
  223. received++
  224. r.Reply(r.Type == "yes", nil)
  225. }
  226. wg.Done()
  227. }()
  228. _, err := client.SendRequest("yes", false, nil)
  229. if err != nil {
  230. t.Fatalf("SendRequest: %v", err)
  231. }
  232. ok, err := client.SendRequest("yes", true, nil)
  233. if err != nil {
  234. t.Fatalf("SendRequest: %v", err)
  235. }
  236. if !ok {
  237. t.Errorf("SendRequest(yes): %v", ok)
  238. }
  239. ok, err = client.SendRequest("no", true, nil)
  240. if err != nil {
  241. t.Fatalf("SendRequest: %v", err)
  242. }
  243. if ok {
  244. t.Errorf("SendRequest(no): %v", ok)
  245. }
  246. client.Close()
  247. wg.Wait()
  248. if received != 3 {
  249. t.Errorf("got %d requests, want %d", received, 3)
  250. }
  251. }
  252. func TestMuxGlobalRequest(t *testing.T) {
  253. clientMux, serverMux := muxPair()
  254. defer serverMux.Close()
  255. defer clientMux.Close()
  256. var seen bool
  257. go func() {
  258. for r := range serverMux.incomingRequests {
  259. seen = seen || r.Type == "peek"
  260. if r.WantReply {
  261. err := r.Reply(r.Type == "yes",
  262. append([]byte(r.Type), r.Payload...))
  263. if err != nil {
  264. t.Errorf("AckRequest: %v", err)
  265. }
  266. }
  267. }
  268. }()
  269. _, _, err := clientMux.SendRequest("peek", false, nil)
  270. if err != nil {
  271. t.Errorf("SendRequest: %v", err)
  272. }
  273. ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
  274. if !ok || string(data) != "yesa" || err != nil {
  275. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  276. ok, data, err)
  277. }
  278. if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
  279. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  280. ok, data, err)
  281. }
  282. if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
  283. t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
  284. ok, data, err)
  285. }
  286. if !seen {
  287. t.Errorf("never saw 'peek' request")
  288. }
  289. }
  290. func TestMuxGlobalRequestUnblock(t *testing.T) {
  291. clientMux, serverMux := muxPair()
  292. defer serverMux.Close()
  293. defer clientMux.Close()
  294. result := make(chan error, 1)
  295. go func() {
  296. _, _, err := clientMux.SendRequest("hello", true, nil)
  297. result <- err
  298. }()
  299. <-serverMux.incomingRequests
  300. serverMux.conn.Close()
  301. err := <-result
  302. if err != io.EOF {
  303. t.Errorf("want EOF, got %v", io.EOF)
  304. }
  305. }
  306. func TestMuxChannelRequestUnblock(t *testing.T) {
  307. a, b, connB := channelPair(t)
  308. defer a.Close()
  309. defer b.Close()
  310. defer connB.Close()
  311. result := make(chan error, 1)
  312. go func() {
  313. _, err := a.SendRequest("hello", true, nil)
  314. result <- err
  315. }()
  316. <-b.incomingRequests
  317. connB.conn.Close()
  318. err := <-result
  319. if err != io.EOF {
  320. t.Errorf("want EOF, got %v", err)
  321. }
  322. }
  323. func TestMuxCloseChannel(t *testing.T) {
  324. r, w, mux := channelPair(t)
  325. defer mux.Close()
  326. defer r.Close()
  327. defer w.Close()
  328. result := make(chan error, 1)
  329. go func() {
  330. var b [1024]byte
  331. _, err := r.Read(b[:])
  332. result <- err
  333. }()
  334. if err := w.Close(); err != nil {
  335. t.Errorf("w.Close: %v", err)
  336. }
  337. if _, err := w.Write([]byte("hello")); err != io.EOF {
  338. t.Errorf("got err %v, want io.EOF after Close", err)
  339. }
  340. if err := <-result; err != io.EOF {
  341. t.Errorf("got %v (%T), want io.EOF", err, err)
  342. }
  343. }
  344. func TestMuxCloseWriteChannel(t *testing.T) {
  345. r, w, mux := channelPair(t)
  346. defer mux.Close()
  347. result := make(chan error, 1)
  348. go func() {
  349. var b [1024]byte
  350. _, err := r.Read(b[:])
  351. result <- err
  352. }()
  353. if err := w.CloseWrite(); err != nil {
  354. t.Errorf("w.CloseWrite: %v", err)
  355. }
  356. if _, err := w.Write([]byte("hello")); err != io.EOF {
  357. t.Errorf("got err %v, want io.EOF after CloseWrite", err)
  358. }
  359. if err := <-result; err != io.EOF {
  360. t.Errorf("got %v (%T), want io.EOF", err, err)
  361. }
  362. }
  363. func TestMuxInvalidRecord(t *testing.T) {
  364. a, b := muxPair()
  365. defer a.Close()
  366. defer b.Close()
  367. packet := make([]byte, 1+4+4+1)
  368. packet[0] = msgChannelData
  369. marshalUint32(packet[1:], 29348723 /* invalid channel id */)
  370. marshalUint32(packet[5:], 1)
  371. packet[9] = 42
  372. a.conn.writePacket(packet)
  373. go a.SendRequest("hello", false, nil)
  374. // 'a' wrote an invalid packet, so 'b' has exited.
  375. req, ok := <-b.incomingRequests
  376. if ok {
  377. t.Errorf("got request %#v after receiving invalid packet", req)
  378. }
  379. }
  380. func TestZeroWindowAdjust(t *testing.T) {
  381. a, b, mux := channelPair(t)
  382. defer a.Close()
  383. defer b.Close()
  384. defer mux.Close()
  385. go func() {
  386. io.WriteString(a, "hello")
  387. // bogus adjust.
  388. a.sendMessage(windowAdjustMsg{})
  389. io.WriteString(a, "world")
  390. a.Close()
  391. }()
  392. want := "helloworld"
  393. c, _ := ioutil.ReadAll(b)
  394. if string(c) != want {
  395. t.Errorf("got %q want %q", c, want)
  396. }
  397. }
  398. func TestMuxMaxPacketSize(t *testing.T) {
  399. a, b, mux := channelPair(t)
  400. defer a.Close()
  401. defer b.Close()
  402. defer mux.Close()
  403. large := make([]byte, a.maxRemotePayload+1)
  404. packet := make([]byte, 1+4+4+1+len(large))
  405. packet[0] = msgChannelData
  406. marshalUint32(packet[1:], a.remoteId)
  407. marshalUint32(packet[5:], uint32(len(large)))
  408. packet[9] = 42
  409. if err := a.mux.conn.writePacket(packet); err != nil {
  410. t.Errorf("could not send packet")
  411. }
  412. go a.SendRequest("hello", false, nil)
  413. _, ok := <-b.incomingRequests
  414. if ok {
  415. t.Errorf("connection still alive after receiving large packet.")
  416. }
  417. }
  418. // Don't ship code with debug=true.
  419. func TestDebug(t *testing.T) {
  420. if debugMux {
  421. t.Error("mux debug switched on")
  422. }
  423. if debugHandshake {
  424. t.Error("handshake debug switched on")
  425. }
  426. if debugTransport {
  427. t.Error("transport debug switched on")
  428. }
  429. }