You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1923 lines
45 KiB

  1. package pq
  2. import (
  3. "bufio"
  4. "context"
  5. "crypto/md5"
  6. "crypto/sha256"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "os"
  15. "os/user"
  16. "path"
  17. "path/filepath"
  18. "strconv"
  19. "strings"
  20. "time"
  21. "unicode"
  22. "github.com/lib/pq/oid"
  23. "github.com/lib/pq/scram"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
  32. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  33. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  34. errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  35. )
  36. // Driver is the Postgres database driver.
  37. type Driver struct{}
  38. // Open opens a new connection to the database. name is a connection string.
  39. // Most users should only use it through database/sql package from the standard
  40. // library.
  41. func (d *Driver) Open(name string) (driver.Conn, error) {
  42. return Open(name)
  43. }
  44. func init() {
  45. sql.Register("postgres", &Driver{})
  46. }
// parameterStatus caches session information reported by the server.
// NOTE(review): presumably populated by processParameterStatus from
// ParameterStatus ('S') messages — confirm against that function.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int
	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
  55. type transactionStatus byte
  56. const (
  57. txnStatusIdle transactionStatus = 'I'
  58. txnStatusIdleInTransaction transactionStatus = 'T'
  59. txnStatusInFailedTransaction transactionStatus = 'E'
  60. )
  61. func (s transactionStatus) String() string {
  62. switch s {
  63. case txnStatusIdle:
  64. return "idle"
  65. case txnStatusIdleInTransaction:
  66. return "idle in transaction"
  67. case txnStatusInFailedTransaction:
  68. return "in a failed transaction"
  69. default:
  70. errorf("unknown transactionStatus %d", s)
  71. }
  72. panic("not reached")
  73. }
  74. // Dialer is the dialer interface. It can be used to obtain more control over
  75. // how pq creates network connections.
  76. type Dialer interface {
  77. Dial(network, address string) (net.Conn, error)
  78. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  79. }
  80. // DialerContext is the context-aware dialer interface.
  81. type DialerContext interface {
  82. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  83. }
  84. type defaultDialer struct {
  85. d net.Dialer
  86. }
  87. func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
  88. return d.d.Dial(network, address)
  89. }
  90. func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
  91. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  92. defer cancel()
  93. return d.DialContext(ctx, network, address)
  94. }
  95. func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  96. return d.d.DialContext(ctx, network, address)
  97. }
