You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

756 lines
17 KiB

  1. package pq
  2. import (
  3. "bytes"
  4. "database/sql"
  5. "database/sql/driver"
  6. "encoding/hex"
  7. "fmt"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. )
  // Cached reflect.Type values consulted while encoding and decoding arrays.
  var (
  	// typeByteSlice identifies []byte, which is encoded as one element rather
  	// than recursed into as a nested array.
  	typeByteSlice = reflect.TypeOf([]byte(nil))
  	// typeDriverValuer is the driver.Valuer interface type.
  	typeDriverValuer = reflect.TypeOf(new(driver.Valuer)).Elem()
  	// typeSQLScanner is the sql.Scanner interface type.
  	typeSQLScanner = reflect.TypeOf(new(sql.Scanner)).Elem()
  )
  15. // Array returns the optimal driver.Valuer and sql.Scanner for an array or
  16. // slice of any dimension.
  17. //
  18. // For example:
  19. // db.Query(`SELECT * FROM t WHERE id = ANY($1)`, pq.Array([]int{235, 401}))
  20. //
  21. // var x []sql.NullInt64
  22. // db.QueryRow('SELECT ARRAY[235, 401]').Scan(pq.Array(&x))
  23. //
  24. // Scanning multi-dimensional arrays is not supported. Arrays where the lower
  25. // bound is not one (such as `[0:0]={1}') are not supported.
  26. func Array(a interface{}) interface {
  27. driver.Valuer
  28. sql.Scanner
  29. } {
  30. switch a := a.(type) {
  31. case []bool:
  32. return (*BoolArray)(&a)
  33. case []float64:
  34. return (*Float64Array)(&a)
  35. case []int64:
  36. return (*Int64Array)(&a)
  37. case []string:
  38. return (*StringArray)(&a)
  39. case *[]bool:
  40. return (*BoolArray)(a)
  41. case *[]float64:
  42. return (*Float64Array)(a)
  43. case *[]int64:
  44. return (*Int64Array)(a)
  45. case *[]string:
  46. return (*StringArray)(a)
  47. }
  48. return GenericArray{a}
  49. }
  // ArrayDelimiter may optionally be implemented by a driver.Valuer or
  // sql.Scanner element type to choose the separator GenericArray places
  // between its elements instead of the default comma.
  type ArrayDelimiter interface {
  	// ArrayDelimiter reports the delimiter character(s) used between
  	// elements of this type.
  	ArrayDelimiter() string
  }
  56. // BoolArray represents a one-dimensional array of the PostgreSQL boolean type.
  57. type BoolArray []bool
  58. // Scan implements the sql.Scanner interface.
  59. func (a *BoolArray) Scan(src interface{}) error {
  60. switch src := src.(type) {
  61. case []byte:
  62. return a.scanBytes(src)
  63. case string:
  64. return a.scanBytes([]byte(src))
  65. case nil:
  66. *a = nil
  67. return nil
  68. }
  69. return fmt.Errorf("pq: cannot convert %T to BoolArray", src)
  70. }
  71. func (a *BoolArray) scanBytes(src []byte) error {
  72. elems, err := scanLinearArray(src, []byte{','}, "BoolArray")
  73. if err != nil {
  74. return err
  75. }
  76. if *a != nil && len(elems) == 0 {
  77. *a = (*a)[:0]
  78. } else {
  79. b := make(BoolArray, len(elems))
  80. for i, v := range elems {
  81. if len(v) != 1 {
  82. return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
  83. }
  84. switch v[0] {
  85. case 't':
  86. b[i] = true
  87. case 'f':
  88. b[i] = false
  89. default:
  90. return fmt.Errorf("pq: could not parse boolean array index %d: invalid boolean %q", i, v)
  91. }
  92. }
  93. *a = b
  94. }
  95. return nil
  96. }
  97. // Value implements the driver.Valuer interface.
  98. func (a BoolArray) Value() (driver.Value, error) {
  99. if a == nil {
  100. return nil, nil
  101. }
  102. if n := len(a); n > 0 {
  103. // There will be exactly two curly brackets, N bytes of values,
  104. // and N-1 bytes of delimiters.
  105. b := make([]byte, 1+2*n)
  106. for i := 0; i < n; i++ {
  107. b[2*i] = ','
  108. if a[i] {
  109. b[1+2*i] = 't'
  110. } else {
  111. b[1+2*i] = 'f'
  112. }
  113. }
  114. b[0] = '{'
  115. b[2*n] = '}'
  116. return string(b), nil
  117. }
  118. return "{}", nil
  119. }
  120. // ByteaArray represents a one-dimensional array of the PostgreSQL bytea type.
  121. type ByteaArray [][]byte
  122. // Scan implements the sql.Scanner interface.
  123. func (a *ByteaArray) Scan(src interface{}) error {
  124. switch src := src.(type) {
  125. case []byte:
  126. return a.scanBytes(src)
  127. case string:
  128. return a.scanBytes([]byte(src))
  129. case nil:
  130. *a = nil
  131. return nil
  132. }
  133. return fmt.Errorf("pq: cannot convert %T to ByteaArray", src)
  134. }
  135. func (a *ByteaArray) scanBytes(src []byte) error {
  136. elems, err := scanLinearArray(src, []byte{','}, "ByteaArray")
  137. if err != nil {
  138. return err
  139. }
  140. if *a != nil && len(elems) == 0 {
  141. *a = (*a)[:0]
  142. } else {
  143. b := make(ByteaArray, len(elems))
  144. for i, v := range elems {
  145. b[i], err = parseBytea(v)
  146. if err != nil {
  147. return fmt.Errorf("could not parse bytea array index %d: %s", i, err.Error())
  148. }
  149. }
  150. *a = b
  151. }
  152. return nil
  153. }
  154. // Value implements the driver.Valuer interface. It uses the "hex" format which
  155. // is only supported on PostgreSQL 9.0 or newer.
  156. func (a ByteaArray) Value() (driver.Value, error) {
  157. if a == nil {
  158. return nil, nil
  159. }
  160. if n := len(a); n > 0 {
  161. // There will be at least two curly brackets, 2*N bytes of quotes,
  162. // 3*N bytes of hex formatting, and N-1 bytes of delimiters.
  163. size := 1 + 6*n
  164. for _, x := range a {
  165. size += hex.EncodedLen(len(x))
  166. }
  167. b := make([]byte, size)
  168. for i, s := 0, b; i < n; i++ {
  169. o := copy(s, `,"\\x`)
  170. o += hex.Encode(s[o:], a[i])
  171. s[o] = '"'
  172. s = s[o+1:]
  173. }
  174. b[0] = '{'
  175. b[size-1] = '}'
  176. return string(b), nil
  177. }
  178. return "{}", nil
  179. }
  180. // Float64Array represents a one-dimensional array of the PostgreSQL double
  181. // precision type.
  182. type Float64Array []float64
  183. // Scan implements the sql.Scanner interface.
  184. func (a *Float64Array) Scan(src interface{}) error {
  185. switch src := src.(type) {
  186. case []byte:
  187. return a.scanBytes(src)
  188. case string:
  189. return a.scanBytes([]byte(src))
  190. case nil:
  191. *a = nil
  192. return nil
  193. }
  194. return fmt.Errorf("pq: cannot convert %T to Float64Array", src)
  195. }
  196. func (a *Float64Array) scanBytes(src []byte) error {
  197. elems, err := scanLinearArray(src, []byte{','}, "Float64Array")
  198. if err != nil {
  199. return err
  200. }
  201. if *a != nil && len(elems) == 0 {
  202. *a = (*a)[:0]
  203. } else {
  204. b := make(Float64Array, len(elems))
  205. for i, v := range elems {
  206. if b[i], err = strconv.ParseFloat(string(v), 64); err != nil {
  207. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  208. }
  209. }
  210. *a = b
  211. }
  212. return nil
  213. }
  214. // Value implements the driver.Valuer interface.
  215. func (a Float64Array) Value() (driver.Value, error) {
  216. if a == nil {
  217. return nil, nil
  218. }
  219. if n := len(a); n > 0 {
  220. // There will be at least two curly brackets, N bytes of values,
  221. // and N-1 bytes of delimiters.
  222. b := make([]byte, 1, 1+2*n)
  223. b[0] = '{'
  224. b = strconv.AppendFloat(b, a[0], 'f', -1, 64)
  225. for i := 1; i < n; i++ {
  226. b = append(b, ',')
  227. b = strconv.AppendFloat(b, a[i], 'f', -1, 64)
  228. }
  229. return string(append(b, '}')), nil
  230. }
  231. return "{}", nil
  232. }
  233. // GenericArray implements the driver.Valuer and sql.Scanner interfaces for
  234. // an array or slice of any dimension.
  235. type GenericArray struct{ A interface{} }
  236. func (GenericArray) evaluateDestination(rt reflect.Type) (reflect.Type, func([]byte, reflect.Value) error, string) {
  237. var assign func([]byte, reflect.Value) error
  238. var del = ","
  239. // TODO calculate the assign function for other types
  240. // TODO repeat this section on the element type of arrays or slices (multidimensional)
  241. {
  242. if reflect.PtrTo(rt).Implements(typeSQLScanner) {
  243. // dest is always addressable because it is an element of a slice.
  244. assign = func(src []byte, dest reflect.Value) (err error) {
  245. ss := dest.Addr().Interface().(sql.Scanner)
  246. if src == nil {
  247. err = ss.Scan(nil)
  248. } else {
  249. err = ss.Scan(src)
  250. }
  251. return
  252. }
  253. goto FoundType
  254. }
  255. assign = func([]byte, reflect.Value) error {
  256. return fmt.Errorf("pq: scanning to %s is not implemented; only sql.Scanner", rt)
  257. }
  258. }
  259. FoundType:
  260. if ad, ok := reflect.Zero(rt).Interface().(ArrayDelimiter); ok {
  261. del = ad.ArrayDelimiter()
  262. }
  263. return rt, assign, del
  264. }
  265. // Scan implements the sql.Scanner interface.
  266. func (a GenericArray) Scan(src interface{}) error {
  267. dpv := reflect.ValueOf(a.A)
  268. switch {
  269. case dpv.Kind() != reflect.Ptr:
  270. return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
  271. case dpv.IsNil():
  272. return fmt.Errorf("pq: destination %T is nil", a.A)
  273. }
  274. dv := dpv.Elem()
  275. switch dv.Kind() {
  276. case reflect.Slice:
  277. case reflect.Array:
  278. default:
  279. return fmt.Errorf("pq: destination %T is not a pointer to array or slice", a.A)
  280. }
  281. switch src := src.(type) {
  282. case []byte:
  283. return a.scanBytes(src, dv)
  284. case string:
  285. return a.scanBytes([]byte(src), dv)
  286. case nil:
  287. if dv.Kind() == reflect.Slice {
  288. dv.Set(reflect.Zero(dv.Type()))
  289. return nil
  290. }
  291. }
  292. return fmt.Errorf("pq: cannot convert %T to %s", src, dv.Type())
  293. }
  294. func (a GenericArray) scanBytes(src []byte, dv reflect.Value) error {
  295. dtype, assign, del := a.evaluateDestination(dv.Type().Elem())
  296. dims, elems, err := parseArray(src, []byte(del))
  297. if err != nil {
  298. return err
  299. }
  300. // TODO allow multidimensional
  301. if len(dims) > 1 {
  302. return fmt.Errorf("pq: scanning from multidimensional ARRAY%s is not implemented",
  303. strings.Replace(fmt.Sprint(dims), " ", "][", -1))
  304. }
  305. // Treat a zero-dimensional array like an array with a single dimension of zero.
  306. if len(dims) == 0 {
  307. dims = append(dims, 0)
  308. }
  309. for i, rt := 0, dv.Type(); i < len(dims); i, rt = i+1, rt.Elem() {
  310. switch rt.Kind() {
  311. case reflect.Slice:
  312. case reflect.Array:
  313. if rt.Len() != dims[i] {
  314. return fmt.Errorf("pq: cannot convert ARRAY%s to %s",
  315. strings.Replace(fmt.Sprint(dims), " ", "][", -1), dv.Type())
  316. }
  317. default:
  318. // TODO handle multidimensional
  319. }
  320. }
  321. values := reflect.MakeSlice(reflect.SliceOf(dtype), len(elems), len(elems))
  322. for i, e := range elems {
  323. if err := assign(e, values.Index(i)); err != nil {
  324. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  325. }
  326. }
  327. // TODO handle multidimensional
  328. switch dv.Kind() {
  329. case reflect.Slice:
  330. dv.Set(values.Slice(0, dims[0]))
  331. case reflect.Array:
  332. for i := 0; i < dims[0]; i++ {
  333. dv.Index(i).Set(values.Index(i))
  334. }
  335. }
  336. return nil
  337. }
  338. // Value implements the driver.Valuer interface.
  339. func (a GenericArray) Value() (driver.Value, error) {
  340. if a.A == nil {
  341. return nil, nil
  342. }
  343. rv := reflect.ValueOf(a.A)
  344. switch rv.Kind() {
  345. case reflect.Slice:
  346. if rv.IsNil() {
  347. return nil, nil
  348. }
  349. case reflect.Array:
  350. default:
  351. return nil, fmt.Errorf("pq: Unable to convert %T to array", a.A)
  352. }
  353. if n := rv.Len(); n > 0 {
  354. // There will be at least two curly brackets, N bytes of values,
  355. // and N-1 bytes of delimiters.
  356. b := make([]byte, 0, 1+2*n)
  357. b, _, err := appendArray(b, rv, n)
  358. return string(b), err
  359. }
  360. return "{}", nil
  361. }
  362. // Int64Array represents a one-dimensional array of the PostgreSQL integer types.
  363. type Int64Array []int64
  364. // Scan implements the sql.Scanner interface.
  365. func (a *Int64Array) Scan(src interface{}) error {
  366. switch src := src.(type) {
  367. case []byte:
  368. return a.scanBytes(src)
  369. case string:
  370. return a.scanBytes([]byte(src))
  371. case nil:
  372. *a = nil
  373. return nil
  374. }
  375. return fmt.Errorf("pq: cannot convert %T to Int64Array", src)
  376. }
  377. func (a *Int64Array) scanBytes(src []byte) error {
  378. elems, err := scanLinearArray(src, []byte{','}, "Int64Array")
  379. if err != nil {
  380. return err
  381. }
  382. if *a != nil && len(elems) == 0 {
  383. *a = (*a)[:0]
  384. } else {
  385. b := make(Int64Array, len(elems))
  386. for i, v := range elems {
  387. if b[i], err = strconv.ParseInt(string(v), 10, 64); err != nil {
  388. return fmt.Errorf("pq: parsing array element index %d: %v", i, err)
  389. }
  390. }
  391. *a = b
  392. }
  393. return nil
  394. }
  395. // Value implements the driver.Valuer interface.
  396. func (a Int64Array) Value() (driver.Value, error) {
  397. if a == nil {
  398. return nil, nil
  399. }
  400. if n := len(a); n > 0 {
  401. // There will be at least two curly brackets, N bytes of values,
  402. // and N-1 bytes of delimiters.
  403. b := make([]byte, 1, 1+2*n)
  404. b[0] = '{'
  405. b = strconv.AppendInt(b, a[0], 10)
  406. for i := 1; i < n; i++ {
  407. b = append(b, ',')
  408. b = strconv.AppendInt(b, a[i], 10)
  409. }
  410. return string(append(b, '}')), nil
  411. }
  412. return "{}", nil
  413. }
  414. // StringArray represents a one-dimensional array of the PostgreSQL character types.
  415. type StringArray []string
  416. // Scan implements the sql.Scanner interface.
  417. func (a *StringArray) Scan(src interface{}) error {
  418. switch src := src.(type) {
  419. case []byte:
  420. return a.scanBytes(src)
  421. case string:
  422. return a.scanBytes([]byte(src))
  423. case nil:
  424. *a = nil
  425. return nil
  426. }
  427. return fmt.Errorf("pq: cannot convert %T to StringArray", src)
  428. }
  429. func (a *StringArray) scanBytes(src []byte) error {
  430. elems, err := scanLinearArray(src, []byte{','}, "StringArray")
  431. if err != nil {
  432. return err
  433. }
  434. if *a != nil && len(elems) == 0 {
  435. *a = (*a)[:0]
  436. } else {
  437. b := make(StringArray, len(elems))
  438. for i, v := range elems {
  439. if b[i] = string(v); v == nil {
  440. return fmt.Errorf("pq: parsing array element index %d: cannot convert nil to string", i)
  441. }
  442. }
  443. *a = b
  444. }
  445. return nil
  446. }
  447. // Value implements the driver.Valuer interface.
  448. func (a StringArray) Value() (driver.Value, error) {
  449. if a == nil {
  450. return nil, nil
  451. }
  452. if n := len(a); n > 0 {
  453. // There will be at least two curly brackets, 2*N bytes of quotes,
  454. // and N-1 bytes of delimiters.
  455. b := make([]byte, 1, 1+3*n)
  456. b[0] = '{'
  457. b = appendArrayQuotedBytes(b, []byte(a[0]))
  458. for i := 1; i < n; i++ {
  459. b = append(b, ',')
  460. b = appendArrayQuotedBytes(b, []byte(a[i]))
  461. }
  462. return string(append(b, '}')), nil
  463. }
  464. return "{}", nil
  465. }
  466. // appendArray appends rv to the buffer, returning the extended buffer and
  467. // the delimiter used between elements.
  468. //
  469. // It panics when n <= 0 or rv's Kind is not reflect.Array nor reflect.Slice.
  470. func appendArray(b []byte, rv reflect.Value, n int) ([]byte, string, error) {
  471. var del string
  472. var err error
  473. b = append(b, '{')
  474. if b, del, err = appendArrayElement(b, rv.Index(0)); err != nil {
  475. return b, del, err
  476. }
  477. for i := 1; i < n; i++ {
  478. b = append(b, del...)
  479. if b, del, err = appendArrayElement(b, rv.Index(i)); err != nil {
  480. return b, del, err
  481. }
  482. }
  483. return append(b, '}'), del, nil
  484. }
  485. // appendArrayElement appends rv to the buffer, returning the extended buffer
  486. // and the delimiter to use before the next element.
  487. //
  488. // When rv's Kind is neither reflect.Array nor reflect.Slice, it is converted
  489. // using driver.DefaultParameterConverter and the resulting []byte or string
  490. // is double-quoted.
  491. //
  492. // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
  493. func appendArrayElement(b []byte, rv reflect.Value) ([]byte, string, error) {
  494. if k := rv.Kind(); k == reflect.Array || k == reflect.Slice {
  495. if t := rv.Type(); t != typeByteSlice && !t.Implements(typeDriverValuer) {
  496. if n := rv.Len(); n > 0 {
  497. return appendArray(b, rv, n)
  498. }
  499. return b, "", nil
  500. }
  501. }
  502. var del = ","
  503. var err error
  504. var iv interface{} = rv.Interface()
  505. if ad, ok := iv.(ArrayDelimiter); ok {
  506. del = ad.ArrayDelimiter()
  507. }
  508. if iv, err = driver.DefaultParameterConverter.ConvertValue(iv); err != nil {
  509. return b, del, err
  510. }
  511. switch v := iv.(type) {
  512. case nil:
  513. return append(b, "NULL"...), del, nil
  514. case []byte:
  515. return appendArrayQuotedBytes(b, v), del, nil
  516. case string:
  517. return appendArrayQuotedBytes(b, []byte(v)), del, nil
  518. }
  519. b, err = appendValue(b, iv)
  520. return b, del, err
  521. }
  // appendArrayQuotedBytes appends v to b as a double-quoted array element,
  // prefixing every '"' and '\' byte with a backslash, and returns the
  // extended buffer.
  func appendArrayQuotedBytes(b, v []byte) []byte {
  	out := append(b, '"')
  	rest := v
  	for {
  		cut := bytes.IndexAny(rest, `"\`)
  		if cut == -1 {
  			return append(append(out, rest...), '"')
  		}
  		out = append(out, rest[:cut]...)
  		out = append(out, '\\', rest[cut])
  		rest = rest[cut+1:]
  	}
  }
  538. func appendValue(b []byte, v driver.Value) ([]byte, error) {
  539. return append(b, encode(nil, v, 0)...), nil
  540. }
  541. // parseArray extracts the dimensions and elements of an array represented in
  542. // text format. Only representations emitted by the backend are supported.
  543. // Notably, whitespace around brackets and delimiters is significant, and NULL
  544. // is case-sensitive.
  545. //
  // del is the element delimiter. dims holds one length per nesting level and
  // is nil for an empty array such as "{}" or "{{}}". elems holds every
  // element in the order it appears, flattened across all levels (an empty,
  // non-nil slice for an empty array). An unquoted NULL is returned as a nil
  // element; a quoted element has its backslash escapes removed.
  //
  // When err is non-nil, dims and elems may still be non-nil and must be
  // ignored.
  //
  // The parser runs in phases joined by labels and goto. Open counts the
  // leading '{' run to learn the nesting depth. Element reads one element.
  // The loop after Element consumes delimiters and closing braces, jumping
  // back to Element after each delimiter. Close consumes trailing '}' bytes
  // after an empty array.
  //
  546. // See http://www.postgresql.org/docs/current/static/arrays.html#ARRAYS-IO
  547. func parseArray(src, del []byte) (dims []int, elems [][]byte, err error) {
  // depth is the current nesting level; i is the read offset into src.
  548. var depth, i int
  549. if len(src) < 1 || src[0] != '{' {
  550. return nil, nil, fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '{', 0)
  551. }
  // Count the leading '{' bytes. A '}' reached here means the array is empty.
  552. Open:
  553. for i < len(src) {
  554. switch src[i] {
  555. case '{':
  556. depth++
  557. i++
  558. case '}':
  559. elems = make([][]byte, 0)
  560. goto Close
  561. default:
  562. break Open
  563. }
  564. }
  // Every byte consumed so far was '{', so i equals the number of dimensions.
  565. dims = make([]int, i)
  566. Element:
  567. for i < len(src) {
  568. switch src[i] {
  569. case '{':
  // A '{' beyond the initial nesting depth is left for the loop below,
  // which reports it as unexpected.
  570. if depth == len(dims) {
  571. break Element
  572. }
  // Enter the next sub-array and restart its element count. After parsing,
  // each entry of dims therefore holds the length of the last sub-array
  // seen at that level.
  573. depth++
  574. dims[depth-1] = 0
  575. i++
  576. case '"':
  // Quoted element: a backslash makes the next byte literal, and the closing
  // quote ends the element. An unterminated quote runs to the end of src and
  // is reported by the depth check further down.
  577. var elem = []byte{}
  578. var escape bool
  579. for i++; i < len(src); i++ {
  580. if escape {
  581. elem = append(elem, src[i])
  582. escape = false
  583. } else {
  584. switch src[i] {
  585. default:
  586. elem = append(elem, src[i])
  587. case '\\':
  588. escape = true
  589. case '"':
  590. elems = append(elems, elem)
  591. i++
  592. break Element
  593. }
  594. }
  595. }
  596. default:
  // Unquoted element: runs up to the next delimiter or '}'. It must be
  // non-empty, and the exact text NULL becomes a nil element.
  597. for start := i; i < len(src); i++ {
  598. if bytes.HasPrefix(src[i:], del) || src[i] == '}' {
  599. elem := src[start:i]
  600. if len(elem) == 0 {
  601. return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
  602. }
  603. if bytes.Equal(elem, []byte("NULL")) {
  604. elem = nil
  605. }
  606. elems = append(elems, elem)
  607. break Element
  608. }
  609. }
  610. }
  611. }
  // After an element: a delimiter counts it at the current level and goes
  // back for the next element. Each '}' counts the finished element or
  // sub-array at the current level and closes that level. Anything else is
  // an error.
  612. for i < len(src) {
  613. if bytes.HasPrefix(src[i:], del) && depth > 0 {
  614. dims[depth-1]++
  615. i += len(del)
  616. goto Element
  617. } else if src[i] == '}' && depth > 0 {
  618. dims[depth-1]++
  619. depth--
  620. i++
  621. } else {
  622. return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
  623. }
  624. }
  // Close the remaining levels of an empty array.
  625. Close:
  626. for i < len(src) {
  627. if src[i] == '}' && depth > 0 {
  628. depth--
  629. i++
  630. } else {
  631. return nil, nil, fmt.Errorf("pq: unable to parse array; unexpected %q at offset %d", src[i], i)
  632. }
  633. }
  // Input ended before every '{' was closed.
  634. if depth > 0 {
  635. err = fmt.Errorf("pq: unable to parse array; expected %q at offset %d", '}', i)
  636. }
  // NOTE(review): this only checks that the element count is divisible by
  // every dimension, not that it equals their product. A ragged input such
  // as {{1,2,3,4},{5,6}} yields dims [2 2] with 6 elements and passes.
  637. if err == nil {
  638. for _, d := range dims {
  639. if (len(elems) % d) != 0 {
  640. err = fmt.Errorf("pq: multidimensional arrays must have elements with matching dimensions")
  641. }
  642. }
  643. }
  644. return
  645. }
  646. func scanLinearArray(src, del []byte, typ string) (elems [][]byte, err error) {
  647. dims, elems, err := parseArray(src, del)
  648. if err != nil {
  649. return nil, err
  650. }
  651. if len(dims) > 1 {
  652. return nil, fmt.Errorf("pq: cannot convert ARRAY%s to %s", strings.Replace(fmt.Sprint(dims), " ", "][", -1), typ)
  653. }
  654. return elems, err
  655. }