smp.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // This file implements the Socialist Millionaires Protocol as described in
  5. // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
  6. // specification is required in order to understand this code and, where
  7. // possible, the variable names in the code match up with the spec.
  8. package otr
  9. import (
  10. "bytes"
  11. "crypto/sha256"
  12. "errors"
  13. "hash"
  14. "math/big"
  15. )
  16. type smpFailure string
  17. func (s smpFailure) Error() string {
  18. return string(s)
  19. }
  20. var smpFailureError = smpFailure("otr: SMP protocol failed")
  21. var smpSecretMissingError = smpFailure("otr: mutual secret needed")
  22. const smpVersion = 1
  23. const (
  24. smpState1 = iota
  25. smpState2
  26. smpState3
  27. smpState4
  28. )
  29. type smpState struct {
  30. state int
  31. a2, a3, b2, b3, pb, qb *big.Int
  32. g2a, g3a *big.Int
  33. g2, g3 *big.Int
  34. g3b, papb, qaqb, ra *big.Int
  35. saved *tlv
  36. secret *big.Int
  37. question string
  38. }
  39. func (c *Conversation) startSMP(question string) (tlvs []tlv) {
  40. if c.smp.state != smpState1 {
  41. tlvs = append(tlvs, c.generateSMPAbort())
  42. }
  43. tlvs = append(tlvs, c.generateSMP1(question))
  44. c.smp.question = ""
  45. c.smp.state = smpState2
  46. return
  47. }
  48. func (c *Conversation) resetSMP() {
  49. c.smp.state = smpState1
  50. c.smp.secret = nil
  51. c.smp.question = ""
  52. }
  53. func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
  54. data := in.data
  55. switch in.typ {
  56. case tlvTypeSMPAbort:
  57. if c.smp.state != smpState1 {
  58. err = smpFailureError
  59. }
  60. c.resetSMP()
  61. return
  62. case tlvTypeSMP1WithQuestion:
  63. // We preprocess this into a SMP1 message.
  64. nulPos := bytes.IndexByte(data, 0)
  65. if nulPos == -1 {
  66. err = errors.New("otr: SMP message with question didn't contain a NUL byte")
  67. return
  68. }
  69. c.smp.question = string(data[:nulPos])
  70. data = data[nulPos+1:]
  71. }
  72. numMPIs, data, ok := getU32(data)
  73. if !ok || numMPIs > 20 {
  74. err = errors.New("otr: corrupt SMP message")
  75. return
  76. }
  77. mpis := make([]*big.Int, numMPIs)
  78. for i := range mpis {
  79. var ok bool
  80. mpis[i], data, ok = getMPI(data)
  81. if !ok {
  82. err = errors.New("otr: corrupt SMP message")
  83. return
  84. }
  85. }
  86. switch in.typ {
  87. case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
  88. if c.smp.state != smpState1 {
  89. c.resetSMP()
  90. out = c.generateSMPAbort()
  91. return
  92. }
  93. if c.smp.secret == nil {
  94. err = smpSecretMissingError
  95. return
  96. }
  97. if err = c.processSMP1(mpis); err != nil {
  98. return
  99. }
  100. c.smp.state = smpState3
  101. out = c.generateSMP2()
  102. case tlvTypeSMP2:
  103. if c.smp.state != smpState2 {
  104. c.resetSMP()
  105. out = c.generateSMPAbort()
  106. return
  107. }
  108. if out, err = c.processSMP2(mpis); err != nil {
  109. out = c.generateSMPAbort()
  110. return
  111. }
  112. c.smp.state = smpState4
  113. case tlvTypeSMP3:
  114. if c.smp.state != smpState3 {
  115. c.resetSMP()
  116. out = c.generateSMPAbort()
  117. return
  118. }
  119. if out, err = c.processSMP3(mpis); err != nil {
  120. return
  121. }
  122. c.smp.state = smpState1
  123. c.smp.secret = nil
  124. complete = true
  125. case tlvTypeSMP4:
  126. if c.smp.state != smpState4 {
  127. c.resetSMP()
  128. out = c.generateSMPAbort()
  129. return
  130. }
  131. if err = c.processSMP4(mpis); err != nil {
  132. out = c.generateSMPAbort()
  133. return
  134. }
  135. c.smp.state = smpState1
  136. c.smp.secret = nil
  137. complete = true
  138. default:
  139. panic("unknown SMP message")
  140. }
  141. return
  142. }
  143. func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
  144. h := sha256.New()
  145. h.Write([]byte{smpVersion})
  146. if weStarted {
  147. h.Write(c.PrivateKey.PublicKey.Fingerprint())
  148. h.Write(c.TheirPublicKey.Fingerprint())
  149. } else {
  150. h.Write(c.TheirPublicKey.Fingerprint())
  151. h.Write(c.PrivateKey.PublicKey.Fingerprint())
  152. }
  153. h.Write(c.SSID[:])
  154. h.Write(mutualSecret)
  155. c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
  156. }
