matrix.go 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279
  1. /**
  2. * Matrix Algebra over an 8-bit Galois Field
  3. *
  4. * Copyright 2015, Klaus Post
  5. * Copyright 2015, Backblaze, Inc.
  6. */
  7. package reedsolomon
  8. import (
  9. "errors"
  10. "fmt"
  11. "strconv"
  12. "strings"
  13. )
// matrix is a two-dimensional byte matrix stored as a slice of rows and
// indexed as m[row][col].
type matrix [][]byte
  16. // newMatrix returns a matrix of zeros.
  17. func newMatrix(rows, cols int) (matrix, error) {
  18. if rows <= 0 {
  19. return nil, errInvalidRowSize
  20. }
  21. if cols <= 0 {
  22. return nil, errInvalidColSize
  23. }
  24. m := matrix(make([][]byte, rows))
  25. for i := range m {
  26. m[i] = make([]byte, cols)
  27. }
  28. return m, nil
  29. }
  30. // NewMatrixData initializes a matrix with the given row-major data.
  31. // Note that data is not copied from input.
  32. func newMatrixData(data [][]byte) (matrix, error) {
  33. m := matrix(data)
  34. err := m.Check()
  35. if err != nil {
  36. return nil, err
  37. }
  38. return m, nil
  39. }
  40. // IdentityMatrix returns an identity matrix of the given size.
  41. func identityMatrix(size int) (matrix, error) {
  42. m, err := newMatrix(size, size)
  43. if err != nil {
  44. return nil, err
  45. }
  46. for i := range m {
  47. m[i][i] = 1
  48. }
  49. return m, nil
  50. }
  51. // errInvalidRowSize will be returned if attempting to create a matrix with negative or zero row number.
  52. var errInvalidRowSize = errors.New("invalid row size")
  53. // errInvalidColSize will be returned if attempting to create a matrix with negative or zero column number.
  54. var errInvalidColSize = errors.New("invalid column size")
  55. // errColSizeMismatch is returned if the size of matrix columns mismatch.
  56. var errColSizeMismatch = errors.New("column size is not the same for all rows")
  57. func (m matrix) Check() error {
  58. rows := len(m)
  59. if rows <= 0 {
  60. return errInvalidRowSize
  61. }
  62. cols := len(m[0])
  63. if cols <= 0 {
  64. return errInvalidColSize
  65. }
  66. for _, col := range m {
  67. if len(col) != cols {
  68. return errColSizeMismatch
  69. }
  70. }
  71. return nil
  72. }
  73. // String returns a human-readable string of the matrix contents.
  74. //
  75. // Example: [[1, 2], [3, 4]]
  76. func (m matrix) String() string {
  77. rowOut := make([]string, 0, len(m))
  78. for _, row := range m {
  79. colOut := make([]string, 0, len(row))
  80. for _, col := range row {
  81. colOut = append(colOut, strconv.Itoa(int(col)))
  82. }
  83. rowOut = append(rowOut, "["+strings.Join(colOut, ", ")+"]")
  84. }
  85. return "[" + strings.Join(rowOut, ", ") + "]"
  86. }
  87. // Multiply multiplies this matrix (the one on the left) by another
  88. // matrix (the one on the right) and returns a new matrix with the result.
  89. func (m matrix) Multiply(right matrix) (matrix, error) {
  90. if len(m[0]) != len(right) {
  91. return nil, fmt.Errorf("columns on left (%d) is different than rows on right (%d)", len(m[0]), len(right))
  92. }
  93. result, _ := newMatrix(len(m), len(right[0]))
  94. for r, row := range result {
  95. for c := range row {
  96. var value byte
  97. for i := range m[0] {
  98. value ^= galMultiply(m[r][i], right[i][c])
  99. }
  100. result[r][c] = value
  101. }
  102. }
  103. return result, nil
  104. }
  105. // Augment returns the concatenation of this matrix and the matrix on the right.
  106. func (m matrix) Augment(right matrix) (matrix, error) {
  107. if len(m) != len(right) {
  108. return nil, errMatrixSize
  109. }
  110. result, _ := newMatrix(len(m), len(m[0])+len(right[0]))
  111. for r, row := range m {
  112. for c := range row {
  113. result[r][c] = m[r][c]
  114. }
  115. cols := len(m[0])
  116. for c := range right[0] {
  117. result[r][cols+c] = right[r][c]
  118. }
  119. }
  120. return result, nil
  121. }
