orm.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. // Copyright 2014 beego Author. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // Package orm provides an ORM for MySQL, PostgreSQL and SQLite.
  15. // Simple Usage
  16. //
  17. // package main
  18. //
  19. // import (
  20. // "fmt"
  21. // "github.com/astaxie/beego/orm"
  22. // _ "github.com/go-sql-driver/mysql" // import your used driver
  23. // )
  24. //
  25. // // Model Struct
  26. // type User struct {
  27. // Id int `orm:"auto"`
  28. // Name string `orm:"size(100)"`
  29. // }
  30. //
  31. // func init() {
  32. // orm.RegisterDataBase("default", "mysql", "root:root@/my_db?charset=utf8", 30)
  33. // }
  34. //
  35. // func main() {
  36. // o := orm.NewOrm()
  37. // user := User{Name: "slene"}
  38. // // insert
  39. // id, err := o.Insert(&user)
  40. // // update
  41. // user.Name = "astaxie"
  42. // num, err := o.Update(&user)
  43. // // read one
  44. // u := User{Id: user.Id}
  45. // err = o.Read(&u)
  46. // // delete
  47. // num, err = o.Delete(&u)
  48. // }
  49. //
  50. // more docs: http://beego.me/docs/mvc/model/overview.md
  51. package orm
  52. import (
  53. "database/sql"
  54. "errors"
  55. "fmt"
  56. "os"
  57. "reflect"
  58. "time"
  59. )
  60. // DebugQueries define the debug
  61. const (
  62. DebugQueries = iota
  63. )
  64. // Define common vars
  65. var (
  66. Debug = false
  67. DebugLog = NewLog(os.Stdout)
  68. DefaultRowsLimit = 1000
  69. DefaultRelsDepth = 2
  70. DefaultTimeLoc = time.Local
  71. ErrTxHasBegan = errors.New("<Ormer.Begin> transaction already begin")
  72. ErrTxDone = errors.New("<Ormer.Commit/Rollback> transaction not begin")
  73. ErrMultiRows = errors.New("<QuerySeter> return multi rows")
  74. ErrNoRows = errors.New("<QuerySeter> no row found")
  75. ErrStmtClosed = errors.New("<QuerySeter> stmt already closed")
  76. ErrArgs = errors.New("<Ormer> args error may be empty")
  77. ErrNotImplement = errors.New("have not implement")
  78. )
  79. // Params stores the Params
  80. type Params map[string]interface{}
  81. // ParamsList stores paramslist
  82. type ParamsList []interface{}
  83. type orm struct {
  84. alias *alias
  85. db dbQuerier
  86. isTx bool
  87. }
  88. var _ Ormer = new(orm)
  89. // get model info and model reflect value
  90. func (o *orm) getMiInd(md interface{}, needPtr bool) (mi *modelInfo, ind reflect.Value) {
  91. val := reflect.ValueOf(md)
  92. ind = reflect.Indirect(val)
  93. typ := ind.Type()
  94. if needPtr && val.Kind() != reflect.Ptr {
  95. panic(fmt.Errorf("<Ormer> cannot use non-ptr model struct `%s`", getFullName(typ)))
  96. }
  97. name := getFullName(typ)
  98. if mi, ok := modelCache.getByFullName(name); ok {
  99. return mi, ind
  100. }
  101. panic(fmt.Errorf("<Ormer> table: `%s` not found, maybe not RegisterModel", name))
  102. }
  103. // get field info from model info by given field name
  104. func (o *orm) getFieldInfo(mi *modelInfo, name string) *fieldInfo {
  105. fi, ok := mi.fields.GetByAny(name)
  106. if !ok {
  107. panic(fmt.Errorf("<Ormer> cannot find field `%s` for model `%s`", name, mi.fullName))
  108. }
  109. return fi
  110. }
  111. // read data to model
  112. func (o *orm) Read(md interface{}, cols ...string) error {
  113. mi, ind := o.getMiInd(md, true)
  114. err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
  115. if err != nil {
  116. return err
  117. }
  118. return nil
  119. }
  120. // read data to model, like Read(), but use "SELECT FOR UPDATE" form
  121. func (o *orm) ReadForUpdate(md interface{}, cols ...string) error {
  122. mi, ind := o.getMiInd(md, true)
  123. err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, true)
  124. if err != nil {
  125. return err
  126. }
  127. return nil
  128. }
  129. // Try to read a row from the database, or insert one if it doesn't exist
  130. func (o *orm) ReadOrCreate(md interface{}, col1 string, cols ...string) (bool, int64, error) {
  131. cols = append([]string{col1}, cols...)
  132. mi, ind := o.getMiInd(md, true)
  133. err := o.alias.DbBaser.Read(o.db, mi, ind, o.alias.TZ, cols, false)
  134. if err == ErrNoRows {
  135. // Create
  136. id, err := o.Insert(md)
  137. return (err == nil), id, err
  138. }
  139. id, vid := int64(0), ind.FieldByIndex(mi.fields.pk.fieldIndex)
  140. if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
  141. id = int64(vid.Uint())
  142. } else if mi.fields.pk.rel {
  143. return o.ReadOrCreate(vid.Interface(), mi.fields.pk.relModelInfo.fields.pk.name)
  144. } else {
  145. id = vid.Int()
  146. }
  147. return false, id, err
  148. }
  149. // insert model data to database
  150. func (o *orm) Insert(md interface{}) (int64, error) {
  151. mi, ind := o.getMiInd(md, true)
  152. id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
  153. if err != nil {
  154. return id, err
  155. }
  156. o.setPk(mi, ind, id)
  157. return id, nil
  158. }
  159. // set auto pk field
  160. func (o *orm) setPk(mi *modelInfo, ind reflect.Value, id int64) {
  161. if mi.fields.pk.auto {
  162. if mi.fields.pk.fieldType&IsPositiveIntegerField > 0 {
  163. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetUint(uint64(id))
  164. } else {
  165. ind.FieldByIndex(mi.fields.pk.fieldIndex).SetInt(id)
  166. }
  167. }
  168. }
  169. // insert some models to database
  170. func (o *orm) InsertMulti(bulk int, mds interface{}) (int64, error) {
  171. var cnt int64
  172. sind := reflect.Indirect(reflect.ValueOf(mds))
  173. switch sind.Kind() {
  174. case reflect.Array, reflect.Slice:
  175. if sind.Len() == 0 {
  176. return cnt, ErrArgs
  177. }
  178. default:
  179. return cnt, ErrArgs
  180. }
  181. if bulk <= 1 {
  182. for i := 0; i < sind.Len(); i++ {
  183. ind := reflect.Indirect(sind.Index(i))
  184. mi, _ := o.getMiInd(ind.Interface(), false)
  185. id, err := o.alias.DbBaser.Insert(o.db, mi, ind, o.alias.TZ)
  186. if err != nil {
  187. return cnt, err
  188. }
  189. o.setPk(mi, ind, id)
  190. cnt++
  191. }
  192. } else {
  193. mi, _ := o.getMiInd(sind.Index(0).Interface(), false)
  194. return o.alias.DbBaser.InsertMulti(o.db, mi, sind, bulk, o.alias.TZ)
  195. }
  196. return cnt, nil
  197. }
  198. // InsertOrUpdate data to database
  199. func (o *orm) InsertOrUpdate(md interface{}, colConflitAndArgs ...string) (int64, error) {
  200. mi, ind := o.getMiInd(md, true)
  201. id, err := o.alias.DbBaser.InsertOrUpdate(o.db, mi, ind, o.alias, colConflitAndArgs...)
  202. if err != nil {
  203. return id, err
  204. }
  205. o.setPk(mi, ind, id)
  206. return id, nil
  207. }
  208. // update model to database.
  209. // cols set the columns those want to update.
  210. func (o *orm) Update(md interface{}, cols ...string) (int64, error) {
  211. mi, ind := o.getMiInd(md, true)
  212. num, err := o.alias.DbBaser.Update(o.db, mi, ind, o.alias.TZ, cols)
  213. if err != nil {
  214. return num, err
  215. }
  216. return num, nil
  217. }
  218. // delete model in database
  219. // cols shows the delete conditions values read from. deafult is pk
  220. func (o *orm) Delete(md interface{}, cols ...string) (int64, error) {
  221. mi, ind := o.getMiInd(md, true)
  222. num, err := o.alias.DbBaser.Delete(o.db, mi, ind, o.alias.TZ, cols)
  223. if err != nil {
  224. return num, err
  225. }
  226. if num > 0 {
  227. o.setPk(mi, ind, 0)
  228. }
  229. return num, nil
  230. }
  231. // create a models to models queryer
  232. func (o *orm) QueryM2M(md interface{}, name string) QueryM2Mer {
  233. mi, ind := o.getMiInd(md, true)
  234. fi := o.getFieldInfo(mi, name)
  235. switch {
  236. case fi.fieldType == RelManyToMany:
  237. case fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough:
  238. default:
  239. panic(fmt.Errorf("<Ormer.QueryM2M> model `%s` . name `%s` is not a m2m field", fi.name, mi.fullName))
  240. }
  241. return newQueryM2M(md, o, mi, fi, ind)
  242. }
  243. // load related models to md model.
  244. // args are limit, offset int and order string.
  245. //
  246. // example:
  247. // orm.LoadRelated(post,"Tags")
  248. // for _,tag := range post.Tags{...}
  249. //
  250. // make sure the relation is defined in model struct tags.
  251. func (o *orm) LoadRelated(md interface{}, name string, args ...interface{}) (int64, error) {
  252. _, fi, ind, qseter := o.queryRelated(md, name)
  253. qs := qseter.(*querySet)
  254. var relDepth int
  255. var limit, offset int64
  256. var order string
  257. for i, arg := range args {
  258. switch i {
  259. case 0:
  260. if v, ok := arg.(bool); ok {
  261. if v {
  262. relDepth = DefaultRelsDepth
  263. }
  264. } else if v, ok := arg.(int); ok {
  265. relDepth = v
  266. }
  267. case 1:
  268. limit = ToInt64(arg)
  269. case 2:
  270. offset = ToInt64(arg)
  271. case 3:
  272. order, _ = arg.(string)
  273. }
  274. }
  275. switch fi.fieldType {
  276. case RelOneToOne, RelForeignKey, RelReverseOne:
  277. limit = 1
  278. offset = 0
  279. }
  280. qs.limit = limit
  281. qs.offset = offset
  282. qs.relDepth = relDepth
  283. if len(order) > 0 {
  284. qs.orders = []string{order}
  285. }
  286. find := ind.FieldByIndex(fi.fieldIndex)
  287. var nums int64
  288. var err error
  289. switch fi.fieldType {
  290. case RelOneToOne, RelForeignKey, RelReverseOne:
  291. val := reflect.New(find.Type().Elem())
  292. container := val.Interface()
  293. err = qs.One(container)
  294. if err == nil {
  295. find.Set(val)
  296. nums = 1
  297. }
  298. default:
  299. nums, err = qs.All(find.Addr().Interface())
  300. }
  301. return nums, err
  302. }
  303. // return a QuerySeter for related models to md model.
  304. // it can do all, update, delete in QuerySeter.
  305. // example:
  306. // qs := orm.QueryRelated(post,"Tag")
  307. // qs.All(&[]*Tag{})
  308. //
  309. func (o *orm) QueryRelated(md interface{}, name string) QuerySeter {
  310. // is this api needed ?
  311. _, _, _, qs := o.queryRelated(md, name)
  312. return qs
  313. }
  314. // get QuerySeter for related models to md model
  315. func (o *orm) queryRelated(md interface{}, name string) (*modelInfo, *fieldInfo, reflect.Value, QuerySeter) {
  316. mi, ind := o.getMiInd(md, true)
  317. fi := o.getFieldInfo(mi, name)
  318. _, _, exist := getExistPk(mi, ind)
  319. if exist == false {
  320. panic(ErrMissPK)
  321. }
  322. var qs *querySet
  323. switch fi.fieldType {
  324. case RelOneToOne, RelForeignKey, RelManyToMany:
  325. if !fi.inModel {
  326. break
  327. }
  328. qs = o.getRelQs(md, mi, fi)
  329. case RelReverseOne, RelReverseMany:
  330. if !fi.inModel {
  331. break
  332. }
  333. qs = o.getReverseQs(md, mi, fi)
  334. }
  335. if qs == nil {
  336. panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel/reverse field", md, name))
  337. }
  338. return mi, fi, ind, qs
  339. }
  340. // get reverse relation QuerySeter
  341. func (o *orm) getReverseQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
  342. switch fi.fieldType {
  343. case RelReverseOne, RelReverseMany:
  344. default:
  345. panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available reverse field", fi.name, mi.fullName))
  346. }
  347. var q *querySet
  348. if fi.fieldType == RelReverseMany && fi.reverseFieldInfo.mi.isThrough {
  349. q = newQuerySet(o, fi.relModelInfo).(*querySet)
  350. q.cond = NewCondition().And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
  351. } else {
  352. q = newQuerySet(o, fi.reverseFieldInfo.mi).(*querySet)
  353. q.cond = NewCondition().And(fi.reverseFieldInfo.column, md)
  354. }
  355. return q
  356. }
  357. // get relation QuerySeter
  358. func (o *orm) getRelQs(md interface{}, mi *modelInfo, fi *fieldInfo) *querySet {
  359. switch fi.fieldType {
  360. case RelOneToOne, RelForeignKey, RelManyToMany:
  361. default:
  362. panic(fmt.Errorf("<Ormer> name `%s` for model `%s` is not an available rel field", fi.name, mi.fullName))
  363. }
  364. q := newQuerySet(o, fi.relModelInfo).(*querySet)
  365. q.cond = NewCondition()
  366. if fi.fieldType == RelManyToMany {
  367. q.cond = q.cond.And(fi.reverseFieldInfoM2M.column+ExprSep+fi.reverseFieldInfo.column, md)
  368. } else {
  369. q.cond = q.cond.And(fi.reverseFieldInfo.column, md)
  370. }
  371. return q
  372. }
  373. // return a QuerySeter for table operations.
  374. // table name can be string or struct.
  375. // e.g. QueryTable("user"), QueryTable(&user{}) or QueryTable((*User)(nil)),
  376. func (o *orm) QueryTable(ptrStructOrTableName interface{}) (qs QuerySeter) {
  377. name := ""
  378. if table, ok := ptrStructOrTableName.(string); ok {
  379. name = snakeString(table)
  380. if mi, ok := modelCache.get(name); ok {
  381. qs = newQuerySet(o, mi)
  382. }
  383. } else {
  384. name = getFullName(indirectType(reflect.TypeOf(ptrStructOrTableName)))
  385. if mi, ok := modelCache.getByFullName(name); ok {
  386. qs = newQuerySet(o, mi)
  387. }
  388. }
  389. if qs == nil {
  390. panic(fmt.Errorf("<Ormer.QueryTable> table name: `%s` not exists", name))
  391. }
  392. return
  393. }
  394. // switch to another registered database driver by given name.
  395. func (o *orm) Using(name string) error {
  396. if o.isTx {
  397. panic(fmt.Errorf("<Ormer.Using> transaction has been start, cannot change db"))
  398. }
  399. if al, ok := dataBaseCache.get(name); ok {
  400. o.alias = al
  401. if Debug {
  402. o.db = newDbQueryLog(al, al.DB)
  403. } else {
  404. o.db = al.DB
  405. }
  406. } else {
  407. return fmt.Errorf("<Ormer.Using> unknown db alias name `%s`", name)
  408. }
  409. return nil
  410. }
  411. // begin transaction
  412. func (o *orm) Begin() error {
  413. if o.isTx {
  414. return ErrTxHasBegan
  415. }
  416. var tx *sql.Tx
  417. tx, err := o.db.(txer).Begin()
  418. if err != nil {
  419. return err
  420. }
  421. o.isTx = true
  422. if Debug {
  423. o.db.(*dbQueryLog).SetDB(tx)
  424. } else {
  425. o.db = tx
  426. }
  427. return nil
  428. }
  429. // commit transaction
  430. func (o *orm) Commit() error {
  431. if o.isTx == false {
  432. return ErrTxDone
  433. }
  434. err := o.db.(txEnder).Commit()
  435. if err == nil {
  436. o.isTx = false
  437. o.Using(o.alias.Name)
  438. } else if err == sql.ErrTxDone {
  439. return ErrTxDone
  440. }
  441. return err
  442. }
  443. // rollback transaction
  444. func (o *orm) Rollback() error {
  445. if o.isTx == false {
  446. return ErrTxDone
  447. }
  448. err := o.db.(txEnder).Rollback()
  449. if err == nil {
  450. o.isTx = false
  451. o.Using(o.alias.Name)
  452. } else if err == sql.ErrTxDone {
  453. return ErrTxDone
  454. }
  455. return err
  456. }
  457. // return a raw query seter for raw sql string.
  458. func (o *orm) Raw(query string, args ...interface{}) RawSeter {
  459. return newRawSet(o, query, args)
  460. }
  461. // return current using database Driver
  462. func (o *orm) Driver() Driver {
  463. return driver(o.alias.Name)
  464. }
  465. // NewOrm create new orm
  466. func NewOrm() Ormer {
  467. BootStrap() // execute only once
  468. o := new(orm)
  469. err := o.Using("default")
  470. if err != nil {
  471. panic(err)
  472. }
  473. return o
  474. }
  475. // NewOrmWithDB create a new ormer object with specify *sql.DB for query
  476. func NewOrmWithDB(driverName, aliasName string, db *sql.DB) (Ormer, error) {
  477. var al *alias
  478. if dr, ok := drivers[driverName]; ok {
  479. al = new(alias)
  480. al.DbBaser = dbBasers[dr]
  481. al.Driver = dr
  482. } else {
  483. return nil, fmt.Errorf("driver name `%s` have not registered", driverName)
  484. }
  485. al.Name = aliasName
  486. al.DriverName = driverName
  487. o := new(orm)
  488. o.alias = al
  489. if Debug {
  490. o.db = newDbQueryLog(o.alias, db)
  491. } else {
  492. o.db = db
  493. }
  494. return o, nil
  495. }