// generateSMP1 builds the initiator's first SMP message.
//
// It picks fresh secret exponents a2 and a3 (stored in c.smp) and publishes
// g2a = g^a2 and g3a = g^a3. Each value comes with a Schnorr-style proof of
// knowledge of its exponent: c = H(magic, g^r), d = r - exponent*c mod q.
//
// If question is non-empty, the TLV is an SMP1WithQuestion whose data starts
// with the question bytes and a NUL; otherwise it is a plain SMP1. The MPI
// payload is the count (6) followed by g2a, c2, d2, g3a, c3, d3.
func (c *Conversation) generateSMP1(question string) tlv {
	// randBuf is scratch space reused by every randMPI call below.
	var randBuf [16]byte
	c.smp.a2 = c.randMPI(randBuf[:])
	c.smp.a3 = c.randMPI(randBuf[:])
	g2a := new(big.Int).Exp(g, c.smp.a2, p)
	g3a := new(big.Int).Exp(g, c.smp.a3, p)
	h := sha256.New()
	// Proof for a2: c2 = H(1, g^r2), d2 = r2 - a2*c2 mod q.
	r2 := c.randMPI(randBuf[:])
	r := new(big.Int).Exp(g, r2, p)
	c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
	d2 := new(big.Int).Mul(c.smp.a2, c2)
	d2.Sub(r2, d2)
	d2.Mod(d2, q)
	// NOTE(review): big.Int.Mod is Euclidean, so for q > 0 the result is
	// already non-negative and this branch appears unreachable.
	if d2.Sign() < 0 {
		d2.Add(d2, q)
	}
	// Proof for a3: c3 = H(2, g^r3), d3 = r3 - a3*c3 mod q.
	r3 := c.randMPI(randBuf[:])
	r.Exp(g, r3, p)
	c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
	d3 := new(big.Int).Mul(c.smp.a3, c3)
	d3.Sub(r3, d3)
	d3.Mod(d3, q)
	if d3.Sign() < 0 {
		d3.Add(d3, q)
	}
	var ret tlv
	if len(question) > 0 {
		// processSMP splits the question off at the first NUL byte.
		ret.typ = tlvTypeSMP1WithQuestion
		ret.data = append(ret.data, question...)
		ret.data = append(ret.data, 0)
	} else {
		ret.typ = tlvTypeSMP1
	}
	ret.data = appendU32(ret.data, 6)
	ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
	return ret
}
  194. func (c *Conversation) processSMP1(mpis []*big.Int) error {
  195. if len(mpis) != 6 {
  196. return errors.New("otr: incorrect number of arguments in SMP1 message")
  197. }
  198. g2a := mpis[0]
  199. c2 := mpis[1]
  200. d2 := mpis[2]
  201. g3a := mpis[3]
  202. c3 := mpis[4]
  203. d3 := mpis[5]
  204. h := sha256.New()
  205. r := new(big.Int).Exp(g, d2, p)
  206. s := new(big.Int).Exp(g2a, c2, p)
  207. r.Mul(r, s)
  208. r.Mod(r, p)
  209. t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
  210. if c2.Cmp(t) != 0 {
  211. return errors.New("otr: ZKP c2 incorrect in SMP1 message")
  212. }
  213. r.Exp(g, d3, p)
  214. s.Exp(g3a, c3, p)
  215. r.Mul(r, s)
  216. r.Mod(r, p)
  217. t.SetBytes(hashMPIs(h, 2, r))
  218. if c3.Cmp(t) != 0 {
  219. return errors.New("otr: ZKP c3 incorrect in SMP1 message")
  220. }
  221. c.smp.g2a = g2a
  222. c.smp.g3a = g3a
  223. return nil
  224. }