// conn is a single connection to a PostgreSQL backend. Its methods (Begin,
// Prepare, Close, Query, Exec, ...) implement the database/sql/driver
// connection interfaces; begin returns cn itself as the driver.Tx.
type conn struct {
	// c is the network connection; ssl may replace it with an upgraded one.
	c net.Conn
	// buf buffers reads from c.
	buf *bufio.Reader
	// namei is the counter gname uses to name prepared statements.
	namei int
	// scratch is reused to build outgoing messages (writeBuf) and to hold
	// small incoming message bodies (recvMessage).
	scratch [512]byte
	// txnStatus is the most recently recorded transaction status.
	txnStatus transactionStatus
	// txnFinish, if non-nil, is invoked by closeTxn when Commit or Rollback
	// returns.
	txnFinish func()
	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values
	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int
	// parameterStatus caches session parameters reported by the server.
	parameterStatus parameterStatus
	// saveMessageType/saveMessageBuffer hold one message pushed back by
	// saveMessage; recvMessage returns it before reading from the wire.
	// A zero saveMessageType means nothing is saved.
	saveMessageType   byte
	saveMessageBuffer []byte
	// If true, this connection is bad and all public-facing functions should
	// return ErrBadConn.
	bad bool
	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool
	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool
	// If true this connection is in the middle of a COPY
	inCopy bool
}
  127. // Handle driver-side settings in parsed connection string.
  128. func (cn *conn) handleDriverSettings(o values) (err error) {
  129. boolSetting := func(key string, val *bool) error {
  130. if value, ok := o[key]; ok {
  131. if value == "yes" {
  132. *val = true
  133. } else if value == "no" {
  134. *val = false
  135. } else {
  136. return fmt.Errorf("unrecognized value %q for %s", value, key)
  137. }
  138. }
  139. return nil
  140. }
  141. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  142. if err != nil {
  143. return err
  144. }
  145. return boolSetting("binary_parameters", &cn.binaryParameters)
  146. }
  147. func (cn *conn) handlePgpass(o values) {
  148. // if a password was supplied, do not process .pgpass
  149. if _, ok := o["password"]; ok {
  150. return
  151. }
  152. filename := os.Getenv("PGPASSFILE")
  153. if filename == "" {
  154. // XXX this code doesn't work on Windows where the default filename is
  155. // XXX %APPDATA%\postgresql\pgpass.conf
  156. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  157. userHome := os.Getenv("HOME")
  158. if userHome == "" {
  159. user, err := user.Current()
  160. if err != nil {
  161. return
  162. }
  163. userHome = user.HomeDir
  164. }
  165. filename = filepath.Join(userHome, ".pgpass")
  166. }
  167. fileinfo, err := os.Stat(filename)
  168. if err != nil {
  169. return
  170. }
  171. mode := fileinfo.Mode()
  172. if mode&(0x77) != 0 {
  173. // XXX should warn about incorrect .pgpass permissions as psql does
  174. return
  175. }
  176. file, err := os.Open(filename)
  177. if err != nil {
  178. return
  179. }
  180. defer file.Close()
  181. scanner := bufio.NewScanner(io.Reader(file))
  182. hostname := o["host"]
  183. ntw, _ := network(o)
  184. port := o["port"]
  185. db := o["dbname"]
  186. username := o["user"]
  187. // From: https://github.com/tg/pgpass/blob/master/reader.go
  188. getFields := func(s string) []string {
  189. fs := make([]string, 0, 5)
  190. f := make([]rune, 0, len(s))
  191. var esc bool
  192. for _, c := range s {
  193. switch {
  194. case esc:
  195. f = append(f, c)
  196. esc = false
  197. case c == '\\':
  198. esc = true
  199. case c == ':':
  200. fs = append(fs, string(f))
  201. f = f[:0]
  202. default:
  203. f = append(f, c)
  204. }
  205. }
  206. return append(fs, string(f))
  207. }
  208. for scanner.Scan() {
  209. line := scanner.Text()
  210. if len(line) == 0 || line[0] == '#' {
  211. continue
  212. }
  213. split := getFields(line)
  214. if len(split) != 5 {
  215. continue
  216. }
  217. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  218. o["password"] = split[4]
  219. return
  220. }
  221. }
  222. }
  223. func (cn *conn) writeBuf(b byte) *writeBuf {
  224. cn.scratch[0] = b
  225. return &writeBuf{
  226. buf: cn.scratch[:5],
  227. pos: 1,
  228. }
  229. }
  230. // Open opens a new connection to the database. dsn is a connection string.
  231. // Most users should only use it through database/sql package from the standard
  232. // library.
  233. func Open(dsn string) (_ driver.Conn, err error) {
  234. return DialOpen(defaultDialer{}, dsn)
  235. }
  236. // DialOpen opens a new connection to the database using a dialer.
  237. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
  238. c, err := NewConnector(dsn)
  239. if err != nil {
  240. return nil, err
  241. }
  242. c.dialer = d
  243. return c.open(context.Background())
  244. }
  245. func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
  246. // Handle any panics during connection initialization. Note that we
  247. // specifically do *not* want to use errRecover(), as that would turn any
  248. // connection errors into ErrBadConns, hiding the real error message from
  249. // the user.
  250. defer errRecoverNoErrBadConn(&err)
  251. o := c.opts
  252. cn = &conn{
  253. opts: o,
  254. dialer: c.dialer,
  255. }
  256. err = cn.handleDriverSettings(o)
  257. if err != nil {
  258. return nil, err
  259. }
  260. cn.handlePgpass(o)
  261. cn.c, err = dial(ctx, c.dialer, o)
  262. if err != nil {
  263. return nil, err
  264. }
  265. err = cn.ssl(o)
  266. if err != nil {
  267. if cn.c != nil {
  268. cn.c.Close()
  269. }
  270. return nil, err
  271. }
  272. // cn.startup panics on error. Make sure we don't leak cn.c.
  273. panicking := true
  274. defer func() {
  275. if panicking {
  276. cn.c.Close()
  277. }
  278. }()
  279. cn.buf = bufio.NewReader(cn.c)
  280. cn.startup(o)
  281. // reset the deadline, in case one was set (see dial)
  282. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  283. err = cn.c.SetDeadline(time.Time{})
  284. }
  285. panicking = false
  286. return cn, err
  287. }
  288. func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
  289. network, address := network(o)
  290. // SSL is not necessary or supported over UNIX domain sockets
  291. if network == "unix" {
  292. o["sslmode"] = "disable"
  293. }
  294. // Zero or not specified means wait indefinitely.
  295. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  296. seconds, err := strconv.ParseInt(timeout, 10, 0)
  297. if err != nil {
  298. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  299. }
  300. duration := time.Duration(seconds) * time.Second
  301. // connect_timeout should apply to the entire connection establishment
  302. // procedure, so we both use a timeout for the TCP connection
  303. // establishment and set a deadline for doing the initial handshake.
  304. // The deadline is then reset after startup() is done.
  305. deadline := time.Now().Add(duration)
  306. var conn net.Conn
  307. if dctx, ok := d.(DialerContext); ok {
  308. ctx, cancel := context.WithTimeout(ctx, duration)
  309. defer cancel()
  310. conn, err = dctx.DialContext(ctx, network, address)
  311. } else {
  312. conn, err = d.DialTimeout(network, address, duration)
  313. }
  314. if err != nil {
  315. return nil, err
  316. }
  317. err = conn.SetDeadline(deadline)
  318. return conn, err
  319. }
  320. if dctx, ok := d.(DialerContext); ok {
  321. return dctx.DialContext(ctx, network, address)
  322. }
  323. return d.Dial(network, address)
  324. }
  325. func network(o values) (string, string) {
  326. host := o["host"]
  327. if strings.HasPrefix(host, "/") {
  328. sockPath := path.Join(host, ".s.PGSQL."+o["port"])
  329. return "unix", sockPath
  330. }
  331. return "tcp", net.JoinHostPort(host, o["port"])
  332. }
  333. type values map[string]string
  334. // scanner implements a tokenizer for libpq-style option strings.
  335. type scanner struct {
  336. s []rune
  337. i int
  338. }
  339. // newScanner returns a new scanner initialized with the option string s.
  340. func newScanner(s string) *scanner {
  341. return &scanner{[]rune(s), 0}
  342. }
  343. // Next returns the next rune.
  344. // It returns 0, false if the end of the text has been reached.
  345. func (s *scanner) Next() (rune, bool) {
  346. if s.i >= len(s.s) {
  347. return 0, false
  348. }
  349. r := s.s[s.i]
  350. s.i++
  351. return r, true
  352. }
  353. // SkipSpaces returns the next non-whitespace rune.
  354. // It returns 0, false if the end of the text has been reached.
  355. func (s *scanner) SkipSpaces() (rune, bool) {
  356. r, ok := s.Next()
  357. for unicode.IsSpace(r) && ok {
  358. r, ok = s.Next()
  359. }
  360. return r, ok
  361. }
  362. // parseOpts parses the options from name and adds them to the values.
  363. //
  364. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  365. func parseOpts(name string, o values) error {
  366. s := newScanner(name)
  367. for {
  368. var (
  369. keyRunes, valRunes []rune
  370. r rune
  371. ok bool
  372. )
  373. if r, ok = s.SkipSpaces(); !ok {
  374. break
  375. }
  376. // Scan the key
  377. for !unicode.IsSpace(r) && r != '=' {
  378. keyRunes = append(keyRunes, r)
  379. if r, ok = s.Next(); !ok {
  380. break
  381. }
  382. }
  383. // Skip any whitespace if we're not at the = yet
  384. if r != '=' {
  385. r, ok = s.SkipSpaces()
  386. }
  387. // The current character should be =
  388. if r != '=' || !ok {
  389. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  390. }
  391. // Skip any whitespace after the =
  392. if r, ok = s.SkipSpaces(); !ok {
  393. // If we reach the end here, the last value is just an empty string as per libpq.
  394. o[string(keyRunes)] = ""
  395. break
  396. }
  397. if r != '\'' {
  398. for !unicode.IsSpace(r) {
  399. if r == '\\' {
  400. if r, ok = s.Next(); !ok {
  401. return fmt.Errorf(`missing character after backslash`)
  402. }
  403. }
  404. valRunes = append(valRunes, r)
  405. if r, ok = s.Next(); !ok {
  406. break
  407. }
  408. }
  409. } else {
  410. quote:
  411. for {
  412. if r, ok = s.Next(); !ok {
  413. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  414. }
  415. switch r {
  416. case '\'':
  417. break quote
  418. case '\\':
  419. r, _ = s.Next()
  420. fallthrough
  421. default:
  422. valRunes = append(valRunes, r)
  423. }
  424. }
  425. }
  426. o[string(keyRunes)] = string(valRunes)
  427. }
  428. return nil
  429. }
  430. func (cn *conn) isInTransaction() bool {
  431. return cn.txnStatus == txnStatusIdleInTransaction ||
  432. cn.txnStatus == txnStatusInFailedTransaction
  433. }
  434. func (cn *conn) checkIsInTransaction(intxn bool) {
  435. if cn.isInTransaction() != intxn {
  436. cn.bad = true
  437. errorf("unexpected transaction status %v", cn.txnStatus)
  438. }
  439. }
  440. func (cn *conn) Begin() (_ driver.Tx, err error) {
  441. return cn.begin("")
  442. }
  443. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  444. if cn.bad {
  445. return nil, driver.ErrBadConn
  446. }
  447. defer cn.errRecover(&err)
  448. cn.checkIsInTransaction(false)
  449. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  450. if err != nil {
  451. return nil, err
  452. }
  453. if commandTag != "BEGIN" {
  454. cn.bad = true
  455. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  456. }
  457. if cn.txnStatus != txnStatusIdleInTransaction {
  458. cn.bad = true
  459. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  460. }
  461. return cn, nil
  462. }
  463. func (cn *conn) closeTxn() {
  464. if finish := cn.txnFinish; finish != nil {
  465. finish()
  466. }
  467. }
  468. func (cn *conn) Commit() (err error) {
  469. defer cn.closeTxn()
  470. if cn.bad {
  471. return driver.ErrBadConn
  472. }
  473. defer cn.errRecover(&err)
  474. cn.checkIsInTransaction(true)
  475. // We don't want the client to think that everything is okay if it tries
  476. // to commit a failed transaction. However, no matter what we return,
  477. // database/sql will release this connection back into the free connection
  478. // pool so we have to abort the current transaction here. Note that you
  479. // would get the same behaviour if you issued a COMMIT in a failed
  480. // transaction, so it's also the least surprising thing to do here.
  481. if cn.txnStatus == txnStatusInFailedTransaction {
  482. if err := cn.rollback(); err != nil {
  483. return err
  484. }
  485. return ErrInFailedTransaction
  486. }
  487. _, commandTag, err := cn.simpleExec("COMMIT")
  488. if err != nil {
  489. if cn.isInTransaction() {
  490. cn.bad = true
  491. }
  492. return err
  493. }
  494. if commandTag != "COMMIT" {
  495. cn.bad = true
  496. return fmt.Errorf("unexpected command tag %s", commandTag)
  497. }
  498. cn.checkIsInTransaction(false)
  499. return nil
  500. }
  501. func (cn *conn) Rollback() (err error) {
  502. defer cn.closeTxn()
  503. if cn.bad {
  504. return driver.ErrBadConn
  505. }
  506. defer cn.errRecover(&err)
  507. return cn.rollback()
  508. }
  509. func (cn *conn) rollback() (err error) {
  510. cn.checkIsInTransaction(true)
  511. _, commandTag, err := cn.simpleExec("ROLLBACK")
  512. if err != nil {
  513. if cn.isInTransaction() {
  514. cn.bad = true
  515. }
  516. return err
  517. }
  518. if commandTag != "ROLLBACK" {
  519. return fmt.Errorf("unexpected command tag %s", commandTag)
  520. }
  521. cn.checkIsInTransaction(false)
  522. return nil
  523. }
  524. func (cn *conn) gname() string {
  525. cn.namei++
  526. return strconv.FormatInt(int64(cn.namei), 10)
  527. }