// errMatrixSize is returned if the dimensions of two matrices do not match.
var errMatrixSize = errors.New("matrix sizes does not match")
  124. func (m matrix) SameSize(n matrix) error {
  125. if len(m) != len(n) {
  126. return errMatrixSize
  127. }
  128. for i := range m {
  129. if len(m[i]) != len(n[i]) {
  130. return errMatrixSize
  131. }
  132. }
  133. return nil
  134. }
  135. // Returns a part of this matrix. Data is copied.
  136. func (m matrix) SubMatrix(rmin, cmin, rmax, cmax int) (matrix, error) {
  137. result, err := newMatrix(rmax-rmin, cmax-cmin)
  138. if err != nil {
  139. return nil, err
  140. }
  141. // OPTME: If used heavily, use copy function to copy slice
  142. for r := rmin; r < rmax; r++ {
  143. for c := cmin; c < cmax; c++ {
  144. result[r-rmin][c-cmin] = m[r][c]
  145. }
  146. }
  147. return result, nil
  148. }
  149. // SwapRows Exchanges two rows in the matrix.
  150. func (m matrix) SwapRows(r1, r2 int) error {
  151. if r1 < 0 || len(m) <= r1 || r2 < 0 || len(m) <= r2 {
  152. return errInvalidRowSize
  153. }
  154. m[r2], m[r1] = m[r1], m[r2]
  155. return nil
  156. }
  157. // IsSquare will return true if the matrix is square
  158. // and nil if the matrix is square
  159. func (m matrix) IsSquare() bool {
  160. return len(m) == len(m[0])
  161. }
  162. // errSingular is returned if the matrix is singular and cannot be inversed
  163. var errSingular = errors.New("matrix is singular")
  164. // errNotSquare is returned if attempting to inverse a non-square matrix.
  165. var errNotSquare = errors.New("only square matrices can be inverted")
  166. // Invert returns the inverse of this matrix.
  167. // Returns ErrSingular when the matrix is singular and doesn't have an inverse.
  168. // The matrix must be square, otherwise ErrNotSquare is returned.
  169. func (m matrix) Invert() (matrix, error) {
  170. if !m.IsSquare() {
  171. return nil, errNotSquare
  172. }
  173. size := len(m)
  174. work, _ := identityMatrix(size)
  175. work, _ = m.Augment(work)
  176. err := work.gaussianElimination()
  177. if err != nil {
  178. return nil, err
  179. }
  180. return work.SubMatrix(0, size, size, size*2)
  181. }
  182. func (m matrix) gaussianElimination() error {
  183. rows := len(m)
  184. columns := len(m[0])
  185. // Clear out the part below the main diagonal and scale the main
  186. // diagonal to be 1.
  187. for r := 0; r < rows; r++ {
  188. // If the element on the diagonal is 0, find a row below
  189. // that has a non-zero and swap them.
  190. if m[r][r] == 0 {
  191. for rowBelow := r + 1; rowBelow < rows; rowBelow++ {
  192. if m[rowBelow][r] != 0 {
  193. m.SwapRows(r, rowBelow)
  194. break
  195. }
  196. }
  197. }
  198. // If we couldn't find one, the matrix is singular.
  199. if m[r][r] == 0 {
  200. return errSingular
  201. }
  202. // Scale to 1.
  203. if m[r][r] != 1 {
  204. scale := galDivide(1, m[r][r])
  205. for c := 0; c < columns; c++ {
  206. m[r][c] = galMultiply(m[r][c], scale)
  207. }
  208. }
  209. // Make everything below the 1 be a 0 by subtracting
  210. // a multiple of it. (Subtraction and addition are
  211. // both exclusive or in the Galois field.)
  212. for rowBelow := r + 1; rowBelow < rows; rowBelow++ {
  213. if m[rowBelow][r] != 0 {
  214. scale := m[rowBelow][r]
  215. for c := 0; c < columns; c++ {
  216. m[rowBelow][c] ^= galMultiply(scale, m[r][c])
  217. }
  218. }
  219. }
  220. }
  221. // Now clear the part above the main diagonal.
  222. for d := 0; d < rows; d++ {
  223. for rowAbove := 0; rowAbove < d; rowAbove++ {
  224. if m[rowAbove][d] != 0 {
  225. scale := m[rowAbove][d]
  226. for c := 0; c < columns; c++ {
  227. m[rowAbove][c] ^= galMultiply(scale, m[d][c])
  228. }
  229. }
  230. }
  231. }
  232. return nil
  233. }
  234. // Create a Vandermonde matrix, which is guaranteed to have the
  235. // property that any subset of rows that forms a square matrix
  236. // is invertible.
  237. func vandermonde(rows, cols int) (matrix, error) {
  238. result, err := newMatrix(rows, cols)
  239. if err != nil {
  240. return nil, err
  241. }
  242. for r, row := range result {
  243. for c := range row {
  244. result[r][c] = galExp(byte(r), c)
  245. }
  246. }
  247. return result, nil
  248. }