You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

110 lines
2.5 KiB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "os"
  6. "path"
  7. "unicode/utf8"
  8. "gopkg.in/yaml.v2"
  9. )
  10. // TableInfo is settings for generating a fixture for table.
  11. type TableInfo struct {
  12. Name string // Table name
  13. Where string // A condition for extracting records. If this value is empty, extracts all records.
  14. }
  15. func (ti *TableInfo) whereClause() string {
  16. if ti.Where == "" {
  17. return ""
  18. }
  19. return fmt.Sprintf(" WHERE %s", ti.Where)
  20. }
  21. // GenerateFixtures generates fixtures for the current contents of a database, and saves
  22. // them to the specified directory
  23. func GenerateFixtures(db *sql.DB, helper Helper, dir string) error {
  24. tables, err := helper.tableNames(db)
  25. if err != nil {
  26. return err
  27. }
  28. for _, table := range tables {
  29. filename := path.Join(dir, table+".yml")
  30. if err := generateFixturesForTable(db, helper, &TableInfo{Name: table}, filename); err != nil {
  31. return err
  32. }
  33. }
  34. return nil
  35. }
  36. // GenerateFixturesForTables generates fixtures for the current contents of specified tables in a database, and saves
  37. // them to the specified directory
  38. func GenerateFixturesForTables(db *sql.DB, tables []*TableInfo, helper Helper, dir string) error {
  39. for _, table := range tables {
  40. filename := path.Join(dir, table.Name+".yml")
  41. if err := generateFixturesForTable(db, helper, table, filename); err != nil {
  42. return err
  43. }
  44. }
  45. return nil
  46. }
  47. func generateFixturesForTable(db *sql.DB, h Helper, table *TableInfo, filename string) error {
  48. query := fmt.Sprintf("SELECT * FROM %s%s", h.quoteKeyword(table.Name), table.whereClause())
  49. rows, err := db.Query(query)
  50. if err != nil {
  51. return err
  52. }
  53. defer rows.Close()
  54. columns, err := rows.Columns()
  55. if err != nil {
  56. return err
  57. }
  58. fixtures := make([]interface{}, 0, 10)
  59. for rows.Next() {
  60. entries := make([]interface{}, len(columns))
  61. entryPtrs := make([]interface{}, len(entries))
  62. for i := range entries {
  63. entryPtrs[i] = &entries[i]
  64. }
  65. if err := rows.Scan(entryPtrs...); err != nil {
  66. return err
  67. }
  68. entryMap := make(map[string]interface{}, len(entries))
  69. for i, column := range columns {
  70. entryMap[column] = convertValue(entries[i])
  71. }
  72. fixtures = append(fixtures, entryMap)
  73. }
  74. if err = rows.Err(); err != nil {
  75. return err
  76. }
  77. f, err := os.Create(filename)
  78. if err != nil {
  79. return err
  80. }
  81. defer f.Close()
  82. marshaled, err := yaml.Marshal(fixtures)
  83. if err != nil {
  84. return err
  85. }
  86. _, err = f.Write(marshaled)
  87. return err
  88. }
  89. func convertValue(value interface{}) interface{} {
  90. switch v := value.(type) {
  91. case []byte:
  92. if utf8.Valid(v) {
  93. return string(v)
  94. }
  95. }
  96. return value
  97. }