You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

421 lines
11 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql"
  7. "fmt"
  8. "strings"
  9. "xorm.io/core"
  10. )
  11. // Ping test if database is ok
  12. func (session *Session) Ping() error {
  13. if session.isAutoClose {
  14. defer session.Close()
  15. }
  16. session.engine.logger.Infof("PING DATABASE %v", session.engine.DriverName())
  17. return session.DB().PingContext(session.ctx)
  18. }
  19. // CreateTable create a table according a bean
  20. func (session *Session) CreateTable(bean interface{}) error {
  21. if session.isAutoClose {
  22. defer session.Close()
  23. }
  24. return session.createTable(bean)
  25. }
  26. func (session *Session) createTable(bean interface{}) error {
  27. if err := session.statement.setRefBean(bean); err != nil {
  28. return err
  29. }
  30. sqlStr := session.statement.genCreateTableSQL()
  31. _, err := session.exec(sqlStr)
  32. return err
  33. }
  34. // CreateIndexes create indexes
  35. func (session *Session) CreateIndexes(bean interface{}) error {
  36. if session.isAutoClose {
  37. defer session.Close()
  38. }
  39. return session.createIndexes(bean)
  40. }
  41. func (session *Session) createIndexes(bean interface{}) error {
  42. if err := session.statement.setRefBean(bean); err != nil {
  43. return err
  44. }
  45. sqls := session.statement.genIndexSQL()
  46. for _, sqlStr := range sqls {
  47. _, err := session.exec(sqlStr)
  48. if err != nil {
  49. return err
  50. }
  51. }
  52. return nil
  53. }
  54. // CreateUniques create uniques
  55. func (session *Session) CreateUniques(bean interface{}) error {
  56. if session.isAutoClose {
  57. defer session.Close()
  58. }
  59. return session.createUniques(bean)
  60. }
  61. func (session *Session) createUniques(bean interface{}) error {
  62. if err := session.statement.setRefBean(bean); err != nil {
  63. return err
  64. }
  65. sqls := session.statement.genUniqueSQL()
  66. for _, sqlStr := range sqls {
  67. _, err := session.exec(sqlStr)
  68. if err != nil {
  69. return err
  70. }
  71. }
  72. return nil
  73. }
  74. // DropIndexes drop indexes
  75. func (session *Session) DropIndexes(bean interface{}) error {
  76. if session.isAutoClose {
  77. defer session.Close()
  78. }
  79. return session.dropIndexes(bean)
  80. }
  81. func (session *Session) dropIndexes(bean interface{}) error {
  82. if err := session.statement.setRefBean(bean); err != nil {
  83. return err
  84. }
  85. sqls := session.statement.genDelIndexSQL()
  86. for _, sqlStr := range sqls {
  87. _, err := session.exec(sqlStr)
  88. if err != nil {
  89. return err
  90. }
  91. }
  92. return nil
  93. }
  94. // DropTable drop table will drop table if exist, if drop failed, it will return error
  95. func (session *Session) DropTable(beanOrTableName interface{}) error {
  96. if session.isAutoClose {
  97. defer session.Close()
  98. }
  99. return session.dropTable(beanOrTableName)
  100. }
  101. func (session *Session) dropTable(beanOrTableName interface{}) error {
  102. tableName := session.engine.TableName(beanOrTableName)
  103. var needDrop = true
  104. if !session.engine.dialect.SupportDropIfExists() {
  105. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  106. results, err := session.queryBytes(sqlStr, args...)
  107. if err != nil {
  108. return err
  109. }
  110. needDrop = len(results) > 0
  111. }
  112. if needDrop {
  113. sqlStr := session.engine.Dialect().DropTableSql(session.engine.TableName(tableName, true))
  114. _, err := session.exec(sqlStr)
  115. return err
  116. }
  117. return nil
  118. }
  119. // IsTableExist if a table is exist
  120. func (session *Session) IsTableExist(beanOrTableName interface{}) (bool, error) {
  121. if session.isAutoClose {
  122. defer session.Close()
  123. }
  124. tableName := session.engine.TableName(beanOrTableName)
  125. return session.isTableExist(tableName)
  126. }
  127. func (session *Session) isTableExist(tableName string) (bool, error) {
  128. sqlStr, args := session.engine.dialect.TableCheckSql(tableName)
  129. results, err := session.queryBytes(sqlStr, args...)
  130. return len(results) > 0, err
  131. }
  132. // IsTableEmpty if table have any records
  133. func (session *Session) IsTableEmpty(bean interface{}) (bool, error) {
  134. if session.isAutoClose {
  135. defer session.Close()
  136. }
  137. return session.isTableEmpty(session.engine.TableName(bean))
  138. }
  139. func (session *Session) isTableEmpty(tableName string) (bool, error) {
  140. var total int64
  141. sqlStr := fmt.Sprintf("select count(*) from %s", session.engine.Quote(session.engine.TableName(tableName, true)))
  142. err := session.queryRow(sqlStr).Scan(&total)
  143. if err != nil {
  144. if err == sql.ErrNoRows {
  145. err = nil
  146. }
  147. return true, err
  148. }
  149. return total == 0, nil
  150. }
  151. // find if index is exist according cols
  152. func (session *Session) isIndexExist2(tableName string, cols []string, unique bool) (bool, error) {
  153. indexes, err := session.engine.dialect.GetIndexes(tableName)
  154. if err != nil {
  155. return false, err
  156. }
  157. for _, index := range indexes {
  158. if sliceEq(index.Cols, cols) {
  159. if unique {
  160. return index.Type == core.UniqueType, nil
  161. }
  162. return index.Type == core.IndexType, nil
  163. }
  164. }
  165. return false, nil
  166. }
  167. func (session *Session) addColumn(colName string) error {
  168. col := session.statement.RefTable.GetColumn(colName)
  169. sql, args := session.statement.genAddColumnStr(col)
  170. _, err := session.exec(sql, args...)
  171. return err
  172. }
  173. func (session *Session) addIndex(tableName, idxName string) error {
  174. index := session.statement.RefTable.Indexes[idxName]
  175. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  176. _, err := session.exec(sqlStr)
  177. return err
  178. }
  179. func (session *Session) addUnique(tableName, uqeName string) error {
  180. index := session.statement.RefTable.Indexes[uqeName]
  181. sqlStr := session.engine.dialect.CreateIndexSql(tableName, index)
  182. _, err := session.exec(sqlStr)
  183. return err
  184. }
  185. // Sync2 synchronize structs to database tables
  186. func (session *Session) Sync2(beans ...interface{}) error {
  187. engine := session.engine
  188. if session.isAutoClose {
  189. session.isAutoClose = false
  190. defer session.Close()
  191. }
  192. tables, err := engine.DBMetas()
  193. if err != nil {
  194. return err
  195. }
  196. session.autoResetStatement = false
  197. defer func() {
  198. session.autoResetStatement = true
  199. session.resetStatement()
  200. }()
  201. var structTables []*core.Table
  202. for _, bean := range beans {
  203. v := rValue(bean)
  204. table, err := engine.mapType(v)
  205. if err != nil {
  206. return err
  207. }
  208. structTables = append(structTables, table)
  209. tbName := engine.TableName(bean)
  210. tbNameWithSchema := engine.TableName(tbName, true)
  211. var oriTable *core.Table
  212. for _, tb := range tables {
  213. if strings.EqualFold(tb.Name, tbName) {
  214. oriTable = tb
  215. break
  216. }
  217. }
  218. if oriTable == nil {
  219. err = session.StoreEngine(session.statement.StoreEngine).createTable(bean)
  220. if err != nil {
  221. return err
  222. }
  223. err = session.createUniques(bean)
  224. if err != nil {
  225. return err
  226. }
  227. err = session.createIndexes(bean)
  228. if err != nil {
  229. return err
  230. }
  231. } else {
  232. for _, col := range table.Columns() {
  233. var oriCol *core.Column
  234. for _, col2 := range oriTable.Columns() {
  235. if strings.EqualFold(col.Name, col2.Name) {
  236. oriCol = col2
  237. break
  238. }
  239. }
  240. if oriCol != nil {
  241. expectedType := engine.dialect.SqlType(col)
  242. curType := engine.dialect.SqlType(oriCol)
  243. if expectedType != curType {
  244. if expectedType == core.Text &&
  245. strings.HasPrefix(curType, core.Varchar) {
  246. // currently only support mysql & postgres
  247. if engine.dialect.DBType() == core.MYSQL ||
  248. engine.dialect.DBType() == core.POSTGRES {
  249. engine.logger.Infof("Table %s column %s change type from %s to %s\n",
  250. tbNameWithSchema, col.Name, curType, expectedType)
  251. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  252. } else {
  253. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s\n",
  254. tbNameWithSchema, col.Name, curType, expectedType)
  255. }
  256. } else if strings.HasPrefix(curType, core.Varchar) && strings.HasPrefix(expectedType, core.Varchar) {
  257. if engine.dialect.DBType() == core.MYSQL {
  258. if oriCol.Length < col.Length {
  259. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  260. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  261. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  262. }
  263. }
  264. } else {
  265. if !(strings.HasPrefix(curType, expectedType) && curType[len(expectedType)] == '(') {
  266. engine.logger.Warnf("Table %s column %s db type is %s, struct type is %s",
  267. tbNameWithSchema, col.Name, curType, expectedType)
  268. }
  269. }
  270. } else if expectedType == core.Varchar {
  271. if engine.dialect.DBType() == core.MYSQL {
  272. if oriCol.Length < col.Length {
  273. engine.logger.Infof("Table %s column %s change type from varchar(%d) to varchar(%d)\n",
  274. tbNameWithSchema, col.Name, oriCol.Length, col.Length)
  275. _, err = session.exec(engine.dialect.ModifyColumnSql(tbNameWithSchema, col))
  276. }
  277. }
  278. }
  279. if col.Default != oriCol.Default {
  280. engine.logger.Warnf("Table %s Column %s db default is %s, struct default is %s",
  281. tbName, col.Name, oriCol.Default, col.Default)
  282. }
  283. if col.Nullable != oriCol.Nullable {
  284. engine.logger.Warnf("Table %s Column %s db nullable is %v, struct nullable is %v",
  285. tbName, col.Name, oriCol.Nullable, col.Nullable)
  286. }
  287. } else {
  288. session.statement.RefTable = table
  289. session.statement.tableName = tbNameWithSchema
  290. err = session.addColumn(col.Name)
  291. }
  292. if err != nil {
  293. return err
  294. }
  295. }
  296. var foundIndexNames = make(map[string]bool)
  297. var addedNames = make(map[string]*core.Index)
  298. for name, index := range table.Indexes {
  299. var oriIndex *core.Index
  300. for name2, index2 := range oriTable.Indexes {
  301. if index.Equal(index2) {
  302. oriIndex = index2
  303. foundIndexNames[name2] = true
  304. break
  305. }
  306. }
  307. if oriIndex != nil {
  308. if oriIndex.Type != index.Type {
  309. sql := engine.dialect.DropIndexSql(tbNameWithSchema, oriIndex)
  310. _, err = session.exec(sql)
  311. if err != nil {
  312. return err
  313. }
  314. oriIndex = nil
  315. }
  316. }
  317. if oriIndex == nil {
  318. addedNames[name] = index
  319. }
  320. }
  321. for name2, index2 := range oriTable.Indexes {
  322. if _, ok := foundIndexNames[name2]; !ok {
  323. sql := engine.dialect.DropIndexSql(tbNameWithSchema, index2)
  324. _, err = session.exec(sql)
  325. if err != nil {
  326. return err
  327. }
  328. }
  329. }
  330. for name, index := range addedNames {
  331. if index.Type == core.UniqueType {
  332. session.statement.RefTable = table
  333. session.statement.tableName = tbNameWithSchema
  334. err = session.addUnique(tbNameWithSchema, name)
  335. } else if index.Type == core.IndexType {
  336. session.statement.RefTable = table
  337. session.statement.tableName = tbNameWithSchema
  338. err = session.addIndex(tbNameWithSchema, name)
  339. }
  340. if err != nil {
  341. return err
  342. }
  343. }
  344. }
  345. }
  346. for _, table := range tables {
  347. var oriTable *core.Table
  348. for _, structTable := range structTables {
  349. if strings.EqualFold(table.Name, session.tbNameNoSchema(structTable)) {
  350. oriTable = structTable
  351. break
  352. }
  353. }
  354. if oriTable == nil {
  355. //engine.LogWarnf("Table %s has no struct to mapping it", table.Name)
  356. continue
  357. }
  358. for _, colName := range table.ColumnsSeq() {
  359. if oriTable.GetColumn(colName) == nil {
  360. engine.logger.Warnf("Table %s has column %s but struct has not related field", engine.TableName(table.Name, true), colName)
  361. }
  362. }
  363. }
  364. return nil
  365. }