You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

727 lines
17 KiB

  1. package pq
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "database/sql/driver"
  6. "encoding/hex"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. )
  // Cached reflect.Type values consulted when encoding and decoding
// GenericArray elements.
var (
	// typeByteSlice is []byte; such values are leaf elements, not nested arrays.
	typeByteSlice = reflect.TypeOf([]byte(nil))

	// typeDriverValuer is the driver.Valuer interface type.
	typeDriverValuer = reflect.TypeOf((*driver.Valuer)(nil)).Elem()

	// typeSqlScanner is the sql.Scanner interface type.
	typeSqlScanner = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
)
  15. // Array returns the optimal driver.Valuer and sql.Scanner for an array or
  16. // slice of any dimension.
  17. //
  18. // For example:
  19. // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
  20. //
  21. // var x []sql.NullInt64
  22. // db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
  23. //
  24. // Scanning multi-dimensional arrays is not supported. Arrays where the lower
  25. // bound is not one (such as `[0:0]={1}') are not supported.
  26. func Array(a interface{}) interface {
  27. driver.Valuer
  28. sql.Scanner
  29. } {
  30. switch a := a.(type) {
  31. case []bool:
  32. return (*BoolArray)(&a)
  33. case []float64:
  34. return (*Float64Array)(&a)
  35. case []int64:
  36. return (*Int64Array)(&a)
  37. case []string:
  38. return (*StringArray)(&a)
  39. case *[]bool:
  40. return (*BoolArray)(a)
  41. case *[]float64:
  42. return (*Float64Array)(a)
  43. case *[]int64:
  44. return (*Int64Array)(a)
  45. case *[]string:
  46. return (*StringArray)(a)
  47. }
  48. return GenericArray{a}
  49. }
  // ArrayDelimiter is an optional interface for the element types handled by
