You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

305 lines
6.6 KiB

  1. package testfixtures
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "io/ioutil"
  6. "path"
  7. "path/filepath"
  8. "regexp"
  9. "strings"
  10. "gopkg.in/yaml.v2"
  11. )
  12. // Context holds the fixtures to be loaded in the database.
  13. type Context struct {
  14. db *sql.DB
  15. helper Helper
  16. fixturesFiles []*fixtureFile
  17. }
  18. type fixtureFile struct {
  19. path string
  20. fileName string
  21. content []byte
  22. insertSQLs []insertSQL
  23. }
  // insertSQL is one prepared INSERT statement and the values bound to its
  // placeholders. The field order matters: values of this type are built with
  // positional composite literals.
  type insertSQL struct {
  	// sql is the statement text.
  	sql string
  	// params are the arguments passed alongside sql when it is executed.
  	params []interface{}
  }
  // dbnameRegexp matches, case-insensitively, any database name that contains
  // "test". DetectTestDatabase uses it to recognise test databases.
  var dbnameRegexp = regexp.MustCompile("(?i)test")
  31. // NewFolder creates a context for all fixtures in a given folder into the database:
  32. // NewFolder(db, &PostgreSQL{}, "my/fixtures/folder")
  33. func NewFolder(db *sql.DB, helper Helper, folderName string) (*Context, error) {
  34. fixtures, err := fixturesFromFolder(folderName)
  35. if err != nil {
  36. return nil, err
  37. }
  38. c, err := newContext(db, helper, fixtures)
  39. if err != nil {
  40. return nil, err
  41. }
  42. return c, nil
  43. }
  44. // NewFiles creates a context for all specified fixtures files into database:
  45. // NewFiles(db, &PostgreSQL{},
  46. // "fixtures/customers.yml",
  47. // "fixtures/orders.yml"
  48. // // add as many files you want
  49. // )
  50. func NewFiles(db *sql.DB, helper Helper, fileNames ...string) (*Context, error) {
  51. fixtures, err := fixturesFromFiles(fileNames...)
  52. if err != nil {
  53. return nil, err
  54. }
  55. c, err := newContext(db, helper, fixtures)
  56. if err != nil {
  57. return nil, err
  58. }
  59. return c, nil
  60. }
  61. func newContext(db *sql.DB, helper Helper, fixtures []*fixtureFile) (*Context, error) {
  62. c := &Context{
  63. db: db,
  64. helper: helper,
  65. fixturesFiles: fixtures,
  66. }
  67. if err := c.helper.init(c.db); err != nil {
  68. return nil, err
  69. }
  70. if err := c.buildInsertSQLs(); err != nil {
  71. return nil, err
  72. }
  73. return c, nil
  74. }
  75. // DetectTestDatabase returns nil if databaseName matches regexp
  76. // if err := fixtures.DetectTestDatabase(); err != nil {
  77. // log.Fatal(err)
  78. // }
  79. func (c *Context) DetectTestDatabase() error {
  80. dbName, err := c.helper.databaseName(c.db)
  81. if err != nil {
  82. return err
  83. }
  84. if !dbnameRegexp.MatchString(dbName) {
  85. return ErrNotTestDatabase
  86. }
  87. return nil
  88. }
  89. // Load wipes and after load all fixtures in the database.
  90. // if err := fixtures.Load(); err != nil {
  91. // log.Fatal(err)
  92. // }
  93. func (c *Context) Load() error {
  94. if !skipDatabaseNameCheck {
  95. if err := c.DetectTestDatabase(); err != nil {
  96. return err
  97. }
  98. }
  99. err := c.helper.disableReferentialIntegrity(c.db, func(tx *sql.Tx) error {
  100. for _, file := range c.fixturesFiles {
  101. modified, err := c.helper.isTableModified(tx, file.fileNameWithoutExtension())
  102. if err != nil {
  103. return err
  104. }
  105. if !modified {
  106. continue
  107. }
  108. if err := file.delete(tx, c.helper); err != nil {
  109. return err
  110. }
  111. err = c.helper.whileInsertOnTable(tx, file.fileNameWithoutExtension(), func() error {
  112. for j, i := range file.insertSQLs {
  113. if _, err := tx.Exec(i.sql, i.params...); err != nil {
  114. return &InsertError{
  115. Err: err,
  116. File: file.fileName,
  117. Index: j,
  118. SQL: i.sql,
  119. Params: i.params,
  120. }
  121. }
  122. }
  123. return nil
  124. })
  125. if err != nil {
  126. return err
  127. }
  128. }
  129. return nil
  130. })
  131. if err != nil {
  132. return err
  133. }
  134. return c.helper.afterLoad(c.db)
  135. }
  136. func (c *Context) buildInsertSQLs() error {
  137. for _, f := range c.fixturesFiles {
  138. var records interface{}
  139. if err := yaml.Unmarshal(f.content, &records); err != nil {
  140. return err
  141. }
  142. switch records := records.(type) {
  143. case []interface{}:
  144. for _, record := range records {
  145. recordMap, ok := record.(map[interface{}]interface{})
  146. if !ok {
  147. return ErrWrongCastNotAMap
  148. }
  149. sql, values, err := f.buildInsertSQL(c.helper, recordMap)
  150. if err != nil {
  151. return err
  152. }
  153. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  154. }
  155. case map[interface{}]interface{}:
  156. for _, record := range records {
  157. recordMap, ok := record.(map[interface{}]interface{})
  158. if !ok {
  159. return ErrWrongCastNotAMap
  160. }
  161. sql, values, err := f.buildInsertSQL(c.helper, recordMap)
  162. if err != nil {
  163. return err
  164. }
  165. f.insertSQLs = append(f.insertSQLs, insertSQL{sql, values})
  166. }
  167. default:
  168. return ErrFileIsNotSliceOrMap
  169. }
  170. }
  171. return nil
  172. }
  173. func (f *fixtureFile) fileNameWithoutExtension() string {
  174. return strings.Replace(f.fileName, filepath.Ext(f.fileName), "", 1)
  175. }
  176. func (f *fixtureFile) delete(tx *sql.Tx, h Helper) error {
  177. _, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", h.quoteKeyword(f.fileNameWithoutExtension())))
  178. return err
  179. }
  180. func (f *fixtureFile) buildInsertSQL(h Helper, record map[interface{}]interface{}) (sqlStr string, values []interface{}, err error) {
  181. var (
  182. sqlColumns []string
  183. sqlValues []string
  184. i = 1
  185. )
  186. for key, value := range record {
  187. keyStr, ok := key.(string)
  188. if !ok {
  189. err = ErrKeyIsNotString
  190. return
  191. }
  192. sqlColumns = append(sqlColumns, h.quoteKeyword(keyStr))
  193. // if string, try convert to SQL or time
  194. // if map or array, convert to json
  195. switch v := value.(type) {
  196. case string:
  197. if strings.HasPrefix(v, "RAW=") {
  198. sqlValues = append(sqlValues, strings.TrimPrefix(v, "RAW="))
  199. continue
  200. }
  201. if t, err := tryStrToDate(v); err == nil {
  202. value = t
  203. }
  204. case []interface{}, map[interface{}]interface{}:
  205. value = recursiveToJSON(v)
  206. }
  207. switch h.paramType() {
  208. case paramTypeDollar:
  209. sqlValues = append(sqlValues, fmt.Sprintf("$%d", i))
  210. case paramTypeQuestion:
  211. sqlValues = append(sqlValues, "?")
  212. case paramTypeColon:
  213. sqlValues = append(sqlValues, fmt.Sprintf(":%d", i))
  214. }
  215. values = append(values, value)
  216. i++
  217. }
  218. sqlStr = fmt.Sprintf(
  219. "INSERT INTO %s (%s) VALUES (%s)",
  220. h.quoteKeyword(f.fileNameWithoutExtension()),
  221. strings.Join(sqlColumns, ", "),
  222. strings.Join(sqlValues, ", "),
  223. )
  224. return
  225. }
  226. func fixturesFromFolder(folderName string) ([]*fixtureFile, error) {
  227. var files []*fixtureFile
  228. fileinfos, err := ioutil.ReadDir(folderName)
  229. if err != nil {
  230. return nil, err
  231. }
  232. for _, fileinfo := range fileinfos {
  233. if !fileinfo.IsDir() && filepath.Ext(fileinfo.Name()) == ".yml" {
  234. fixture := &fixtureFile{
  235. path: path.Join(folderName, fileinfo.Name()),
  236. fileName: fileinfo.Name(),
  237. }
  238. fixture.content, err = ioutil.ReadFile(fixture.path)
  239. if err != nil {
  240. return nil, err
  241. }
  242. files = append(files, fixture)
  243. }
  244. }
  245. return files, nil
  246. }
  247. func fixturesFromFiles(fileNames ...string) ([]*fixtureFile, error) {
  248. var (
  249. fixtureFiles []*fixtureFile
  250. err error
  251. )
  252. for _, f := range fileNames {
  253. fixture := &fixtureFile{
  254. path: f,
  255. fileName: filepath.Base(f),
  256. }
  257. fixture.content, err = ioutil.ReadFile(fixture.path)
  258. if err != nil {
  259. return nil, err
  260. }
  261. fixtureFiles = append(fixtureFiles, fixture)
  262. }
  263. return fixtureFiles, nil
  264. }