1
0

fec.go 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package kcp
  2. import (
  3. "encoding/binary"
  4. "sync/atomic"
  5. "github.com/klauspost/reedsolomon"
  6. )
  // Sizes of the FEC framing that precedes each payload.
  const (
  	// fecHeaderSize is the FEC header length: a 4-byte seqid followed by a
  	// 2-byte type flag.
  	fecHeaderSize = 6
  	// fecHeaderSizePlus2 additionally covers the 2-byte data-size field that
  	// follows the header.
  	fecHeaderSizePlus2 = fecHeaderSize + 2
  )

  // Type flags written into the FEC header.
  const (
  	typeData = 0xf1 // packet carries a data shard
  	typeFEC  = 0xf2 // packet carries a Reed-Solomon parity shard
  )
  13. type (
  14. // fecPacket is a decoded FEC packet
  15. fecPacket struct {
  16. seqid uint32
  17. flag uint16
  18. data []byte
  19. }
  20. // fecDecoder for decoding incoming packets
  21. fecDecoder struct {
  22. rxlimit int // queue size limit
  23. dataShards int
  24. parityShards int
  25. shardSize int
  26. rx []fecPacket // ordered receive queue
  27. // caches
  28. decodeCache [][]byte
  29. flagCache []bool
  30. // RS decoder
  31. codec reedsolomon.Encoder
  32. }
  33. )
  34. func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
  35. if dataShards <= 0 || parityShards <= 0 {
  36. return nil
  37. }
  38. if rxlimit < dataShards+parityShards {
  39. return nil
  40. }
  41. fec := new(fecDecoder)
  42. fec.rxlimit = rxlimit
  43. fec.dataShards = dataShards
  44. fec.parityShards = parityShards
  45. fec.shardSize = dataShards + parityShards
  46. enc, err := reedsolomon.New(dataShards, parityShards, reedsolomon.WithMaxGoroutines(1))
  47. if err != nil {
  48. return nil
  49. }
  50. fec.codec = enc
  51. fec.decodeCache = make([][]byte, fec.shardSize)
  52. fec.flagCache = make([]bool, fec.shardSize)
  53. return fec
  54. }
  55. // decodeBytes a fec packet
  56. func (dec *fecDecoder) decodeBytes(data []byte) fecPacket {
  57. var pkt fecPacket
  58. pkt.seqid = binary.LittleEndian.Uint32(data)
  59. pkt.flag = binary.LittleEndian.Uint16(data[4:])
  60. // allocate memory & copy
  61. buf := xmitBuf.Get().([]byte)[:len(data)-6]
  62. copy(buf, data[6:])
  63. pkt.data = buf
  64. return pkt
  65. }
  66. // decode a fec packet
  67. func (dec *fecDecoder) decode(pkt fecPacket) (recovered [][]byte) {
  68. // insertion
  69. n := len(dec.rx) - 1
  70. insertIdx := 0
  71. for i := n; i >= 0; i-- {
  72. if pkt.seqid == dec.rx[i].seqid { // de-duplicate
  73. xmitBuf.Put(pkt.data)
  74. return nil
  75. } else if _itimediff(pkt.seqid, dec.rx[i].seqid) > 0 { // insertion
  76. insertIdx = i + 1
  77. break
  78. }
  79. }
  80. // insert into ordered rx queue
  81. if insertIdx == n+1 {
  82. dec.rx = append(dec.rx, pkt)
  83. } else {
  84. dec.rx = append(dec.rx, fecPacket{})
  85. copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
  86. dec.rx[insertIdx] = pkt
  87. }
  88. // shard range for current packet
  89. shardBegin := pkt.seqid - pkt.seqid%uint32(dec.shardSize)
  90. shardEnd := shardBegin + uint32(dec.shardSize) - 1
  91. // max search range in ordered queue for current shard
  92. searchBegin := insertIdx - int(pkt.seqid%uint32(dec.shardSize))
  93. if searchBegin < 0 {
  94. searchBegin = 0
  95. }
  96. searchEnd := searchBegin + dec.shardSize - 1
  97. if searchEnd >= len(dec.rx) {
  98. searchEnd = len(dec.rx) - 1
  99. }
  100. // re-construct datashards
  101. if searchEnd-searchBegin+1 >= dec.dataShards {
  102. var numshard, numDataShard, first, maxlen int
  103. // zero cache
  104. shards := dec.decodeCache
  105. shardsflag := dec.flagCache
  106. for k := range dec.decodeCache {
  107. shards[k] = nil
  108. shardsflag[k] = false
  109. }
  110. // shard assembly
  111. for i := searchBegin; i <= searchEnd; i++ {
  112. seqid := dec.rx[i].seqid
  113. if _itimediff(seqid, shardEnd) > 0 {
  114. break
  115. } else if _itimediff(seqid, shardBegin) >= 0 {
  116. shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data
  117. shardsflag[seqid%uint32(dec.shardSize)] = true
  118. numshard++
  119. if dec.rx[i].flag == typeData {
  120. numDataShard++
  121. }
  122. if numshard == 1 {
  123. first = i
  124. }
  125. if len(dec.rx[i].data) > maxlen {
  126. maxlen = len(dec.rx[i].data)
  127. }
  128. }
  129. }
  130. if numDataShard == dec.dataShards {
  131. // case 1: no lost data shards
  132. dec.rx = dec.freeRange(first, numshard, dec.rx)
  133. } else if numshard >= dec.dataShards {
  134. // case 2: data shard lost, but recoverable from parity shard
  135. for k := range shards {
  136. if shards[k] != nil {
  137. dlen := len(shards[k])
  138. shards[k] = shards[k][:maxlen]
  139. xorBytes(shards[k][dlen:], shards[k][dlen:], shards[k][dlen:])
  140. }
  141. }
  142. if err := dec.codec.Reconstruct(shards); err == nil {
  143. for k := range shards[:dec.dataShards] {
  144. if !shardsflag[k] {
  145. recovered = append(recovered, shards[k])
  146. }
  147. }
  148. }
  149. dec.rx = dec.freeRange(first, numshard, dec.rx)
  150. }
  151. }
  152. // keep rxlimit
  153. if len(dec.rx) > dec.rxlimit {
  154. if dec.rx[0].flag == typeData { // record unrecoverable data
  155. atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
  156. }
  157. dec.rx = dec.freeRange(0, 1, dec.rx)
  158. }
  159. return
  160. }
  161. // free a range of fecPacket, and zero for GC recycling
  162. func (dec *fecDecoder) freeRange(first, n int, q []fecPacket) []fecPacket {
  163. for i := first; i < first+n; i++ { // free
  164. xmitBuf.Put(q[i].data)
  165. }
  166. copy(q[first:], q[first+n:])
  167. for i := 0; i < n; i++ { // dereference data
  168. q[len(q)-1-i].data = nil
  169. }
  170. return q[:len(q)-n]
  171. }
  172. type (
  173. // fecEncoder for encoding outgoing packets
  174. fecEncoder struct {
  175. dataShards int
  176. parityShards int
  177. shardSize int
  178. paws uint32 // Protect Against Wrapped Sequence numbers
  179. next uint32 // next seqid
  180. shardCount int // count the number of datashards collected
  181. maxSize int // record maximum data length in datashard
  182. headerOffset int // FEC header offset
  183. payloadOffset int // FEC payload offset
  184. // caches
  185. shardCache [][]byte
  186. encodeCache [][]byte
  187. // RS encoder
  188. codec reedsolomon.Encoder
  189. }
  190. )
  191. func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
  192. if dataShards <= 0 || parityShards <= 0 {
  193. return nil
  194. }
  195. fec := new(fecEncoder)
  196. fec.dataShards = dataShards
  197. fec.parityShards = parityShards
  198. fec.shardSize = dataShards + parityShards
  199. fec.paws = (0xffffffff/uint32(fec.shardSize) - 1) * uint32(fec.shardSize)
  200. fec.headerOffset = offset
  201. fec.payloadOffset = fec.headerOffset + fecHeaderSize
  202. enc, err := reedsolomon.New(dataShards, parityShards, reedsolomon.WithMaxGoroutines(1))
  203. if err != nil {
  204. return nil
  205. }
  206. fec.codec = enc
  207. // caches
  208. fec.encodeCache = make([][]byte, fec.shardSize)
  209. fec.shardCache = make([][]byte, fec.shardSize)
  210. for k := range fec.shardCache {
  211. fec.shardCache[k] = make([]byte, mtuLimit)
  212. }
  213. return fec
  214. }
  215. // encode the packet, output parity shards if we have enough datashards
  216. // the content of returned parityshards will change in next encode
  217. func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
  218. enc.markData(b[enc.headerOffset:])
  219. binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))
  220. // copy data to fec datashards
  221. sz := len(b)
  222. enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
  223. copy(enc.shardCache[enc.shardCount], b)
  224. enc.shardCount++
  225. // record max datashard length
  226. if sz > enc.maxSize {
  227. enc.maxSize = sz
  228. }
  229. // calculate Reed-Solomon Erasure Code
  230. if enc.shardCount == enc.dataShards {
  231. // bzero each datashard's tail
  232. for i := 0; i < enc.dataShards; i++ {
  233. shard := enc.shardCache[i]
  234. slen := len(shard)
  235. xorBytes(shard[slen:enc.maxSize], shard[slen:enc.maxSize], shard[slen:enc.maxSize])
  236. }
  237. // construct equal-sized slice with stripped header
  238. cache := enc.encodeCache
  239. for k := range cache {
  240. cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
  241. }
  242. // rs encode
  243. if err := enc.codec.Encode(cache); err == nil {
  244. ps = enc.shardCache[enc.dataShards:]
  245. for k := range ps {
  246. enc.markFEC(ps[k][enc.headerOffset:])
  247. ps[k] = ps[k][:enc.maxSize]
  248. }
  249. }
  250. // reset counters to zero
  251. enc.shardCount = 0
  252. enc.maxSize = 0
  253. }
  254. return
  255. }
  256. func (enc *fecEncoder) markData(data []byte) {
  257. binary.LittleEndian.PutUint32(data, enc.next)
  258. binary.LittleEndian.PutUint16(data[4:], typeData)
  259. enc.next++
  260. }
  261. func (enc *fecEncoder) markFEC(data []byte) {
  262. binary.LittleEndian.PutUint32(data, enc.next)
  263. binary.LittleEndian.PutUint16(data[4:], typeFEC)
  264. enc.next = (enc.next + 1) % enc.paws
  265. }