You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1996 lines
46 KiB

  1. package pq
  2. import (
  3. "bufio"
  4. "context"
  5. "crypto/md5"
  6. "crypto/sha256"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "os"
  15. "os/user"
  16. "path"
  17. "path/filepath"
  18. "strconv"
  19. "strings"
  20. "time"
  21. "unicode"
  22. "github.com/lib/pq/oid"
  23. "github.com/lib/pq/scram"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
  32. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  33. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  34. errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  35. )
  36. // Driver is the Postgres database driver.
  37. type Driver struct{}
  38. // Open opens a new connection to the database. name is a connection string.
  39. // Most users should only use it through database/sql package from the standard
  40. // library.
  41. func (d *Driver) Open(name string) (driver.Conn, error) {
  42. return Open(name)
  43. }
  44. func init() {
  45. sql.Register("postgres", &Driver{})
  46. }
  47. type parameterStatus struct {
  48. // server version in the same format as server_version_num, or 0 if
  49. // unavailable
  50. serverVersion int
  51. // the current location based on the TimeZone value of the session, if
  52. // available
  53. currentLocation *time.Location
  54. }
  55. type transactionStatus byte
  56. const (
  57. txnStatusIdle transactionStatus = 'I'
  58. txnStatusIdleInTransaction transactionStatus = 'T'
  59. txnStatusInFailedTransaction transactionStatus = 'E'
  60. )
  61. func (s transactionStatus) String() string {
  62. switch s {
  63. case txnStatusIdle:
  64. return "idle"
  65. case txnStatusIdleInTransaction:
  66. return "idle in transaction"
  67. case txnStatusInFailedTransaction:
  68. return "in a failed transaction"
  69. default:
  70. errorf("unknown transactionStatus %d", s)
  71. }
  72. panic("not reached")
  73. }
  74. // Dialer is the dialer interface. It can be used to obtain more control over
  75. // how pq creates network connections.
  76. type Dialer interface {
  77. Dial(network, address string) (net.Conn, error)
  78. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  79. }
  80. // DialerContext is the context-aware dialer interface.
  81. type DialerContext interface {
  82. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  83. }
  84. type defaultDialer struct {
  85. d net.Dialer
  86. }
  87. func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
  88. return d.d.Dial(network, address)
  89. }
  90. func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
  91. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  92. defer cancel()
  93. return d.DialContext(ctx, network, address)
  94. }
  95. func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  96. return d.d.DialContext(ctx, network, address)
  97. }
  98. type conn struct {
  99. c net.Conn
  100. buf *bufio.Reader
  101. namei int
  102. scratch [512]byte
  103. txnStatus transactionStatus
  104. txnFinish func()
  105. // Save connection arguments to use during CancelRequest.
  106. dialer Dialer
  107. opts values
  108. // Cancellation key data for use with CancelRequest messages.
  109. processID int
  110. secretKey int
  111. parameterStatus parameterStatus
  112. saveMessageType byte
  113. saveMessageBuffer []byte
  114. // If true, this connection is bad and all public-facing functions should
  115. // return ErrBadConn.
  116. bad bool
  117. // If set, this connection should never use the binary format when
  118. // receiving query results from prepared statements. Only provided for
  119. // debugging.
  120. disablePreparedBinaryResult bool
  121. // Whether to always send []byte parameters over as binary. Enables single
  122. // round-trip mode for non-prepared Query calls.
  123. binaryParameters bool
  124. // If true this connection is in the middle of a COPY
  125. inCopy bool
  126. // If not nil, notices will be synchronously sent here
  127. noticeHandler func(*Error)
  128. // If not nil, notifications will be synchronously sent here
  129. notificationHandler func(*Notification)
  130. // GSSAPI context
  131. gss GSS
  132. }
  133. // Handle driver-side settings in parsed connection string.
  134. func (cn *conn) handleDriverSettings(o values) (err error) {
  135. boolSetting := func(key string, val *bool) error {
  136. if value, ok := o[key]; ok {
  137. if value == "yes" {
  138. *val = true
  139. } else if value == "no" {
  140. *val = false
  141. } else {
  142. return fmt.Errorf("unrecognized value %q for %s", value, key)
  143. }
  144. }
  145. return nil
  146. }
  147. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  148. if err != nil {
  149. return err
  150. }
  151. return boolSetting("binary_parameters", &cn.binaryParameters)
  152. }
  153. func (cn *conn) handlePgpass(o values) {
  154. // if a password was supplied, do not process .pgpass
  155. if _, ok := o["password"]; ok {
  156. return
  157. }
  158. filename := os.Getenv("PGPASSFILE")
  159. if filename == "" {
  160. // XXX this code doesn't work on Windows where the default filename is
  161. // XXX %APPDATA%\postgresql\pgpass.conf
  162. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  163. userHome := os.Getenv("HOME")
  164. if userHome == "" {
  165. user, err := user.Current()
  166. if err != nil {
  167. return
  168. }
  169. userHome = user.HomeDir
  170. }
  171. filename = filepath.Join(userHome, ".pgpass")
  172. }
  173. fileinfo, err := os.Stat(filename)
  174. if err != nil {
  175. return
  176. }
  177. mode := fileinfo.Mode()
  178. if mode&(0x77) != 0 {
  179. // XXX should warn about incorrect .pgpass permissions as psql does
  180. return
  181. }
  182. file, err := os.Open(filename)
  183. if err != nil {
  184. return
  185. }
  186. defer file.Close()
  187. scanner := bufio.NewScanner(io.Reader(file))
  188. hostname := o["host"]
  189. ntw, _ := network(o)
  190. port := o["port"]
  191. db := o["dbname"]
  192. username := o["user"]
  193. // From: https://github.com/tg/pgpass/blob/master/reader.go
  194. getFields := func(s string) []string {
  195. fs := make([]string, 0, 5)
  196. f := make([]rune, 0, len(s))
  197. var esc bool
  198. for _, c := range s {
  199. switch {
  200. case esc:
  201. f = append(f, c)
  202. esc = false
  203. case c == '\\':
  204. esc = true
  205. case c == ':':
  206. fs = append(fs, string(f))
  207. f = f[:0]
  208. default:
  209. f = append(f, c)
  210. }
  211. }
  212. return append(fs, string(f))
  213. }
  214. for scanner.Scan() {
  215. line := scanner.Text()
  216. if len(line) == 0 || line[0] == '#' {
  217. continue
  218. }
  219. split := getFields(line)
  220. if len(split) != 5 {
  221. continue
  222. }
  223. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  224. o["password"] = split[4]
  225. return
  226. }
  227. }
  228. }
  229. func (cn *conn) writeBuf(b byte) *writeBuf {
  230. cn.scratch[0] = b
  231. return &writeBuf{
  232. buf: cn.scratch[:5],
  233. pos: 1,
  234. }
  235. }
  236. // Open opens a new connection to the database. dsn is a connection string.
  237. // Most users should only use it through database/sql package from the standard
  238. // library.
  239. func Open(dsn string) (_ driver.Conn, err error) {
  240. return DialOpen(defaultDialer{}, dsn)
  241. }
  242. // DialOpen opens a new connection to the database using a dialer.
  243. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
  244. c, err := NewConnector(dsn)
  245. if err != nil {
  246. return nil, err
  247. }
  248. c.dialer = d
  249. return c.open(context.Background())
  250. }
  251. func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
  252. // Handle any panics during connection initialization. Note that we
  253. // specifically do *not* want to use errRecover(), as that would turn any
  254. // connection errors into ErrBadConns, hiding the real error message from
  255. // the user.
  256. defer errRecoverNoErrBadConn(&err)
  257. o := c.opts
  258. cn = &conn{
  259. opts: o,
  260. dialer: c.dialer,
  261. }
  262. err = cn.handleDriverSettings(o)
  263. if err != nil {
  264. return nil, err
  265. }
  266. cn.handlePgpass(o)
  267. cn.c, err = dial(ctx, c.dialer, o)
  268. if err != nil {
  269. return nil, err
  270. }
  271. err = cn.ssl(o)
  272. if err != nil {
  273. if cn.c != nil {
  274. cn.c.Close()
  275. }
  276. return nil, err
  277. }
  278. // cn.startup panics on error. Make sure we don't leak cn.c.
  279. panicking := true
  280. defer func() {
  281. if panicking {
  282. cn.c.Close()
  283. }
  284. }()
  285. cn.buf = bufio.NewReader(cn.c)
  286. cn.startup(o)
  287. // reset the deadline, in case one was set (see dial)
  288. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  289. err = cn.c.SetDeadline(time.Time{})
  290. }
  291. panicking = false
  292. return cn, err
  293. }
  294. func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
  295. network, address := network(o)
  296. // Zero or not specified means wait indefinitely.
  297. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  298. seconds, err := strconv.ParseInt(timeout, 10, 0)
  299. if err != nil {
  300. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  301. }
  302. duration := time.Duration(seconds) * time.Second
  303. // connect_timeout should apply to the entire connection establishment
  304. // procedure, so we both use a timeout for the TCP connection
  305. // establishment and set a deadline for doing the initial handshake.
  306. // The deadline is then reset after startup() is done.
  307. deadline := time.Now().Add(duration)
  308. var conn net.Conn
  309. if dctx, ok := d.(DialerContext); ok {
  310. ctx, cancel := context.WithTimeout(ctx, duration)
  311. defer cancel()
  312. conn, err = dctx.DialContext(ctx, network, address)
  313. } else {
  314. conn, err = d.DialTimeout(network, address, duration)
  315. }
  316. if err != nil {
  317. return nil, err
  318. }
  319. err = conn.SetDeadline(deadline)
  320. return conn, err
  321. }
  322. if dctx, ok := d.(DialerContext); ok {
  323. return dctx.DialContext(ctx, network, address)
  324. }
  325. return d.Dial(network, address)
  326. }
  327. func network(o values) (string, string) {
  328. host := o["host"]
  329. if strings.HasPrefix(host, "/") {
  330. sockPath := path.Join(host, ".s.PGSQL."+o["port"])
  331. return "unix", sockPath
  332. }
  333. return "tcp", net.JoinHostPort(host, o["port"])
  334. }
  335. type values map[string]string
  336. // scanner implements a tokenizer for libpq-style option strings.
  337. type scanner struct {
  338. s []rune
  339. i int
  340. }
  341. // newScanner returns a new scanner initialized with the option string s.
  342. func newScanner(s string) *scanner {
  343. return &scanner{[]rune(s), 0}
  344. }
  345. // Next returns the next rune.
  346. // It returns 0, false if the end of the text has been reached.
  347. func (s *scanner) Next() (rune, bool) {
  348. if s.i >= len(s.s) {
  349. return 0, false
  350. }
  351. r := s.s[s.i]
  352. s.i++
  353. return r, true
  354. }
  355. // SkipSpaces returns the next non-whitespace rune.
  356. // It returns 0, false if the end of the text has been reached.
  357. func (s *scanner) SkipSpaces() (rune, bool) {
  358. r, ok := s.Next()
  359. for unicode.IsSpace(r) && ok {
  360. r, ok = s.Next()
  361. }
  362. return r, ok
  363. }
  364. // parseOpts parses the options from name and adds them to the values.
  365. //
  366. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  367. func parseOpts(name string, o values) error {
  368. s := newScanner(name)
  369. for {
  370. var (
  371. keyRunes, valRunes []rune
  372. r rune
  373. ok bool
  374. )
  375. if r, ok = s.SkipSpaces(); !ok {
  376. break
  377. }
  378. // Scan the key
  379. for !unicode.IsSpace(r) && r != '=' {
  380. keyRunes = append(keyRunes, r)
  381. if r, ok = s.Next(); !ok {
  382. break
  383. }
  384. }
  385. // Skip any whitespace if we're not at the = yet
  386. if r != '=' {
  387. r, ok = s.SkipSpaces()
  388. }
  389. // The current character should be =
  390. if r != '=' || !ok {
  391. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  392. }
  393. // Skip any whitespace after the =
  394. if r, ok = s.SkipSpaces(); !ok {
  395. // If we reach the end here, the last value is just an empty string as per libpq.
  396. o[string(keyRunes)] = ""
  397. break
  398. }
  399. if r != '\'' {
  400. for !unicode.IsSpace(r) {
  401. if r == '\\' {
  402. if r, ok = s.Next(); !ok {
  403. return fmt.Errorf(`missing character after backslash`)
  404. }
  405. }
  406. valRunes = append(valRunes, r)
  407. if r, ok = s.Next(); !ok {
  408. break
  409. }
  410. }
  411. } else {
  412. quote:
  413. for {
  414. if r, ok = s.Next(); !ok {
  415. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  416. }
  417. switch r {
  418. case '\'':
  419. break quote
  420. case '\\':
  421. r, _ = s.Next()
  422. fallthrough
  423. default:
  424. valRunes = append(valRunes, r)
  425. }
  426. }
  427. }
  428. o[string(keyRunes)] = string(valRunes)
  429. }
  430. return nil
  431. }
  432. func (cn *conn) isInTransaction() bool {
  433. return cn.txnStatus == txnStatusIdleInTransaction ||
  434. cn.txnStatus == txnStatusInFailedTransaction
  435. }
  436. func (cn *conn) checkIsInTransaction(intxn bool) {
  437. if cn.isInTransaction() != intxn {
  438. cn.bad = true
  439. errorf("unexpected transaction status %v", cn.txnStatus)
  440. }
  441. }
  442. func (cn *conn) Begin() (_ driver.Tx, err error) {
  443. return cn.begin("")
  444. }
  445. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  446. if cn.bad {
  447. return nil, driver.ErrBadConn
  448. }
  449. defer cn.errRecover(&err)
  450. cn.checkIsInTransaction(false)
  451. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  452. if err != nil {
  453. return nil, err
  454. }
  455. if commandTag != "BEGIN" {
  456. cn.bad = true
  457. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  458. }
  459. if cn.txnStatus != txnStatusIdleInTransaction {
  460. cn.bad = true
  461. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  462. }
  463. return cn, nil
  464. }
  465. func (cn *conn) closeTxn() {
  466. if finish := cn.txnFinish; finish != nil {
  467. finish()
  468. }
  469. }
  470. func (cn *conn) Commit() (err error) {
  471. defer cn.closeTxn()
  472. if cn.bad {
  473. return driver.ErrBadConn
  474. }
  475. defer cn.errRecover(&err)
  476. cn.checkIsInTransaction(true)
  477. // We don't want the client to think that everything is okay if it tries
  478. // to commit a failed transaction. However, no matter what we return,
  479. // database/sql will release this connection back into the free connection
  480. // pool so we have to abort the current transaction here. Note that you
  481. // would get the same behaviour if you issued a COMMIT in a failed
  482. // transaction, so it's also the least surprising thing to do here.
  483. if cn.txnStatus == txnStatusInFailedTransaction {
  484. if err := cn.rollback(); err != nil {
  485. return err
  486. }
  487. return ErrInFailedTransaction
  488. }
  489. _, commandTag, err := cn.simpleExec("COMMIT")
  490. if err != nil {
  491. if cn.isInTransaction() {
  492. cn.bad = true
  493. }
  494. return err
  495. }
  496. if commandTag != "COMMIT" {
  497. cn.bad = true
  498. return fmt.Errorf("unexpected command tag %s", commandTag)
  499. }
  500. cn.checkIsInTransaction(false)
  501. return nil
  502. }
  503. func (cn *conn) Rollback() (err error) {
  504. defer cn.closeTxn()
  505. if cn.bad {
  506. return driver.ErrBadConn
  507. }
  508. defer cn.errRecover(&err)
  509. return cn.rollback()
  510. }
  511. func (cn *conn) rollback() (err error) {
  512. cn.checkIsInTransaction(true)
  513. _, commandTag, err := cn.simpleExec("ROLLBACK")
  514. if err != nil {
  515. if cn.isInTransaction() {
  516. cn.bad = true
  517. }
  518. return err
  519. }
  520. if commandTag != "ROLLBACK" {
  521. return fmt.Errorf("unexpected command tag %s", commandTag)
  522. }
  523. cn.checkIsInTransaction(false)
  524. return nil
  525. }
  526. func (cn *conn) gname() string {
  527. cn.namei++
  528. return strconv.FormatInt(int64(cn.namei), 10)
  529. }
  530. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  531. b := cn.writeBuf('Q')
  532. b.string(q)
  533. cn.send(b)
  534. for {
  535. t, r := cn.recv1()
  536. switch t {
  537. case 'C':
  538. res, commandTag = cn.parseComplete(r.string())
  539. case 'Z':
  540. cn.processReadyForQuery(r)
  541. if res == nil && err == nil {
  542. err = errUnexpectedReady
  543. }
  544. // done
  545. return
  546. case 'E':
  547. err = parseError(r)
  548. case 'I':
  549. res = emptyRows
  550. case 'T', 'D':
  551. // ignore any results
  552. default:
  553. cn.bad = true
  554. errorf("unknown response for simple query: %q", t)
  555. }
  556. }
  557. }
  558. func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  559. defer cn.errRecover(&err)
  560. b := cn.writeBuf('Q')
  561. b.string(q)
  562. cn.send(b)
  563. for {
  564. t, r := cn.recv1()
  565. switch t {
  566. case 'C', 'I':
  567. // We allow queries which don't return any results through Query as
  568. // well as Exec. We still have to give database/sql a rows object
  569. // the user can close, though, to avoid connections from being
  570. // leaked. A "rows" with done=true works fine for that purpose.
  571. if err != nil {
  572. cn.bad = true
  573. errorf("unexpected message %q in simple query execution", t)
  574. }
  575. if res == nil {
  576. res = &rows{
  577. cn: cn,
  578. }
  579. }
  580. // Set the result and tag to the last command complete if there wasn't a
  581. // query already run. Although queries usually return from here and cede
  582. // control to Next, a query with zero results does not.
  583. if t == 'C' && res.colNames == nil {
  584. res.result, res.tag = cn.parseComplete(r.string())
  585. }
  586. res.done = true
  587. case 'Z':
  588. cn.processReadyForQuery(r)
  589. // done
  590. return
  591. case 'E':
  592. res = nil
  593. err = parseError(r)
  594. case 'D':
  595. if res == nil {
  596. cn.bad = true
  597. errorf("unexpected DataRow in simple query execution")
  598. }
  599. // the query didn't fail; kick off to Next
  600. cn.saveMessage(t, r)
  601. return
  602. case 'T':
  603. // res might be non-nil here if we received a previous
  604. // CommandComplete, but that's fine; just overwrite it
  605. res = &rows{cn: cn}
  606. res.rowsHeader = parsePortalRowDescribe(r)
  607. // To work around a bug in QueryRow in Go 1.2 and earlier, wait
  608. // until the first DataRow has been received.
  609. default:
  610. cn.bad = true
  611. errorf("unknown response for simple query: %q", t)
  612. }
  613. }
  614. }
  615. type noRows struct{}
  616. var emptyRows noRows
  617. var _ driver.Result = noRows{}
  618. func (noRows) LastInsertId() (int64, error) {
  619. return 0, errNoLastInsertID
  620. }
  621. func (noRows) RowsAffected() (int64, error) {
  622. return 0, errNoRowsAffected
  623. }
  624. // Decides which column formats to use for a prepared statement. The input is
  625. // an array of type oids, one element per result column.
  626. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  627. if len(colTyps) == 0 {
  628. return nil, colFmtDataAllText
  629. }
  630. colFmts = make([]format, len(colTyps))
  631. if forceText {
  632. return colFmts, colFmtDataAllText
  633. }
  634. allBinary := true
  635. allText := true
  636. for i, t := range colTyps {
  637. switch t.OID {
  638. // This is the list of types to use binary mode for when receiving them
  639. // through a prepared statement. If a type appears in this list, it
  640. // must also be implemented in binaryDecode in encode.go.
  641. case oid.T_bytea:
  642. fallthrough
  643. case oid.T_int8:
  644. fallthrough
  645. case oid.T_int4:
  646. fallthrough
  647. case oid.T_int2:
  648. fallthrough
  649. case oid.T_uuid:
  650. colFmts[i] = formatBinary
  651. allText = false
  652. default:
  653. allBinary = false
  654. }
  655. }
  656. if allBinary {
  657. return colFmts, colFmtDataAllBinary
  658. } else if allText {
  659. return colFmts, colFmtDataAllText
  660. } else {
  661. colFmtData = make([]byte, 2+len(colFmts)*2)
  662. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  663. for i, v := range colFmts {
  664. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  665. }
  666. return colFmts, colFmtData
  667. }
  668. }
  669. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  670. st := &stmt{cn: cn, name: stmtName}
  671. b := cn.writeBuf('P')
  672. b.string(st.name)
  673. b.string(q)
  674. b.int16(0)
  675. b.next('D')
  676. b.byte('S')
  677. b.string(st.name)
  678. b.next('S')
  679. cn.send(b)
  680. cn.readParseResponse()
  681. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  682. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  683. cn.readReadyForQuery()
  684. return st
  685. }
  686. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  687. if cn.bad {
  688. return nil, driver.ErrBadConn
  689. }
  690. defer cn.errRecover(&err)
  691. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  692. s, err := cn.prepareCopyIn(q)
  693. if err == nil {
  694. cn.inCopy = true
  695. }
  696. return s, err
  697. }
  698. return cn.prepareTo(q, cn.gname()), nil
  699. }
  700. func (cn *conn) Close() (err error) {
  701. // Skip cn.bad return here because we always want to close a connection.
  702. defer cn.errRecover(&err)
  703. // Ensure that cn.c.Close is always run. Since error handling is done with
  704. // panics and cn.errRecover, the Close must be in a defer.
  705. defer func() {
  706. cerr := cn.c.Close()
  707. if err == nil {
  708. err = cerr
  709. }
  710. }()
  711. // Don't go through send(); ListenerConn relies on us not scribbling on the
  712. // scratch buffer of this connection.
  713. return cn.sendSimpleMessage('X')
  714. }
  715. // Implement the "Queryer" interface
  716. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  717. return cn.query(query, args)
  718. }
  719. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  720. if cn.bad {
  721. return nil, driver.ErrBadConn
  722. }
  723. if cn.inCopy {
  724. return nil, errCopyInProgress
  725. }
  726. defer cn.errRecover(&err)
  727. // Check to see if we can use the "simpleQuery" interface, which is
  728. // *much* faster than going through prepare/exec
  729. if len(args) == 0 {
  730. return cn.simpleQuery(query)
  731. }
  732. if cn.binaryParameters {
  733. cn.sendBinaryModeQuery(query, args)
  734. cn.readParseResponse()
  735. cn.readBindResponse()
  736. rows := &rows{cn: cn}
  737. rows.rowsHeader = cn.readPortalDescribeResponse()
  738. cn.postExecuteWorkaround()
  739. return rows, nil
  740. }
  741. st := cn.prepareTo(query, "")
  742. st.exec(args)
  743. return &rows{
  744. cn: cn,
  745. rowsHeader: st.rowsHeader,
  746. }, nil
  747. }
  748. // Implement the optional "Execer" interface for one-shot queries
  749. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  750. if cn.bad {
  751. return nil, driver.ErrBadConn
  752. }
  753. defer cn.errRecover(&err)
  754. // Check to see if we can use the "simpleExec" interface, which is
  755. // *much* faster than going through prepare/exec
  756. if len(args) == 0 {
  757. // ignore commandTag, our caller doesn't care
  758. r, _, err := cn.simpleExec(query)
  759. return r, err
  760. }
  761. if cn.binaryParameters {
  762. cn.sendBinaryModeQuery(query, args)
  763. cn.readParseResponse()
  764. cn.readBindResponse()
  765. cn.readPortalDescribeResponse()
  766. cn.postExecuteWorkaround()
  767. res, _, err = cn.readExecuteResponse("Execute")
  768. return res, err
  769. }
  770. // Use the unnamed statement to defer planning until bind
  771. // time, or else value-based selectivity estimates cannot be
  772. // used.
  773. st := cn.prepareTo(query, "")
  774. r, err := st.Exec(args)
  775. if err != nil {
  776. panic(err)
  777. }
  778. return r, err
  779. }
  780. func (cn *conn) send(m *writeBuf) {
  781. _, err := cn.c.Write(m.wrap())
  782. if err != nil {
  783. panic(err)
  784. }
  785. }
  786. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  787. _, err := cn.c.Write((m.wrap())[1:])
  788. return err
  789. }
  790. // Send a message of type typ to the server on the other end of cn. The
  791. // message should have no payload. This method does not use the scratch
  792. // buffer.
  793. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  794. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  795. return err
  796. }
  797. // saveMessage memorizes a message and its buffer in the conn struct.
  798. // recvMessage will then return these values on the next call to it. This
  799. // method is useful in cases where you have to see what the next message is
  800. // going to be (e.g. to see whether it's an error or not) but you can't handle
  801. // the message yourself.
  802. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  803. if cn.saveMessageType != 0 {
  804. cn.bad = true
  805. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  806. }
  807. cn.saveMessageType = typ
  808. cn.saveMessageBuffer = *buf
  809. }
  810. // recvMessage receives any message from the backend, or returns an error if
  811. // a problem occurred while reading the message.
  812. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  813. // workaround for a QueryRow bug, see exec
  814. if cn.saveMessageType != 0 {
  815. t := cn.saveMessageType
  816. *r = cn.saveMessageBuffer
  817. cn.saveMessageType = 0
  818. cn.saveMessageBuffer = nil
  819. return t, nil
  820. }
  821. x := cn.scratch[:5]
  822. _, err := io.ReadFull(cn.buf, x)
  823. if err != nil {
  824. return 0, err
  825. }
  826. // read the type and length of the message that follows
  827. t := x[0]
  828. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  829. var y []byte
  830. if n <= len(cn.scratch) {
  831. y = cn.scratch[:n]
  832. } else {
  833. y = make([]byte, n)
  834. }
  835. _, err = io.ReadFull(cn.buf, y)
  836. if err != nil {
  837. return 0, err
  838. }
  839. *r = y
  840. return t, nil
  841. }
  842. // recv receives a message from the backend, but if an error happened while
  843. // reading the message or the received message was an ErrorResponse, it panics.
  844. // NoticeResponses are ignored. This function should generally be used only
  845. // during the startup sequence.
  846. func (cn *conn) recv() (t byte, r *readBuf) {
  847. for {
  848. var err error
  849. r = &readBuf{}
  850. t, err = cn.recvMessage(r)
  851. if err != nil {
  852. panic(err)
  853. }
  854. switch t {
  855. case 'E':
  856. panic(parseError(r))
  857. case 'N':
  858. if n := cn.noticeHandler; n != nil {
  859. n(parseError(r))
  860. }
  861. case 'A':
  862. if n := cn.notificationHandler; n != nil {
  863. n(recvNotification(r))
  864. }
  865. default:
  866. return
  867. }
  868. }
  869. }
  870. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  871. // the caller to avoid an allocation.
  872. func (cn *conn) recv1Buf(r *readBuf) byte {
  873. for {
  874. t, err := cn.recvMessage(r)
  875. if err != nil {
  876. panic(err)
  877. }
  878. switch t {
  879. case 'A':
  880. if n := cn.notificationHandler; n != nil {
  881. n(recvNotification(r))
  882. }
  883. case 'N':
  884. if n := cn.noticeHandler; n != nil {
  885. n(parseError(r))
  886. }
  887. case 'S':
  888. cn.processParameterStatus(r)
  889. default:
  890. return t
  891. }
  892. }
  893. }
  894. // recv1 receives a message from the backend, panicking if an error occurs
  895. // while attempting to read it. All asynchronous messages are ignored, with
  896. // the exception of ErrorResponse.
  897. func (cn *conn) recv1() (t byte, r *readBuf) {
  898. r = &readBuf{}
  899. t = cn.recv1Buf(r)
  900. return t, r
  901. }
  902. func (cn *conn) ssl(o values) error {
  903. upgrade, err := ssl(o)
  904. if err != nil {
  905. return err
  906. }
  907. if upgrade == nil {
  908. // Nothing to do
  909. return nil
  910. }
  911. w := cn.writeBuf(0)
  912. w.int32(80877103)
  913. if err = cn.sendStartupPacket(w); err != nil {
  914. return err
  915. }
  916. b := cn.scratch[:1]
  917. _, err = io.ReadFull(cn.c, b)
  918. if err != nil {
  919. return err
  920. }
  921. if b[0] != 'S' {
  922. return ErrSSLNotSupported
  923. }
  924. cn.c, err = upgrade(cn.c)
  925. return err
  926. }
  927. // isDriverSetting returns true iff a setting is purely for configuring the
  928. // driver's options and should not be sent to the server in the connection
  929. // startup packet.
  930. func isDriverSetting(key string) bool {
  931. switch key {
  932. case "host", "port":
  933. return true
  934. case "password":
  935. return true
  936. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  937. return true
  938. case "fallback_application_name":
  939. return true
  940. case "connect_timeout":
  941. return true
  942. case "disable_prepared_binary_result":
  943. return true
  944. case "binary_parameters":
  945. return true
  946. case "krbsrvname":
  947. return true
  948. case "krbspn":
  949. return true
  950. default:
  951. return false
  952. }
  953. }
  954. func (cn *conn) startup(o values) {
  955. w := cn.writeBuf(0)
  956. w.int32(196608)
  957. // Send the backend the name of the database we want to connect to, and the
  958. // user we want to connect as. Additionally, we send over any run-time
  959. // parameters potentially included in the connection string. If the server
  960. // doesn't recognize any of them, it will reply with an error.
  961. for k, v := range o {
  962. if isDriverSetting(k) {
  963. // skip options which can't be run-time parameters
  964. continue
  965. }
  966. // The protocol requires us to supply the database name as "database"
  967. // instead of "dbname".
  968. if k == "dbname" {
  969. k = "database"
  970. }
  971. w.string(k)
  972. w.string(v)
  973. }
  974. w.string("")
  975. if err := cn.sendStartupPacket(w); err != nil {
  976. panic(err)
  977. }
  978. for {
  979. t, r := cn.recv()
  980. switch t {
  981. case 'K':
  982. cn.processBackendKeyData(r)
  983. case 'S':
  984. cn.processParameterStatus(r)
  985. case 'R':
  986. cn.auth(r, o)
  987. case 'Z':
  988. cn.processReadyForQuery(r)
  989. return
  990. default:
  991. errorf("unknown response for startup: %q", t)
  992. }
  993. }
  994. }
  995. func (cn *conn) auth(r *readBuf, o values) {
  996. switch code := r.int32(); code {
  997. case 0:
  998. // OK
  999. case 3:
  1000. w := cn.writeBuf('p')
  1001. w.string(o["password"])
  1002. cn.send(w)
  1003. t, r := cn.recv()
  1004. if t != 'R' {
  1005. errorf("unexpected password response: %q", t)
  1006. }
  1007. if r.int32() != 0 {
  1008. errorf("unexpected authentication response: %q", t)
  1009. }
  1010. case 5:
  1011. s := string(r.next(4))
  1012. w := cn.writeBuf('p')
  1013. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  1014. cn.send(w)
  1015. t, r := cn.recv()
  1016. if t != 'R' {
  1017. errorf("unexpected password response: %q", t)
  1018. }
  1019. if r.int32() != 0 {
  1020. errorf("unexpected authentication response: %q", t)
  1021. }
  1022. case 7: // GSSAPI, startup
  1023. if newGss == nil {
  1024. errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
  1025. }
  1026. cli, err := newGss()
  1027. if err != nil {
  1028. errorf("kerberos error: %s", err.Error())
  1029. }
  1030. var token []byte
  1031. if spn, ok := o["krbspn"]; ok {
  1032. // Use the supplied SPN if provided..
  1033. token, err = cli.GetInitTokenFromSpn(spn)
  1034. } else {
  1035. // Allow the kerberos service name to be overridden
  1036. service := "postgres"
  1037. if val, ok := o["krbsrvname"]; ok {
  1038. service = val
  1039. }
  1040. token, err = cli.GetInitToken(o["host"], service)
  1041. }
  1042. if err != nil {
  1043. errorf("failed to get Kerberos ticket: %q", err)
  1044. }
  1045. w := cn.writeBuf('p')
  1046. w.bytes(token)
  1047. cn.send(w)
  1048. // Store for GSSAPI continue message
  1049. cn.gss = cli
  1050. case 8: // GSSAPI continue
  1051. if cn.gss == nil {
  1052. errorf("GSSAPI protocol error")
  1053. }
  1054. b := []byte(*r)
  1055. done, tokOut, err := cn.gss.Continue(b)
  1056. if err == nil && !done {
  1057. w := cn.writeBuf('p')
  1058. w.bytes(tokOut)
  1059. cn.send(w)
  1060. }
  1061. // Errors fall through and read the more detailed message
  1062. // from the server..
  1063. case 10:
  1064. sc := scram.NewClient(sha256.New, o["user"], o["password"])
  1065. sc.Step(nil)
  1066. if sc.Err() != nil {
  1067. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1068. }
  1069. scOut := sc.Out()
  1070. w := cn.writeBuf('p')
  1071. w.string("SCRAM-SHA-256")
  1072. w.int32(len(scOut))
  1073. w.bytes(scOut)
  1074. cn.send(w)
  1075. t, r := cn.recv()
  1076. if t != 'R' {
  1077. errorf("unexpected password response: %q", t)
  1078. }
  1079. if r.int32() != 11 {
  1080. errorf("unexpected authentication response: %q", t)
  1081. }
  1082. nextStep := r.next(len(*r))
  1083. sc.Step(nextStep)
  1084. if sc.Err() != nil {
  1085. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1086. }
  1087. scOut = sc.Out()
  1088. w = cn.writeBuf('p')
  1089. w.bytes(scOut)
  1090. cn.send(w)
  1091. t, r = cn.recv()
  1092. if t != 'R' {
  1093. errorf("unexpected password response: %q", t)
  1094. }
  1095. if r.int32() != 12 {
  1096. errorf("unexpected authentication response: %q", t)
  1097. }
  1098. nextStep = r.next(len(*r))
  1099. sc.Step(nextStep)
  1100. if sc.Err() != nil {
  1101. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1102. }
  1103. default:
  1104. errorf("unknown authentication response: %d", code)
  1105. }
  1106. }
  1107. type format int
  1108. const formatText format = 0
  1109. const formatBinary format = 1
  1110. // One result-column format code with the value 1 (i.e. all binary).
  1111. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1112. // No result-column format codes (i.e. all text).
  1113. var colFmtDataAllText = []byte{0, 0}