// GenericArray. When a driver.Valuer or sql.Scanner element implements it,
// its result replaces the default "," placed between array elements.
type ArrayDelimiter interface {
	// ArrayDelimiter reports the delimiter character(s) to use between
	// elements of this type.
	ArrayDelimiter() string
}
  56. // BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
  57. type BoolArray []bool
  58. // Scan implements the sql.Scanner interface.
  59. func (a *BoolArray) Scan(src interface{}) error {
  60. switch src := src.(type) {
  61. case []byte:
  62. return a.scanBytes(src)
  63. case string:
  64. return a.scanBytes([]byte(src))
  65. }
  66. return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
  67. }
  68. func (a *BoolArray) scanBytes(src []byte) error {
  69. elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
  70. if err != nil {
  71. return err
  72. }
  73. if len(elems) == 0 {
  74. *a = (*a)[:0]
  75. } else {
  76. b := make(BoolArray, len(elems))
  77. for i, v := range elems {
  78. if len(v) != 1 {
  79. return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
  80. }
  81. switch v[0] {
  82. case 't':
  83. b[i] = true
  84. case 'f':
  85. b[i] = false
  86. default:
  87. return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
  88. }
  89. }
  90. *a = b
  91. }
  92. return nil
  93. }
  94. // Value implements the driver.Valuer interface.
  95. func (a BoolArray) Value() (driver.Value, error) {
  96. if a == nil {
  97. return nil, nil
  98. }
  99. if n := len(a); n > 0 {
  100. // There will be exactly two curly brackets, N bytes of values,
  101. // and N-1 bytes of delimiters.
  102. b := make([]byte, 1+2*n)
  103. for i := 0; i < n; i++ {
  104. b[2*i] = ','
  105. if a[i] {
  106. b[1+2*i] = 't'
  107. } else {
  108. b[1+2*i] = 'f'
  109. }
  110. }
  111. b[0] = '{'
  112. b[2*n] = '}'
  113. return string(b), nil
  114. }
  115. return "{}", nil
  116. }
  117. // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type.
  118. type ByteaArray [][]byte
  119. // Scan implements the sql.Scanner interface.
  120. func (a *ByteaArray) Scan(src interface{}) error {
  121. switch src := src.(type) {
  122. case []byte:
  123. return a.scanBytes(src)
  124. case string:
  125. return a.scanBytes([]byte(src))
  126. }
  127. return fmt.Errorf("pq: cannot convert %T to ByteaArray", src)
  128. }
  129. func (a *ByteaArray) scanBytes(src []byte) error {
  130. elems, err := scanLinearArray(src, []byte{','}, "ByteaArray")
  131. if err != nil {
  132. return err
  133. }
  134. if len(elems) == 0 {
  135. *a = (*a)[:0]
  136. } else {
  137. b := make(ByteaArray, len(elems))
  138. for i, v := range elems {
  139. b[i], err = parseBytea(v)
  140. if err != nil {
  141. return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
  142. }
  143. }
  144. *a = b
  145. }
  146. return nil
  147. }
  148. // Value implements the driver.Valuer interface. It uses the "hex" format which
  149. // is only supported on PostgreSQL 9.0 or newer.
  150. func (a ByteaArray) Value() (driver.Value, error) {
  151. if a == nil {
  152. return nil, nil
  153. }
  154. if n := len(a); n > 0 {
  155. // There will be at least two curly brackets, 2*N bytes of quotes,
  156. // 3*N bytes of hex formatting, and N-1 bytes of delimiters.
  157. size := 1 + 6*n
  158. for _, x := range a {
  159. size += hex.EncodedLen(len(x))
  160. }
  161. b := make([]byte, size)
  162. for i, s := 0, b; i < n; i++ {
  163. o := copy(s, `,"\\x`)
  164. o += hex.Encode(s[o:], a[i])
  165. s[o] = '"'
  166. s = s[o+1:]
  167. }
  168. b[0] = '{'
  169. b[size-1] = '}'
  170. return string(b), nil
  171. }
  172. return "{}", nil
  173. }
  174. // Float64Array represents a one-dimensional array of the PostgreSQL double
  175. // precision type.
  176. type Float64Array []float64
  177. // Scan implements the sql.Scanner interface.
  178. func (a *Float64Array) Scan(src interface{}) error {
  179. switch src := src.(type) {
  180. case []byte:
  181. return a.scanBytes(src)
  182. case string:
  183. return a.scanBytes([]byte(src))
  184. }
  185. return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
  186. }
  187. func (a *Float64Array) scanBytes(src []byte) error {
  188. elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
  189. if err != nil {
  190. return err
  191. }
  192. if len(elems) == 0 {
  193. *a = (*a)[:0]
  194. } else {
  195. b := make(Float64Array, len(elems))
  196. for i, v := range elems {
  197. if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
  198. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  199. }
  200. }
  201. *a = b
  202. }
  203. return nil
  204. }
  205. // Value implements the driver.Valuer interface.
  206. func (a Float64Array) Value() (driver.Value, error) {
  207. if a == nil {
  208. return nil, nil
  209. }
  210. if n := len(a); n > 0 {
  211. // There will be at least two curly brackets, N bytes of values,
  212. // and N-1 bytes of delimiters.
  213. b := make([]byte, 1, 1+2*n)
  214. b[0] = '{'
  215. b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
  216. for i := 1; i < n; i++ {
  217. b = append(b, ',')
  218. b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
  219. }
  220. return string(append(b, '}')), nil
  221. }
  222. return "{}", nil
  223. }
  224. // GenericArray implements the driver.Valuer and sql.Scanner interfaces for
  225. // an array or slice of any dimension.
  226. type GenericArray struct{ A interface{} }
  227. func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
  228. var assign func([]byte, reflect.Value) error
  229. var del = ","
  230. // TODO calculate the assign function for other types
  231. // TODO repeat this section on the element type of arrays or slices (multidimensional)
  232. {
  233. if reflect.PtrTo(rt).Implements(typeSqlScanner) {
  234. // dest is always addressable because it is an element of a slice.
  235. assign = func(src []byte, dest reflect.Value) (err error) {
  236. ss := dest.Addr().Interface().(sql.Scanner)
  237. if src == nil {
  238. err = ss.Scan(nil)
  239. } else {
  240. err = ss.Scan(src)
  241. }
  242. return
  243. }
  244. goto FoundType
  245. }
  246. assign = func([]byte, reflect.Value) error {
  247. return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
  248. }
  249. }
  250. FoundType:
  251. if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
  252. del = ad.ArrayDelimiter()
  253. }
  254. return rt, assign, del
  255. }
  256. // Scan implements the sql.Scanner interface.
  257. func (a GenericArray) Scan(src interface{}) error {
  258. dpv := reflect.ValueOf(a.A)
  259. switch {
  260. case dpv.Kind() != reflect.Ptr:
  261. return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
  262. case dpv.IsNil():
  263. return fmt.Errorf("pq: destination %T is nil", a.A)
  264. }
  265. dv := dpv.Elem()
  266. switch dv.Kind() {
  267. case reflect.Slice:
  268. case reflect.Array:
  269. default:
  270. return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
  271. }
  272. switch src := src.(type) {
  273. case []byte:
  274. return a.scanBytes(src, dv)
  275. case string:
  276. return a.scanBytes([]byte(src), dv)
  277. }
  278. return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type())
  279. }
  280. func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
  281. dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
  282. dims, elems, err := parseArray(src, []byte(del))
  283. if err != nil {
  284. return err
  285. }
  286. // TODO allow multidimensional
  287. if len(dims) > 1 {
  288. return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
  289. strings.Replace(fmt.Sprint(dims), " ", "][", -1))
  290. }
  291. // Treat a zero-dimensional array like an array with a single dimension of zero.
  292. if len(dims) == 0 {
  293. dims = append(dims, 0)
  294. }
  295. for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
  296. switch rt.Kind() {
  297. case reflect.Slice:
  298. case reflect.Array:
  299. if rt.Len() != dims[i] {
  300. return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
  301. strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
  302. }
  303. default:
  304. // TODO handle multidimensional
  305. }
  306. }
  307. values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
  308. for i, e := range elems {
  309. if err := assign(e, values.Index(i)); err != nil {
  310. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  311. }
  312. }
  313. // TODO handle multidimensional
  314. switch dv.Kind() {
  315. case reflect.Slice:
  316. dv.Set(values.Slice(0, dims[0]))
  317. case reflect.Array:
  318. for i := 0; i < dims[0]; i++ {
  319. dv.Index(i).Set(values.Index(i))
  320. }
  321. }
  322. return nil
  323. }
  324. // Value implements the driver.Valuer interface.
  325. func (a GenericArray) Value() (driver.Value, error) {
  326. if a.A == nil {
  327. return nil, nil
  328. }
  329. rv := reflect.ValueOf(a.A)
  330. if k := rv.Kind(); k != reflect.Array && k != reflect.Slice {
  331. return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
  332. }
  333. if n := rv.Len(); n > 0 {
  334. // There will be at least two curly brackets, N bytes of values,
  335. // and N-1 bytes of delimiters.
  336. b := make([]byte, 0, 1+2*n)
  337. b, _, err := appendArray(b, rv, n)
  338. return string(b), err
  339. }
  340. return "{}", nil
  341. }
  342. // Int64Array represents a one-dimensional array of the PostgreSQL integer types.
  343. type Int64Array []int64
  344. // Scan implements the sql.Scanner interface.
  345. func (a *Int64Array) Scan(src interface{}) error {
  346. switch src := src.(type) {
  347. case []byte:
  348. return a.scanBytes(src)
  349. case string:
  350. return a.scanBytes([]byte(src))
  351. }
  352. return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
  353. }
  354. func (a *Int64Array) scanBytes(src []byte) error {
  355. elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
  356. if err != nil {
  357. return err
  358. }
  359. if len(elems) == 0 {
  360. *a = (*a)[:0]
  361. } else {
  362. b := make(Int64Array, len(elems))
  363. for i, v := range elems {
  364. if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
  365. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  366. }
  367. }
  368. *a = b
  369. }
  370. return nil
  371. }
  372. // Value implements the driver.Valuer interface.
  373. func (a Int64Array) Value() (driver.Value, error) {
  374. if a == nil {
  375. return nil, nil
  376. }
  377. if n := len(a); n > 0 {
  378. // There will be at least two curly brackets, N bytes of values,
  379. // and N-1 bytes of delimiters.
  380. b := make([]byte, 1, 1+2*n)
  381. b[0] = '{'
  382. b = strconv.AppendInt(b, a[0], 10)
  383. for i := 1; i < n; i++ {
  384. b = append(b, ',')
  385. b = strconv.AppendInt(b, a[i], 10)
  386. }
  387. return string(append(b, '}')), nil
  388. }
  389. return "{}", nil
  390. }
  391. // StringArray represents a one-dimensional array of the PostgreSQL character types.
  392. type StringArray []string
  393. // Scan implements the sql.Scanner interface.
  394. func (a *StringArray) Scan(src interface{}) error {
  395. switch src := src.(type) {
  396. case []byte:
  397. return a.scanBytes(src)
  398. case string:
  399. return a.scanBytes([]byte(src))
  400. }
  401. return fmt.Errorf("pq: cannot convert %T to StringArray", src)
  402. }
  403. func (a *StringArray) scanBytes(src []byte) error {
  404. elems, err := scanLinearArray(src, []byte{','}, "StringArray")
  405. if err != nil {
  406. return err
  407. }
  408. if len(elems) == 0 {
  409. *a = (*a)[:0]
  410. } else {
  411. b := make(StringArray, len(elems))
  412. for i, v := range elems {
  413. if b[i] = string(v); v == nil {
  414. return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
  415. }
  416. }
  417. *a = b
  418. }
  419. return nil
  420. }
  421. // Value implements the driver.Valuer interface.
  422. func (a StringArray) Value() (driver.Value, error) {
  423. if a == nil {
  424. return nil, nil
  425. }
  426. if n := len(a); n > 0 {
  427. // There will be at least two curly brackets, 2*N bytes of quotes,
  428. // and N-1 bytes of delimiters.
  429. b := make([]byte, 1, 1+3*n)
  430. b[0] = '{'
  431. b = appendArrayQuotedBytes(b, []byte(a[0]))
  432. for i := 1; i < n; i++ {
  433. b = append(b, ',')
  434. b = appendArrayQuotedBytes(b, []byte(a[i]))
  435. }
  436. return string(append(b, '}')), nil
  437. }
  438. return "{}", nil
  439. }
  440. // appendArray appends rv to the buffer, returning the extended buffer and
  441. // the delimiter used between elements.
  442. //
  443. // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
  444. func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
  445. var del string
  446. var err error
  447. b = append(b, '{')
  448. if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
  449. return b, del, err
  450. }
  451. for i := 1; i < n; i++ {
  452. b = append(b, del...)
  453. if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
  454. return b, del, err
  455. }
  456. }
  457. return append(b, '}'), del, nil
  458. }
  459. // appendArrayElement appends rv to the buffer, returning the extended buffer
  460. // and the delimiter to use before the next element.
  461. //
  462. // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
  463. // using driver.DefaultParameterConverter and the resulting []byte or string
  464. // is double-quoted.
  465. //
  466. // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
  467. func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
  468. if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
  469. if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
  470. if n := rv.Len(); n > 0 {
  471. return appendArray(b, rv, n)
  472. }
  473. return b, "", nil
  474. }
  475. }
  476. var del string = ","
  477. var err error
  478. var iv interface{} = rv.Interface()
  479. if ad, ok := iv.(ArrayDelimiter); ok {
  480. del = ad.ArrayDelimiter()
  481. }
  482. if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
  483. return b, del, err
  484. }
  485. switch v := iv.(type) {
  486. case nil:
  487. return append(b, "NULL"...), del, nil
  488. case []byte:
  489. return appendArrayQuotedBytes(b, v), del, nil
  490. case string:
  491. return appendArrayQuotedBytes(b, []byte(v)), del, nil
  492. }
  493. b, err = appendValue(b, iv)
  494. return b, del, err
  495. }
  // appendArrayQuotedBytes appends v to b as a double-quoted array element,