// simpleExec runs q with the simple query protocol (a single 'Q' message)
// and reads responses until ReadyForQuery ('Z').
//
// Returns:
//   - res: the result of the last CommandComplete ('C'), or emptyRows after
//     an EmptyQueryResponse ('I'); nil if neither was received.
//   - commandTag: the tag of the last CommandComplete.
//   - err: the server's ErrorResponse ('E'), if any. If ReadyForQuery
//     arrives with neither a result nor an error, err is errUnexpectedReady.
//
// RowDescription/DataRow messages are discarded. Write and read failures and
// unknown message types are raised as panics (send, recv1, errorf), so
// callers must recover them (e.g. via errRecover).
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			// CommandComplete: remember result and tag, keep reading.
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			cn.processReadyForQuery(r)
			if res == nil && err == nil {
				err = errUnexpectedReady
			}
			// done
			return
		case 'E':
			// Record the error but keep draining until ReadyForQuery so the
			// protocol stays in sync.
			err = parseError(r)
		case 'I':
			res = emptyRows
		case 'T', 'D':
			// ignore any results
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
// simpleQuery runs q with the simple query protocol and returns a *rows for
// the caller to iterate.
//
// When the query yields rows, simpleQuery returns as soon as the first
// DataRow ('D') arrives. That message is pushed back with saveMessage.
// NOTE(review): it and the remaining messages are presumably consumed by
// rows.Next — confirm.
//
// For statements that return no rows, it drains to ReadyForQuery and
// returns a rows value already marked done, carrying the CommandComplete
// result and tag, so database/sql still gets something it can close.
// After an ErrorResponse the returned rows is nil and err holds the server
// error. Protocol violations mark the connection bad and raise via errorf;
// the deferred errRecover recovers them into err.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if err != nil {
				cn.bad = true
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' && res.colNames == nil {
				res.result, res.tag = cn.parseComplete(r.string())
			}
			res.done = true
		case 'Z':
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			// Discard any partial result; keep reading until ReadyForQuery.
			res = nil
			err = parseError(r)
		case 'D':
			if res == nil {
				cn.bad = true
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
  613. type noRows struct{}
  614. var emptyRows noRows
  615. var _ driver.Result = noRows{}
  616. func (noRows) LastInsertId() (int64, error) {
  617. return 0, errNoLastInsertID
  618. }
  619. func (noRows) RowsAffected() (int64, error) {
  620. return 0, errNoRowsAffected
  621. }
  622. // Decides which column formats to use for a prepared statement. The input is
  623. // an array of type oids, one element per result column.
  624. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  625. if len(colTyps) == 0 {
  626. return nil, colFmtDataAllText
  627. }
  628. colFmts = make([]format, len(colTyps))
  629. if forceText {
  630. return colFmts, colFmtDataAllText
  631. }
  632. allBinary := true
  633. allText := true
  634. for i, t := range colTyps {
  635. switch t.OID {
  636. // This is the list of types to use binary mode for when receiving them
  637. // through a prepared statement. If a type appears in this list, it
  638. // must also be implemented in binaryDecode in encode.go.
  639. case oid.T_bytea:
  640. fallthrough
  641. case oid.T_int8:
  642. fallthrough
  643. case oid.T_int4:
  644. fallthrough
  645. case oid.T_int2:
  646. fallthrough
  647. case oid.T_uuid:
  648. colFmts[i] = formatBinary
  649. allText = false
  650. default:
  651. allBinary = false
  652. }
  653. }
  654. if allBinary {
  655. return colFmts, colFmtDataAllBinary
  656. } else if allText {
  657. return colFmts, colFmtDataAllText
  658. } else {
  659. colFmtData = make([]byte, 2+len(colFmts)*2)
  660. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  661. for i, v := range colFmts {
  662. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  663. }
  664. return colFmts, colFmtData
  665. }
  666. }
  667. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  668. st := &stmt{cn: cn, name: stmtName}
  669. b := cn.writeBuf('P')
  670. b.string(st.name)
  671. b.string(q)
  672. b.int16(0)
  673. b.next('D')
  674. b.byte('S')
  675. b.string(st.name)
  676. b.next('S')
  677. cn.send(b)
  678. cn.readParseResponse()
  679. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  680. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  681. cn.readReadyForQuery()
  682. return st
  683. }
  684. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  685. if cn.bad {
  686. return nil, driver.ErrBadConn
  687. }
  688. defer cn.errRecover(&err)
  689. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  690. s, err := cn.prepareCopyIn(q)
  691. if err == nil {
  692. cn.inCopy = true
  693. }
  694. return s, err
  695. }
  696. return cn.prepareTo(q, cn.gname()), nil
  697. }
  698. func (cn *conn) Close() (err error) {
  699. // Skip cn.bad return here because we always want to close a connection.
  700. defer cn.errRecover(&err)
  701. // Ensure that cn.c.Close is always run. Since error handling is done with
  702. // panics and cn.errRecover, the Close must be in a defer.
  703. defer func() {
  704. cerr := cn.c.Close()
  705. if err == nil {
  706. err = cerr
  707. }
  708. }()
  709. // Don't go through send(); ListenerConn relies on us not scribbling on the
  710. // scratch buffer of this connection.
  711. return cn.sendSimpleMessage('X')
  712. }
  713. // Implement the "Queryer" interface
  714. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  715. return cn.query(query, args)
  716. }
  717. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  718. if cn.bad {
  719. return nil, driver.ErrBadConn
  720. }
  721. if cn.inCopy {
  722. return nil, errCopyInProgress
  723. }
  724. defer cn.errRecover(&err)
  725. // Check to see if we can use the "simpleQuery" interface, which is
  726. // *much* faster than going through prepare/exec
  727. if len(args) == 0 {
  728. return cn.simpleQuery(query)
  729. }
  730. if cn.binaryParameters {
  731. cn.sendBinaryModeQuery(query, args)
  732. cn.readParseResponse()
  733. cn.readBindResponse()
  734. rows := &rows{cn: cn}
  735. rows.rowsHeader = cn.readPortalDescribeResponse()
  736. cn.postExecuteWorkaround()
  737. return rows, nil
  738. }
  739. st := cn.prepareTo(query, "")
  740. st.exec(args)
  741. return &rows{
  742. cn: cn,
  743. rowsHeader: st.rowsHeader,
  744. }, nil
  745. }
  746. // Implement the optional "Execer" interface for one-shot queries
  747. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  748. if cn.bad {
  749. return nil, driver.ErrBadConn
  750. }
  751. defer cn.errRecover(&err)
  752. // Check to see if we can use the "simpleExec" interface, which is
  753. // *much* faster than going through prepare/exec
  754. if len(args) == 0 {
  755. // ignore commandTag, our caller doesn't care
  756. r, _, err := cn.simpleExec(query)
  757. return r, err
  758. }
  759. if cn.binaryParameters {
  760. cn.sendBinaryModeQuery(query, args)
  761. cn.readParseResponse()
  762. cn.readBindResponse()
  763. cn.readPortalDescribeResponse()
  764. cn.postExecuteWorkaround()
  765. res, _, err = cn.readExecuteResponse("Execute")
  766. return res, err
  767. }
  768. // Use the unnamed statement to defer planning until bind
  769. // time, or else value-based selectivity estimates cannot be
  770. // used.
  771. st := cn.prepareTo(query, "")
  772. r, err := st.Exec(args)
  773. if err != nil {
  774. panic(err)
  775. }
  776. return r, err
  777. }
  778. func (cn *conn) send(m *writeBuf) {
  779. _, err := cn.c.Write(m.wrap())
  780. if err != nil {
  781. panic(err)
  782. }
  783. }
  784. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  785. _, err := cn.c.Write((m.wrap())[1:])
  786. return err
  787. }
  788. // Send a message of type typ to the server on the other end of cn. The
  789. // message should have no payload. This method does not use the scratch
  790. // buffer.
  791. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  792. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  793. return err
  794. }
  795. // saveMessage memorizes a message and its buffer in the conn struct.
  796. // recvMessage will then return these values on the next call to it. This
  797. // method is useful in cases where you have to see what the next message is
  798. // going to be (e.g. to see whether it's an error or not) but you can't handle
  799. // the message yourself.
  800. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  801. if cn.saveMessageType != 0 {
  802. cn.bad = true
  803. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  804. }
  805. cn.saveMessageType = typ
  806. cn.saveMessageBuffer = *buf
  807. }
  808. // recvMessage receives any message from the backend, or returns an error if
  809. // a problem occurred while reading the message.
  810. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  811. // workaround for a QueryRow bug, see exec
  812. if cn.saveMessageType != 0 {
  813. t := cn.saveMessageType
  814. *r = cn.saveMessageBuffer
  815. cn.saveMessageType = 0
  816. cn.saveMessageBuffer = nil
  817. return t, nil
  818. }
  819. x := cn.scratch[:5]
  820. _, err := io.ReadFull(cn.buf, x)
  821. if err != nil {
  822. return 0, err
  823. }
  824. // read the type and length of the message that follows
  825. t := x[0]
  826. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  827. var y []byte
  828. if n <= len(cn.scratch) {
  829. y = cn.scratch[:n]
  830. } else {
  831. y = make([]byte, n)
  832. }
  833. _, err = io.ReadFull(cn.buf, y)
  834. if err != nil {
  835. return 0, err
  836. }
  837. *r = y
  838. return t, nil
  839. }
  840. // recv receives a message from the backend, but if an error happened while
  841. // reading the message or the received message was an ErrorResponse, it panics.
  842. // NoticeResponses are ignored. This function should generally be used only
  843. // during the startup sequence.
  844. func (cn *conn) recv() (t byte, r *readBuf) {
  845. for {
  846. var err error
  847. r = &readBuf{}
  848. t, err = cn.recvMessage(r)
  849. if err != nil {
  850. panic(err)
  851. }
  852. switch t {
  853. case 'E':
  854. panic(parseError(r))
  855. case 'N':
  856. // ignore
  857. default:
  858. return
  859. }
  860. }
  861. }
  862. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  863. // the caller to avoid an allocation.
  864. func (cn *conn) recv1Buf(r *readBuf) byte {
  865. for {
  866. t, err := cn.recvMessage(r)
  867. if err != nil {
  868. panic(err)
  869. }
  870. switch t {
  871. case 'A', 'N':
  872. // ignore
  873. case 'S':
  874. cn.processParameterStatus(r)
  875. default:
  876. return t
  877. }
  878. }
  879. }
  880. // recv1 receives a message from the backend, panicking if an error occurs
  881. // while attempting to read it. All asynchronous messages are ignored, with
  882. // the exception of ErrorResponse.
  883. func (cn *conn) recv1() (t byte, r *readBuf) {
  884. r = &readBuf{}
  885. t = cn.recv1Buf(r)
  886. return t, r
  887. }
  888. func (cn *conn) ssl(o values) error {
  889. upgrade, err := ssl(o)
  890. if err != nil {
  891. return err
  892. }
  893. if upgrade == nil {
  894. // Nothing to do
  895. return nil
  896. }
  897. w := cn.writeBuf(0)
  898. w.int32(80877103)
  899. if err = cn.sendStartupPacket(w); err != nil {
  900. return err
  901. }
  902. b := cn.scratch[:1]
  903. _, err = io.ReadFull(cn.c, b)
  904. if err != nil {
  905. return err
  906. }
  907. if b[0] != 'S' {
  908. return ErrSSLNotSupported
  909. }
  910. cn.c, err = upgrade(cn.c)
  911. return err
  912. }
  913. // isDriverSetting returns true iff a setting is purely for configuring the
  914. // driver's options and should not be sent to the server in the connection
  915. // startup packet.
  916. func isDriverSetting(key string) bool {
  917. switch key {
  918. case "host", "port":
  919. return true
  920. case "password":
  921. return true
  922. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  923. return true
  924. case "fallback_application_name":
  925. return true
  926. case "connect_timeout":
  927. return true
  928. case "disable_prepared_binary_result":
  929. return true
  930. case "binary_parameters":
  931. return true
  932. default:
  933. return false
  934. }
  935. }
  936. func (cn *conn) startup(o values) {
  937. w := cn.writeBuf(0)
  938. w.int32(196608)
  939. // Send the backend the name of the database we want to connect to, and the
  940. // user we want to connect as. Additionally, we send over any run-time
  941. // parameters potentially included in the connection string. If the server
  942. // doesn't recognize any of them, it will reply with an error.
  943. for k, v := range o {
  944. if isDriverSetting(k) {
  945. // skip options which can't be run-time parameters
  946. continue
  947. }
  948. // The protocol requires us to supply the database name as "database"
  949. // instead of "dbname".
  950. if k == "dbname" {
  951. k = "database"
  952. }
  953. w.string(k)
  954. w.string(v)
  955. }
  956. w.string("")
  957. if err := cn.sendStartupPacket(w); err != nil {
  958. panic(err)
  959. }
  960. for {
  961. t, r := cn.recv()
  962. switch t {
  963. case 'K':
  964. cn.processBackendKeyData(r)
  965. case 'S':
  966. cn.processParameterStatus(r)
  967. case 'R':
  968. cn.auth(r, o)
  969. case 'Z':
  970. cn.processReadyForQuery(r)
  971. return
  972. default:
  973. errorf("unknown response for startup: %q", t)
  974. }
  975. }
  976. }
  977. func (cn *conn) auth(r *readBuf, o values) {
  978. switch code := r.int32(); code {
  979. case 0:
  980. // OK
  981. case 3:
  982. w := cn.writeBuf('p')
  983. w.string(o["password"])
  984. cn.send(w)
  985. t, r := cn.recv()
  986. if t != 'R' {
  987. errorf("unexpected password response: %q", t)
  988. }
  989. if r.int32() != 0 {
  990. errorf("unexpected authentication response: %q", t)
  991. }
  992. case 5:
  993. s := string(r.next(4))
  994. w := cn.writeBuf('p')
  995. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  996. cn.send(w)
  997. t, r := cn.recv()
  998. if t != 'R' {
  999. errorf("unexpected password response: %q", t)
  1000. }
  1001. if r.int32() != 0 {
  1002. errorf("unexpected authentication response: %q", t)
  1003. }
  1004. case 10:
  1005. sc := scram.NewClient(sha256.New, o["user"], o["password"])
  1006. sc.Step(nil)
  1007. if sc.Err() != nil {
  1008. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1009. }
  1010. scOut := sc.Out()
  1011. w := cn.writeBuf('p')
  1012. w.string("SCRAM-SHA-256")
  1013. w.int32(len(scOut))
  1014. w.bytes(scOut)
  1015. cn.send(w)
  1016. t, r := cn.recv()
  1017. if t != 'R' {
  1018. errorf("unexpected password response: %q", t)
  1019. }
  1020. if r.int32() != 11 {
  1021. errorf("unexpected authentication response: %q", t)
  1022. }
  1023. nextStep := r.next(len(*r))
  1024. sc.Step(nextStep)
  1025. if sc.Err() != nil {
  1026. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1027. }
  1028. scOut = sc.Out()
  1029. w = cn.writeBuf('p')
  1030. w.bytes(scOut)
  1031. cn.send(w)
  1032. t, r = cn.recv()
  1033. if t != 'R' {
  1034. errorf("unexpected password response: %q", t)
  1035. }
  1036. if r.int32() != 12 {
  1037. errorf("unexpected authentication response: %q", t)
  1038. }
  1039. nextStep = r.next(len(*r))
  1040. sc.Step(nextStep)
  1041. if sc.Err() != nil {
  1042. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1043. }
  1044. default:
  1045. errorf("unknown authentication response: %d", code)
  1046. }
  1047. }
  1048. type format int
  1049. const formatText format = 0
  1050. const formatBinary format = 1
  1051. // One result-column format code with the value 1 (i.e. all binary).
  1052. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1053. // No result-column format codes (i.e. all text).
  1054. var colFmtDataAllText = []byte{0, 0}
