messages_test.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. // Copyright 2011 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "bytes"
  7. "math/big"
  8. "math/rand"
  9. "reflect"
  10. "testing"
  11. "testing/quick"
  12. )
  // intLengthTests is the table driving TestIntLength. Each expectation is
  // spelled as "4 + k"; presumably the 4 is a fixed length prefix and k the
  // payload size in bytes (confirm against intLength).
  var intLengthTests = []struct {
  	val    int // integer handed to intLength after conversion to *big.Int
  	length int // value intLength is expected to return for val
  }{
  	{val: 0, length: 4 + 0},
  	{val: 1, length: 4 + 1},
  	{val: 127, length: 4 + 1},
  	{val: 128, length: 4 + 2},
  	{val: -1, length: 4 + 1},
  }
  22. func TestIntLength(t *testing.T) {
  23. for _, test := range intLengthTests {
  24. v := new(big.Int).SetInt64(int64(test.val))
  25. length := intLength(v)
  26. if length != test.length {
  27. t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
  28. }
  29. }
  30. }
  // msgAllTypes is a test-only message holding one field of each kind used by
  // the Marshal/Unmarshal round-trip test. Field order is significant and must
  // not be changed.
  //
  // The first field carries the sshtype tag (21). The last field is tagged
  // ssh:"rest"; presumably it receives whatever bytes follow the preceding
  // fields (confirm against Unmarshal).
  type msgAllTypes struct {
  	Bool    bool `sshtype:"21"`
  	Array   [16]byte
  	Uint64  uint64
  	Uint32  uint32
  	Uint8   uint8
  	String  string
  	Strings []string
  	Bytes   []byte
  	Int     *big.Int
  	Rest    []byte `ssh:"rest"`
  }
  43. func (t *msgAllTypes) Generate(rand *rand.Rand, size int) reflect.Value {
  44. m := &msgAllTypes{}
  45. m.Bool = rand.Intn(2) == 1
  46. randomBytes(m.Array[:], rand)
  47. m.Uint64 = uint64(rand.Int63n(1<<63 - 1))
  48. m.Uint32 = uint32(rand.Intn((1 << 31) - 1))
  49. m.Uint8 = uint8(rand.Intn(1 << 8))
  50. m.String = string(m.Array[:])
  51. m.Strings = randomNameList(rand)
  52. m.Bytes = m.Array[:]
  53. m.Int = randomInt(rand)
  54. m.Rest = m.Array[:]
  55. return reflect.ValueOf(m)
  56. }
  57. func TestMarshalUnmarshal(t *testing.T) {
  58. rand := rand.New(rand.NewSource(0))
  59. iface := &msgAllTypes{}
  60. ty := reflect.ValueOf(iface).Type()
  61. n := 100
  62. if testing.Short() {
  63. n = 5
  64. }
  65. for j := 0; j < n; j++ {
  66. v, ok := quick.Value(ty, rand)
  67. if !ok {
  68. t.Errorf("failed to create value")
  69. break
  70. }
  71. m1 := v.Elem().Interface()
  72. m2 := iface
  73. marshaled := Marshal(m1)
  74. if err := Unmarshal(marshaled, m2); err != nil {
  75. t.Errorf("Unmarshal %#v: %s", m1, err)
  76. break
  77. }
  78. if !reflect.DeepEqual(v.Interface(), m2) {
  79. t.Errorf("got: %#v\nwant:%#v\n%x", m2, m1, marshaled)
  80. break
  81. }
  82. }
  83. }
  84. func TestUnmarshalEmptyPacket(t *testing.T) {
  85. var b []byte
  86. var m channelRequestSuccessMsg
  87. if err := Unmarshal(b, &m); err == nil {
  88. t.Fatalf("unmarshal of empty slice succeeded")
  89. }
  90. }
  91. func TestUnmarshalUnexpectedPacket(t *testing.T) {
  92. type S struct {
  93. I uint32 `sshtype:"43"`
  94. S string
  95. B bool
  96. }
  97. s := S{11, "hello", true}
  98. packet := Marshal(s)
  99. packet[0] = 42
  100. roundtrip := S{}
  101. err := Unmarshal(packet, &roundtrip)
  102. if err == nil {
  103. t.Fatal("expected error, not nil")
  104. }
  105. }
  106. func TestMarshalPtr(t *testing.T) {
  107. s := struct {
  108. S string
  109. }{"hello"}
  110. m1 := Marshal(s)
  111. m2 := Marshal(&s)
  112. if !bytes.Equal(m1, m2) {
  113. t.Errorf("got %q, want %q for marshaled pointer", m2, m1)
  114. }
  115. }
  116. func TestBareMarshalUnmarshal(t *testing.T) {
  117. type S struct {
  118. I uint32
  119. S string
  120. B bool
  121. }
  122. s := S{42, "hello", true}
  123. packet := Marshal(s)
  124. roundtrip := S{}
  125. Unmarshal(packet, &roundtrip)
  126. if !reflect.DeepEqual(s, roundtrip) {
  127. t.Errorf("got %#v, want %#v", roundtrip, s)
  128. }
  129. }
  130. func TestBareMarshal(t *testing.T) {
  131. type S2 struct {
  132. I uint32
  133. }
  134. s := S2{42}
  135. packet := Marshal(s)
  136. i, rest, ok := parseUint32(packet)
  137. if len(rest) > 0 || !ok {
  138. t.Errorf("parseInt(%q): parse error", packet)
  139. }
  140. if i != s.I {
  141. t.Errorf("got %d, want %d", i, s.I)
  142. }
  143. }
  144. func TestUnmarshalShortKexInitPacket(t *testing.T) {
  145. // This used to panic.
  146. // Issue 11348
  147. packet := []byte{0x14, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0x30, 0xff, 0xff, 0xff, 0xff}
  148. kim := &kexInitMsg{}
  149. if err := Unmarshal(packet, kim); err == nil {
  150. t.Error("truncated packet unmarshaled without error")
  151. }
  152. }
  153. func TestMarshalMultiTag(t *testing.T) {
  154. var res struct {
  155. A uint32 `sshtype:"1|2"`
  156. }
  157. good1 := struct {
  158. A uint32 `sshtype:"1"`
  159. }{
  160. 1,
  161. }
  162. good2 := struct {
  163. A uint32 `sshtype:"2"`
  164. }{
  165. 1,
  166. }
  167. if e := Unmarshal(Marshal(good1), &res); e != nil {
  168. t.Errorf("error unmarshaling multipart tag: %v", e)
  169. }
  170. if e := Unmarshal(Marshal(good2), &res); e != nil {
  171. t.Errorf("error unmarshaling multipart tag: %v", e)
  172. }
  173. bad1 := struct {
  174. A uint32 `sshtype:"3"`
  175. }{
  176. 1,
  177. }
  178. if e := Unmarshal(Marshal(bad1), &res); e == nil {
  179. t.Errorf("bad struct unmarshaled without error")
  180. }
  181. }
  // randomBytes overwrites every element of out with the low 8 bits of a
  // fresh rand.Int31() draw, consuming exactly len(out) draws from rand.
  func randomBytes(out []byte, rand *rand.Rand) {
  	for idx := range out {
  		out[idx] = byte(rand.Int31())
  	}
  }
  // randomNameList returns between 0 and 15 pseudo-random names. Each name is
  // 1 to 16 bytes long and uses only the letters 'a' through 'p'. The result
  // is never nil, even when it is empty.
  func randomNameList(rand *rand.Rand) []string {
  	count := int(rand.Int31() & 15)
  	names := make([]string, 0, count)

  	for len(names) < count {
  		name := make([]byte, 1+int(rand.Int31()&15))
  		for k := range name {
  			name[k] = 'a' + uint8(rand.Int31()&15)
  		}
  		names = append(names, string(name))
  	}
  	return names
  }
  // randomInt returns a *big.Int holding one rand.Uint32() draw reinterpreted
  // as a signed 32-bit value, so results can be negative.
  func randomInt(rand *rand.Rand) *big.Int {
  	signed := int32(rand.Uint32())
  	return big.NewInt(int64(signed))
  }
  201. func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  202. ki := &kexInitMsg{}
  203. randomBytes(ki.Cookie[:], rand)
  204. ki.KexAlgos = randomNameList(rand)
  205. ki.ServerHostKeyAlgos = randomNameList(rand)
  206. ki.CiphersClientServer = randomNameList(rand)
  207. ki.CiphersServerClient = randomNameList(rand)
  208. ki.MACsClientServer = randomNameList(rand)
  209. ki.MACsServerClient = randomNameList(rand)
  210. ki.CompressionClientServer = randomNameList(rand)
  211. ki.CompressionServerClient = randomNameList(rand)
  212. ki.LanguagesClientServer = randomNameList(rand)
  213. ki.LanguagesServerClient = randomNameList(rand)
  214. if rand.Int31()&1 == 1 {
  215. ki.FirstKexFollows = true
  216. }
  217. return reflect.ValueOf(ki)
  218. }
  219. func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
  220. dhi := &kexDHInitMsg{}
  221. dhi.X = randomInt(rand)
  222. return reflect.ValueOf(dhi)
  223. }
  224. var (
  225. _kexInitMsg = new(kexInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
  226. _kexDHInitMsg = new(kexDHInitMsg).Generate(rand.New(rand.NewSource(0)), 10).Elem().Interface()
  227. _kexInit = Marshal(_kexInitMsg)
  228. _kexDHInit = Marshal(_kexDHInitMsg)
  229. )
  230. func BenchmarkMarshalKexInitMsg(b *testing.B) {
  231. for i := 0; i < b.N; i++ {
  232. Marshal(_kexInitMsg)
  233. }
  234. }
  235. func BenchmarkUnmarshalKexInitMsg(b *testing.B) {
  236. m := new(kexInitMsg)
  237. for i := 0; i < b.N; i++ {
  238. Unmarshal(_kexInit, m)
  239. }
  240. }
  241. func BenchmarkMarshalKexDHInitMsg(b *testing.B) {
  242. for i := 0; i < b.N; i++ {
  243. Marshal(_kexDHInitMsg)
  244. }
  245. }
  246. func BenchmarkUnmarshalKexDHInitMsg(b *testing.B) {
  247. m := new(kexDHInitMsg)
  248. for i := 0; i < b.N; i++ {
  249. Unmarshal(_kexDHInit, m)
  250. }
  251. }