You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
2.3 KiB

  1. package testfixtures
import (
	"database/sql"
	"fmt"
	"strings"
)
  6. // SQLServer is the helper for SQL Server for this package.
  7. // SQL Server >= 2008 is required.
  8. type SQLServer struct {
  9. baseHelper
  10. tables []string
  11. }
  12. func (h *SQLServer) init(db *sql.DB) error {
  13. var err error
  14. h.tables, err = h.getTables(db)
  15. if err != nil {
  16. return err
  17. }
  18. return nil
  19. }
  20. func (*SQLServer) paramType() int {
  21. return paramTypeQuestion
  22. }
  23. func (*SQLServer) quoteKeyword(str string) string {
  24. return fmt.Sprintf("[%s]", str)
  25. }
  26. func (*SQLServer) databaseName(db *sql.DB) (dbname string) {
  27. db.QueryRow("SELECT DB_NAME()").Scan(&dbname)
  28. return
  29. }
  30. func (*SQLServer) getTables(db *sql.DB) ([]string, error) {
  31. rows, err := db.Query("SELECT table_name FROM information_schema.tables")
  32. if err != nil {
  33. return nil, err
  34. }
  35. tables := make([]string, 0)
  36. defer rows.Close()
  37. for rows.Next() {
  38. var table string
  39. rows.Scan(&table)
  40. tables = append(tables, table)
  41. }
  42. return tables, nil
  43. }
  44. func (*SQLServer) tableHasIdentityColumn(tx *sql.Tx, tableName string) bool {
  45. sql := `
  46. SELECT COUNT(*)
  47. FROM SYS.IDENTITY_COLUMNS
  48. WHERE OBJECT_NAME(OBJECT_ID) = ?
  49. `
  50. var count int
  51. tx.QueryRow(sql, tableName).Scan(&count)
  52. return count > 0
  53. }
  54. func (h *SQLServer) whileInsertOnTable(tx *sql.Tx, tableName string, fn func() error) error {
  55. if h.tableHasIdentityColumn(tx, tableName) {
  56. defer tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s OFF", h.quoteKeyword(tableName)))
  57. _, err := tx.Exec(fmt.Sprintf("SET IDENTITY_INSERT %s ON", h.quoteKeyword(tableName)))
  58. if err != nil {
  59. return err
  60. }
  61. }
  62. return fn()
  63. }
  64. func (h *SQLServer) disableReferentialIntegrity(db *sql.DB, loadFn loadFunction) error {
  65. // ensure the triggers are re-enable after all
  66. defer func() {
  67. sql := ""
  68. for _, table := range h.tables {
  69. sql += fmt.Sprintf("ALTER TABLE %s WITH CHECK CHECK CONSTRAINT ALL;", h.quoteKeyword(table))
  70. }
  71. if _, err := db.Exec(sql); err != nil {
  72. fmt.Printf("Error on re-enabling constraints: %v\n", err)
  73. }
  74. }()
  75. sql := ""
  76. for _, table := range h.tables {
  77. sql += fmt.Sprintf("ALTER TABLE %s NOCHECK CONSTRAINT ALL;", h.quoteKeyword(table))
  78. }
  79. if _, err := db.Exec(sql); err != nil {
  80. return err
  81. }
  82. tx, err := db.Begin()
  83. if err != nil {
  84. return err
  85. }
  86. if err = loadFn(tx); err != nil {
  87. tx.Rollback()
  88. return err
  89. }
  90. return tx.Commit()
  91. }