// stmt is a prepared statement on connection cn.
type stmt struct {
	cn   *conn
	name string // statement name sent in Bind and Close messages
	// rowsHeader is the result-column metadata; Query copies it into each
	// rows value it returns.
	rowsHeader
	// colFmtData is appended verbatim to each Bind message as the
	// result-column format codes.
	colFmtData []byte
	// paramTyps holds the parameter type OIDs; its length is the number of
	// arguments exec requires.
	paramTyps []oid.Oid
	// closed is set by Close once the server acknowledges the close, so
	// repeated calls return nil.
	closed bool
}
  1063. func (st *stmt) Close() (err error) {
  1064. if st.closed {
  1065. return nil
  1066. }
  1067. if st.cn.bad {
  1068. return driver.ErrBadConn
  1069. }
  1070. defer st.cn.errRecover(&err)
  1071. w := st.cn.writeBuf('C')
  1072. w.byte('S')
  1073. w.string(st.name)
  1074. st.cn.send(w)
  1075. st.cn.send(st.cn.writeBuf('S'))
  1076. t, _ := st.cn.recv1()
  1077. if t != '3' {
  1078. st.cn.bad = true
  1079. errorf("unexpected close response: %q", t)
  1080. }
  1081. st.closed = true
  1082. t, r := st.cn.recv1()
  1083. if t != 'Z' {
  1084. st.cn.bad = true
  1085. errorf("expected ready for query, but got: %q", t)
  1086. }
  1087. st.cn.processReadyForQuery(r)
  1088. return nil
  1089. }
  1090. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1091. if st.cn.bad {
  1092. return nil, driver.ErrBadConn
  1093. }
  1094. defer st.cn.errRecover(&err)
  1095. st.exec(v)
  1096. return &rows{
  1097. cn: st.cn,
  1098. rowsHeader: st.rowsHeader,
  1099. }, nil
  1100. }
  1101. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1102. if st.cn.bad {
  1103. return nil, driver.ErrBadConn
  1104. }
  1105. defer st.cn.errRecover(&err)
  1106. st.exec(v)
  1107. res, _, err = st.cn.readExecuteResponse("simple query")
  1108. return res, err
  1109. }
  1110. func (st *stmt) exec(v []driver.Value) {
  1111. if len(v) >= 65536 {
  1112. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1113. }
  1114. if len(v) != len(st.paramTyps) {
  1115. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1116. }
  1117. cn := st.cn
  1118. w := cn.writeBuf('B')
  1119. w.byte(0) // unnamed portal
  1120. w.string(st.name)
  1121. if cn.binaryParameters {
  1122. cn.sendBinaryParameters(w, v)
  1123. } else {
  1124. w.int16(0)
  1125. w.int16(len(v))
  1126. for i, x := range v {
  1127. if x == nil {
  1128. w.int32(-1)
  1129. } else {
  1130. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1131. w.int32(len(b))
  1132. w.bytes(b)
  1133. }
  1134. }
  1135. }
  1136. w.bytes(st.colFmtData)
  1137. w.next('E')
  1138. w.byte(0)
  1139. w.int32(0)
  1140. w.next('S')
  1141. cn.send(w)
  1142. cn.readBindResponse()
  1143. cn.postExecuteWorkaround()
  1144. }
  1145. func (st *stmt) NumInput() int {
  1146. return len(st.paramTyps)
  1147. }
  1148. // parseComplete parses the "command tag" from a CommandComplete message, and
  1149. // returns the number of rows affected (if applicable) and a string
  1150. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1151. // command tag could not be parsed, parseComplete panics.
  1152. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1153. commandsWithAffectedRows := []string{
  1154. "SELECT ",
  1155. // INSERT is handled below
  1156. "UPDATE ",
  1157. "DELETE ",
  1158. "FETCH ",
  1159. "MOVE ",
  1160. "COPY ",
  1161. }
  1162. var affectedRows *string
  1163. for _, tag := range commandsWithAffectedRows {
  1164. if strings.HasPrefix(commandTag, tag) {
  1165. t := commandTag[len(tag):]
  1166. affectedRows = &t
  1167. commandTag = tag[:len(tag)-1]
  1168. break
  1169. }
  1170. }
  1171. // INSERT also includes the oid of the inserted row in its command tag.
  1172. // Oids in user tables are deprecated, and the oid is only returned when
  1173. // exactly one row is inserted, so it's unlikely to be of value to any
  1174. // real-world application and we can ignore it.
  1175. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1176. parts := strings.Split(commandTag, " ")
  1177. if len(parts) != 3 {
  1178. cn.bad = true
  1179. errorf("unexpected INSERT command tag %s", commandTag)
  1180. }
  1181. affectedRows = &parts[len(parts)-1]
  1182. commandTag = "INSERT"
  1183. }
  1184. // There should be no affected rows attached to the tag, just return it
  1185. if affectedRows == nil {
  1186. return driver.RowsAffected(0), commandTag
  1187. }
  1188. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1189. if err != nil {
  1190. cn.bad = true
  1191. errorf("could not parse commandTag: %s", err)
  1192. }
  1193. return driver.RowsAffected(n), commandTag
  1194. }
