You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

436 lines
11 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql"
  7. "errors"
  8. "fmt"
  9. "reflect"
  10. "strings"
  11. "github.com/go-xorm/core"
  12. )
  13. // Ping test if database is ok
  14. func (session *Session) Ping() error {
  15. if session.isAutoClose {
  16. defer session.Close()
  17. }
  18. session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
  19. return session.DB().Ping()
  20. }
  21. // CreateTable create a table according a bean
  22. func (session *Session) CreateTable(bean interface{}) error {
  23. if session.isAutoClose {
  24. defer session.Close()
  25. }
  26. return session.createTable(bean)
  27. }
  28. func (session *Session) createTable(bean interface{}) error {
  29. v := rValue(bean)
  30. if err := session.statement.setRefValue(v); err != nil {
  31. return err
  32. }
  33. sqlStr := session.statement.genCreateTableSQL()
  34. _, err := session.exec(sqlStr)
  35. return err
  36. }
  37. // CreateIndexes create indexes
  38. func (session *Session) CreateIndexes(bean interface{}) error {
  39. if session.isAutoClose {
  40. defer session.Close()
  41. }
  42. return session.createIndexes(bean)
  43. }
  44. func (session *Session) createIndexes(bean interface{}) error {
  45. v := rValue(bean)
  46. if err := session.statement.setRefValue(v); err != nil {
  47. return err
  48. }
  49. sqls := session.statement.genIndexSQL()
  50. for _, sqlStr := range sqls {
  51. _, err := session.exec(sqlStr)
  52. if err != nil {
  53. return err
  54. }
  55. }
  56. return nil
  57. }
  58. // CreateUniques create uniques
  59. func (session *Session) CreateUniques(bean interface{}) error {
  60. if session.isAutoClose {
  61. defer session.Close()
  62. }
  63. return session.createUniques(bean)
  64. }
  65. func (session *Session) createUniques(bean interface{}) error {
  66. v := rValue(bean)
  67. if err := session.statement.setRefValue(v); err != nil {
  68. return err
  69. }
  70. sqls := session.statement.genUniqueSQL()
  71. for _, sqlStr := range sqls {
  72. _, err := session.exec(sqlStr)
  73. if err != nil {
  74. return err
  75. }
  76. }
  77. return nil
  78. }
  79. // DropIndexes drop indexes
  80. func (session *Session) DropIndexes(bean interface{}) error {
  81. if session.isAutoClose {
  82. defer session.Close()
  83. }
  84. return session.dropIndexes(bean)
  85. }
  86. func (session *Session) dropIndexes(bean interface{}) error {
  87. v := rValue(bean)
  88. if err := session.statement.setRefValue(v); err != nil {
  89. return err
  90. }
  91. sqls := session.statement.genDelIndexSQL()
  92. for _, sqlStr := range sqls {
  93. _, err := session.exec(sqlStr)
  94. if err != nil {
  95. return err
  96. }
  97. }
  98. return nil
  99. }
  100. // DropTable drop table will drop table if exist, if drop failed, it will return error
  101. func (session *Session) DropTable(beanOrTableName interface{}) error {
  102. if session.isAutoClose {
  103. defer session.Close()
  104. }
  105. return session.dropTable(beanOrTableName)
  106. }
  107. func (session *Session) dropTable(beanOrTableName interface{}) error {
  108. tableName, err := session.engine.tableName(beanOrTableName)
  109. if err != nil {
  110. return err
  111. }
  112. var needDrop = true
  113. if !session.engine.dialect.SupportDropIfExists() {
  114. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  115. results, err := session.queryBytes(sqlStr, args...)
  116. if err != nil {
  117. return err
  118. }
  119. needDrop = len(results) > 0
  120. }
  121. if needDrop {
  122. sqlStr := session.engine.Dialect().DropTableSql(tableName)
  123. _, err = session.exec(sqlStr)
  124. return err
  125. }
  126. return nil
  127. }
  128. // IsTableExist if a table is exist
  129. func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
  130. if session.isAutoClose {
  131. defer session.Close()
  132. }
  133. tableName, err := session.engine.tableName(beanOrTableName)
  134. if err != nil {
  135. return false, err
  136. }
  137. return session.isTableExist(tableName)
  138. }
  139. func (session *Session) isTableExist(tableName string) (bool, error) {
  140. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  141. results, err := session.queryBytes(sqlStr, args...)
  142. return len(results) > 0, err
  143. }
  144. // IsTableEmpty if table have any records
  145. func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
  146. v := rValue(bean)
  147. t := v.Type()
  148. if t.Kind() == reflect.String {
  149. if session.isAutoClose {
  150. defer session.Close()
  151. }
  152. return session.isTableEmpty(bean.(string))
  153. } else if t.Kind() == reflect.Struct {
  154. rows, err := session.Count(bean)
  155. return rows == 0, err
  156. }
  157. return false, errors.New("bean should be a struct or struct's point")
  158. }
  159. func (session *Session) isTableEmpty(tableName string) (bool, error) {
  160. var total int64
  161. sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(tableName))
  162. err := session.queryRow(sqlStr).Scan(&total)
  163. if err != nil {
  164. if err == sql.ErrNoRows {
  165. err = nil
  166. }
  167. return true, err
  168. }
  169. return total == 0, nil
  170. }
  171. // find if index is exist according cols
  172. func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
  173. indexes, err := session.engine.dialect.GetIndexes(tableName)
  174. if err != nil {
  175. return false, err
  176. }
  177. for _, index := range indexes {
  178. if sliceEq(index.Cols, cols) {
  179. if unique {
  180. return index.Type == core.UniqueType, nil
  181. }
  182. return index.Type == core.IndexType, nil
  183. }
  184. }
  185. return false, nil
  186. }
  187. func (session *Session) addColumn(colName string) error {
  188. col := session.statement.RefTable.GetColumn(colName)
  189. sql, args := session.statement.genAddColumnStr(col)
  190. _, err := session.exec(sql, args...)
  191. return err
  192. }
  193. func (session *Session) addIndex(tableName, idxName string) error {
  194. index := session.statement.RefTable.Indexes[idxName]
  195. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  196. _, err := session.exec(sqlStr)
  197. return err
  198. }
  199. func (session *Session) addUnique(tableName, uqeName string) error {
  200. index := session.statement.RefTable.Indexes[uqeName]
  201. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  202. _, err := session.exec(sqlStr)
  203. return err
  204. }
  205. // Sync2 synchronize structs to database tables
  206. func (session *Session) Sync2(beans ...interface{}) error {
  207. engine := session.engine
  208. if session.isAutoClose {
  209. session.isAutoClose = false
  210. defer session.Close()
  211. }
  212. tables, err := engine.DBMetas()
  213. if err != nil {
  214. return err
  215. }
  216. var structTables []*core.Table
  217. for _, bean := range beans {
  218. v := rValue(bean)
  219. table, err := engine.mapType(v)
  220. if err != nil {
  221. return err
  222. }
  223. structTables = append(structTables, table)
  224. var tbName = session.tbNameNoSchema(table)
  225. var oriTable *core.Table
  226. for _, tb := range tables {
  227. if strings.EqualFold(tb.Name, tbName) {
  228. oriTable = tb
  229. break
  230. }
  231. }
  232. if oriTable == nil {
  233. err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
  234. if err != nil {
  235. return err
  236. }
  237. err = session.createUniques(bean)
  238. if err != nil {
  239. return err
  240. }
  241. err = session.createIndexes(bean)
  242. if err != nil {
  243. return err
  244. }
  245. } else {
  246. for _, col := range table.Columns() {
  247. var oriCol *core.Column
  248. for _, col2 := range oriTable.Columns() {
  249. if strings.EqualFold(col.Name, col2.Name) {
  250. oriCol = col2
  251. break
  252. }
  253. }
  254. if oriCol != nil {
  255. expectedType := engine.dialect.SqlType(col)
  256. curType := engine.dialect.SqlType(oriCol)
  257. if expectedType != curType {
  258. if expectedType == core.Text &&
  259. strings.HasPrefix(curType, core.Varchar) {
  260. // currently only support mysql & postgres
  261. if engine.dialect.DBType() == core.MYSQL ||
  262. engine.dialect.DBType() == core.POSTGRES {
  263. engine.logger.Infof("Table %s column %s change type from %s to %s\n",
  264. tbName, col.Name, curType, expectedType)
  265. _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
  266. } else {
  267. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
  268. tbName, col.Name, curType, expectedType)
  269. }
  270. } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
  271. if engine.dialect.DBType() == core.MYSQL {
  272. if oriCol.Length < col.Length {
  273. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  274. tbName, col.Name, oriCol.Length, col.Length)
  275. _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
  276. }
  277. }
  278. } else {
  279. if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
  280. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
  281. tbName, col.Name, curType, expectedType)
  282. }
  283. }
  284. } else if expectedType == core.Varchar {
  285. if engine.dialect.DBType() == core.MYSQL {
  286. if oriCol.Length < col.Length {
  287. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  288. tbName, col.Name, oriCol.Length, col.Length)
  289. _, err = session.exec(engine.dialect.ModifyColumnSql(table.Name, col))
  290. }
  291. }
  292. }
  293. if col.Default != oriCol.Default {
  294. engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
  295. tbName, col.Name, oriCol.Default, col.Default)
  296. }
  297. if col.Nullable != oriCol.Nullable {
  298. engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
  299. tbName, col.Name, oriCol.Nullable, col.Nullable)
  300. }
  301. } else {
  302. session.statement.RefTable = table
  303. session.statement.tableName = tbName
  304. err = session.addColumn(col.Name)
  305. }
  306. if err != nil {
  307. return err
  308. }
  309. }
  310. var foundIndexNames = make(map[string]bool)
  311. var addedNames = make(map[string]*core.Index)
  312. for name, index := range table.Indexes {
  313. var oriIndex *core.Index
  314. for name2, index2 := range oriTable.Indexes {
  315. if index.Equal(index2) {
  316. oriIndex = index2
  317. foundIndexNames[name2] = true
  318. break
  319. }
  320. }
  321. if oriIndex != nil {
  322. if oriIndex.Type != index.Type {
  323. sql := engine.dialect.DropIndexSql(tbName, oriIndex)
  324. _, err = session.exec(sql)
  325. if err != nil {
  326. return err
  327. }
  328. oriIndex = nil
  329. }
  330. }
  331. if oriIndex == nil {
  332. addedNames[name] = index
  333. }
  334. }
  335. for name2, index2 := range oriTable.Indexes {
  336. if _, ok := foundIndexNames[name2]; !ok {
  337. sql := engine.dialect.DropIndexSql(tbName, index2)
  338. _, err = session.exec(sql)
  339. if err != nil {
  340. return err
  341. }
  342. }
  343. }
  344. for name, index := range addedNames {
  345. if index.Type == core.UniqueType {
  346. session.statement.RefTable = table
  347. session.statement.tableName = tbName
  348. err = session.addUnique(tbName, name)
  349. } else if index.Type == core.IndexType {
  350. session.statement.RefTable = table
  351. session.statement.tableName = tbName
  352. err = session.addIndex(tbName, name)
  353. }
  354. if err != nil {
  355. return err
  356. }
  357. }
  358. }
  359. }
  360. for _, table := range tables {
  361. var oriTable *core.Table
  362. for _, structTable := range structTables {
  363. if strings.EqualFold(table.Name, session.tbNameNoSchema(structTable)) {
  364. oriTable = structTable
  365. break
  366. }
  367. }
  368. if oriTable == nil {
  369. //engine.LogWarnf("Table %s has no struct to mapping it", table.Name)
  370. continue
  371. }
  372. for _, colName := range table.ColumnsSeq() {
  373. if oriTable.GetColumn(colName) == nil {
  374. engine.logger.Warnf("Table %s has column %s but struct has not related field", table.Name, colName)
  375. }
  376. }
  377. }
  378. return nil
  379. }