// generateSMP2 builds the responder's SMP2 reply after processSMP1 has
// accepted the initiator's g2a and g3a.
//
// It picks secret exponents b2 (local) and b3 (stored in c.smp) and
// publishes g2b and g3b with Schnorr-style proofs (magic bytes 3 and 4). It
// derives the shared generators g2 = g2a^b2 and g3 = g3a^b3, then computes
// Pb = g3^r4 and Qb = g^r4 * g2^secret. A proof (cP, D5, D6) with magic
// byte 5 shows that Pb and Qb were built from the same r4. g2, g3, Pb and Qb
// are stored in c.smp.
//
// The MPI payload is the count (11) followed by
// g2b, c2, d2, g3b, c3, d3, Pb, Qb, cP, D5, D6.
func (c *Conversation) generateSMP2() tlv {
	// randBuf is scratch space reused by every randMPI call below.
	var randBuf [16]byte
	b2 := c.randMPI(randBuf[:])
	c.smp.b3 = c.randMPI(randBuf[:])
	r2 := c.randMPI(randBuf[:])
	r3 := c.randMPI(randBuf[:])
	r4 := c.randMPI(randBuf[:])
	r5 := c.randMPI(randBuf[:])
	r6 := c.randMPI(randBuf[:])
	g2b := new(big.Int).Exp(g, b2, p)
	g3b := new(big.Int).Exp(g, c.smp.b3, p)
	// Proof for b2: c2 = H(3, g^r2), d2 = r2 - b2*c2 mod q.
	r := new(big.Int).Exp(g, r2, p)
	h := sha256.New()
	c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
	d2 := new(big.Int).Mul(b2, c2)
	d2.Sub(r2, d2)
	d2.Mod(d2, q)
	// NOTE(review): big.Int.Mod is Euclidean, so for q > 0 these Sign() < 0
	// branches appear unreachable.
	if d2.Sign() < 0 {
		d2.Add(d2, q)
	}
	// Proof for b3: c3 = H(4, g^r3), d3 = r3 - b3*c3 mod q.
	r.Exp(g, r3, p)
	c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
	d3 := new(big.Int).Mul(c.smp.b3, c3)
	d3.Sub(r3, d3)
	d3.Mod(d3, q)
	if d3.Sign() < 0 {
		d3.Add(d3, q)
	}
	// Shared generators, then Pb = g3^r4 and Qb = g^r4 * g2^secret.
	c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
	c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
	c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
	c.smp.qb = new(big.Int).Exp(g, r4, p)
	r.Exp(c.smp.g2, c.smp.secret, p)
	c.smp.qb.Mul(c.smp.qb, r)
	c.smp.qb.Mod(c.smp.qb, p)
	// cP = H(5, g3^r5, g^r5 * g2^r6).
	s := new(big.Int)
	s.Exp(c.smp.g2, r6, p)
	r.Exp(g, r5, p)
	s.Mul(r, s)
	s.Mod(s, p)
	r.Exp(c.smp.g3, r5, p)
	cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
	// D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
	s.Mul(r4, cp)
	r.Sub(r5, s)
	d5 := new(big.Int).Mod(r, q)
	if d5.Sign() < 0 {
		d5.Add(d5, q)
	}
	s.Mul(c.smp.secret, cp)
	r.Sub(r6, s)
	d6 := new(big.Int).Mod(r, q)
	if d6.Sign() < 0 {
		d6.Add(d6, q)
	}
	var ret tlv
	ret.typ = tlvTypeSMP2
	ret.data = appendU32(ret.data, 11)
	ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
	return ret
}
  286. func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
  287. if len(mpis) != 11 {
  288. err = errors.New("otr: incorrect number of arguments in SMP2 message")
  289. return
  290. }
  291. g2b := mpis[0]
  292. c2 := mpis[1]
  293. d2 := mpis[2]
  294. g3b := mpis[3]
  295. c3 := mpis[4]
  296. d3 := mpis[5]
  297. pb := mpis[6]
  298. qb := mpis[7]
  299. cp := mpis[8]
  300. d5 := mpis[9]
  301. d6 := mpis[10]
  302. h := sha256.New()
  303. r := new(big.Int).Exp(g, d2, p)
  304. s := new(big.Int).Exp(g2b, c2, p)
  305. r.Mul(r, s)
  306. r.Mod(r, p)
  307. s.SetBytes(hashMPIs(h, 3, r))
  308. if c2.Cmp(s) != 0 {
  309. err = errors.New("otr: ZKP c2 failed in SMP2 message")
  310. return
  311. }
  312. r.Exp(g, d3, p)
  313. s.Exp(g3b, c3, p)
  314. r.Mul(r, s)
  315. r.Mod(r, p)
  316. s.SetBytes(hashMPIs(h, 4, r))
  317. if c3.Cmp(s) != 0 {
  318. err = errors.New("otr: ZKP c3 failed in SMP2 message")
  319. return
  320. }
  321. c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
  322. c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
  323. r.Exp(g, d5, p)
  324. s.Exp(c.smp.g2, d6, p)
  325. r.Mul(r, s)
  326. s.Exp(qb, cp, p)
  327. r.Mul(r, s)
  328. r.Mod(r, p)
  329. s.Exp(c.smp.g3, d5, p)
  330. t := new(big.Int).Exp(pb, cp, p)
  331. s.Mul(s, t)
  332. s.Mod(s, p)
  333. t.SetBytes(hashMPIs(h, 5, s, r))
  334. if cp.Cmp(t) != 0 {
  335. err = errors.New("otr: ZKP cP failed in SMP2 message")
  336. return
  337. }
  338. var randBuf [16]byte
  339. r4 := c.randMPI(randBuf[:])
  340. r5 := c.randMPI(randBuf[:])
  341. r6 := c.randMPI(randBuf[:])
  342. r7 := c.randMPI(randBuf[:])
  343. pa := new(big.Int).Exp(c.smp.g3, r4, p)
  344. r.Exp(c.smp.g2, c.smp.secret, p)
  345. qa := new(big.Int).Exp(g, r4, p)
  346. qa.Mul(qa, r)
  347. qa.Mod(qa, p)
  348. r.Exp(g, r5, p)
  349. s.Exp(c.smp.g2, r6, p)
  350. r.Mul(r, s)
  351. r.Mod(r, p)
  352. s.Exp(c.smp.g3, r5, p)
  353. cp.SetBytes(hashMPIs(h, 6, s, r))
  354. r.Mul(r4, cp)
  355. d5 = new(big.Int).Sub(r5, r)
  356. d5.Mod(d5, q)
  357. if d5.Sign() < 0 {
  358. d5.Add(d5, q)
  359. }
  360. r.Mul(c.smp.secret, cp)
  361. d6 = new(big.Int).Sub(r6, r)
  362. d6.Mod(d6, q)
  363. if d6.Sign() < 0 {
  364. d6.Add(d6, q)
  365. }
  366. r.ModInverse(qb, p)
  367. qaqb := new(big.Int).Mul(qa, r)
  368. qaqb.Mod(qaqb, p)
  369. ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
  370. r.Exp(qaqb, r7, p)
  371. s.Exp(g, r7, p)
  372. cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
  373. r.Mul(c.smp.a3, cr)
  374. d7 := new(big.Int).Sub(r7, r)
  375. d7.Mod(d7, q)
  376. if d7.Sign() < 0 {
  377. d7.Add(d7, q)
  378. }
  379. c.smp.g3b = g3b
  380. c.smp.qaqb = qaqb
  381. r.ModInverse(pb, p)
  382. c.smp.papb = new(big.Int).Mul(pa, r)
  383. c.smp.papb.Mod(c.smp.papb, p)
  384. c.smp.ra = ra
  385. out.typ = tlvTypeSMP3
  386. out.data = appendU32(out.data, 8)
  387. out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
  388. return
  389. }
  390. func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
  391. if len(mpis) != 8 {
  392. err = errors.New("otr: incorrect number of arguments in SMP3 message")
  393. return
  394. }
  395. pa := mpis[0]
  396. qa := mpis[1]
  397. cp := mpis[2]
  398. d5 := mpis[3]
  399. d6 := mpis[4]
  400. ra := mpis[5]
  401. cr := mpis[6]
  402. d7 := mpis[7]
  403. h := sha256.New()
  404. r := new(big.Int).Exp(g, d5, p)
  405. s := new(big.Int).Exp(c.smp.g2, d6, p)
  406. r.Mul(r, s)
  407. s.Exp(qa, cp, p)
  408. r.Mul(r, s)
  409. r.Mod(r, p)
  410. s.Exp(c.smp.g3, d5, p)
  411. t := new(big.Int).Exp(pa, cp, p)
  412. s.Mul(s, t)
  413. s.Mod(s, p)
  414. t.SetBytes(hashMPIs(h, 6, s, r))
  415. if t.Cmp(cp) != 0 {
  416. err = errors.New("otr: ZKP cP failed in SMP3 message")
  417. return
  418. }
  419. r.ModInverse(c.smp.qb, p)
  420. qaqb := new(big.Int).Mul(qa, r)
  421. qaqb.Mod(qaqb, p)
  422. r.Exp(qaqb, d7, p)
  423. s.Exp(ra, cr, p)
  424. r.Mul(r, s)
  425. r.Mod(r, p)
  426. s.Exp(g, d7, p)
  427. t.Exp(c.smp.g3a, cr, p)
  428. s.Mul(s, t)
  429. s.Mod(s, p)
  430. t.SetBytes(hashMPIs(h, 7, s, r))
  431. if t.Cmp(cr) != 0 {
  432. err = errors.New("otr: ZKP cR failed in SMP3 message")
  433. return
  434. }
  435. var randBuf [16]byte
  436. r7 := c.randMPI(randBuf[:])
  437. rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
  438. r.Exp(qaqb, r7, p)
  439. s.Exp(g, r7, p)
  440. cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
  441. r.Mul(c.smp.b3, cr)
  442. d7 = new(big.Int).Sub(r7, r)
  443. d7.Mod(d7, q)
  444. if d7.Sign() < 0 {
  445. d7.Add(d7, q)
  446. }
  447. out.typ = tlvTypeSMP4
  448. out.data = appendU32(out.data, 3)
  449. out.data = appendMPIs(out.data, rb, cr, d7)
  450. r.ModInverse(c.smp.pb, p)
  451. r.Mul(pa, r)
  452. r.Mod(r, p)
  453. s.Exp(ra, c.smp.b3, p)
  454. if r.Cmp(s) != 0 {
  455. err = smpFailureError
  456. }
  457. return
  458. }
  459. func (c *Conversation) processSMP4(mpis []*big.Int) error {
  460. if len(mpis) != 3 {
  461. return errors.New("otr: incorrect number of arguments in SMP4 message")
  462. }
  463. rb := mpis[0]
  464. cr := mpis[1]
  465. d7 := mpis[2]
  466. h := sha256.New()
  467. r := new(big.Int).Exp(c.smp.qaqb, d7, p)
  468. s := new(big.Int).Exp(rb, cr, p)
  469. r.Mul(r, s)
  470. r.Mod(r, p)
  471. s.Exp(g, d7, p)
  472. t := new(big.Int).Exp(c.smp.g3b, cr, p)
  473. s.Mul(s, t)
  474. s.Mod(s, p)
  475. t.SetBytes(hashMPIs(h, 8, s, r))
  476. if t.Cmp(cr) != 0 {
  477. return errors.New("otr: ZKP cR failed in SMP4 message")
  478. }
  479. r.Exp(rb, c.smp.a3, p)
  480. if r.Cmp(c.smp.papb) != 0 {
  481. return smpFailureError
  482. }
  483. return nil
  484. }
  485. func (c *Conversation) generateSMPAbort() tlv {
  486. return tlv{typ: tlvTypeSMPAbort}
  487. }
  488. func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
  489. if h != nil {
  490. h.Reset()
  491. } else {
  492. h = sha256.New()
  493. }
  494. h.Write([]byte{magic})
  495. for _, mpi := range mpis {
  496. h.Write(appendMPI(nil, mpi))
  497. }
  498. return h.Sum(nil)
  499. }