// prefixing every '"' and '\' in v with a backslash.
func appendArrayQuotedBytes(b, v []byte) []byte {
	b = append(b, '"')
	for len(v) > 0 {
		j := bytes.IndexAny(v, `"\`)
		if j == -1 {
			// No more characters that need escaping.
			b = append(b, v...)
			break
		}
		b = append(b, v[:j]...)
		b = append(b, '\\', v[j])
		v = v[j+1:]
	}
	return append(b, '"')
}
  512. func appendValue(b []byte, v driver.Value) ([]byte, error) {
  513. return append(b, encode(nil, v, 0)...), nil
  514. }
  // parseArray extracts the dimensions and elements of an array represented in
// text format. Only representations emitted by the backend are supported.
// Notably, whitespace around brackets and delimiters is significant, and NULL
// is case-sensitive.
//
// dims holds the length of each dimension; it is nil for an empty array such
// as "{}". Each entry of elems is the unescaped element text, or nil for an
// unquoted NULL. Malformed input, such as unbalanced braces, a delimiter after
// the outermost '}', or nesting deeper than the leading braces, is reported as
// an error instead of causing an index-out-of-range panic.
//
// See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
	var depth, i int

	if len(src) < 1 || src[0] != '{' {
		return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
	}

	// Consume the leading run of '{'; its length is the number of dimensions.
	// A '}' right after it means the array is empty.
Open:
	for i < len(src) {
		switch src[i] {
		case '{':
			depth++
			i++
		case '}':
			elems = make([][]byte, 0)
			goto Close
		default:
			break Open
		}
	}
	dims = make([]int, i)

Element:
	for i < len(src) {
		switch src[i] {
		case '{':
			// Nesting may not exceed the dimensions declared by the leading
			// braces; stop here so the loop below reports the '{'.
			if depth == len(dims) {
				break Element
			}
			depth++
			dims[depth-1] = 0
			i++
		case '"':
			var elem = []byte{}
			var escape bool
			for i++; i < len(src); i++ {
				if escape {
					elem = append(elem, src[i])
					escape = false
				} else {
					switch src[i] {
					default:
						elem = append(elem, src[i])
					case '\\':
						escape = true
					case '"':
						elems = append(elems, elem)
						i++
						break Element
					}
				}
			}
		default:
			for start := i; i < len(src); i++ {
				if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
					elem := src[start:i]
					if len(elem) == 0 {
						return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
					}
					if bytes.Equal(elem, []byte("NULL")) {
						elem = nil
					}
					elems = append(elems, elem)
					break Element
				}
			}
		}
	}

	// After an element, a delimiter starts the next element and '}' closes the
	// current dimension. Both are only valid while a dimension is still open;
	// otherwise dims[depth-1] would be out of range.
	for i < len(src) {
		if bytes.HasPrefix(src[i:], del) && depth > 0 {
			dims[depth-1]++
			i += len(del)
			goto Element
		} else if src[i] == '}' && depth > 0 {
			dims[depth-1]++
			depth--
			i++
		} else {
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}

Close:
	for i < len(src) {
		if src[i] == '}' && depth > 0 {
			depth--
			i++
		} else {
			return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
		}
	}
	if depth > 0 {
		err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
	}
	if err == nil {
		for _, d := range dims {
			if (len(elems) % d) != 0 {
				err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
			}
		}
	}
	return
}
  617. func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
  618. dims, elems, err := parseArray(src, del)
  619. if err != nil {
  620. return nil, err
  621. }
  622. if len(dims) > 1 {
  623. return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
  624. }
  625. return elems, err
  626. }