cmd_utils.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314
  1. // Copyright 2014 beego Author. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package orm
  15. import (
  16. "fmt"
  17. "os"
  18. "strings"
  19. )
  20. type dbIndex struct {
  21. Table string
  22. Name string
  23. SQL string
  24. }
  25. // create database drop sql.
  26. func getDbDropSQL(al *alias) (sqls []string) {
  27. if len(modelCache.cache) == 0 {
  28. fmt.Println("no Model found, need register your model")
  29. os.Exit(2)
  30. }
  31. Q := al.DbBaser.TableQuote()
  32. for _, mi := range modelCache.allOrdered() {
  33. sqls = append(sqls, fmt.Sprintf(`DROP TABLE IF EXISTS %s%s%s`, Q, mi.table, Q))
  34. }
  35. return sqls
  36. }
  37. // get database column type string.
  38. func getColumnTyp(al *alias, fi *fieldInfo) (col string) {
  39. T := al.DbBaser.DbTypes()
  40. fieldType := fi.fieldType
  41. fieldSize := fi.size
  42. checkColumn:
  43. switch fieldType {
  44. case TypeBooleanField:
  45. col = T["bool"]
  46. case TypeCharField:
  47. if al.Driver == DRPostgres && fi.toText {
  48. col = T["string-text"]
  49. } else {
  50. col = fmt.Sprintf(T["string"], fieldSize)
  51. }
  52. case TypeTextField:
  53. col = T["string-text"]
  54. case TypeTimeField:
  55. col = T["time.Time-clock"]
  56. case TypeDateField:
  57. col = T["time.Time-date"]
  58. case TypeDateTimeField:
  59. col = T["time.Time"]
  60. case TypeBitField:
  61. col = T["int8"]
  62. case TypeSmallIntegerField:
  63. col = T["int16"]
  64. case TypeIntegerField:
  65. col = T["int32"]
  66. case TypeBigIntegerField:
  67. if al.Driver == DRSqlite {
  68. fieldType = TypeIntegerField
  69. goto checkColumn
  70. }
  71. col = T["int64"]
  72. case TypePositiveBitField:
  73. col = T["uint8"]
  74. case TypePositiveSmallIntegerField:
  75. col = T["uint16"]
  76. case TypePositiveIntegerField:
  77. col = T["uint32"]
  78. case TypePositiveBigIntegerField:
  79. col = T["uint64"]
  80. case TypeFloatField:
  81. col = T["float64"]
  82. case TypeDecimalField:
  83. s := T["float64-decimal"]
  84. if strings.Index(s, "%d") == -1 {
  85. col = s
  86. } else {
  87. col = fmt.Sprintf(s, fi.digits, fi.decimals)
  88. }
  89. case TypeJSONField:
  90. if al.Driver != DRPostgres {
  91. fieldType = TypeCharField
  92. goto checkColumn
  93. }
  94. col = T["json"]
  95. case TypeJsonbField:
  96. if al.Driver != DRPostgres {
  97. fieldType = TypeCharField
  98. goto checkColumn
  99. }
  100. col = T["jsonb"]
  101. case RelForeignKey, RelOneToOne:
  102. fieldType = fi.relModelInfo.fields.pk.fieldType
  103. fieldSize = fi.relModelInfo.fields.pk.size
  104. goto checkColumn
  105. }
  106. return
  107. }
  108. // create alter sql string.
  109. func getColumnAddQuery(al *alias, fi *fieldInfo) string {
  110. Q := al.DbBaser.TableQuote()
  111. typ := getColumnTyp(al, fi)
  112. if fi.null == false {
  113. typ += " " + "NOT NULL"
  114. }
  115. return fmt.Sprintf("ALTER TABLE %s%s%s ADD COLUMN %s%s%s %s %s",
  116. Q, fi.mi.table, Q,
  117. Q, fi.column, Q,
  118. typ, getColumnDefault(fi),
  119. )
  120. }
  121. // create database creation string.
  122. func getDbCreateSQL(al *alias) (sqls []string, tableIndexes map[string][]dbIndex) {
  123. if len(modelCache.cache) == 0 {
  124. fmt.Println("no Model found, need register your model")
  125. os.Exit(2)
  126. }
  127. Q := al.DbBaser.TableQuote()
  128. T := al.DbBaser.DbTypes()
  129. sep := fmt.Sprintf("%s, %s", Q, Q)
  130. tableIndexes = make(map[string][]dbIndex)
  131. for _, mi := range modelCache.allOrdered() {
  132. sql := fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
  133. sql += fmt.Sprintf("-- Table Structure for `%s`\n", mi.fullName)
  134. sql += fmt.Sprintf("-- %s\n", strings.Repeat("-", 50))
  135. sql += fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s%s%s (\n", Q, mi.table, Q)
  136. columns := make([]string, 0, len(mi.fields.fieldsDB))
  137. sqlIndexes := [][]string{}
  138. for _, fi := range mi.fields.fieldsDB {
  139. column := fmt.Sprintf(" %s%s%s ", Q, fi.column, Q)
  140. col := getColumnTyp(al, fi)
  141. if fi.auto {
  142. switch al.Driver {
  143. case DRSqlite, DRPostgres:
  144. column += T["auto"]
  145. default:
  146. column += col + " " + T["auto"]
  147. }
  148. } else if fi.pk {
  149. column += col + " " + T["pk"]
  150. } else {
  151. column += col
  152. if fi.null == false {
  153. column += " " + "NOT NULL"
  154. }
  155. //if fi.initial.String() != "" {
  156. // column += " DEFAULT " + fi.initial.String()
  157. //}
  158. // Append attribute DEFAULT
  159. column += getColumnDefault(fi)
  160. if fi.unique {
  161. column += " " + "UNIQUE"
  162. }
  163. if fi.index {
  164. sqlIndexes = append(sqlIndexes, []string{fi.column})
  165. }
  166. }
  167. if strings.Index(column, "%COL%") != -1 {
  168. column = strings.Replace(column, "%COL%", fi.column, -1)
  169. }
  170. columns = append(columns, column)
  171. }
  172. if mi.model != nil {
  173. allnames := getTableUnique(mi.addrField)
  174. if !mi.manual && len(mi.uniques) > 0 {
  175. allnames = append(allnames, mi.uniques)
  176. }
  177. for _, names := range allnames {
  178. cols := make([]string, 0, len(names))
  179. for _, name := range names {
  180. if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
  181. cols = append(cols, fi.column)
  182. } else {
  183. panic(fmt.Errorf("cannot found column `%s` when parse UNIQUE in `%s.TableUnique`", name, mi.fullName))
  184. }
  185. }
  186. column := fmt.Sprintf(" UNIQUE (%s%s%s)", Q, strings.Join(cols, sep), Q)
  187. columns = append(columns, column)
  188. }
  189. }
  190. sql += strings.Join(columns, ",\n")
  191. sql += "\n)"
  192. if al.Driver == DRMySQL {
  193. var engine string
  194. if mi.model != nil {
  195. engine = getTableEngine(mi.addrField)
  196. }
  197. if engine == "" {
  198. engine = al.Engine
  199. }
  200. sql += " ENGINE=" + engine
  201. }
  202. sql += ";"
  203. sqls = append(sqls, sql)
  204. if mi.model != nil {
  205. for _, names := range getTableIndex(mi.addrField) {
  206. cols := make([]string, 0, len(names))
  207. for _, name := range names {
  208. if fi, ok := mi.fields.GetByAny(name); ok && fi.dbcol {
  209. cols = append(cols, fi.column)
  210. } else {
  211. panic(fmt.Errorf("cannot found column `%s` when parse INDEX in `%s.TableIndex`", name, mi.fullName))
  212. }
  213. }
  214. sqlIndexes = append(sqlIndexes, cols)
  215. }
  216. }
  217. for _, names := range sqlIndexes {
  218. name := mi.table + "_" + strings.Join(names, "_")
  219. cols := strings.Join(names, sep)
  220. sql := fmt.Sprintf("CREATE INDEX %s%s%s ON %s%s%s (%s%s%s);", Q, name, Q, Q, mi.table, Q, Q, cols, Q)
  221. index := dbIndex{}
  222. index.Table = mi.table
  223. index.Name = name
  224. index.SQL = sql
  225. tableIndexes[mi.table] = append(tableIndexes[mi.table], index)
  226. }
  227. }
  228. return
  229. }
  230. // Get string value for the attribute "DEFAULT" for the CREATE, ALTER commands
  231. func getColumnDefault(fi *fieldInfo) string {
  232. var (
  233. v, t, d string
  234. )
  235. // Skip default attribute if field is in relations
  236. if fi.rel || fi.reverse {
  237. return v
  238. }
  239. t = " DEFAULT '%s' "
  240. // These defaults will be useful if there no config value orm:"default" and NOT NULL is on
  241. switch fi.fieldType {
  242. case TypeTimeField, TypeDateField, TypeDateTimeField, TypeTextField:
  243. return v
  244. case TypeBitField, TypeSmallIntegerField, TypeIntegerField,
  245. TypeBigIntegerField, TypePositiveBitField, TypePositiveSmallIntegerField,
  246. TypePositiveIntegerField, TypePositiveBigIntegerField, TypeFloatField,
  247. TypeDecimalField:
  248. t = " DEFAULT %s "
  249. d = "0"
  250. case TypeBooleanField:
  251. t = " DEFAULT %s "
  252. d = "FALSE"
  253. case TypeJSONField, TypeJsonbField:
  254. d = "{}"
  255. }
  256. if fi.colDefault {
  257. if !fi.initial.Exist() {
  258. v = fmt.Sprintf(t, "")
  259. } else {
  260. v = fmt.Sprintf(t, fi.initial.String())
  261. }
  262. } else {
  263. if !fi.null {
  264. v = fmt.Sprintf(t, d)
  265. }
  266. }
  267. return v
  268. }