// rowsHeader describes the columns of a result set. The three slices are
// parallel: entry i of each describes column i.
type rowsHeader struct {
	colNames []string    // column names
	colTyps  []fieldDesc // type OID, length and type modifier per column
	colFmts  []format    // wire format code (text/binary) per column
}
// rows is the driver.Rows value returned for a query on cn. Next reads the
// result messages directly from the connection as it is called.
type rows struct {
	cn     *conn
	finish func() // if non-nil, deferred by Close and run when it returns
	// rowsHeader is the metadata of the result set currently being read.
	rowsHeader
	done   bool          // set once ReadyForQuery has been consumed
	rb     readBuf       // buffer reused for every message Next reads
	result driver.Result // from the most recent CommandComplete, if any
	tag    string        // command name from the most recent CommandComplete
	// next holds the metadata of a following result set, announced by a
	// RowDescription, until NextResultSet adopts it.
	next *rowsHeader
}
  1210. func (rs *rows) Close() error {
  1211. if finish := rs.finish; finish != nil {
  1212. defer finish()
  1213. }
  1214. // no need to look at cn.bad as Next() will
  1215. for {
  1216. err := rs.Next(nil)
  1217. switch err {
  1218. case nil:
  1219. case io.EOF:
  1220. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1221. // description, used with HasNextResultSet). We need to fetch messages until
  1222. // we hit a 'Z', which is done by waiting for done to be set.
  1223. if rs.done {
  1224. return nil
  1225. }
  1226. default:
  1227. return err
  1228. }
  1229. }
  1230. }
  1231. func (rs *rows) Columns() []string {
  1232. return rs.colNames
  1233. }
  1234. func (rs *rows) Result() driver.Result {
  1235. if rs.result == nil {
  1236. return emptyRows
  1237. }
  1238. return rs.result
  1239. }
  1240. func (rs *rows) Tag() string {
  1241. return rs.tag
  1242. }
