You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

132 lines
2.8 KiB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. )
// Oracle is the Oracle database helper for this package.
//
// init fills enabledConstraints and sequences from the connected schema once;
// the other methods then work from those cached lists, so constraints or
// sequences created after init are not covered.
type Oracle struct {
	baseHelper
	// enabledConstraints holds the foreign-key (constraint_type 'R')
	// constraints that had status ENABLED in user_constraints when init ran.
	enabledConstraints []oracleConstraint
	// sequences holds the sequence names read from user_sequences when init ran.
	sequences []string
}
  13. type oracleConstraint struct {
  14. tableName string
  15. constraintName string
  16. }
  17. func (h *Oracle) init(db *sql.DB) error {
  18. var err error
  19. h.enabledConstraints, err = h.getEnabledConstraints(db)
  20. if err != nil {
  21. return err
  22. }
  23. h.sequences, err = h.getSequences(db)
  24. if err != nil {
  25. return err
  26. }
  27. return nil
  28. }
  29. func (*Oracle) paramType() int {
  30. return paramTypeColon
  31. }
  32. func (*Oracle) quoteKeyword(str string) string {
  33. return fmt.Sprintf("\"%s\"", strings.ToUpper(str))
  34. }
  35. func (*Oracle) databaseName(db *sql.DB) (dbName string) {
  36. db.QueryRow("SELECT user FROM DUAL").Scan(&dbName)
  37. return
  38. }
  39. func (*Oracle) getEnabledConstraints(db *sql.DB) ([]oracleConstraint, error) {
  40. constraints := make([]oracleConstraint, 0)
  41. rows, err := db.Query(`
  42. SELECT table_name, constraint_name
  43. FROM user_constraints
  44. WHERE constraint_type = 'R'
  45. AND status = 'ENABLED'
  46. `)
  47. if err != nil {
  48. return nil, err
  49. }
  50. defer rows.Close()
  51. for rows.Next() {
  52. var constraint oracleConstraint
  53. rows.Scan(&constraint.tableName, &constraint.constraintName)
  54. constraints = append(constraints, constraint)
  55. }
  56. return constraints, nil
  57. }
  58. func (*Oracle) getSequences(db *sql.DB) ([]string, error) {
  59. sequences := make([]string, 0)
  60. rows, err := db.Query("SELECT sequence_name FROM user_sequences")
  61. if err != nil {
  62. return nil, err
  63. }
  64. defer rows.Close()
  65. for rows.Next() {
  66. var sequence string
  67. rows.Scan(&sequence)
  68. sequences = append(sequences, sequence)
  69. }
  70. return sequences, nil
  71. }
  72. func (h *Oracle) resetSequences(db *sql.DB) error {
  73. for _, sequence := range h.sequences {
  74. _, err := db.Exec(fmt.Sprintf("DROP SEQUENCE %s", h.quoteKeyword(sequence)))
  75. if err != nil {
  76. return err
  77. }
  78. _, err = db.Exec(fmt.Sprintf("CREATE SEQUENCE %s START WITH %d", h.quoteKeyword(sequence), resetSequencesTo))
  79. if err != nil {
  80. return err
  81. }
  82. }
  83. return nil
  84. }
  85. func (h *Oracle) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
  86. // re-enable after load
  87. defer func() {
  88. for _, c := range h.enabledConstraints {
  89. db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
  90. }
  91. }()
  92. // disable foreign keys
  93. for _, c := range h.enabledConstraints {
  94. _, err := db.Exec(fmt.Sprintf("ALTER TABLE %s DISABLE CONSTRAINT %s", h.quoteKeyword(c.tableName), h.quoteKeyword(c.constraintName)))
  95. if err != nil {
  96. return err
  97. }
  98. }
  99. tx, err := db.Begin()
  100. if err != nil {
  101. return err
  102. }
  103. if err = loadFn(tx); err != nil {
  104. tx.Rollback()
  105. return err
  106. }
  107. if err = tx.Commit(); err != nil {
  108. return err
  109. }
  110. return h.resetSequences(db)
  111. }