You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

338 lines
6.7 KiB

  1. // Copyright 2019 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package core
  5. import (
  6. "database/sql"
  7. "errors"
  8. "reflect"
  9. "sync"
  10. )
// Rows embeds *sql.Rows, so all of its methods (Next, Scan, Columns, Err,
// Close, ...) are promoted, and adds a *DB reference. The DB's Mapper is used
// by ScanStructByName and its reflectNew by ScanMap.
type Rows struct {
	*sql.Rows
	db *DB
}
  15. func (rs *Rows) ToMapString() ([]map[string]string, error) {
  16. cols, err := rs.Columns()
  17. if err != nil {
  18. return nil, err
  19. }
  20. var results = make([]map[string]string, 0, 10)
  21. for rs.Next() {
  22. var record = make(map[string]string, len(cols))
  23. err = rs.ScanMap(&record)
  24. if err != nil {
  25. return nil, err
  26. }
  27. results = append(results, record)
  28. }
  29. return results, nil
  30. }
  31. // scan data to a struct's pointer according field index
  32. func (rs *Rows) ScanStructByIndex(dest ...interface{}) error {
  33. if len(dest) == 0 {
  34. return errors.New("at least one struct")
  35. }
  36. vvvs := make([]reflect.Value, len(dest))
  37. for i, s := range dest {
  38. vv := reflect.ValueOf(s)
  39. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  40. return errors.New("dest should be a struct's pointer")
  41. }
  42. vvvs[i] = vv.Elem()
  43. }
  44. cols, err := rs.Columns()
  45. if err != nil {
  46. return err
  47. }
  48. newDest := make([]interface{}, len(cols))
  49. var i = 0
  50. for _, vvv := range vvvs {
  51. for j := 0; j < vvv.NumField(); j++ {
  52. newDest[i] = vvv.Field(j).Addr().Interface()
  53. i = i + 1
  54. }
  55. }
  56. return rs.Rows.Scan(newDest...)
  57. }
  58. var (
  59. fieldCache = make(map[reflect.Type]map[string]int)
  60. fieldCacheMutex sync.RWMutex
  61. )
  62. func fieldByName(v reflect.Value, name string) reflect.Value {
  63. t := v.Type()
  64. fieldCacheMutex.RLock()
  65. cache, ok := fieldCache[t]
  66. fieldCacheMutex.RUnlock()
  67. if !ok {
  68. cache = make(map[string]int)
  69. for i := 0; i < v.NumField(); i++ {
  70. cache[t.Field(i).Name] = i
  71. }
  72. fieldCacheMutex.Lock()
  73. fieldCache[t] = cache
  74. fieldCacheMutex.Unlock()
  75. }
  76. if i, ok := cache[name]; ok {
  77. return v.Field(i)
  78. }
  79. return reflect.Zero(t)
  80. }
  81. // scan data to a struct's pointer according field name
  82. func (rs *Rows) ScanStructByName(dest interface{}) error {
  83. vv := reflect.ValueOf(dest)
  84. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Struct {
  85. return errors.New("dest should be a struct's pointer")
  86. }
  87. cols, err := rs.Columns()
  88. if err != nil {
  89. return err
  90. }
  91. newDest := make([]interface{}, len(cols))
  92. var v EmptyScanner
  93. for j, name := range cols {
  94. f := fieldByName(vv.Elem(), rs.db.Mapper.Table2Obj(name))
  95. if f.IsValid() {
  96. newDest[j] = f.Addr().Interface()
  97. } else {
  98. newDest[j] = &v
  99. }
  100. }
  101. return rs.Rows.Scan(newDest...)
  102. }
  103. // scan data to a slice's pointer, slice's length should equal to columns' number
  104. func (rs *Rows) ScanSlice(dest interface{}) error {
  105. vv := reflect.ValueOf(dest)
  106. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Slice {
  107. return errors.New("dest should be a slice's pointer")
  108. }
  109. vvv := vv.Elem()
  110. cols, err := rs.Columns()
  111. if err != nil {
  112. return err
  113. }
  114. newDest := make([]interface{}, len(cols))
  115. for j := 0; j < len(cols); j++ {
  116. if j >= vvv.Len() {
  117. newDest[j] = reflect.New(vvv.Type().Elem()).Interface()
  118. } else {
  119. newDest[j] = vvv.Index(j).Addr().Interface()
  120. }
  121. }
  122. err = rs.Rows.Scan(newDest...)
  123. if err != nil {
  124. return err
  125. }
  126. srcLen := vvv.Len()
  127. for i := srcLen; i < len(cols); i++ {
  128. vvv = reflect.Append(vvv, reflect.ValueOf(newDest[i]).Elem())
  129. }
  130. return nil
  131. }
  132. // scan data to a map's pointer
  133. func (rs *Rows) ScanMap(dest interface{}) error {
  134. vv := reflect.ValueOf(dest)
  135. if vv.Kind() != reflect.Ptr || vv.Elem().Kind() != reflect.Map {
  136. return errors.New("dest should be a map's pointer")
  137. }
  138. cols, err := rs.Columns()
  139. if err != nil {
  140. return err
  141. }
  142. newDest := make([]interface{}, len(cols))
  143. vvv := vv.Elem()
  144. for i := range cols {
  145. newDest[i] = rs.db.reflectNew(vvv.Type().Elem()).Interface()
  146. }
  147. err = rs.Rows.Scan(newDest...)
  148. if err != nil {
  149. return err
  150. }
  151. for i, name := range cols {
  152. vname := reflect.ValueOf(name)
  153. vvv.SetMapIndex(vname, reflect.ValueOf(newDest[i]).Elem())
  154. }
  155. return nil
  156. }