// stmt is a named server-side statement bound to one connection. Its Close,
// NumInput, Exec and Query methods provide the driver.Stmt method set.
type stmt struct {
	// cn is the connection the statement belongs to.
	cn *conn
	// name is the statement name sent in Bind and Close messages.
	name string
	// rowsHeader describes the result columns; Query copies it into the
	// returned rows.
	rowsHeader
	// colFmtData is the result-column format code list written into every
	// Bind message for this statement.
	colFmtData []byte
	// paramTyps holds one type OID per parameter; its length is the number
	// of parameters the statement requires.
	paramTyps []oid.Oid
	// closed is set once Close has received the server's CloseComplete.
	closed bool
}
  1122. func (st *stmt) Close() (err error) {
  1123. if st.closed {
  1124. return nil
  1125. }
  1126. if st.cn.bad {
  1127. return driver.ErrBadConn
  1128. }
  1129. defer st.cn.errRecover(&err)
  1130. w := st.cn.writeBuf('C')
  1131. w.byte('S')
  1132. w.string(st.name)
  1133. st.cn.send(w)
  1134. st.cn.send(st.cn.writeBuf('S'))
  1135. t, _ := st.cn.recv1()
  1136. if t != '3' {
  1137. st.cn.bad = true
  1138. errorf("unexpected close response: %q", t)
  1139. }
  1140. st.closed = true
  1141. t, r := st.cn.recv1()
  1142. if t != 'Z' {
  1143. st.cn.bad = true
  1144. errorf("expected ready for query, but got: %q", t)
  1145. }
  1146. st.cn.processReadyForQuery(r)
  1147. return nil
  1148. }
  1149. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1150. if st.cn.bad {
  1151. return nil, driver.ErrBadConn
  1152. }
  1153. defer st.cn.errRecover(&err)
  1154. st.exec(v)
  1155. return &rows{
  1156. cn: st.cn,
  1157. rowsHeader: st.rowsHeader,
  1158. }, nil
  1159. }
  1160. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1161. if st.cn.bad {
  1162. return nil, driver.ErrBadConn
  1163. }
  1164. defer st.cn.errRecover(&err)
  1165. st.exec(v)
  1166. res, _, err = st.cn.readExecuteResponse("simple query")
  1167. return res, err
  1168. }
  1169. func (st *stmt) exec(v []driver.Value) {
  1170. if len(v) >= 65536 {
  1171. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1172. }
  1173. if len(v) != len(st.paramTyps) {
  1174. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1175. }
  1176. cn := st.cn
  1177. w := cn.writeBuf('B')
  1178. w.byte(0) // unnamed portal
  1179. w.string(st.name)
  1180. if cn.binaryParameters {
  1181. cn.sendBinaryParameters(w, v)
  1182. } else {
  1183. w.int16(0)
  1184. w.int16(len(v))
  1185. for i, x := range v {
  1186. if x == nil {
  1187. w.int32(-1)
  1188. } else {
  1189. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1190. w.int32(len(b))
  1191. w.bytes(b)
  1192. }
  1193. }
  1194. }
  1195. w.bytes(st.colFmtData)
  1196. w.next('E')
  1197. w.byte(0)
  1198. w.int32(0)
  1199. w.next('S')
  1200. cn.send(w)
  1201. cn.readBindResponse()
  1202. cn.postExecuteWorkaround()
  1203. }
  1204. func (st *stmt) NumInput() int {
  1205. return len(st.paramTyps)
  1206. }
  1207. // parseComplete parses the "command tag" from a CommandComplete message, and
  1208. // returns the number of rows affected (if applicable) and a string
  1209. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1210. // command tag could not be parsed, parseComplete panics.
  1211. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1212. commandsWithAffectedRows := []string{
  1213. "SELECT ",
  1214. // INSERT is handled below
  1215. "UPDATE ",
  1216. "DELETE ",
  1217. "FETCH ",
  1218. "MOVE ",
  1219. "COPY ",
  1220. }
  1221. var affectedRows *string
  1222. for _, tag := range commandsWithAffectedRows {
  1223. if strings.HasPrefix(commandTag, tag) {
  1224. t := commandTag[len(tag):]
  1225. affectedRows = &t
  1226. commandTag = tag[:len(tag)-1]
  1227. break
  1228. }
  1229. }
  1230. // INSERT also includes the oid of the inserted row in its command tag.
  1231. // Oids in user tables are deprecated, and the oid is only returned when
  1232. // exactly one row is inserted, so it's unlikely to be of value to any
  1233. // real-world application and we can ignore it.
  1234. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1235. parts := strings.Split(commandTag, " ")
  1236. if len(parts) != 3 {
  1237. cn.bad = true
  1238. errorf("unexpected INSERT command tag %s", commandTag)
  1239. }
  1240. affectedRows = &parts[len(parts)-1]
  1241. commandTag = "INSERT"
  1242. }
  1243. // There should be no affected rows attached to the tag, just return it
  1244. if affectedRows == nil {
  1245. return driver.RowsAffected(0), commandTag
  1246. }
  1247. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1248. if err != nil {
  1249. cn.bad = true
  1250. errorf("could not parse commandTag: %s", err)
  1251. }
  1252. return driver.RowsAffected(n), commandTag
  1253. }
