123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572 |
- // Copyright 2012 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- // This file implements the Socialist Millionaires Protocol as described in
- // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol
- // specification is required in order to understand this code and, where
- // possible, the variable names in the code match up with the spec.
- package otr
- import (
- "bytes"
- "crypto/sha256"
- "errors"
- "hash"
- "math/big"
- )
// smpFailure is the error type of the SMP sentinel errors declared below.
type smpFailure string

// Error implements the error interface by returning the message verbatim.
func (f smpFailure) Error() string { return string(f) }

var (
	// smpFailureError is returned when the peer aborts an exchange that is
	// in progress, or when an exchange finishes but the secrets differ.
	smpFailureError = smpFailure("otr: SMP protocol failed")

	// smpSecretMissingError is returned when an SMP1 message arrives before
	// the local mutual secret has been set.
	smpSecretMissingError = smpFailure("otr: mutual secret needed")
)
// smpVersion is the version byte that calcSMPSecret mixes into the hash
// from which the SMP secret is derived.
const smpVersion = 1

// SMP state-machine states. In smpStateN the next SMP message accepted is
// SMPN; an out-of-order SMP message makes processSMP abort the exchange.
const (
	smpState1 = iota // idle: an incoming SMP1 may start an exchange
	smpState2        // we sent SMP1 and are waiting for SMP2
	smpState3        // we sent SMP2 and are waiting for SMP3
	smpState4        // we sent SMP3 and are waiting for SMP4
)
- type smpState struct {
- state int
- a2, a3, b2, b3, pb, qb *big.Int
- g2a, g3a *big.Int
- g2, g3 *big.Int
- g3b, papb, qaqb, ra *big.Int
- saved *tlv
- secret *big.Int
- question string
- }
- func (c *Conversation) startSMP(question string) (tlvs []tlv) {
- if c.smp.state != smpState1 {
- tlvs = append(tlvs, c.generateSMPAbort())
- }
- tlvs = append(tlvs, c.generateSMP1(question))
- c.smp.question = ""
- c.smp.state = smpState2
- return
- }
- func (c *Conversation) resetSMP() {
- c.smp.state = smpState1
- c.smp.secret = nil
- c.smp.question = ""
- }
- func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) {
- data := in.data
- switch in.typ {
- case tlvTypeSMPAbort:
- if c.smp.state != smpState1 {
- err = smpFailureError
- }
- c.resetSMP()
- return
- case tlvTypeSMP1WithQuestion:
- // We preprocess this into a SMP1 message.
- nulPos := bytes.IndexByte(data, 0)
- if nulPos == -1 {
- err = errors.New("otr: SMP message with question didn't contain a NUL byte")
- return
- }
- c.smp.question = string(data[:nulPos])
- data = data[nulPos+1:]
- }
- numMPIs, data, ok := getU32(data)
- if !ok || numMPIs > 20 {
- err = errors.New("otr: corrupt SMP message")
- return
- }
- mpis := make([]*big.Int, numMPIs)
- for i := range mpis {
- var ok bool
- mpis[i], data, ok = getMPI(data)
- if !ok {
- err = errors.New("otr: corrupt SMP message")
- return
- }
- }
- switch in.typ {
- case tlvTypeSMP1, tlvTypeSMP1WithQuestion:
- if c.smp.state != smpState1 {
- c.resetSMP()
- out = c.generateSMPAbort()
- return
- }
- if c.smp.secret == nil {
- err = smpSecretMissingError
- return
- }
- if err = c.processSMP1(mpis); err != nil {
- return
- }
- c.smp.state = smpState3
- out = c.generateSMP2()
- case tlvTypeSMP2:
- if c.smp.state != smpState2 {
- c.resetSMP()
- out = c.generateSMPAbort()
- return
- }
- if out, err = c.processSMP2(mpis); err != nil {
- out = c.generateSMPAbort()
- return
- }
- c.smp.state = smpState4
- case tlvTypeSMP3:
- if c.smp.state != smpState3 {
- c.resetSMP()
- out = c.generateSMPAbort()
- return
- }
- if out, err = c.processSMP3(mpis); err != nil {
- return
- }
- c.smp.state = smpState1
- c.smp.secret = nil
- complete = true
- case tlvTypeSMP4:
- if c.smp.state != smpState4 {
- c.resetSMP()
- out = c.generateSMPAbort()
- return
- }
- if err = c.processSMP4(mpis); err != nil {
- out = c.generateSMPAbort()
- return
- }
- c.smp.state = smpState1
- c.smp.secret = nil
- complete = true
- default:
- panic("unknown SMP message")
- }
- return
- }
- func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) {
- h := sha256.New()
- h.Write([]byte{smpVersion})
- if weStarted {
- h.Write(c.PrivateKey.PublicKey.Fingerprint())
- h.Write(c.TheirPublicKey.Fingerprint())
- } else {
- h.Write(c.TheirPublicKey.Fingerprint())
- h.Write(c.PrivateKey.PublicKey.Fingerprint())
- }
- h.Write(c.SSID[:])
- h.Write(mutualSecret)
- c.smp.secret = new(big.Int).SetBytes(h.Sum(nil))
- }
// generateSMP1 builds the first SMP message, sent by the side that starts
// the exchange (startSMP calls it). It draws fresh random exponents a2 and
// a3, stores them in c.smp for the later steps, and emits g2a = g^a2 and
// g3a = g^a3 mod p together with zero-knowledge proofs of knowledge of a2
// and a3:
//
//	c2 = H(1, g^r2), D2 = r2 - a2*c2 mod q
//	c3 = H(2, g^r3), D3 = r3 - a3*c3 mod q
//
// where H is hashMPIs. If question is non-empty, the TLV type is
// tlvTypeSMP1WithQuestion and the payload starts with question followed by
// a NUL byte; otherwise the type is tlvTypeSMP1. The payload then holds the
// MPI count 6 (via appendU32) and g2a, c2, D2, g3a, c3, D3.
- func (c *Conversation) generateSMP1(question string) tlv {
// randBuf is scratch space handed to every randMPI call below.
- var randBuf [16]byte
- c.smp.a2 = c.randMPI(randBuf[:])
- c.smp.a3 = c.randMPI(randBuf[:])
- g2a := new(big.Int).Exp(g, c.smp.a2, p)
- g3a := new(big.Int).Exp(g, c.smp.a3, p)
// One hash instance, reset by each hashMPIs call.
- h := sha256.New()
// Proof of knowledge of a2 (hash prefix 1).
- r2 := c.randMPI(randBuf[:])
- r := new(big.Int).Exp(g, r2, p)
- c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r))
- d2 := new(big.Int).Mul(c.smp.a2, c2)
- d2.Sub(r2, d2)
- d2.Mod(d2, q)
// NOTE(review): big.Int.Mod is Euclidean (result in [0, q)), so this
// branch appears unreachable; it is kept only as a defensive guard.
- if d2.Sign() < 0 {
- d2.Add(d2, q)
- }
// Proof of knowledge of a3 (hash prefix 2); r is reused to hold g^r3.
- r3 := c.randMPI(randBuf[:])
- r.Exp(g, r3, p)
- c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r))
- d3 := new(big.Int).Mul(c.smp.a3, c3)
- d3.Sub(r3, d3)
- d3.Mod(d3, q)
- if d3.Sign() < 0 {
- d3.Add(d3, q)
- }
// Encode the TLV: optional "question\x00" prefix, MPI count, then the MPIs.
- var ret tlv
- if len(question) > 0 {
- ret.typ = tlvTypeSMP1WithQuestion
- ret.data = append(ret.data, question...)
- ret.data = append(ret.data, 0)
- } else {
- ret.typ = tlvTypeSMP1
- }
- ret.data = appendU32(ret.data, 6)
- ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3)
- return ret
- }
- func (c *Conversation) processSMP1(mpis []*big.Int) error {
- if len(mpis) != 6 {
- return errors.New("otr: incorrect number of arguments in SMP1 message")
- }
- g2a := mpis[0]
- c2 := mpis[1]
- d2 := mpis[2]
- g3a := mpis[3]
- c3 := mpis[4]
- d3 := mpis[5]
- h := sha256.New()
- r := new(big.Int).Exp(g, d2, p)
- s := new(big.Int).Exp(g2a, c2, p)
- r.Mul(r, s)
- r.Mod(r, p)
- t := new(big.Int).SetBytes(hashMPIs(h, 1, r))
- if c2.Cmp(t) != 0 {
- return errors.New("otr: ZKP c2 incorrect in SMP1 message")
- }
- r.Exp(g, d3, p)
- s.Exp(g3a, c3, p)
- r.Mul(r, s)
- r.Mod(r, p)
- t.SetBytes(hashMPIs(h, 2, r))
- if c3.Cmp(t) != 0 {
- return errors.New("otr: ZKP c3 incorrect in SMP1 message")
- }
- c.smp.g2a = g2a
- c.smp.g3a = g3a
- return nil
- }
// generateSMP2 builds the SMP2 reply on the responding side. It relies on
// processSMP1 having stored c.smp.g2a and c.smp.g3a, and on c.smp.secret
// being set. It draws fresh exponents b2, b3 (only b3 is kept in c.smp)
// and r2..r6, then emits eleven MPIs:
//
//	g2b = g^b2, c2 = H(3, g^r2), D2 = r2 - b2*c2 mod q
//	g3b = g^b3, c3 = H(4, g^r3), D3 = r3 - b3*c3 mod q
//	Pb = g3^r4, Qb = g^r4 * g2^secret
//	cP = H(5, g3^r5, g^r5 * g2^r6), D5 = r5 - r4*cP, D6 = r6 - secret*cP mod q
//
// where g2 = g2a^b2 and g3 = g3a^b3. g2, g3, Pb and Qb are stored in c.smp
// for processSMP3. H is hashMPIs; group elements are reduced mod p.
- func (c *Conversation) generateSMP2() tlv {
// randBuf is scratch space handed to every randMPI call below.
- var randBuf [16]byte
- b2 := c.randMPI(randBuf[:])
- c.smp.b3 = c.randMPI(randBuf[:])
- r2 := c.randMPI(randBuf[:])
- r3 := c.randMPI(randBuf[:])
- r4 := c.randMPI(randBuf[:])
- r5 := c.randMPI(randBuf[:])
- r6 := c.randMPI(randBuf[:])
- g2b := new(big.Int).Exp(g, b2, p)
- g3b := new(big.Int).Exp(g, c.smp.b3, p)
// Proof for b2 (hash prefix 3); r is reused as scratch from here on.
- r := new(big.Int).Exp(g, r2, p)
// One hash instance, reset by each hashMPIs call.
- h := sha256.New()
- c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r))
- d2 := new(big.Int).Mul(b2, c2)
- d2.Sub(r2, d2)
- d2.Mod(d2, q)
// NOTE(review): big.Int.Mod is Euclidean (result in [0, q)), so this and
// the similar Sign checks below appear unreachable; kept as guards.
- if d2.Sign() < 0 {
- d2.Add(d2, q)
- }
// Proof for b3 (hash prefix 4).
- r.Exp(g, r3, p)
- c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r))
- d3 := new(big.Int).Mul(c.smp.b3, c3)
- d3.Sub(r3, d3)
- d3.Mod(d3, q)
- if d3.Sign() < 0 {
- d3.Add(d3, q)
- }
// Shared generators g2 = g2a^b2 and g3 = g3a^b3.
- c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p)
- c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p)
// Pb = g3^r4 and Qb = g^r4 * g2^secret mod p.
- c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p)
- c.smp.qb = new(big.Int).Exp(g, r4, p)
- r.Exp(c.smp.g2, c.smp.secret, p)
- c.smp.qb.Mul(c.smp.qb, r)
- c.smp.qb.Mod(c.smp.qb, p)
// cP = H(5, g3^r5, g^r5 * g2^r6): s ends up holding g^r5 * g2^r6 and r
// holds g3^r5 when the hash is taken.
- s := new(big.Int)
- s.Exp(c.smp.g2, r6, p)
- r.Exp(g, r5, p)
- s.Mul(r, s)
- s.Mod(s, p)
- r.Exp(c.smp.g3, r5, p)
- cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s))
- // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q
- s.Mul(r4, cp)
- r.Sub(r5, s)
- d5 := new(big.Int).Mod(r, q)
- if d5.Sign() < 0 {
- d5.Add(d5, q)
- }
- s.Mul(c.smp.secret, cp)
- r.Sub(r6, s)
- d6 := new(big.Int).Mod(r, q)
- if d6.Sign() < 0 {
- d6.Add(d6, q)
- }
// SMP2 payload: MPI count 11, then g2b, c2, D2, g3b, c3, D3, Pb, Qb, cP, D5, D6.
- var ret tlv
- ret.typ = tlvTypeSMP2
- ret.data = appendU32(ret.data, 11)
- ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6)
- return ret
- }
- func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) {
- if len(mpis) != 11 {
- err = errors.New("otr: incorrect number of arguments in SMP2 message")
- return
- }
- g2b := mpis[0]
- c2 := mpis[1]
- d2 := mpis[2]
- g3b := mpis[3]
- c3 := mpis[4]
- d3 := mpis[5]
- pb := mpis[6]
- qb := mpis[7]
- cp := mpis[8]
- d5 := mpis[9]
- d6 := mpis[10]
- h := sha256.New()
- r := new(big.Int).Exp(g, d2, p)
- s := new(big.Int).Exp(g2b, c2, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.SetBytes(hashMPIs(h, 3, r))
- if c2.Cmp(s) != 0 {
- err = errors.New("otr: ZKP c2 failed in SMP2 message")
- return
- }
- r.Exp(g, d3, p)
- s.Exp(g3b, c3, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.SetBytes(hashMPIs(h, 4, r))
- if c3.Cmp(s) != 0 {
- err = errors.New("otr: ZKP c3 failed in SMP2 message")
- return
- }
- c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p)
- c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p)
- r.Exp(g, d5, p)
- s.Exp(c.smp.g2, d6, p)
- r.Mul(r, s)
- s.Exp(qb, cp, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.Exp(c.smp.g3, d5, p)
- t := new(big.Int).Exp(pb, cp, p)
- s.Mul(s, t)
- s.Mod(s, p)
- t.SetBytes(hashMPIs(h, 5, s, r))
- if cp.Cmp(t) != 0 {
- err = errors.New("otr: ZKP cP failed in SMP2 message")
- return
- }
- var randBuf [16]byte
- r4 := c.randMPI(randBuf[:])
- r5 := c.randMPI(randBuf[:])
- r6 := c.randMPI(randBuf[:])
- r7 := c.randMPI(randBuf[:])
- pa := new(big.Int).Exp(c.smp.g3, r4, p)
- r.Exp(c.smp.g2, c.smp.secret, p)
- qa := new(big.Int).Exp(g, r4, p)
- qa.Mul(qa, r)
- qa.Mod(qa, p)
- r.Exp(g, r5, p)
- s.Exp(c.smp.g2, r6, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.Exp(c.smp.g3, r5, p)
- cp.SetBytes(hashMPIs(h, 6, s, r))
- r.Mul(r4, cp)
- d5 = new(big.Int).Sub(r5, r)
- d5.Mod(d5, q)
- if d5.Sign() < 0 {
- d5.Add(d5, q)
- }
- r.Mul(c.smp.secret, cp)
- d6 = new(big.Int).Sub(r6, r)
- d6.Mod(d6, q)
- if d6.Sign() < 0 {
- d6.Add(d6, q)
- }
- r.ModInverse(qb, p)
- qaqb := new(big.Int).Mul(qa, r)
- qaqb.Mod(qaqb, p)
- ra := new(big.Int).Exp(qaqb, c.smp.a3, p)
- r.Exp(qaqb, r7, p)
- s.Exp(g, r7, p)
- cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r))
- r.Mul(c.smp.a3, cr)
- d7 := new(big.Int).Sub(r7, r)
- d7.Mod(d7, q)
- if d7.Sign() < 0 {
- d7.Add(d7, q)
- }
- c.smp.g3b = g3b
- c.smp.qaqb = qaqb
- r.ModInverse(pb, p)
- c.smp.papb = new(big.Int).Mul(pa, r)
- c.smp.papb.Mod(c.smp.papb, p)
- c.smp.ra = ra
- out.typ = tlvTypeSMP3
- out.data = appendU32(out.data, 8)
- out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7)
- return
- }
- func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) {
- if len(mpis) != 8 {
- err = errors.New("otr: incorrect number of arguments in SMP3 message")
- return
- }
- pa := mpis[0]
- qa := mpis[1]
- cp := mpis[2]
- d5 := mpis[3]
- d6 := mpis[4]
- ra := mpis[5]
- cr := mpis[6]
- d7 := mpis[7]
- h := sha256.New()
- r := new(big.Int).Exp(g, d5, p)
- s := new(big.Int).Exp(c.smp.g2, d6, p)
- r.Mul(r, s)
- s.Exp(qa, cp, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.Exp(c.smp.g3, d5, p)
- t := new(big.Int).Exp(pa, cp, p)
- s.Mul(s, t)
- s.Mod(s, p)
- t.SetBytes(hashMPIs(h, 6, s, r))
- if t.Cmp(cp) != 0 {
- err = errors.New("otr: ZKP cP failed in SMP3 message")
- return
- }
- r.ModInverse(c.smp.qb, p)
- qaqb := new(big.Int).Mul(qa, r)
- qaqb.Mod(qaqb, p)
- r.Exp(qaqb, d7, p)
- s.Exp(ra, cr, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.Exp(g, d7, p)
- t.Exp(c.smp.g3a, cr, p)
- s.Mul(s, t)
- s.Mod(s, p)
- t.SetBytes(hashMPIs(h, 7, s, r))
- if t.Cmp(cr) != 0 {
- err = errors.New("otr: ZKP cR failed in SMP3 message")
- return
- }
- var randBuf [16]byte
- r7 := c.randMPI(randBuf[:])
- rb := new(big.Int).Exp(qaqb, c.smp.b3, p)
- r.Exp(qaqb, r7, p)
- s.Exp(g, r7, p)
- cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r))
- r.Mul(c.smp.b3, cr)
- d7 = new(big.Int).Sub(r7, r)
- d7.Mod(d7, q)
- if d7.Sign() < 0 {
- d7.Add(d7, q)
- }
- out.typ = tlvTypeSMP4
- out.data = appendU32(out.data, 3)
- out.data = appendMPIs(out.data, rb, cr, d7)
- r.ModInverse(c.smp.pb, p)
- r.Mul(pa, r)
- r.Mod(r, p)
- s.Exp(ra, c.smp.b3, p)
- if r.Cmp(s) != 0 {
- err = smpFailureError
- }
- return
- }
- func (c *Conversation) processSMP4(mpis []*big.Int) error {
- if len(mpis) != 3 {
- return errors.New("otr: incorrect number of arguments in SMP4 message")
- }
- rb := mpis[0]
- cr := mpis[1]
- d7 := mpis[2]
- h := sha256.New()
- r := new(big.Int).Exp(c.smp.qaqb, d7, p)
- s := new(big.Int).Exp(rb, cr, p)
- r.Mul(r, s)
- r.Mod(r, p)
- s.Exp(g, d7, p)
- t := new(big.Int).Exp(c.smp.g3b, cr, p)
- s.Mul(s, t)
- s.Mod(s, p)
- t.SetBytes(hashMPIs(h, 8, s, r))
- if t.Cmp(cr) != 0 {
- return errors.New("otr: ZKP cR failed in SMP4 message")
- }
- r.Exp(rb, c.smp.a3, p)
- if r.Cmp(c.smp.papb) != 0 {
- return smpFailureError
- }
- return nil
- }
- func (c *Conversation) generateSMPAbort() tlv {
- return tlv{typ: tlvTypeSMPAbort}
- }
- func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte {
- if h != nil {
- h.Reset()
- } else {
- h = sha256.New()
- }
- h.Write([]byte{magic})
- for _, mpi := range mpis {
- h.Write(appendMPI(nil, mpi))
- }
- return h.Sum(nil)
- }
|