// Next reads messages from the connection until it can report one of the
// following:
//   - a DataRow ('D'): its values are decoded into dest and nil is returned;
//   - a RowDescription ('T') starting another result set: it is stored in
//     rs.next and io.EOF is returned;
//   - ReadyForQuery ('Z'), which ends the query: rs.done is set and io.EOF is
//     returned, or the server error if an ErrorResponse ('E') came before it.
//
// CommandComplete ('C') updates rs.result and rs.tag, and EmptyQueryResponse
// ('I') is skipped. Panics raised while reading (e.g. by errorf) are recovered
// by the deferred conn.errRecover(&err).
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}
	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)
	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Remember the error; it is returned when ReadyForQuery arrives.
			err = parseError(&rs.rb)
		case 'C', 'I':
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			// n is the number of columns in this row.
			n := rs.rb.int16()
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			// Fill at most n slots; any extra slots in dest are left untouched.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				l := rs.rb.int32()
				// A length of -1 denotes a NULL value.
				if l == -1 {
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			// err is nil here: a non-nil err was rejected above.
			return
		case 'T':
			next := parsePortalRowDescribe(&rs.rb)
			rs.next = &next
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
  1296. func (rs *rows) HasNextResultSet() bool {
  1297. hasNext := rs.next != nil && !rs.done
  1298. return hasNext
  1299. }
  1300. func (rs *rows) NextResultSet() error {
  1301. if rs.next == nil {
  1302. return io.EOF
  1303. }
  1304. rs.rowsHeader = *rs.next
  1305. rs.next = nil
  1306. return nil
  1307. }
  1308. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1309. // used as part of an SQL statement. For example:
  1310. //
  1311. // tblname := "my_table"
  1312. // data := "my_data"
  1313. // quoted := pq.QuoteIdentifier(tblname)
  1314. // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  1315. //
  1316. // Any double quotes in name will be escaped. The quoted identifier will be
  1317. // case sensitive when used in a query. If the input string contains a zero
  1318. // byte, the result will be truncated immediately before it.
  1319. func QuoteIdentifier(name string) string {
  1320. end := strings.IndexRune(name, 0)
  1321. if end > -1 {
  1322. name = name[:end]
  1323. }
  1324. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1325. }
  1326. // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
  1327. // to DDL and other statements that do not accept parameters) to be used as part
  1328. // of an SQL statement. For example:
  1329. //
  1330. // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
  1331. // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
  1332. //
  1333. // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
  1334. // replaced by two backslashes (i.e. "\\") and the C-style escape identifier
  1335. // that PostgreSQL provides ('E') will be prepended to the string.
  1336. func QuoteLiteral(literal string) string {
  1337. // This follows the PostgreSQL internal algorithm for handling quoted literals
  1338. // from libpq, which can be found in the "PQEscapeStringInternal" function,
  1339. // which is found in the libpq/fe-exec.c source file:
  1340. // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
  1341. //
  1342. // substitute any single-quotes (') with two single-quotes ('')
  1343. literal = strings.Replace(literal, `'`, `''`, -1)
  1344. // determine if the string has any backslashes (\) in it.
  1345. // if it does, replace any backslashes (\) with two backslashes (\\)
  1346. // then, we need to wrap the entire string with a PostgreSQL
  1347. // C-style escape. Per how "PQEscapeStringInternal" handles this case, we
  1348. // also add a space before the "E"
  1349. if strings.Contains(literal, `\`) {
  1350. literal = strings.Replace(literal, `\`, `\\`, -1)
  1351. literal = ` E'` + literal + `'`
  1352. } else {
  1353. // otherwise, we can just wrap the literal with a pair of single quotes
  1354. literal = `'` + literal + `'`
  1355. }
  1356. return literal
  1357. }
  1358. func md5s(s string) string {
  1359. h := md5.New()
  1360. h.Write([]byte(s))
  1361. return fmt.Sprintf("%x", h.Sum(nil))
  1362. }
  1363. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1364. // Do one pass over the parameters to see if we're going to send any of
  1365. // them over in binary. If we are, create a paramFormats array at the
  1366. // same time.
  1367. var paramFormats []int
  1368. for i, x := range args {
  1369. _, ok := x.([]byte)
  1370. if ok {
  1371. if paramFormats == nil {
  1372. paramFormats = make([]int, len(args))
  1373. }
  1374. paramFormats[i] = 1
  1375. }
  1376. }
  1377. if paramFormats == nil {
  1378. b.int16(0)
  1379. } else {
  1380. b.int16(len(paramFormats))
  1381. for _, x := range paramFormats {
  1382. b.int16(x)
  1383. }
  1384. }
  1385. b.int16(len(args))
  1386. for _, x := range args {
  1387. if x == nil {
  1388. b.int32(-1)
  1389. } else {
  1390. datum := binaryEncode(&cn.parameterStatus, x)
  1391. b.int32(len(datum))
  1392. b.bytes(datum)
  1393. }
  1394. }
  1395. }
  1396. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1397. if len(args) >= 65536 {
  1398. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1399. }
  1400. b := cn.writeBuf('P')
  1401. b.byte(0) // unnamed statement
  1402. b.string(query)
  1403. b.int16(0)
  1404. b.next('B')
  1405. b.int16(0) // unnamed portal and statement
  1406. cn.sendBinaryParameters(b, args)
  1407. b.bytes(colFmtDataAllText)
  1408. b.next('D')
  1409. b.byte('P')
  1410. b.byte(0) // unnamed portal
  1411. b.next('E')
  1412. b.byte(0)
  1413. b.int32(0)
  1414. b.next('S')
  1415. cn.send(b)
  1416. }
  1417. func (cn *conn) processParameterStatus(r *readBuf) {
  1418. var err error
  1419. param := r.string()
  1420. switch param {
  1421. case "server_version":
  1422. var major1 int
  1423. var major2 int
  1424. var minor int
  1425. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1426. if err == nil {
  1427. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1428. }
  1429. case "TimeZone":
  1430. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1431. if err != nil {
  1432. cn.parameterStatus.currentLocation = nil
  1433. }
  1434. default:
  1435. // ignore
  1436. }
  1437. }
  1438. func (cn *conn) processReadyForQuery(r *readBuf) {
  1439. cn.txnStatus = transactionStatus(r.byte())
  1440. }
  1441. func (cn *conn) readReadyForQuery() {
  1442. t, r := cn.recv1()
  1443. switch t {
  1444. case 'Z':
  1445. cn.processReadyForQuery(r)
  1446. return
  1447. default:
  1448. cn.bad = true
  1449. errorf("unexpected message %q; expected ReadyForQuery", t)
  1450. }
  1451. }
  1452. func (cn *conn) processBackendKeyData(r *readBuf) {
  1453. cn.processID = r.int32()
  1454. cn.secretKey = r.int32()
  1455. }
  1456. func (cn *conn) readParseResponse() {
  1457. t, r := cn.recv1()
  1458. switch t {
  1459. case '1':
  1460. return
  1461. case 'E':
  1462. err := parseError(r)
  1463. cn.readReadyForQuery()
  1464. panic(err)
  1465. default:
  1466. cn.bad = true
  1467. errorf("unexpected Parse response %q", t)
  1468. }
  1469. }
  1470. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1471. for {
  1472. t, r := cn.recv1()
  1473. switch t {
  1474. case 't':
  1475. nparams := r.int16()
  1476. paramTyps = make([]oid.Oid, nparams)
  1477. for i := range paramTyps {
  1478. paramTyps[i] = r.oid()
  1479. }
  1480. case 'n':
  1481. return paramTyps, nil, nil
  1482. case 'T':
  1483. colNames, colTyps = parseStatementRowDescribe(r)
  1484. return paramTyps, colNames, colTyps
  1485. case 'E':
  1486. err := parseError(r)
  1487. cn.readReadyForQuery()
  1488. panic(err)
  1489. default:
  1490. cn.bad = true
  1491. errorf("unexpected Describe statement response %q", t)
  1492. }
  1493. }
  1494. }
  1495. func (cn *conn) readPortalDescribeResponse() rowsHeader {
  1496. t, r := cn.recv1()
  1497. switch t {
  1498. case 'T':
  1499. return parsePortalRowDescribe(r)
  1500. case 'n':
  1501. return rowsHeader{}
  1502. case 'E':
  1503. err := parseError(r)
  1504. cn.readReadyForQuery()
  1505. panic(err)
  1506. default:
  1507. cn.bad = true
  1508. errorf("unexpected Describe response %q", t)
  1509. }
  1510. panic("not reached")
  1511. }
  1512. func (cn *conn) readBindResponse() {
  1513. t, r := cn.recv1()
  1514. switch t {
  1515. case '2':
  1516. return
  1517. case 'E':
  1518. err := parseError(r)
  1519. cn.readReadyForQuery()
  1520. panic(err)
  1521. default:
  1522. cn.bad = true
  1523. errorf("unexpected Bind response %q", t)
  1524. }
  1525. }
  1526. func (cn *conn) postExecuteWorkaround() {
  1527. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1528. // any errors from rows.Next, which masks errors that happened during the
  1529. // execution of the query. To avoid the problem in common cases, we wait
  1530. // here for one more message from the database. If it's not an error the
  1531. // query will likely succeed (or perhaps has already, if it's a
  1532. // CommandComplete), so we push the message into the conn struct; recv1
  1533. // will return it as the next message for rows.Next or rows.Close.
  1534. // However, if it's an error, we wait until ReadyForQuery and then return
  1535. // the error to our caller.
  1536. for {
  1537. t, r := cn.recv1()
  1538. switch t {
  1539. case 'E':
  1540. err := parseError(r)
  1541. cn.readReadyForQuery()
  1542. panic(err)
  1543. case 'C', 'D', 'I':
  1544. // the query didn't fail, but we can't process this message
  1545. cn.saveMessage(t, r)
  1546. return
  1547. default:
  1548. cn.bad = true
  1549. errorf("unexpected message during extended query execution: %q", t)
  1550. }
  1551. }
  1552. }
  1553. // Only for Exec(), since we ignore the returned data
  1554. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1555. for {
  1556. t, r := cn.recv1()
  1557. switch t {
  1558. case 'C':
  1559. if err != nil {
  1560. cn.bad = true
  1561. errorf("unexpected CommandComplete after error %s", err)
  1562. }
  1563. res, commandTag = cn.parseComplete(r.string())
  1564. case 'Z':
  1565. cn.processReadyForQuery(r)
  1566. if res == nil && err == nil {
  1567. err = errUnexpectedReady
  1568. }
  1569. return res, commandTag, err
  1570. case 'E':
  1571. err = parseError(r)
  1572. case 'T', 'D', 'I':
  1573. if err != nil {
  1574. cn.bad = true
  1575. errorf("unexpected %q after error %s", t, err)
  1576. }
  1577. if t == 'I' {
  1578. res = emptyRows
  1579. }
  1580. // ignore any results
  1581. default:
  1582. cn.bad = true
  1583. errorf("unknown %s response: %q", protocolState, t)
  1584. }
  1585. }
  1586. }
  1587. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1588. n := r.int16()
  1589. colNames = make([]string, n)
  1590. colTyps = make([]fieldDesc, n)
  1591. for i := range colNames {
  1592. colNames[i] = r.string()
  1593. r.next(6)
  1594. colTyps[i].OID = r.oid()
  1595. colTyps[i].Len = r.int16()
  1596. colTyps[i].Mod = r.int32()
  1597. // format code not known when describing a statement; always 0
  1598. r.next(2)
  1599. }
  1600. return
  1601. }
  1602. func parsePortalRowDescribe(r *readBuf) rowsHeader {
  1603. n := r.int16()
  1604. colNames := make([]string, n)
  1605. colFmts := make([]format, n)
  1606. colTyps := make([]fieldDesc, n)
  1607. for i := range colNames {
  1608. colNames[i] = r.string()
  1609. r.next(6)
  1610. colTyps[i].OID = r.oid()
  1611. colTyps[i].Len = r.int16()
  1612. colTyps[i].Mod = r.int32()
  1613. colFmts[i] = format(r.int16())
  1614. }
  1615. return rowsHeader{
  1616. colNames: colNames,
  1617. colFmts: colFmts,
  1618. colTyps: colTyps,
  1619. }
  1620. }
  1621. // parseEnviron tries to mimic some of libpq's environment handling
  1622. //
  1623. // To ease testing, it does not directly reference os.Environ, but is
  1624. // designed to accept its output.
  1625. //
  1626. // Environment-set connection information is intended to have a higher
  1627. // precedence than a library default but lower than any explicitly
  1628. // passed information (such as in the URL or connection string).
  1629. func parseEnviron(env []string) (out map[string]string) {
  1630. out = make(map[string]string)
  1631. for _, v := range env {
  1632. parts := strings.SplitN(v, "=", 2)
  1633. accrue := func(keyname string) {
  1634. out[keyname] = parts[1]
  1635. }
  1636. unsupported := func() {
  1637. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1638. }
  1639. // The order of these is the same as is seen in the
  1640. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1641. // keys cause a panic; these should be unset prior to
  1642. // execution. Options which pq expects to be set to a
  1643. // certain value are allowed, but must be set to that
  1644. // value if present (they can, of course, be absent).
  1645. switch parts[0] {
  1646. case "PGHOST":
  1647. accrue("host")
  1648. case "PGHOSTADDR":
  1649. unsupported()
  1650. case "PGPORT":
  1651. accrue("port")
  1652. case "PGDATABASE":
  1653. accrue("dbname")
  1654. case "PGUSER":
  1655. accrue("user")
  1656. case "PGPASSWORD":
  1657. accrue("password")
  1658. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1659. unsupported()
  1660. case "PGOPTIONS":
  1661. accrue("options")
  1662. case "PGAPPNAME":
  1663. accrue("application_name")
  1664. case "PGSSLMODE":
  1665. accrue("sslmode")
  1666. case "PGSSLCERT":
  1667. accrue("sslcert")
  1668. case "PGSSLKEY":
  1669. accrue("sslkey")
  1670. case "PGSSLROOTCERT":
  1671. accrue("sslrootcert")
  1672. case "PGREQUIRESSL", "PGSSLCRL":
  1673. unsupported()
  1674. case "PGREQUIREPEER":
  1675. unsupported()
  1676. case "PGKRBSRVNAME", "PGGSSLIB":
  1677. unsupported()
  1678. case "PGCONNECT_TIMEOUT":
  1679. accrue("connect_timeout")
  1680. case "PGCLIENTENCODING":
  1681. accrue("client_encoding")
  1682. case "PGDATESTYLE":
  1683. accrue("datestyle")
  1684. case "PGTZ":
  1685. accrue("timezone")
  1686. case "PGGEQO":
  1687. accrue("geqo")
  1688. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1689. unsupported()
  1690. }
  1691. }
  1692. return out
  1693. }
  1694. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1695. func isUTF8(name string) bool {
  1696. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1697. s := strings.Map(alnumLowerASCII, name)
  1698. return s == "utf8" || s == "unicode"
  1699. }
  1700. func alnumLowerASCII(ch rune) rune {
  1701. if 'A' <= ch && ch <= 'Z' {
  1702. return ch + ('a' - 'A')
  1703. }
  1704. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1705. return ch
  1706. }
  1707. return -1 // discard
  1708. }