// rowsHeader holds per-column metadata for a result set. The three slices
// are parallel and indexed by column position.
type rowsHeader struct {
	// colNames holds the column names.
	colNames []string
	// colTyps holds each column's type description (OID, length, modifier).
	colTyps []fieldDesc
	// colFmts holds each column's format code (text or binary).
	colFmts []format
}
// rows reads a result set from the connection (for example, the value
// returned by stmt.Query).
type rows struct {
	// cn is the connection the results are read from.
	cn *conn
	// finish, when non-nil, is deferred by Close.
	finish func()
	rowsHeader
	// done is set once ReadyForQuery has been received.
	done bool
	// rb is the buffer each incoming message is read into.
	rb readBuf
	// result and tag are parsed from the most recent CommandComplete.
	result driver.Result
	tag    string
	// next holds the column metadata of a following result set. Next sets
	// it when a RowDescription arrives mid-stream; NextResultSet consumes it.
	next *rowsHeader
}
  1269. func (rs *rows) Close() error {
  1270. if finish := rs.finish; finish != nil {
  1271. defer finish()
  1272. }
  1273. // no need to look at cn.bad as Next() will
  1274. for {
  1275. err := rs.Next(nil)
  1276. switch err {
  1277. case nil:
  1278. case io.EOF:
  1279. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1280. // description, used with HasNextResultSet). We need to fetch messages until
  1281. // we hit a 'Z', which is done by waiting for done to be set.
  1282. if rs.done {
  1283. return nil
  1284. }
  1285. default:
  1286. return err
  1287. }
  1288. }
  1289. }
  1290. func (rs *rows) Columns() []string {
  1291. return rs.colNames
  1292. }
  1293. func (rs *rows) Result() driver.Result {
  1294. if rs.result == nil {
  1295. return emptyRows
  1296. }
  1297. return rs.result
  1298. }
  1299. func (rs *rows) Tag() string {
  1300. return rs.tag
  1301. }
