You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

165 lines
4.4 KiB

  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "context"
  7. "database/sql"
  8. "errors"
  9. "reflect"
  10. )
  11. type Stmt struct {
  12. *sql.Stmt
  13. db *DB
  14. names map[string]int
  15. }
  16. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
  17. names := make(map[string]int)
  18. var i int
  19. query = re.ReplaceAllStringFunc(query, func(src string) string {
  20. names[src[1:]] = i
  21. i += 1
  22. return "?"
  23. })
  24. stmt, err := db.DB.PrepareContext(ctx, query)
  25. if err != nil {
  26. return nil, err
  27. }
  28. return &Stmt{stmt, db, names}, nil
  29. }
  30. func (db *DB) Prepare(query string) (*Stmt, error) {
  31. return db.PrepareContext(context.Background(), query)
  32. }
  33. func (s *Stmt) ExecMapContext(ctx context.Context, mp interface{}) (sql.Result, error) {
  34. vv := reflect.ValueOf(mp)
  35. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  36. return nil, errors.New("mp should be a map's pointer")
  37. }
  38. args := make([]interface{}, len(s.names))
  39. for k, i := range s.names {
  40. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  41. }
  42. return s.Stmt.ExecContext(ctx, args...)
  43. }
  44. func (s *Stmt) ExecMap(mp interface{}) (sql.Result, error) {
  45. return s.ExecMapContext(context.Background(), mp)
  46. }
  47. func (s *Stmt) ExecStructContext(ctx context.Context, st interface{}) (sql.Result, error) {
  48. vv := reflect.ValueOf(st)
  49. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  50. return nil, errors.New("mp should be a map's pointer")
  51. }
  52. args := make([]interface{}, len(s.names))
  53. for k, i := range s.names {
  54. args[i] = vv.Elem().FieldByName(k).Interface()
  55. }
  56. return s.Stmt.ExecContext(ctx, args...)
  57. }
  58. func (s *Stmt) ExecStruct(st interface{}) (sql.Result, error) {
  59. return s.ExecStructContext(context.Background(), st)
  60. }
  61. func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
  62. rows, err := s.Stmt.QueryContext(ctx, args...)
  63. if err != nil {
  64. return nil, err
  65. }
  66. return &Rows{rows, s.db}, nil
  67. }
  68. func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
  69. return s.QueryContext(context.Background(), args...)
  70. }
  71. func (s *Stmt) QueryMapContext(ctx context.Context, mp interface{}) (*Rows, error) {
  72. vv := reflect.ValueOf(mp)
  73. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  74. return nil, errors.New("mp should be a map's pointer")
  75. }
  76. args := make([]interface{}, len(s.names))
  77. for k, i := range s.names {
  78. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  79. }
  80. return s.QueryContext(ctx, args...)
  81. }
  82. func (s *Stmt) QueryMap(mp interface{}) (*Rows, error) {
  83. return s.QueryMapContext(context.Background(), mp)
  84. }
  85. func (s *Stmt) QueryStructContext(ctx context.Context, st interface{}) (*Rows, error) {
  86. vv := reflect.ValueOf(st)
  87. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  88. return nil, errors.New("mp should be a map's pointer")
  89. }
  90. args := make([]interface{}, len(s.names))
  91. for k, i := range s.names {
  92. args[i] = vv.Elem().FieldByName(k).Interface()
  93. }
  94. return s.Query(args...)
  95. }
  96. func (s *Stmt) QueryStruct(st interface{}) (*Rows, error) {
  97. return s.QueryStructContext(context.Background(), st)
  98. }
  99. func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
  100. rows, err := s.QueryContext(ctx, args...)
  101. return &Row{rows, err}
  102. }
  103. func (s *Stmt) QueryRow(args ...interface{}) *Row {
  104. return s.QueryRowContext(context.Background(), args...)
  105. }
  106. func (s *Stmt) QueryRowMapContext(ctx context.Context, mp interface{}) *Row {
  107. vv := reflect.ValueOf(mp)
  108. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  109. return &Row{nil, errors.New("mp should be a map's pointer")}
  110. }
  111. args := make([]interface{}, len(s.names))
  112. for k, i := range s.names {
  113. args[i] = vv.Elem().MapIndex(reflect.ValueOf(k)).Interface()
  114. }
  115. return s.QueryRowContext(ctx, args...)
  116. }
  117. func (s *Stmt) QueryRowMap(mp interface{}) *Row {
  118. return s.QueryRowMapContext(context.Background(), mp)
  119. }
  120. func (s *Stmt) QueryRowStructContext(ctx context.Context, st interface{}) *Row {
  121. vv := reflect.ValueOf(st)
  122. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  123. return &Row{nil, errors.New("st should be a struct's pointer")}
  124. }
  125. args := make([]interface{}, len(s.names))
  126. for k, i := range s.names {
  127. args[i] = vv.Elem().FieldByName(k).Interface()
  128. }
  129. return s.QueryRowContext(ctx, args...)
  130. }
  131. func (s *Stmt) QueryRowStruct(st interface{}) *Row {
  132. return s.QueryRowStructContext(context.Background(), st)
  133. }