You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

279 lines
6.6 KiB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "errors"
  5. "fmt"
  6. "io/ioutil"
  7. "path"
  8. "path/filepath"
  9. "regexp"
  10. "strings"
  11. "gopkg.in/yaml.v2"
  12. )
  13. // Context holds the fixtures to be loaded in the database.
  14. type Context struct {
  15. db *sql.DB
  16. helper Helper
  17. fixturesFiles []*fixtureFile
  18. }
  type (
  	// fixtureFile is one YAML fixture file. Its base name without the
  	// extension is used as the table name.
  	fixtureFile struct {
  		path       string      // path the file was read from
  		fileName   string      // base name, e.g. "customers.yml"
  		content    []byte      // raw YAML content
  		insertSQLs []insertSQL // INSERT statements built from content
  	}

  	// insertSQL is a single parameterized INSERT statement and its arguments.
  	insertSQL struct {
  		sql    string        // statement text with placeholders
  		params []interface{} // arguments, in placeholder order
  	}
  )
  29. var (
  30. // ErrWrongCastNotAMap is returned when a map is not a map[interface{}]interface{}
  31. ErrWrongCastNotAMap = errors.New("Could not cast record: not a map[interface{}]interface{}")
  32. // ErrFileIsNotSliceOrMap is returned the the fixture file is not a slice or map.
  33. ErrFileIsNotSliceOrMap = errors.New("The fixture file is not a slice or map")
  34. // ErrKeyIsNotString is returned when a record is not of type string
  35. ErrKeyIsNotString = errors.New("Record map key is not string")
  36. // ErrNotTestDatabase is returned when the database name doesn't contains "test"
  37. ErrNotTestDatabase = errors.New(`Loading aborted because the database name does not contains "test"`)
  38. dbnameRegexp = regexp.MustCompile("(?i)test")
  39. )
  40. // NewFolder craetes a context for all fixtures in a given folder into the database:
  41. // NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
  42. func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
  43. fixtures, err := fixturesFromFolder(folderName)
  44. if err != nil {
  45. return nil, err
  46. }
  47. c, err := newContext(db, helper, fixtures)
  48. if err != nil {
  49. return nil, err
  50. }
  51. return c, nil
  52. }
  53. // NewFiles craetes a context for all specified fixtures files into database:
  54. // NewFiles(db, &PostgreSQL{},
  55. // "fixtures/customers.yml",
  56. // "fixtures/orders.yml"
  57. // // add as many files you want
  58. // )
  59. func NewFiles(db *sql.DB, helper Helper, fileNames ...string) (*Context, error) {
  60. fixtures, err := fixturesFromFiles(fileNames...)
  61. if err != nil {
  62. return nil, err
  63. }
  64. c, err := newContext(db, helper, fixtures)
  65. if err != nil {
  66. return nil, err
  67. }
  68. return c, nil
  69. }
  70. func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, error) {
  71. c := &Context{
  72. db: db,
  73. helper: helper,
  74. fixturesFiles: fixtures,
  75. }
  76. if err := c.helper.init(c.db); err != nil {
  77. return nil, err
  78. }
  79. if err := c.buildInsertSQLs(); err != nil {
  80. return nil, err
  81. }
  82. return c, nil
  83. }
  84. // Load wipes and after load all fixtures in the database.
  85. // if err := fixtures.Load(); err != nil {
  86. // log.Fatal(err)
  87. // }
  88. func (c *Context) Load() error {
  89. if !skipDatabaseNameCheck {
  90. if !dbnameRegexp.MatchString(c.helper.databaseName(c.db)) {
  91. return ErrNotTestDatabase
  92. }
  93. }
  94. err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
  95. for _, file := range c.fixturesFiles {
  96. if err := file.delete(tx, c.helper); err != nil {
  97. return err
  98. }
  99. err := c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
  100. for _, i := range file.insertSQLs {
  101. if _, err := tx.Exec(i.sql, i.params...); err != nil {
  102. return err
  103. }
  104. }
  105. return nil
  106. })
  107. if err != nil {
  108. return err
  109. }
  110. }
  111. return nil
  112. })
  113. return err
  114. }
  115. func (c *Context) buildInsertSQLs() error {
  116. for _, f := range c.fixturesFiles {
  117. var records interface{}
  118. if err := yaml.Unmarshal(f.content, &records); err != nil {
  119. return err
  120. }
  121. switch records := records.(type) {
  122. case []interface{}:
  123. for _, record := range records {
  124. recordMap, ok := record.(map[interface{}]interface{})
  125. if !ok {
  126. return ErrWrongCastNotAMap
  127. }
  128. sql, values, err := f.buildInsertSQL(c.helper, recordMap)
  129. if err != nil {
  130. return err
  131. }
  132. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  133. }
  134. case map[interface{}]interface{}:
  135. for _, record := range records {
  136. recordMap, ok := record.(map[interface{}]interface{})
  137. if !ok {
  138. return ErrWrongCastNotAMap
  139. }
  140. sql, values, err := f.buildInsertSQL(c.helper, recordMap)
  141. if err != nil {
  142. return err
  143. }
  144. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  145. }
  146. default:
  147. return ErrFileIsNotSliceOrMap
  148. }
  149. }
  150. return nil
  151. }
  152. func (f *fixtureFile) fileNameWithoutExtension() string {
  153. return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
  154. }
  155. func (f *fixtureFile) delete(tx *sql.Tx, h Helper) error {
  156. _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension())))
  157. return err
  158. }
  159. func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
  160. var (
  161. sqlColumns []string
  162. sqlValues []string
  163. i = 1
  164. )
  165. for key, value := range record {
  166. keyStr, ok := key.(string)
  167. if !ok {
  168. err = ErrKeyIsNotString
  169. return
  170. }
  171. sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
  172. switch h.paramType() {
  173. case paramTypeDollar:
  174. sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
  175. case paramTypeQuestion:
  176. sqlValues = append(sqlValues, "?")
  177. case paramTypeColon:
  178. switch {
  179. case isDateTime(value):
  180. sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD HH24:MI:SS')", i))
  181. case isDate(value):
  182. sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'YYYY-MM-DD')", i))
  183. case isTime(value):
  184. sqlValues = append(sqlValues, fmt.Sprintf("to_date(:%d, 'HH24:MI:SS')", i))
  185. default:
  186. sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
  187. }
  188. }
  189. i++
  190. values = append(values, value)
  191. }
  192. sqlStr = fmt.Sprintf(
  193. "INSERT INTO %s (%s) VALUES (%s)",
  194. h.quoteKeyword(f.fileNameWithoutExtension()),
  195. strings.Join(sqlColumns, ", "),
  196. strings.Join(sqlValues, ", "),
  197. )
  198. return
  199. }
  200. func fixturesFromFolder(folderName string) ([]*fixtureFile, error) {
  201. var files []*fixtureFile
  202. fileinfos, err := ioutil.ReadDir(folderName)
  203. if err != nil {
  204. return nil, err
  205. }
  206. for _, fileinfo := range fileinfos {
  207. if !fileinfo.IsDir() && filepath.Ext(fileinfo.Name()) == ".yml" {
  208. fixture := &fixtureFile{
  209. path: path.Join(folderName, fileinfo.Name()),
  210. fileName: fileinfo.Name(),
  211. }
  212. fixture.content, err = ioutil.ReadFile(fixture.path)
  213. if err != nil {
  214. return nil, err
  215. }
  216. files = append(files, fixture)
  217. }
  218. }
  219. return files, nil
  220. }
  221. func fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
  222. var (
  223. fixtureFiles []*fixtureFile
  224. err error
  225. )
  226. for _, f := range fileNames {
  227. fixture := &fixtureFile{
  228. path: f,
  229. fileName: filepath.Base(f),
  230. }
  231. fixture.content, err = ioutil.ReadFile(fixture.path)
  232. if err != nil {
  233. return nil, err
  234. }
  235. fixtureFiles = append(fixtureFiles, fixture)
  236. }
  237. return fixtureFiles, nil
  238. }