// Next advances to the next row, decoding its column values into dest.
//
// Messages that are not rows (CommandComplete, EmptyQueryResponse) are
// consumed along the way. Next returns io.EOF in two cases:
//   - the query finishes (ReadyForQuery);
//   - a RowDescription starts another result set, reachable through
//     NextResultSet.
//
// A server ErrorResponse is remembered and returned once ReadyForQuery
// arrives. Protocol violations panic and are recovered into err.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}
	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)
	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// ErrorResponse: remember it and keep reading until ReadyForQuery.
			err = parseError(&rs.rb)
		case 'C', 'I':
			// CommandComplete carries the command tag; EmptyQueryResponse has
			// nothing to record.
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			// DataRow: a column count followed by length-prefixed values.
			n := rs.rb.int16()
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			// Only min(n, len(dest)) values are decoded; dest is nil when
			// called from Close to drain the results.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				l := rs.rb.int32()
				if l == -1 {
					// A length of -1 denotes NULL.
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			return
		case 'T':
			// RowDescription: another result set follows.
			next := parsePortalRowDescribe(&rs.rb)
			rs.next = &next
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
  1355. func (rs *rows) HasNextResultSet() bool {
  1356. hasNext := rs.next != nil && !rs.done
  1357. return hasNext
  1358. }
  1359. func (rs *rows) NextResultSet() error {
  1360. if rs.next == nil {
  1361. return io.EOF
  1362. }
  1363. rs.rowsHeader = *rs.next
  1364. rs.next = nil
  1365. return nil
  1366. }
  1367. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1368. // used as part of an SQL statement. For example:
  1369. //
  1370. // tblname := "my_table"
  1371. // data := "my_data"
  1372. // quoted := pq.QuoteIdentifier(tblname)
  1373. // err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  1374. //
  1375. // Any double quotes in name will be escaped. The quoted identifier will be
  1376. // case sensitive when used in a query. If the input string contains a zero
  1377. // byte, the result will be truncated immediately before it.
  1378. func QuoteIdentifier(name string) string {
  1379. end := strings.IndexRune(name, 0)
  1380. if end > -1 {
  1381. name = name[:end]
  1382. }
  1383. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1384. }
  1385. // QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
  1386. // to DDL and other statements that do not accept parameters) to be used as part
  1387. // of an SQL statement. For example:
  1388. //
  1389. // exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
  1390. // err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
  1391. //
  1392. // Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
  1393. // replaced by two backslashes (i.e. "\\") and the C-style escape identifier
  1394. // that PostgreSQL provides ('E') will be prepended to the string.
  1395. func QuoteLiteral(literal string) string {
  1396. // This follows the PostgreSQL internal algorithm for handling quoted literals
  1397. // from libpq, which can be found in the "PQEscapeStringInternal" function,
  1398. // which is found in the libpq/fe-exec.c source file:
  1399. // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
  1400. //
  1401. // substitute any single-quotes (') with two single-quotes ('')
  1402. literal = strings.Replace(literal, `'`, `''`, -1)
  1403. // determine if the string has any backslashes (\) in it.
  1404. // if it does, replace any backslashes (\) with two backslashes (\\)
  1405. // then, we need to wrap the entire string with a PostgreSQL
  1406. // C-style escape. Per how "PQEscapeStringInternal" handles this case, we
  1407. // also add a space before the "E"
  1408. if strings.Contains(literal, `\`) {
  1409. literal = strings.Replace(literal, `\`, `\\`, -1)
  1410. literal = ` E'` + literal + `'`
  1411. } else {
  1412. // otherwise, we can just wrap the literal with a pair of single quotes
  1413. literal = `'` + literal + `'`
  1414. }
  1415. return literal
  1416. }
  1417. func md5s(s string) string {
  1418. h := md5.New()
  1419. h.Write([]byte(s))
  1420. return fmt.Sprintf("%x", h.Sum(nil))
  1421. }
  1422. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1423. // Do one pass over the parameters to see if we're going to send any of
  1424. // them over in binary. If we are, create a paramFormats array at the
  1425. // same time.
  1426. var paramFormats []int
  1427. for i, x := range args {
  1428. _, ok := x.([]byte)
  1429. if ok {
  1430. if paramFormats == nil {
  1431. paramFormats = make([]int, len(args))
  1432. }
  1433. paramFormats[i] = 1
  1434. }
  1435. }
  1436. if paramFormats == nil {
  1437. b.int16(0)
  1438. } else {
  1439. b.int16(len(paramFormats))
  1440. for _, x := range paramFormats {
  1441. b.int16(x)
  1442. }
  1443. }
  1444. b.int16(len(args))
  1445. for _, x := range args {
  1446. if x == nil {
  1447. b.int32(-1)
  1448. } else {
  1449. datum := binaryEncode(&cn.parameterStatus, x)
  1450. b.int32(len(datum))
  1451. b.bytes(datum)
  1452. }
  1453. }
  1454. }
  1455. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1456. if len(args) >= 65536 {
  1457. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1458. }
  1459. b := cn.writeBuf('P')
  1460. b.byte(0) // unnamed statement
  1461. b.string(query)
  1462. b.int16(0)
  1463. b.next('B')
  1464. b.int16(0) // unnamed portal and statement
  1465. cn.sendBinaryParameters(b, args)
  1466. b.bytes(colFmtDataAllText)
  1467. b.next('D')
  1468. b.byte('P')
  1469. b.byte(0) // unnamed portal
  1470. b.next('E')
  1471. b.byte(0)
  1472. b.int32(0)
  1473. b.next('S')
  1474. cn.send(b)
  1475. }
  1476. func (cn *conn) processParameterStatus(r *readBuf) {
  1477. var err error
  1478. param := r.string()
  1479. switch param {
  1480. case "server_version":
  1481. var major1 int
  1482. var major2 int
  1483. var minor int
  1484. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1485. if err == nil {
  1486. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1487. }
  1488. case "TimeZone":
  1489. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1490. if err != nil {
  1491. cn.parameterStatus.currentLocation = nil
  1492. }
  1493. default:
  1494. // ignore
  1495. }
  1496. }
  1497. func (cn *conn) processReadyForQuery(r *readBuf) {
  1498. cn.txnStatus = transactionStatus(r.byte())
  1499. }
  1500. func (cn *conn) readReadyForQuery() {
  1501. t, r := cn.recv1()
  1502. switch t {
  1503. case 'Z':
  1504. cn.processReadyForQuery(r)
  1505. return
  1506. default:
  1507. cn.bad = true
  1508. errorf("unexpected message %q; expected ReadyForQuery", t)
  1509. }
  1510. }
  1511. func (cn *conn) processBackendKeyData(r *readBuf) {
  1512. cn.processID = r.int32()
  1513. cn.secretKey = r.int32()
  1514. }
  1515. func (cn *conn) readParseResponse() {
  1516. t, r := cn.recv1()
  1517. switch t {
  1518. case '1':
  1519. return
  1520. case 'E':
  1521. err := parseError(r)
  1522. cn.readReadyForQuery()
  1523. panic(err)
  1524. default:
  1525. cn.bad = true
  1526. errorf("unexpected Parse response %q", t)
  1527. }
  1528. }
  1529. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1530. for {
  1531. t, r := cn.recv1()
  1532. switch t {
  1533. case 't':
  1534. nparams := r.int16()
  1535. paramTyps = make([]oid.Oid, nparams)
  1536. for i := range paramTyps {
  1537. paramTyps[i] = r.oid()
  1538. }
  1539. case 'n':
  1540. return paramTyps, nil, nil
  1541. case 'T':
  1542. colNames, colTyps = parseStatementRowDescribe(r)
  1543. return paramTyps, colNames, colTyps
  1544. case 'E':
  1545. err := parseError(r)
  1546. cn.readReadyForQuery()
  1547. panic(err)
  1548. default:
  1549. cn.bad = true
  1550. errorf("unexpected Describe statement response %q", t)
  1551. }
  1552. }
  1553. }
  1554. func (cn *conn) readPortalDescribeResponse() rowsHeader {
  1555. t, r := cn.recv1()
  1556. switch t {
  1557. case 'T':
  1558. return parsePortalRowDescribe(r)
  1559. case 'n':
  1560. return rowsHeader{}
  1561. case 'E':
  1562. err := parseError(r)
  1563. cn.readReadyForQuery()
  1564. panic(err)
  1565. default:
  1566. cn.bad = true
  1567. errorf("unexpected Describe response %q", t)
  1568. }
  1569. panic("not reached")
  1570. }
  1571. func (cn *conn) readBindResponse() {
  1572. t, r := cn.recv1()
  1573. switch t {
  1574. case '2':
  1575. return
  1576. case 'E':
  1577. err := parseError(r)
  1578. cn.readReadyForQuery()
  1579. panic(err)
  1580. default:
  1581. cn.bad = true
  1582. errorf("unexpected Bind response %q", t)
  1583. }
  1584. }
  1585. func (cn *conn) postExecuteWorkaround() {
  1586. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1587. // any errors from rows.Next, which masks errors that happened during the
  1588. // execution of the query. To avoid the problem in common cases, we wait
  1589. // here for one more message from the database. If it's not an error the
  1590. // query will likely succeed (or perhaps has already, if it's a
  1591. // CommandComplete), so we push the message into the conn struct; recv1
  1592. // will return it as the next message for rows.Next or rows.Close.
  1593. // However, if it's an error, we wait until ReadyForQuery and then return
  1594. // the error to our caller.
  1595. for {
  1596. t, r := cn.recv1()
  1597. switch t {
  1598. case 'E':
  1599. err := parseError(r)
  1600. cn.readReadyForQuery()
  1601. panic(err)
  1602. case 'C', 'D', 'I':
  1603. // the query didn't fail, but we can't process this message
  1604. cn.saveMessage(t, r)
  1605. return
  1606. default:
  1607. cn.bad = true
  1608. errorf("unexpected message during extended query execution: %q", t)
  1609. }
  1610. }
  1611. }
// readExecuteResponse consumes the server's reply to an executed command, up
// to and including ReadyForQuery. It is only for Exec(), since returned rows
// are ignored.
//
// Return values:
//   - res and commandTag come from the CommandComplete message; res is
//     emptyRows after an EmptyQueryResponse.
//   - err is the server's ErrorResponse if one arrived, or
//     errUnexpectedReady if ReadyForQuery arrives with neither a result nor
//     an error.
//
// protocolState names the protocol phase in the error message for an
// unknown message type. Messages other than ErrorResponse and ReadyForQuery
// that arrive after an error, and unknown messages, mark the connection bad
// and panic.
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
	for {
		t, r := cn.recv1()
		switch t {
		case 'C':
			// CommandComplete
			if err != nil {
				cn.bad = true
				errorf("unexpected CommandComplete after error %s", err)
			}
			res, commandTag = cn.parseComplete(r.string())
		case 'Z':
			// ReadyForQuery ends the exchange.
			cn.processReadyForQuery(r)
			if res == nil && err == nil {
				err = errUnexpectedReady
			}
			return res, commandTag, err
		case 'E':
			// ErrorResponse: remember it and keep reading until ReadyForQuery.
			err = parseError(r)
		case 'T', 'D', 'I':
			if err != nil {
				cn.bad = true
				errorf("unexpected %q after error %s", t, err)
			}
			if t == 'I' {
				// EmptyQueryResponse
				res = emptyRows
			}
			// ignore any results
		default:
			cn.bad = true
			errorf("unknown %s response: %q", protocolState, t)
		}
	}
}
  1646. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1647. n := r.int16()
  1648. colNames = make([]string, n)
  1649. colTyps = make([]fieldDesc, n)
  1650. for i := range colNames {
  1651. colNames[i] = r.string()
  1652. r.next(6)
  1653. colTyps[i].OID = r.oid()
  1654. colTyps[i].Len = r.int16()
  1655. colTyps[i].Mod = r.int32()
  1656. // format code not known when describing a statement; always 0
  1657. r.next(2)
  1658. }
  1659. return
  1660. }
  1661. func parsePortalRowDescribe(r *readBuf) rowsHeader {
  1662. n := r.int16()
  1663. colNames := make([]string, n)
  1664. colFmts := make([]format, n)
  1665. colTyps := make([]fieldDesc, n)
  1666. for i := range colNames {
  1667. colNames[i] = r.string()
  1668. r.next(6)
  1669. colTyps[i].OID = r.oid()
  1670. colTyps[i].Len = r.int16()
  1671. colTyps[i].Mod = r.int32()
  1672. colFmts[i] = format(r.int16())
  1673. }
  1674. return rowsHeader{
  1675. colNames: colNames,
  1676. colFmts: colFmts,
  1677. colTyps: colTyps,
  1678. }
  1679. }
  1680. // parseEnviron tries to mimic some of libpq's environment handling
  1681. //
  1682. // To ease testing, it does not directly reference os.Environ, but is
  1683. // designed to accept its output.
  1684. //
  1685. // Environment-set connection information is intended to have a higher
  1686. // precedence than a library default but lower than any explicitly
  1687. // passed information (such as in the URL or connection string).
  1688. func parseEnviron(env []string) (out map[string]string) {
  1689. out = make(map[string]string)
  1690. for _, v := range env {
  1691. parts := strings.SplitN(v, "=", 2)
  1692. accrue := func(keyname string) {
  1693. out[keyname] = parts[1]
  1694. }
  1695. unsupported := func() {
  1696. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1697. }
  1698. // The order of these is the same as is seen in the
  1699. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1700. // keys cause a panic; these should be unset prior to
  1701. // execution. Options which pq expects to be set to a
  1702. // certain value are allowed, but must be set to that
  1703. // value if present (they can, of course, be absent).
  1704. switch parts[0] {
  1705. case "PGHOST":
  1706. accrue("host")
  1707. case "PGHOSTADDR":
  1708. unsupported()
  1709. case "PGPORT":
  1710. accrue("port")
  1711. case "PGDATABASE":
  1712. accrue("dbname")
  1713. case "PGUSER":
  1714. accrue("user")
  1715. case "PGPASSWORD":
  1716. accrue("password")
  1717. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1718. unsupported()
  1719. case "PGOPTIONS":
  1720. accrue("options")
  1721. case "PGAPPNAME":
  1722. accrue("application_name")
  1723. case "PGSSLMODE":
  1724. accrue("sslmode")
  1725. case "PGSSLCERT":
  1726. accrue("sslcert")
  1727. case "PGSSLKEY":
  1728. accrue("sslkey")
  1729. case "PGSSLROOTCERT":
  1730. accrue("sslrootcert")
  1731. case "PGREQUIRESSL", "PGSSLCRL":
  1732. unsupported()
  1733. case "PGREQUIREPEER":
  1734. unsupported()
  1735. case "PGKRBSRVNAME", "PGGSSLIB":
  1736. unsupported()
  1737. case "PGCONNECT_TIMEOUT":
  1738. accrue("connect_timeout")
  1739. case "PGCLIENTENCODING":
  1740. accrue("client_encoding")
  1741. case "PGDATESTYLE":
  1742. accrue("datestyle")
  1743. case "PGTZ":
  1744. accrue("timezone")
  1745. case "PGGEQO":
  1746. accrue("geqo")
  1747. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1748. unsupported()
  1749. }
  1750. }
  1751. return out
  1752. }
  1753. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1754. func isUTF8(name string) bool {
  1755. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1756. s := strings.Map(alnumLowerASCII, name)
  1757. return s == "utf8" || s == "unicode"
  1758. }
  1759. func alnumLowerASCII(ch rune) rune {
  1760. if 'A' <= ch && ch <= 'Z' {
  1761. return ch + ('a' - 'A')
  1762. }
  1763. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1764. return ch
  1765. }
  1766. return -1 // discard
  1767. }