db_tables.go 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476
  1. // Copyright 2014 beego Author. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package orm
  15. import (
  16. "fmt"
  17. "strings"
  18. "time"
  19. )
  20. // table info struct.
  21. type dbTable struct {
  22. id int
  23. index string
  24. name string
  25. names []string
  26. sel bool
  27. inner bool
  28. mi *modelInfo
  29. fi *fieldInfo
  30. jtl *dbTable
  31. }
  32. // tables collection struct, contains some tables.
  33. type dbTables struct {
  34. tablesM map[string]*dbTable
  35. tables []*dbTable
  36. mi *modelInfo
  37. base dbBaser
  38. skipEnd bool
  39. }
  40. // set table info to collection.
  41. // if not exist, create new.
  42. func (t *dbTables) set(names []string, mi *modelInfo, fi *fieldInfo, inner bool) *dbTable {
  43. name := strings.Join(names, ExprSep)
  44. if j, ok := t.tablesM[name]; ok {
  45. j.name = name
  46. j.mi = mi
  47. j.fi = fi
  48. j.inner = inner
  49. } else {
  50. i := len(t.tables) + 1
  51. jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
  52. t.tablesM[name] = jt
  53. t.tables = append(t.tables, jt)
  54. }
  55. return t.tablesM[name]
  56. }
  57. // add table info to collection.
  58. func (t *dbTables) add(names []string, mi *modelInfo, fi *fieldInfo, inner bool) (*dbTable, bool) {
  59. name := strings.Join(names, ExprSep)
  60. if _, ok := t.tablesM[name]; ok == false {
  61. i := len(t.tables) + 1
  62. jt := &dbTable{i, fmt.Sprintf("T%d", i), name, names, false, inner, mi, fi, nil}
  63. t.tablesM[name] = jt
  64. t.tables = append(t.tables, jt)
  65. return jt, true
  66. }
  67. return t.tablesM[name], false
  68. }
  69. // get table info in collection.
  70. func (t *dbTables) get(name string) (*dbTable, bool) {
  71. j, ok := t.tablesM[name]
  72. return j, ok
  73. }
  74. // get related fields info in recursive depth loop.
  75. // loop once, depth decreases one.
  76. func (t *dbTables) loopDepth(depth int, prefix string, fi *fieldInfo, related []string) []string {
  77. if depth < 0 || fi.fieldType == RelManyToMany {
  78. return related
  79. }
  80. if prefix == "" {
  81. prefix = fi.name
  82. } else {
  83. prefix = prefix + ExprSep + fi.name
  84. }
  85. related = append(related, prefix)
  86. depth--
  87. for _, fi := range fi.relModelInfo.fields.fieldsRel {
  88. related = t.loopDepth(depth, prefix, fi, related)
  89. }
  90. return related
  91. }
  92. // parse related fields.
  93. func (t *dbTables) parseRelated(rels []string, depth int) {
  94. relsNum := len(rels)
  95. related := make([]string, relsNum)
  96. copy(related, rels)
  97. relDepth := depth
  98. if relsNum != 0 {
  99. relDepth = 0
  100. }
  101. relDepth--
  102. for _, fi := range t.mi.fields.fieldsRel {
  103. related = t.loopDepth(relDepth, "", fi, related)
  104. }
  105. for i, s := range related {
  106. var (
  107. exs = strings.Split(s, ExprSep)
  108. names = make([]string, 0, len(exs))
  109. mmi = t.mi
  110. cancel = true
  111. jtl *dbTable
  112. )
  113. inner := true
  114. for _, ex := range exs {
  115. if fi, ok := mmi.fields.GetByAny(ex); ok && fi.rel && fi.fieldType != RelManyToMany {
  116. names = append(names, fi.name)
  117. mmi = fi.relModelInfo
  118. if fi.null || t.skipEnd {
  119. inner = false
  120. }
  121. jt := t.set(names, mmi, fi, inner)
  122. jt.jtl = jtl
  123. if fi.reverse {
  124. cancel = false
  125. }
  126. if cancel {
  127. jt.sel = depth > 0
  128. if i < relsNum {
  129. jt.sel = true
  130. }
  131. }
  132. jtl = jt
  133. } else {
  134. panic(fmt.Errorf("unknown model/table name `%s`", ex))
  135. }
  136. }
  137. }
  138. }
  139. // generate join string.
  140. func (t *dbTables) getJoinSQL() (join string) {
  141. Q := t.base.TableQuote()
  142. for _, jt := range t.tables {
  143. if jt.inner {
  144. join += "INNER JOIN "
  145. } else {
  146. join += "LEFT OUTER JOIN "
  147. }
  148. var (
  149. table string
  150. t1, t2 string
  151. c1, c2 string
  152. )
  153. t1 = "T0"
  154. if jt.jtl != nil {
  155. t1 = jt.jtl.index
  156. }
  157. t2 = jt.index
  158. table = jt.mi.table
  159. switch {
  160. case jt.fi.fieldType == RelManyToMany || jt.fi.fieldType == RelReverseMany || jt.fi.reverse && jt.fi.reverseFieldInfo.fieldType == RelManyToMany:
  161. c1 = jt.fi.mi.fields.pk.column
  162. for _, ffi := range jt.mi.fields.fieldsRel {
  163. if jt.fi.mi == ffi.relModelInfo {
  164. c2 = ffi.column
  165. break
  166. }
  167. }
  168. default:
  169. c1 = jt.fi.column
  170. c2 = jt.fi.relModelInfo.fields.pk.column
  171. if jt.fi.reverse {
  172. c1 = jt.mi.fields.pk.column
  173. c2 = jt.fi.reverseFieldInfo.column
  174. }
  175. }
  176. join += fmt.Sprintf("%s%s%s %s ON %s.%s%s%s = %s.%s%s%s ", Q, table, Q, t2,
  177. t2, Q, c2, Q, t1, Q, c1, Q)
  178. }
  179. return
  180. }
  181. // parse orm model struct field tag expression.
  182. func (t *dbTables) parseExprs(mi *modelInfo, exprs []string) (index, name string, info *fieldInfo, success bool) {
  183. var (
  184. jtl *dbTable
  185. fi *fieldInfo
  186. fiN *fieldInfo
  187. mmi = mi
  188. )
  189. num := len(exprs) - 1
  190. var names []string
  191. inner := true
  192. loopFor:
  193. for i, ex := range exprs {
  194. var ok, okN bool
  195. if fiN != nil {
  196. fi = fiN
  197. ok = true
  198. fiN = nil
  199. }
  200. if i == 0 {
  201. fi, ok = mmi.fields.GetByAny(ex)
  202. }
  203. _ = okN
  204. if ok {
  205. isRel := fi.rel || fi.reverse
  206. names = append(names, fi.name)
  207. switch {
  208. case fi.rel:
  209. mmi = fi.relModelInfo
  210. if fi.fieldType == RelManyToMany {
  211. mmi = fi.relThroughModelInfo
  212. }
  213. case fi.reverse:
  214. mmi = fi.reverseFieldInfo.mi
  215. }
  216. if i < num {
  217. fiN, okN = mmi.fields.GetByAny(exprs[i+1])
  218. }
  219. if isRel && (fi.mi.isThrough == false || num != i) {
  220. if fi.null || t.skipEnd {
  221. inner = false
  222. }
  223. if t.skipEnd && okN || !t.skipEnd {
  224. if t.skipEnd && okN && fiN.pk {
  225. goto loopEnd
  226. }
  227. jt, _ := t.add(names, mmi, fi, inner)
  228. jt.jtl = jtl
  229. jtl = jt
  230. }
  231. }
  232. if num != i {
  233. continue
  234. }
  235. loopEnd:
  236. if i == 0 || jtl == nil {
  237. index = "T0"
  238. } else {
  239. index = jtl.index
  240. }
  241. info = fi
  242. if jtl == nil {
  243. name = fi.name
  244. } else {
  245. name = jtl.name + ExprSep + fi.name
  246. }
  247. switch {
  248. case fi.rel:
  249. case fi.reverse:
  250. switch fi.reverseFieldInfo.fieldType {
  251. case RelOneToOne, RelForeignKey:
  252. index = jtl.index
  253. info = fi.reverseFieldInfo.mi.fields.pk
  254. name = info.name
  255. }
  256. }
  257. break loopFor
  258. } else {
  259. index = ""
  260. name = ""
  261. info = nil
  262. success = false
  263. return
  264. }
  265. }
  266. success = index != "" && info != nil
  267. return
  268. }
  269. // generate condition sql.
  270. func (t *dbTables) getCondSQL(cond *Condition, sub bool, tz *time.Location) (where string, params []interface{}) {
  271. if cond == nil || cond.IsEmpty() {
  272. return
  273. }
  274. Q := t.base.TableQuote()
  275. mi := t.mi
  276. for i, p := range cond.params {
  277. if i > 0 {
  278. if p.isOr {
  279. where += "OR "
  280. } else {
  281. where += "AND "
  282. }
  283. }
  284. if p.isNot {
  285. where += "NOT "
  286. }
  287. if p.isCond {
  288. w, ps := t.getCondSQL(p.cond, true, tz)
  289. if w != "" {
  290. w = fmt.Sprintf("( %s) ", w)
  291. }
  292. where += w
  293. params = append(params, ps...)
  294. } else {
  295. exprs := p.exprs
  296. num := len(exprs) - 1
  297. operator := ""
  298. if operators[exprs[num]] {
  299. operator = exprs[num]
  300. exprs = exprs[:num]
  301. }
  302. index, _, fi, suc := t.parseExprs(mi, exprs)
  303. if suc == false {
  304. panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(p.exprs, ExprSep)))
  305. }
  306. if operator == "" {
  307. operator = "exact"
  308. }
  309. operSQL, args := t.base.GenerateOperatorSQL(mi, fi, operator, p.args, tz)
  310. leftCol := fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q)
  311. t.base.GenerateOperatorLeftCol(fi, operator, &leftCol)
  312. where += fmt.Sprintf("%s %s ", leftCol, operSQL)
  313. params = append(params, args...)
  314. }
  315. }
  316. if sub == false && where != "" {
  317. where = "WHERE " + where
  318. }
  319. return
  320. }
  321. // generate group sql.
  322. func (t *dbTables) getGroupSQL(groups []string) (groupSQL string) {
  323. if len(groups) == 0 {
  324. return
  325. }
  326. Q := t.base.TableQuote()
  327. groupSqls := make([]string, 0, len(groups))
  328. for _, group := range groups {
  329. exprs := strings.Split(group, ExprSep)
  330. index, _, fi, suc := t.parseExprs(t.mi, exprs)
  331. if suc == false {
  332. panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
  333. }
  334. groupSqls = append(groupSqls, fmt.Sprintf("%s.%s%s%s", index, Q, fi.column, Q))
  335. }
  336. groupSQL = fmt.Sprintf("GROUP BY %s ", strings.Join(groupSqls, ", "))
  337. return
  338. }
  339. // generate order sql.
  340. func (t *dbTables) getOrderSQL(orders []string) (orderSQL string) {
  341. if len(orders) == 0 {
  342. return
  343. }
  344. Q := t.base.TableQuote()
  345. orderSqls := make([]string, 0, len(orders))
  346. for _, order := range orders {
  347. asc := "ASC"
  348. if order[0] == '-' {
  349. asc = "DESC"
  350. order = order[1:]
  351. }
  352. exprs := strings.Split(order, ExprSep)
  353. index, _, fi, suc := t.parseExprs(t.mi, exprs)
  354. if suc == false {
  355. panic(fmt.Errorf("unknown field/column name `%s`", strings.Join(exprs, ExprSep)))
  356. }
  357. orderSqls = append(orderSqls, fmt.Sprintf("%s.%s%s%s %s", index, Q, fi.column, Q, asc))
  358. }
  359. orderSQL = fmt.Sprintf("ORDER BY %s ", strings.Join(orderSqls, ", "))
  360. return
  361. }
  362. // generate limit sql.
  363. func (t *dbTables) getLimitSQL(mi *modelInfo, offset int64, limit int64) (limits string) {
  364. if limit == 0 {
  365. limit = int64(DefaultRowsLimit)
  366. }
  367. if limit < 0 {
  368. // no limit
  369. if offset > 0 {
  370. maxLimit := t.base.MaxLimit()
  371. if maxLimit == 0 {
  372. limits = fmt.Sprintf("OFFSET %d", offset)
  373. } else {
  374. limits = fmt.Sprintf("LIMIT %d OFFSET %d", maxLimit, offset)
  375. }
  376. }
  377. } else if offset <= 0 {
  378. limits = fmt.Sprintf("LIMIT %d", limit)
  379. } else {
  380. limits = fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset)
  381. }
  382. return
  383. }
  384. // crete new tables collection.
  385. func newDbTables(mi *modelInfo, base dbBaser) *dbTables {
  386. tables := &dbTables{}
  387. tables.tablesM = make(map[string]*dbTable)
  388. tables.mi = mi
  389. tables.base = base
  390. return tables
  391. }