// Row wraps a result set from which only the first row is read by its Scan*
// methods, or a deferred error that those methods return immediately.
type Row struct {
	// rows is the underlying result set.
	rows *Rows
	// Normally exactly one of rows and err is non-nil.
	// NOTE(review): NewRow does not enforce this.
	err error // deferred error for easy chaining
}
  162. // ErrorRow return an error row
  163. func ErrorRow(err error) *Row {
  164. return &Row{
  165. err: err,
  166. }
  167. }
  168. // NewRow from rows
  169. func NewRow(rows *Rows, err error) *Row {
  170. return &Row{rows, err}
  171. }
  172. func (row *Row) Columns() ([]string, error) {
  173. if row.err != nil {
  174. return nil, row.err
  175. }
  176. return row.rows.Columns()
  177. }
  178. func (row *Row) Scan(dest ...interface{}) error {
  179. if row.err != nil {
  180. return row.err
  181. }
  182. defer row.rows.Close()
  183. for _, dp := range dest {
  184. if _, ok := dp.(*sql.RawBytes); ok {
  185. return errors.New("sql: RawBytes isn't allowed on Row.Scan")
  186. }
  187. }
  188. if !row.rows.Next() {
  189. if err := row.rows.Err(); err != nil {
  190. return err
  191. }
  192. return sql.ErrNoRows
  193. }
  194. err := row.rows.Scan(dest...)
  195. if err != nil {
  196. return err
  197. }
  198. // Make sure the query can be processed to completion with no errors.
  199. return row.rows.Close()
  200. }
  201. func (row *Row) ScanStructByName(dest interface{}) error {
  202. if row.err != nil {
  203. return row.err
  204. }
  205. defer row.rows.Close()
  206. if !row.rows.Next() {
  207. if err := row.rows.Err(); err != nil {
  208. return err
  209. }
  210. return sql.ErrNoRows
  211. }
  212. err := row.rows.ScanStructByName(dest)
  213. if err != nil {
  214. return err
  215. }
  216. // Make sure the query can be processed to completion with no errors.
  217. return row.rows.Close()
  218. }
  219. func (row *Row) ScanStructByIndex(dest interface{}) error {
  220. if row.err != nil {
  221. return row.err
  222. }
  223. defer row.rows.Close()
  224. if !row.rows.Next() {
  225. if err := row.rows.Err(); err != nil {
  226. return err
  227. }
  228. return sql.ErrNoRows
  229. }
  230. err := row.rows.ScanStructByIndex(dest)
  231. if err != nil {
  232. return err
  233. }
  234. // Make sure the query can be processed to completion with no errors.
  235. return row.rows.Close()
  236. }
  237. // scan data to a slice's pointer, slice's length should equal to columns' number
  238. func (row *Row) ScanSlice(dest interface{}) error {
  239. if row.err != nil {
  240. return row.err
  241. }
  242. defer row.rows.Close()
  243. if !row.rows.Next() {
  244. if err := row.rows.Err(); err != nil {
  245. return err
  246. }
  247. return sql.ErrNoRows
  248. }
  249. err := row.rows.ScanSlice(dest)
  250. if err != nil {
  251. return err
  252. }
  253. // Make sure the query can be processed to completion with no errors.
  254. return row.rows.Close()
  255. }
  256. // scan data to a map's pointer
  257. func (row *Row) ScanMap(dest interface{}) error {
  258. if row.err != nil {
  259. return row.err
  260. }
  261. defer row.rows.Close()
  262. if !row.rows.Next() {
  263. if err := row.rows.Err(); err != nil {
  264. return err
  265. }
  266. return sql.ErrNoRows
  267. }
  268. err := row.rows.ScanMap(dest)
  269. if err != nil {
  270. return err
  271. }
  272. // Make sure the query can be processed to completion with no errors.
  273. return row.rows.Close()
  274. }
  275. func (row *Row) ToMapString() (map[string]string, error) {
  276. cols, err := row.Columns()
  277. if err != nil {
  278. return nil, err
  279. }
  280. var record = make(map[string]string, len(cols))
  281. err = row.ScanMap(&record)
  282. if err != nil {
  283. return nil, err
  284. }
  285. return record, nil
  286. }