You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

430 lines
11 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql"
  7. "fmt"
  8. "strings"
  9. "xorm.io/core"
  10. )
  11. // Ping test if database is ok
  12. func (session *Session) Ping() error {
  13. if session.isAutoClose {
  14. defer session.Close()
  15. }
  16. session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
  17. return session.DB().PingContext(session.ctx)
  18. }
  19. // CreateTable create a table according a bean
  20. func (session *Session) CreateTable(bean interface{}) error {
  21. if session.isAutoClose {
  22. defer session.Close()
  23. }
  24. return session.createTable(bean)
  25. }
  26. func (session *Session) createTable(bean interface{}) error {
  27. if err := session.statement.setRefBean(bean); err != nil {
  28. return err
  29. }
  30. sqlStr := session.statement.genCreateTableSQL()
  31. _, err := session.exec(sqlStr)
  32. return err
  33. }
  34. // CreateIndexes create indexes
  35. func (session *Session) CreateIndexes(bean interface{}) error {
  36. if session.isAutoClose {
  37. defer session.Close()
  38. }
  39. return session.createIndexes(bean)
  40. }
  41. func (session *Session) createIndexes(bean interface{}) error {
  42. if err := session.statement.setRefBean(bean); err != nil {
  43. return err
  44. }
  45. sqls := session.statement.genIndexSQL()
  46. for _, sqlStr := range sqls {
  47. _, err := session.exec(sqlStr)
  48. if err != nil {
  49. return err
  50. }
  51. }
  52. return nil
  53. }
  54. // CreateUniques create uniques
  55. func (session *Session) CreateUniques(bean interface{}) error {
  56. if session.isAutoClose {
  57. defer session.Close()
  58. }
  59. return session.createUniques(bean)
  60. }
  61. func (session *Session) createUniques(bean interface{}) error {
  62. if err := session.statement.setRefBean(bean); err != nil {
  63. return err
  64. }
  65. sqls := session.statement.genUniqueSQL()
  66. for _, sqlStr := range sqls {
  67. _, err := session.exec(sqlStr)
  68. if err != nil {
  69. return err
  70. }
  71. }
  72. return nil
  73. }
  74. // DropIndexes drop indexes
  75. func (session *Session) DropIndexes(bean interface{}) error {
  76. if session.isAutoClose {
  77. defer session.Close()
  78. }
  79. return session.dropIndexes(bean)
  80. }
  81. func (session *Session) dropIndexes(bean interface{}) error {
  82. if err := session.statement.setRefBean(bean); err != nil {
  83. return err
  84. }
  85. sqls := session.statement.genDelIndexSQL()
  86. for _, sqlStr := range sqls {
  87. _, err := session.exec(sqlStr)
  88. if err != nil {
  89. return err
  90. }
  91. }
  92. return nil
  93. }
  94. // DropTable drop table will drop table if exist, if drop failed, it will return error
  95. func (session *Session) DropTable(beanOrTableName interface{}) error {
  96. if session.isAutoClose {
  97. defer session.Close()
  98. }
  99. return session.dropTable(beanOrTableName)
  100. }
  101. func (session *Session) dropTable(beanOrTableName interface{}) error {
  102. tableName := session.engine.TableName(beanOrTableName)
  103. var needDrop = true
  104. if !session.engine.dialect.SupportDropIfExists() {
  105. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  106. results, err := session.queryBytes(sqlStr, args...)
  107. if err != nil {
  108. return err
  109. }
  110. needDrop = len(results) > 0
  111. }
  112. if needDrop {
  113. sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
  114. _, err := session.exec(sqlStr)
  115. return err
  116. }
  117. return nil
  118. }
  119. // IsTableExist if a table is exist
  120. func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
  121. if session.isAutoClose {
  122. defer session.Close()
  123. }
  124. tableName := session.engine.TableName(beanOrTableName)
  125. return session.isTableExist(tableName)
  126. }
  127. func (session *Session) isTableExist(tableName string) (bool, error) {
  128. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  129. results, err := session.queryBytes(sqlStr, args...)
  130. return len(results) > 0, err
  131. }
  132. // IsTableEmpty if table have any records
  133. func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
  134. if session.isAutoClose {
  135. defer session.Close()
  136. }
  137. return session.isTableEmpty(session.engine.TableName(bean))
  138. }
  139. func (session *Session) isTableEmpty(tableName string) (bool, error) {
  140. var total int64
  141. sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
  142. err := session.queryRow(sqlStr).Scan(&total)
  143. if err != nil {
  144. if err == sql.ErrNoRows {
  145. err = nil
  146. }
  147. return true, err
  148. }
  149. return total == 0, nil
  150. }
  151. // find if index is exist according cols
  152. func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
  153. indexes, err := session.engine.dialect.GetIndexes(tableName)
  154. if err != nil {
  155. return false, err
  156. }
  157. for _, index := range indexes {
  158. if sliceEq(index.Cols, cols) {
  159. if unique {
  160. return index.Type == core.UniqueType, nil
  161. }
  162. return index.Type == core.IndexType, nil
  163. }
  164. }
  165. return false, nil
  166. }
  167. func (session *Session) addColumn(colName string) error {
  168. col := session.statement.RefTable.GetColumn(colName)
  169. sql, args := session.statement.genAddColumnStr(col)
  170. _, err := session.exec(sql, args...)
  171. return err
  172. }
  173. func (session *Session) addIndex(tableName, idxName string) error {
  174. index := session.statement.RefTable.Indexes[idxName]
  175. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  176. _, err := session.exec(sqlStr)
  177. return err
  178. }
  179. func (session *Session) addUnique(tableName, uqeName string) error {
  180. index := session.statement.RefTable.Indexes[uqeName]
  181. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  182. _, err := session.exec(sqlStr)
  183. return err
  184. }
  185. // Sync2 synchronize structs to database tables
  186. func (session *Session) Sync2(beans ...interface{}) error {
  187. engine := session.engine
  188. if session.isAutoClose {
  189. session.isAutoClose = false
  190. defer session.Close()
  191. }
  192. tables, err := engine.dialect.GetTables()
  193. if err != nil {
  194. return err
  195. }
  196. session.autoResetStatement = false
  197. defer func() {
  198. session.autoResetStatement = true
  199. session.resetStatement()
  200. }()
  201. for _, bean := range beans {
  202. v := rValue(bean)
  203. table, err := engine.mapType(v)
  204. if err != nil {
  205. return err
  206. }
  207. var tbName string
  208. if len(session.statement.AltTableName) > 0 {
  209. tbName = session.statement.AltTableName
  210. } else {
  211. tbName = engine.TableName(bean)
  212. }
  213. tbNameWithSchema := engine.tbNameWithSchema(tbName)
  214. var oriTable *core.Table
  215. for _, tb := range tables {
  216. if strings.EqualFold(engine.tbNameWithSchema(tb.Name), engine.tbNameWithSchema(tbName)) {
  217. oriTable = tb
  218. break
  219. }
  220. }
  221. // this is a new table
  222. if oriTable == nil {
  223. err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
  224. if err != nil {
  225. return err
  226. }
  227. err = session.createUniques(bean)
  228. if err != nil {
  229. return err
  230. }
  231. err = session.createIndexes(bean)
  232. if err != nil {
  233. return err
  234. }
  235. continue
  236. }
  237. // this will modify an old table
  238. if err = engine.loadTableInfo(oriTable); err != nil {
  239. return err
  240. }
  241. // check columns
  242. for _, col := range table.Columns() {
  243. var oriCol *core.Column
  244. for _, col2 := range oriTable.Columns() {
  245. if strings.EqualFold(col.Name, col2.Name) {
  246. oriCol = col2
  247. break
  248. }
  249. }
  250. // column is not exist on table
  251. if oriCol == nil {
  252. session.statement.RefTable = table
  253. session.statement.tableName = tbNameWithSchema
  254. if err = session.addColumn(col.Name); err != nil {
  255. return err
  256. }
  257. continue
  258. }
  259. err = nil
  260. expectedType := engine.dialect.SqlType(col)
  261. curType := engine.dialect.SqlType(oriCol)
  262. if expectedType != curType {
  263. if expectedType == core.Text &&
  264. strings.HasPrefix(curType, core.Varchar) {
  265. // currently only support mysql & postgres
  266. if engine.dialect.DBType() == core.MYSQL ||
  267. engine.dialect.DBType() == core.POSTGRES {
  268. engine.logger.Infof("Table %s column %s change type from %s to %s\n",
  269. tbNameWithSchema, col.Name, curType, expectedType)
  270. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  271. } else {
  272. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
  273. tbNameWithSchema, col.Name, curType, expectedType)
  274. }
  275. } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
  276. if engine.dialect.DBType() == core.MYSQL {
  277. if oriCol.Length < col.Length {
  278. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  279. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  280. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  281. }
  282. }
  283. } else {
  284. if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
  285. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
  286. tbNameWithSchema, col.Name, curType, expectedType)
  287. }
  288. }
  289. } else if expectedType == core.Varchar {
  290. if engine.dialect.DBType() == core.MYSQL {
  291. if oriCol.Length < col.Length {
  292. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  293. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  294. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  295. }
  296. }
  297. }
  298. if col.Default != oriCol.Default {
  299. if (col.SQLType.Name == core.Bool || col.SQLType.Name == core.Boolean) &&
  300. ((strings.EqualFold(col.Default, "true") && oriCol.Default == "1") ||
  301. (strings.EqualFold(col.Default, "false") && oriCol.Default == "0")) {
  302. } else {
  303. engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
  304. tbName, col.Name, oriCol.Default, col.Default)
  305. }
  306. }
  307. if col.Nullable != oriCol.Nullable {
  308. engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
  309. tbName, col.Name, oriCol.Nullable, col.Nullable)
  310. }
  311. if err != nil {
  312. return err
  313. }
  314. }
  315. var foundIndexNames = make(map[string]bool)
  316. var addedNames = make(map[string]*core.Index)
  317. for name, index := range table.Indexes {
  318. var oriIndex *core.Index
  319. for name2, index2 := range oriTable.Indexes {
  320. if index.Equal(index2) {
  321. oriIndex = index2
  322. foundIndexNames[name2] = true
  323. break
  324. }
  325. }
  326. if oriIndex != nil {
  327. if oriIndex.Type != index.Type {
  328. sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
  329. _, err = session.exec(sql)
  330. if err != nil {
  331. return err
  332. }
  333. oriIndex = nil
  334. }
  335. }
  336. if oriIndex == nil {
  337. addedNames[name] = index
  338. }
  339. }
  340. for name2, index2 := range oriTable.Indexes {
  341. if _, ok := foundIndexNames[name2]; !ok {
  342. sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
  343. _, err = session.exec(sql)
  344. if err != nil {
  345. return err
  346. }
  347. }
  348. }
  349. for name, index := range addedNames {
  350. if index.Type == core.UniqueType {
  351. session.statement.RefTable = table
  352. session.statement.tableName = tbNameWithSchema
  353. err = session.addUnique(tbNameWithSchema, name)
  354. } else if index.Type == core.IndexType {
  355. session.statement.RefTable = table
  356. session.statement.tableName = tbNameWithSchema
  357. err = session.addIndex(tbNameWithSchema, name)
  358. }
  359. if err != nil {
  360. return err
  361. }
  362. }
  363. // check all the columns which removed from struct fields but left on database tables.
  364. for _, colName := range oriTable.ColumnsSeq() {
  365. if table.GetColumn(colName) == nil {
  366. engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(oriTable.Name, true), colName)
  367. }
  368. }
  369. }
  370. return nil
  371. }