You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

211 lines
4.5 KiB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. )
  6. // PostgreSQL is the PG helper for this package
  7. type PostgreSQL struct {
  8. baseHelper
  9. // UseAlterConstraint If true, the contraint disabling will do
  10. // using ALTER CONTRAINT sintax, only allowed in PG >= 9.4.
  11. // If false, the constraint disabling will use DISABLE TRIGGER ALL,
  12. // which requires SUPERUSER privileges.
  13. UseAlterConstraint bool
  14. tables []string
  15. sequences []string
  16. nonDeferrableConstraints []pgConstraint
  17. }
  // pgConstraint identifies one table constraint by the table it belongs to
// and its own name, as read from information_schema.table_constraints.
type pgConstraint struct {
	tableName, constraintName string
}
  22. func (h *PostgreSQL) init(db *sql.DB) error {
  23. var err error
  24. h.tables, err = h.getTables(db)
  25. if err != nil {
  26. return err
  27. }
  28. h.sequences, err = h.getSequences(db)
  29. if err != nil {
  30. return err
  31. }
  32. h.nonDeferrableConstraints, err = h.getNonDeferrableConstraints(db)
  33. if err != nil {
  34. return err
  35. }
  36. return nil
  37. }
  38. func (*PostgreSQL) paramType() int {
  39. return paramTypeDollar
  40. }
  41. func (*PostgreSQL) databaseName(db *sql.DB) (dbName string) {
  42. db.QueryRow("SELECT current_database()").Scan(&dbName)
  43. return
  44. }
  45. func (h *PostgreSQL) getTables(db *sql.DB) ([]string, error) {
  46. var tables []string
  47. sql := `
  48. SELECT table_name
  49. FROM information_schema.tables
  50. WHERE table_schema = 'public'
  51. AND table_type = 'BASE TABLE';
  52. `
  53. rows, err := db.Query(sql)
  54. if err != nil {
  55. return nil, err
  56. }
  57. defer rows.Close()
  58. for rows.Next() {
  59. var table string
  60. rows.Scan(&table)
  61. tables = append(tables, table)
  62. }
  63. return tables, nil
  64. }
  65. func (h *PostgreSQL) getSequences(db *sql.DB) ([]string, error) {
  66. var sequences []string
  67. sql := "SELECT relname FROM pg_class WHERE relkind = 'S'"
  68. rows, err := db.Query(sql)
  69. if err != nil {
  70. return nil, err
  71. }
  72. defer rows.Close()
  73. for rows.Next() {
  74. var sequence string
  75. if err = rows.Scan(&sequence); err != nil {
  76. return nil, err
  77. }
  78. sequences = append(sequences, sequence)
  79. }
  80. return sequences, nil
  81. }
  82. func (*PostgreSQL) getNonDeferrableConstraints(db *sql.DB) ([]pgConstraint, error) {
  83. var constraints []pgConstraint
  84. sql := `
  85. SELECT table_name, constraint_name
  86. FROM information_schema.table_constraints
  87. WHERE constraint_type = 'FOREIGN KEY'
  88. AND is_deferrable = 'NO'`
  89. rows, err := db.Query(sql)
  90. if err != nil {
  91. return nil, err
  92. }
  93. defer rows.Close()
  94. for rows.Next() {
  95. var constraint pgConstraint
  96. err = rows.Scan(&constraint.tableName, &constraint.constraintName)
  97. if err != nil {
  98. return nil, err
  99. }
  100. constraints = append(constraints, constraint)
  101. }
  102. return constraints, nil
  103. }
  104. func (h *PostgreSQL) disableTriggers(db *sql.DB, loadFn loadFunction) error {
  105. defer func() {
  106. // re-enable triggers after load
  107. var sql string
  108. for _, table := range h.tables {
  109. sql += fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER ALL;", h.quoteKeyword(table))
  110. }
  111. db.Exec(sql)
  112. }()
  113. tx, err := db.Begin()
  114. if err != nil {
  115. return err
  116. }
  117. var sql string
  118. for _, table := range h.tables {
  119. sql += fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER ALL;", h.quoteKeyword(table))
  120. }
  121. if _, err = tx.Exec(sql); err != nil {
  122. return err
  123. }
  124. if err = loadFn(tx); err != nil {
  125. tx.Rollback()
  126. return err
  127. }
  128. return tx.Commit()
  129. }
  130. func (h *PostgreSQL) makeConstraintsDeferrable(db *sql.DB, loadFn loadFunction) error {
  131. defer func() {
  132. // ensure constraint being not deferrable again after load
  133. var sql string
  134. for _, constraint := range h.nonDeferrableConstraints {
  135. sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s NOT DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
  136. }
  137. db.Exec(sql)
  138. }()
  139. var sql string
  140. for _, constraint := range h.nonDeferrableConstraints {
  141. sql += fmt.Sprintf("ALTER TABLE %s ALTER CONSTRAINT %s DEFERRABLE;", h.quoteKeyword(constraint.tableName), h.quoteKeyword(constraint.constraintName))
  142. }
  143. if _, err := db.Exec(sql); err != nil {
  144. return err
  145. }
  146. tx, err := db.Begin()
  147. if err != nil {
  148. return err
  149. }
  150. if _, err = tx.Exec("SET CONSTRAINTS ALL DEFERRED"); err != nil {
  151. return nil
  152. }
  153. if err = loadFn(tx); err != nil {
  154. tx.Rollback()
  155. return err
  156. }
  157. return tx.Commit()
  158. }
  159. func (h *PostgreSQL) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
  160. // ensure sequences being reset after load
  161. defer h.resetSequences(db)
  162. if h.UseAlterConstraint {
  163. return h.makeConstraintsDeferrable(db, loadFn)
  164. } else {
  165. return h.disableTriggers(db, loadFn)
  166. }
  167. }
  168. func (h *PostgreSQL) resetSequences(db *sql.DB) error {
  169. for _, sequence := range h.sequences {
  170. _, err := db.Exec(fmt.Sprintf("SELECT SETVAL('%s', %d)", sequence, resetSequencesTo))
  171. if err != nil {
  172. return err
  173. }
  174. }
  175. return nil
  176. }