You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1914 lines
45 KiB

  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "io/ioutil"
  14. "net"
  15. "os"
  16. "os/user"
  17. "path"
  18. "path/filepath"
  19. "strconv"
  20. "strings"
  21. "time"
  22. "unicode"
  23. "github.com/lib/pq/oid"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
  32. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  33. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  34. errNoLastInsertId = errors.New("no LastInsertId available after the empty statement")
  35. )
  36. type drv struct{}
  37. func (d *drv) Open(name string) (driver.Conn, error) {
  38. return Open(name)
  39. }
  40. func init() {
  41. sql.Register("postgres", &drv{})
  42. }
// parameterStatus holds session state reported by the server. It is filled
// in elsewhere — presumably by processParameterStatus, which recv1Buf calls
// for 'S' messages.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int
	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
  51. type transactionStatus byte
  52. const (
  53. txnStatusIdle transactionStatus = 'I'
  54. txnStatusIdleInTransaction transactionStatus = 'T'
  55. txnStatusInFailedTransaction transactionStatus = 'E'
  56. )
  57. func (s transactionStatus) String() string {
  58. switch s {
  59. case txnStatusIdle:
  60. return "idle"
  61. case txnStatusIdleInTransaction:
  62. return "idle in transaction"
  63. case txnStatusInFailedTransaction:
  64. return "in a failed transaction"
  65. default:
  66. errorf("unknown transactionStatus %d", s)
  67. }
  68. panic("not reached")
  69. }
  70. type Dialer interface {
  71. Dial(network, address string) (net.Conn, error)
  72. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  73. }
  74. type defaultDialer struct{}
  75. func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
  76. return net.Dial(ntw, addr)
  77. }
  78. func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
  79. return net.DialTimeout(ntw, addr, timeout)
  80. }
  81. type conn struct {
  82. c net.Conn
  83. buf *bufio.Reader
  84. namei int
  85. scratch [512]byte
  86. txnStatus transactionStatus
  87. parameterStatus parameterStatus
  88. saveMessageType byte
  89. saveMessageBuffer []byte
  90. // If true, this connection is bad and all public-facing functions should
  91. // return ErrBadConn.
  92. bad bool
  93. // If set, this connection should never use the binary format when
  94. // receiving query results from prepared statements. Only provided for
  95. // debugging.
  96. disablePreparedBinaryResult bool
  97. // Whether to always send []byte parameters over as binary. Enables single
  98. // round-trip mode for non-prepared Query calls.
  99. binaryParameters bool
  100. // If true this connection is in the middle of a COPY
  101. inCopy bool
  102. }
  103. // Handle driver-side settings in parsed connection string.
  104. func (c *conn) handleDriverSettings(o values) (err error) {
  105. boolSetting := func(key string, val *bool) error {
  106. if value := o.Get(key); value != "" {
  107. if value == "yes" {
  108. *val = true
  109. } else if value == "no" {
  110. *val = false
  111. } else {
  112. return fmt.Errorf("unrecognized value %q for %s", value, key)
  113. }
  114. }
  115. return nil
  116. }
  117. err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult)
  118. if err != nil {
  119. return err
  120. }
  121. err = boolSetting("binary_parameters", &c.binaryParameters)
  122. if err != nil {
  123. return err
  124. }
  125. return nil
  126. }
  127. func (c *conn) handlePgpass(o values) {
  128. // if a password was supplied, do not process .pgpass
  129. _, ok := o["password"]
  130. if ok {
  131. return
  132. }
  133. filename := os.Getenv("PGPASSFILE")
  134. if filename == "" {
  135. // XXX this code doesn't work on Windows where the default filename is
  136. // XXX %APPDATA%\postgresql\pgpass.conf
  137. user, err := user.Current()
  138. if err != nil {
  139. return
  140. }
  141. filename = filepath.Join(user.HomeDir, ".pgpass")
  142. }
  143. fileinfo, err := os.Stat(filename)
  144. if err != nil {
  145. return
  146. }
  147. mode := fileinfo.Mode()
  148. if mode&(0x77) != 0 {
  149. // XXX should warn about incorrect .pgpass permissions as psql does
  150. return
  151. }
  152. file, err := os.Open(filename)
  153. if err != nil {
  154. return
  155. }
  156. defer file.Close()
  157. scanner := bufio.NewScanner(io.Reader(file))
  158. hostname := o.Get("host")
  159. ntw, _ := network(o)
  160. port := o.Get("port")
  161. db := o.Get("dbname")
  162. username := o.Get("user")
  163. // From: https://github.com/tg/pgpass/blob/master/reader.go
  164. getFields := func(s string) []string {
  165. fs := make([]string, 0, 5)
  166. f := make([]rune, 0, len(s))
  167. var esc bool
  168. for _, c := range s {
  169. switch {
  170. case esc:
  171. f = append(f, c)
  172. esc = false
  173. case c == '\\':
  174. esc = true
  175. case c == ':':
  176. fs = append(fs, string(f))
  177. f = f[:0]
  178. default:
  179. f = append(f, c)
  180. }
  181. }
  182. return append(fs, string(f))
  183. }
  184. for scanner.Scan() {
  185. line := scanner.Text()
  186. if len(line) == 0 || line[0] == '#' {
  187. continue
  188. }
  189. split := getFields(line)
  190. if len(split) != 5 {
  191. continue
  192. }
  193. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  194. o["password"] = split[4]
  195. return
  196. }
  197. }
  198. }
  199. func (c *conn) writeBuf(b byte) *writeBuf {
  200. c.scratch[0] = b
  201. return &writeBuf{
  202. buf: c.scratch[:5],
  203. pos: 1,
  204. }
  205. }
  206. func Open(name string) (_ driver.Conn, err error) {
  207. return DialOpen(defaultDialer{}, name)
  208. }
  209. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  210. // Handle any panics during connection initialization. Note that we
  211. // specifically do *not* want to use errRecover(), as that would turn any
  212. // connection errors into ErrBadConns, hiding the real error message from
  213. // the user.
  214. defer errRecoverNoErrBadConn(&err)
  215. o := make(values)
  216. // A number of defaults are applied here, in this order:
  217. //
  218. // * Very low precedence defaults applied in every situation
  219. // * Environment variables
  220. // * Explicitly passed connection information
  221. o.Set("host", "localhost")
  222. o.Set("port", "5432")
  223. // N.B.: Extra float digits should be set to 3, but that breaks
  224. // Postgres 8.4 and older, where the max is 2.
  225. o.Set("extra_float_digits", "2")
  226. for k, v := range parseEnviron(os.Environ()) {
  227. o.Set(k, v)
  228. }
  229. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  230. name, err = ParseURL(name)
  231. if err != nil {
  232. return nil, err
  233. }
  234. }
  235. if err := parseOpts(name, o); err != nil {
  236. return nil, err
  237. }
  238. // Use the "fallback" application name if necessary
  239. if fallback := o.Get("fallback_application_name"); fallback != "" {
  240. if !o.Isset("application_name") {
  241. o.Set("application_name", fallback)
  242. }
  243. }
  244. // We can't work with any client_encoding other than UTF-8 currently.
  245. // However, we have historically allowed the user to set it to UTF-8
  246. // explicitly, and there's no reason to break such programs, so allow that.
  247. // Note that the "options" setting could also set client_encoding, but
  248. // parsing its value is not worth it. Instead, we always explicitly send
  249. // client_encoding as a separate run-time parameter, which should override
  250. // anything set in options.
  251. if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) {
  252. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  253. }
  254. o.Set("client_encoding", "UTF8")
  255. // DateStyle needs a similar treatment.
  256. if datestyle := o.Get("datestyle"); datestyle != "" {
  257. if datestyle != "ISO, MDY" {
  258. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  259. "ISO, MDY", datestyle))
  260. }
  261. } else {
  262. o.Set("datestyle", "ISO, MDY")
  263. }
  264. // If a user is not provided by any other means, the last
  265. // resort is to use the current operating system provided user
  266. // name.
  267. if o.Get("user") == "" {
  268. u, err := userCurrent()
  269. if err != nil {
  270. return nil, err
  271. } else {
  272. o.Set("user", u)
  273. }
  274. }
  275. cn := &conn{}
  276. err = cn.handleDriverSettings(o)
  277. if err != nil {
  278. return nil, err
  279. }
  280. cn.handlePgpass(o)
  281. cn.c, err = dial(d, o)
  282. if err != nil {
  283. return nil, err
  284. }
  285. cn.ssl(o)
  286. cn.buf = bufio.NewReader(cn.c)
  287. cn.startup(o)
  288. // reset the deadline, in case one was set (see dial)
  289. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  290. err = cn.c.SetDeadline(time.Time{})
  291. }
  292. return cn, err
  293. }
  294. func dial(d Dialer, o values) (net.Conn, error) {
  295. ntw, addr := network(o)
  296. // SSL is not necessary or supported over UNIX domain sockets
  297. if ntw == "unix" {
  298. o["sslmode"] = "disable"
  299. }
  300. // Zero or not specified means wait indefinitely.
  301. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  302. seconds, err := strconv.ParseInt(timeout, 10, 0)
  303. if err != nil {
  304. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  305. }
  306. duration := time.Duration(seconds) * time.Second
  307. // connect_timeout should apply to the entire connection establishment
  308. // procedure, so we both use a timeout for the TCP connection
  309. // establishment and set a deadline for doing the initial handshake.
  310. // The deadline is then reset after startup() is done.
  311. deadline := time.Now().Add(duration)
  312. conn, err := d.DialTimeout(ntw, addr, duration)
  313. if err != nil {
  314. return nil, err
  315. }
  316. err = conn.SetDeadline(deadline)
  317. return conn, err
  318. }
  319. return d.Dial(ntw, addr)
  320. }
  321. func network(o values) (string, string) {
  322. host := o.Get("host")
  323. if strings.HasPrefix(host, "/") {
  324. sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
  325. return "unix", sockPath
  326. }
  327. return "tcp", net.JoinHostPort(host, o.Get("port"))
  328. }
  329. type values map[string]string
  330. func (vs values) Set(k, v string) {
  331. vs[k] = v
  332. }
  333. func (vs values) Get(k string) (v string) {
  334. return vs[k]
  335. }
  336. func (vs values) Isset(k string) bool {
  337. _, ok := vs[k]
  338. return ok
  339. }
  340. // scanner implements a tokenizer for libpq-style option strings.
  341. type scanner struct {
  342. s []rune
  343. i int
  344. }
  345. // newScanner returns a new scanner initialized with the option string s.
  346. func newScanner(s string) *scanner {
  347. return &scanner{[]rune(s), 0}
  348. }
  349. // Next returns the next rune.
  350. // It returns 0, false if the end of the text has been reached.
  351. func (s *scanner) Next() (rune, bool) {
  352. if s.i >= len(s.s) {
  353. return 0, false
  354. }
  355. r := s.s[s.i]
  356. s.i++
  357. return r, true
  358. }
  359. // SkipSpaces returns the next non-whitespace rune.
  360. // It returns 0, false if the end of the text has been reached.
  361. func (s *scanner) SkipSpaces() (rune, bool) {
  362. r, ok := s.Next()
  363. for unicode.IsSpace(r) && ok {
  364. r, ok = s.Next()
  365. }
  366. return r, ok
  367. }
  368. // parseOpts parses the options from name and adds them to the values.
  369. //
  370. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  371. func parseOpts(name string, o values) error {
  372. s := newScanner(name)
  373. for {
  374. var (
  375. keyRunes, valRunes []rune
  376. r rune
  377. ok bool
  378. )
  379. if r, ok = s.SkipSpaces(); !ok {
  380. break
  381. }
  382. // Scan the key
  383. for !unicode.IsSpace(r) && r != '=' {
  384. keyRunes = append(keyRunes, r)
  385. if r, ok = s.Next(); !ok {
  386. break
  387. }
  388. }
  389. // Skip any whitespace if we're not at the = yet
  390. if r != '=' {
  391. r, ok = s.SkipSpaces()
  392. }
  393. // The current character should be =
  394. if r != '=' || !ok {
  395. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  396. }
  397. // Skip any whitespace after the =
  398. if r, ok = s.SkipSpaces(); !ok {
  399. // If we reach the end here, the last value is just an empty string as per libpq.
  400. o.Set(string(keyRunes), "")
  401. break
  402. }
  403. if r != '\'' {
  404. for !unicode.IsSpace(r) {
  405. if r == '\\' {
  406. if r, ok = s.Next(); !ok {
  407. return fmt.Errorf(`missing character after backslash`)
  408. }
  409. }
  410. valRunes = append(valRunes, r)
  411. if r, ok = s.Next(); !ok {
  412. break
  413. }
  414. }
  415. } else {
  416. quote:
  417. for {
  418. if r, ok = s.Next(); !ok {
  419. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  420. }
  421. switch r {
  422. case '\'':
  423. break quote
  424. case '\\':
  425. r, _ = s.Next()
  426. fallthrough
  427. default:
  428. valRunes = append(valRunes, r)
  429. }
  430. }
  431. }
  432. o.Set(string(keyRunes), string(valRunes))
  433. }
  434. return nil
  435. }
  436. func (cn *conn) isInTransaction() bool {
  437. return cn.txnStatus == txnStatusIdleInTransaction ||
  438. cn.txnStatus == txnStatusInFailedTransaction
  439. }
  440. func (cn *conn) checkIsInTransaction(intxn bool) {
  441. if cn.isInTransaction() != intxn {
  442. cn.bad = true
  443. errorf("unexpected transaction status %v", cn.txnStatus)
  444. }
  445. }
  446. func (cn *conn) Begin() (_ driver.Tx, err error) {
  447. if cn.bad {
  448. return nil, driver.ErrBadConn
  449. }
  450. defer cn.errRecover(&err)
  451. cn.checkIsInTransaction(false)
  452. _, commandTag, err := cn.simpleExec("BEGIN")
  453. if err != nil {
  454. return nil, err
  455. }
  456. if commandTag != "BEGIN" {
  457. cn.bad = true
  458. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  459. }
  460. if cn.txnStatus != txnStatusIdleInTransaction {
  461. cn.bad = true
  462. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  463. }
  464. return cn, nil
  465. }
  466. func (cn *conn) Commit() (err error) {
  467. if cn.bad {
  468. return driver.ErrBadConn
  469. }
  470. defer cn.errRecover(&err)
  471. cn.checkIsInTransaction(true)
  472. // We don't want the client to think that everything is okay if it tries
  473. // to commit a failed transaction. However, no matter what we return,
  474. // database/sql will release this connection back into the free connection
  475. // pool so we have to abort the current transaction here. Note that you
  476. // would get the same behaviour if you issued a COMMIT in a failed
  477. // transaction, so it's also the least surprising thing to do here.
  478. if cn.txnStatus == txnStatusInFailedTransaction {
  479. if err := cn.Rollback(); err != nil {
  480. return err
  481. }
  482. return ErrInFailedTransaction
  483. }
  484. _, commandTag, err := cn.simpleExec("COMMIT")
  485. if err != nil {
  486. if cn.isInTransaction() {
  487. cn.bad = true
  488. }
  489. return err
  490. }
  491. if commandTag != "COMMIT" {
  492. cn.bad = true
  493. return fmt.Errorf("unexpected command tag %s", commandTag)
  494. }
  495. cn.checkIsInTransaction(false)
  496. return nil
  497. }
  498. func (cn *conn) Rollback() (err error) {
  499. if cn.bad {
  500. return driver.ErrBadConn
  501. }
  502. defer cn.errRecover(&err)
  503. cn.checkIsInTransaction(true)
  504. _, commandTag, err := cn.simpleExec("ROLLBACK")
  505. if err != nil {
  506. if cn.isInTransaction() {
  507. cn.bad = true
  508. }
  509. return err
  510. }
  511. if commandTag != "ROLLBACK" {
  512. return fmt.Errorf("unexpected command tag %s", commandTag)
  513. }
  514. cn.checkIsInTransaction(false)
  515. return nil
  516. }
  517. func (cn *conn) gname() string {
  518. cn.namei++
  519. return strconv.FormatInt(int64(cn.namei), 10)
  520. }
  521. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  522. b := cn.writeBuf('Q')
  523. b.string(q)
  524. cn.send(b)
  525. for {
  526. t, r := cn.recv1()
  527. switch t {
  528. case 'C':
  529. res, commandTag = cn.parseComplete(r.string())
  530. case 'Z':
  531. cn.processReadyForQuery(r)
  532. if res == nil && err == nil {
  533. err = errUnexpectedReady
  534. }
  535. // done
  536. return
  537. case 'E':
  538. err = parseError(r)
  539. case 'I':
  540. res = emptyRows
  541. case 'T', 'D':
  542. // ignore any results
  543. default:
  544. cn.bad = true
  545. errorf("unknown response for simple query: %q", t)
  546. }
  547. }
  548. }
  549. func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  550. defer cn.errRecover(&err)
  551. b := cn.writeBuf('Q')
  552. b.string(q)
  553. cn.send(b)
  554. for {
  555. t, r := cn.recv1()
  556. switch t {
  557. case 'C', 'I':
  558. // We allow queries which don't return any results through Query as
  559. // well as Exec. We still have to give database/sql a rows object
  560. // the user can close, though, to avoid connections from being
  561. // leaked. A "rows" with done=true works fine for that purpose.
  562. if err != nil {
  563. cn.bad = true
  564. errorf("unexpected message %q in simple query execution", t)
  565. }
  566. if res == nil {
  567. res = &rows{
  568. cn: cn,
  569. }
  570. }
  571. res.done = true
  572. case 'Z':
  573. cn.processReadyForQuery(r)
  574. // done
  575. return
  576. case 'E':
  577. res = nil
  578. err = parseError(r)
  579. case 'D':
  580. if res == nil {
  581. cn.bad = true
  582. errorf("unexpected DataRow in simple query execution")
  583. }
  584. // the query didn't fail; kick off to Next
  585. cn.saveMessage(t, r)
  586. return
  587. case 'T':
  588. // res might be non-nil here if we received a previous
  589. // CommandComplete, but that's fine; just overwrite it
  590. res = &rows{cn: cn}
  591. res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
  592. // To work around a bug in QueryRow in Go 1.2 and earlier, wait
  593. // until the first DataRow has been received.
  594. default:
  595. cn.bad = true
  596. errorf("unknown response for simple query: %q", t)
  597. }
  598. }
  599. }
  600. type noRows struct{}
  601. var emptyRows noRows
  602. var _ driver.Result = noRows{}
  603. func (noRows) LastInsertId() (int64, error) {
  604. return 0, errNoLastInsertId
  605. }
  606. func (noRows) RowsAffected() (int64, error) {
  607. return 0, errNoRowsAffected
  608. }
  609. // Decides which column formats to use for a prepared statement. The input is
  610. // an array of type oids, one element per result column.
  611. func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) {
  612. if len(colTyps) == 0 {
  613. return nil, colFmtDataAllText
  614. }
  615. colFmts = make([]format, len(colTyps))
  616. if forceText {
  617. return colFmts, colFmtDataAllText
  618. }
  619. allBinary := true
  620. allText := true
  621. for i, o := range colTyps {
  622. switch o {
  623. // This is the list of types to use binary mode for when receiving them
  624. // through a prepared statement. If a type appears in this list, it
  625. // must also be implemented in binaryDecode in encode.go.
  626. case oid.T_bytea:
  627. fallthrough
  628. case oid.T_int8:
  629. fallthrough
  630. case oid.T_int4:
  631. fallthrough
  632. case oid.T_int2:
  633. colFmts[i] = formatBinary
  634. allText = false
  635. default:
  636. allBinary = false
  637. }
  638. }
  639. if allBinary {
  640. return colFmts, colFmtDataAllBinary
  641. } else if allText {
  642. return colFmts, colFmtDataAllText
  643. } else {
  644. colFmtData = make([]byte, 2+len(colFmts)*2)
  645. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  646. for i, v := range colFmts {
  647. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  648. }
  649. return colFmts, colFmtData
  650. }
  651. }
  652. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  653. st := &stmt{cn: cn, name: stmtName}
  654. b := cn.writeBuf('P')
  655. b.string(st.name)
  656. b.string(q)
  657. b.int16(0)
  658. b.next('D')
  659. b.byte('S')
  660. b.string(st.name)
  661. b.next('S')
  662. cn.send(b)
  663. cn.readParseResponse()
  664. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  665. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  666. cn.readReadyForQuery()
  667. return st
  668. }
  669. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  670. if cn.bad {
  671. return nil, driver.ErrBadConn
  672. }
  673. defer cn.errRecover(&err)
  674. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  675. s, err := cn.prepareCopyIn(q)
  676. if err == nil {
  677. cn.inCopy = true
  678. }
  679. return s, err
  680. }
  681. return cn.prepareTo(q, cn.gname()), nil
  682. }
  683. func (cn *conn) Close() (err error) {
  684. // Skip cn.bad return here because we always want to close a connection.
  685. defer cn.errRecover(&err)
  686. // Ensure that cn.c.Close is always run. Since error handling is done with
  687. // panics and cn.errRecover, the Close must be in a defer.
  688. defer func() {
  689. cerr := cn.c.Close()
  690. if err == nil {
  691. err = cerr
  692. }
  693. }()
  694. // Don't go through send(); ListenerConn relies on us not scribbling on the
  695. // scratch buffer of this connection.
  696. return cn.sendSimpleMessage('X')
  697. }
  698. // Implement the "Queryer" interface
  699. func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
  700. if cn.bad {
  701. return nil, driver.ErrBadConn
  702. }
  703. if cn.inCopy {
  704. return nil, errCopyInProgress
  705. }
  706. defer cn.errRecover(&err)
  707. // Check to see if we can use the "simpleQuery" interface, which is
  708. // *much* faster than going through prepare/exec
  709. if len(args) == 0 {
  710. return cn.simpleQuery(query)
  711. }
  712. if cn.binaryParameters {
  713. cn.sendBinaryModeQuery(query, args)
  714. cn.readParseResponse()
  715. cn.readBindResponse()
  716. rows := &rows{cn: cn}
  717. rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
  718. cn.postExecuteWorkaround()
  719. return rows, nil
  720. } else {
  721. st := cn.prepareTo(query, "")
  722. st.exec(args)
  723. return &rows{
  724. cn: cn,
  725. colNames: st.colNames,
  726. colTyps: st.colTyps,
  727. colFmts: st.colFmts,
  728. }, nil
  729. }
  730. }
  731. // Implement the optional "Execer" interface for one-shot queries
  732. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  733. if cn.bad {
  734. return nil, driver.ErrBadConn
  735. }
  736. defer cn.errRecover(&err)
  737. // Check to see if we can use the "simpleExec" interface, which is
  738. // *much* faster than going through prepare/exec
  739. if len(args) == 0 {
  740. // ignore commandTag, our caller doesn't care
  741. r, _, err := cn.simpleExec(query)
  742. return r, err
  743. }
  744. if cn.binaryParameters {
  745. cn.sendBinaryModeQuery(query, args)
  746. cn.readParseResponse()
  747. cn.readBindResponse()
  748. cn.readPortalDescribeResponse()
  749. cn.postExecuteWorkaround()
  750. res, _, err = cn.readExecuteResponse("Execute")
  751. return res, err
  752. } else {
  753. // Use the unnamed statement to defer planning until bind
  754. // time, or else value-based selectivity estimates cannot be
  755. // used.
  756. st := cn.prepareTo(query, "")
  757. r, err := st.Exec(args)
  758. if err != nil {
  759. panic(err)
  760. }
  761. return r, err
  762. }
  763. }
  764. func (cn *conn) send(m *writeBuf) {
  765. _, err := cn.c.Write(m.wrap())
  766. if err != nil {
  767. panic(err)
  768. }
  769. }
  770. func (cn *conn) sendStartupPacket(m *writeBuf) {
  771. // sanity check
  772. if m.buf[0] != 0 {
  773. panic("oops")
  774. }
  775. _, err := cn.c.Write((m.wrap())[1:])
  776. if err != nil {
  777. panic(err)
  778. }
  779. }
  780. // Send a message of type typ to the server on the other end of cn. The
  781. // message should have no payload. This method does not use the scratch
  782. // buffer.
  783. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  784. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  785. return err
  786. }
  787. // saveMessage memorizes a message and its buffer in the conn struct.
  788. // recvMessage will then return these values on the next call to it. This
  789. // method is useful in cases where you have to see what the next message is
  790. // going to be (e.g. to see whether it's an error or not) but you can't handle
  791. // the message yourself.
  792. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  793. if cn.saveMessageType != 0 {
  794. cn.bad = true
  795. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  796. }
  797. cn.saveMessageType = typ
  798. cn.saveMessageBuffer = *buf
  799. }
  800. // recvMessage receives any message from the backend, or returns an error if
  801. // a problem occurred while reading the message.
  802. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  803. // workaround for a QueryRow bug, see exec
  804. if cn.saveMessageType != 0 {
  805. t := cn.saveMessageType
  806. *r = cn.saveMessageBuffer
  807. cn.saveMessageType = 0
  808. cn.saveMessageBuffer = nil
  809. return t, nil
  810. }
  811. x := cn.scratch[:5]
  812. _, err := io.ReadFull(cn.buf, x)
  813. if err != nil {
  814. return 0, err
  815. }
  816. // read the type and length of the message that follows
  817. t := x[0]
  818. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  819. var y []byte
  820. if n <= len(cn.scratch) {
  821. y = cn.scratch[:n]
  822. } else {
  823. y = make([]byte, n)
  824. }
  825. _, err = io.ReadFull(cn.buf, y)
  826. if err != nil {
  827. return 0, err
  828. }
  829. *r = y
  830. return t, nil
  831. }
  832. // recv receives a message from the backend, but if an error happened while
  833. // reading the message or the received message was an ErrorResponse, it panics.
  834. // NoticeResponses are ignored. This function should generally be used only
  835. // during the startup sequence.
  836. func (cn *conn) recv() (t byte, r *readBuf) {
  837. for {
  838. var err error
  839. r = &readBuf{}
  840. t, err = cn.recvMessage(r)
  841. if err != nil {
  842. panic(err)
  843. }
  844. switch t {
  845. case 'E':
  846. panic(parseError(r))
  847. case 'N':
  848. // ignore
  849. default:
  850. return
  851. }
  852. }
  853. }
  854. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  855. // the caller to avoid an allocation.
  856. func (cn *conn) recv1Buf(r *readBuf) byte {
  857. for {
  858. t, err := cn.recvMessage(r)
  859. if err != nil {
  860. panic(err)
  861. }
  862. switch t {
  863. case 'A', 'N':
  864. // ignore
  865. case 'S':
  866. cn.processParameterStatus(r)
  867. default:
  868. return t
  869. }
  870. }
  871. }
  872. // recv1 receives a message from the backend, panicking if an error occurs
  873. // while attempting to read it. All asynchronous messages are ignored, with
  874. // the exception of ErrorResponse.
  875. func (cn *conn) recv1() (t byte, r *readBuf) {
  876. r = &readBuf{}
  877. t = cn.recv1Buf(r)
  878. return t, r
  879. }
  880. func (cn *conn) ssl(o values) {
  881. verifyCaOnly := false
  882. tlsConf := tls.Config{}
  883. switch mode := o.Get("sslmode"); mode {
  884. // "require" is the default.
  885. case "", "require":
  886. // We must skip TLS's own verification since it requires full
  887. // verification since Go 1.3.
  888. tlsConf.InsecureSkipVerify = true
  889. // From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
  890. // Note: For backwards compatibility with earlier versions of PostgreSQL, if a
  891. // root CA file exists, the behavior of sslmode=require will be the same as
  892. // that of verify-ca, meaning the server certificate is validated against the
  893. // CA. Relying on this behavior is discouraged, and applications that need
  894. // certificate validation should always use verify-ca or verify-full.
  895. if _, err := os.Stat(o.Get("sslrootcert")); err == nil {
  896. verifyCaOnly = true
  897. } else {
  898. o.Set("sslrootcert", "")
  899. }
  900. case "verify-ca":
  901. // We must skip TLS's own verification since it requires full
  902. // verification since Go 1.3.
  903. tlsConf.InsecureSkipVerify = true
  904. verifyCaOnly = true
  905. case "verify-full":
  906. tlsConf.ServerName = o.Get("host")
  907. case "disable":
  908. return
  909. default:
  910. errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
  911. }
  912. cn.setupSSLClientCertificates(&tlsConf, o)
  913. cn.setupSSLCA(&tlsConf, o)
  914. w := cn.writeBuf(0)
  915. w.int32(80877103)
  916. cn.sendStartupPacket(w)
  917. b := cn.scratch[:1]
  918. _, err := io.ReadFull(cn.c, b)
  919. if err != nil {
  920. panic(err)
  921. }
  922. if b[0] != 'S' {
  923. panic(ErrSSLNotSupported)
  924. }
  925. client := tls.Client(cn.c, &tlsConf)
  926. if verifyCaOnly {
  927. cn.verifyCA(client, &tlsConf)
  928. }
  929. cn.c = client
  930. }
  931. // verifyCA carries out a TLS handshake to the server and verifies the
  932. // presented certificate against the effective CA, i.e. the one specified in
  933. // sslrootcert or the system CA if sslrootcert was not specified.
  934. func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) {
  935. err := client.Handshake()
  936. if err != nil {
  937. panic(err)
  938. }
  939. certs := client.ConnectionState().PeerCertificates
  940. opts := x509.VerifyOptions{
  941. DNSName: client.ConnectionState().ServerName,
  942. Intermediates: x509.NewCertPool(),
  943. Roots: tlsConf.RootCAs,
  944. }
  945. for i, cert := range certs {
  946. if i == 0 {
  947. continue
  948. }
  949. opts.Intermediates.AddCert(cert)
  950. }
  951. _, err = certs[0].Verify(opts)
  952. if err != nil {
  953. panic(err)
  954. }
  955. }
  956. // This function sets up SSL client certificates based on either the "sslkey"
  957. // and "sslcert" settings (possibly set via the environment variables PGSSLKEY
  958. // and PGSSLCERT, respectively), or if they aren't set, from the .postgresql
  959. // directory in the user's home directory. If the file paths are set
  960. // explicitly, the files must exist. The key file must also not be
  961. // world-readable, or this function will panic with
  962. // ErrSSLKeyHasWorldPermissions.
  963. func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) {
  964. var missingOk bool
  965. sslkey := o.Get("sslkey")
  966. sslcert := o.Get("sslcert")
  967. if sslkey != "" && sslcert != "" {
  968. // If the user has set an sslkey and sslcert, they *must* exist.
  969. missingOk = false
  970. } else {
  971. // Automatically load certificates from ~/.postgresql.
  972. user, err := user.Current()
  973. if err != nil {
  974. // user.Current() might fail when cross-compiling. We have to
  975. // ignore the error and continue without client certificates, since
  976. // we wouldn't know where to load them from.
  977. return
  978. }
  979. sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
  980. sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
  981. missingOk = true
  982. }
  983. // Check that both files exist, and report the error or stop, depending on
  984. // which behaviour we want. Note that we don't do any more extensive
  985. // checks than this (such as checking that the paths aren't directories);
  986. // LoadX509KeyPair() will take care of the rest.
  987. keyfinfo, err := os.Stat(sslkey)
  988. if err != nil && missingOk {
  989. return
  990. } else if err != nil {
  991. panic(err)
  992. }
  993. _, err = os.Stat(sslcert)
  994. if err != nil && missingOk {
  995. return
  996. } else if err != nil {
  997. panic(err)
  998. }
  999. // If we got this far, the key file must also have the correct permissions
  1000. kmode := keyfinfo.Mode()
  1001. if kmode != kmode&0600 {
  1002. panic(ErrSSLKeyHasWorldPermissions)
  1003. }
  1004. cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
  1005. if err != nil {
  1006. panic(err)
  1007. }
  1008. tlsConf.Certificates = []tls.Certificate{cert}
  1009. }
  1010. // Sets up RootCAs in the TLS configuration if sslrootcert is set.
  1011. func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) {
  1012. if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" {
  1013. tlsConf.RootCAs = x509.NewCertPool()
  1014. cert, err := ioutil.ReadFile(sslrootcert)
  1015. if err != nil {
  1016. panic(err)
  1017. }
  1018. ok := tlsConf.RootCAs.AppendCertsFromPEM(cert)
  1019. if !ok {
  1020. errorf("couldn't parse pem in sslrootcert")
  1021. }
  1022. }
  1023. }
  // isDriverSetting reports whether key configures only the driver itself
  // (connection target, credentials, TLS, timeouts, driver behaviour) and so
  // must not be forwarded to the server as a run-time parameter in the
  // startup packet.
  func isDriverSetting(key string) bool {
  	switch key {
  	case "host", "port",
  		"password",
  		"sslmode", "sslcert", "sslkey", "sslrootcert",
  		"fallback_application_name",
  		"connect_timeout",
  		"disable_prepared_binary_result",
  		"binary_parameters":
  		return true
  	}
  	return false
  }
  // startup sends the startup packet built from the settings in o, then
  // processes the server's replies until the first ReadyForQuery ('Z').
  // Authentication requests ('R') are delegated to cn.auth, ParameterStatus
  // ('S') messages to cn.processParameterStatus, and 'K' (BackendKeyData)
  // messages are read and discarded. Any other message type panics via
  // errorf.
  func (cn *conn) startup(o values) {
  	// NOTE(review): writeBuf(0) presumably builds a message without a type
  	// byte, as the startup packet has none. 196608 == 3<<16, the protocol
  	// version 3.0 field.
  	w := cn.writeBuf(0)
  	w.int32(196608)
  	// Send the backend the name of the database we want to connect to, and the
  	// user we want to connect as. Additionally, we send over any run-time
  	// parameters potentially included in the connection string. If the server
  	// doesn't recognize any of them, it will reply with an error.
  	for k, v := range o {
  		if isDriverSetting(k) {
  			// skip options which can't be run-time parameters
  			continue
  		}
  		// The protocol requires us to supply the database name as "database"
  		// instead of "dbname".
  		if k == "dbname" {
  			k = "database"
  		}
  		w.string(k)
  		w.string(v)
  	}
  	// An empty name terminates the parameter list.
  	w.string("")
  	cn.sendStartupPacket(w)

  	for {
  		t, r := cn.recv()
  		switch t {
  		case 'K':
  			// BackendKeyData: ignored.
  		case 'S':
  			cn.processParameterStatus(r)
  		case 'R':
  			cn.auth(r, o)
  		case 'Z':
  			// The server is ready for queries; startup is complete.
  			cn.processReadyForQuery(r)
  			return
  		default:
  			errorf("unknown response for startup: %q", t)
  		}
  	}
  }
  1085. func (cn *conn) auth(r *readBuf, o values) {
  1086. switch code := r.int32(); code {
  1087. case 0:
  1088. // OK
  1089. case 3:
  1090. w := cn.writeBuf('p')
  1091. w.string(o.Get("password"))
  1092. cn.send(w)
  1093. t, r := cn.recv()
  1094. if t != 'R' {
  1095. errorf("unexpected password response: %q", t)
  1096. }
  1097. if r.int32() != 0 {
  1098. errorf("unexpected authentication response: %q", t)
  1099. }
  1100. case 5:
  1101. s := string(r.next(4))
  1102. w := cn.writeBuf('p')
  1103. w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
  1104. cn.send(w)
  1105. t, r := cn.recv()
  1106. if t != 'R' {
  1107. errorf("unexpected password response: %q", t)
  1108. }
  1109. if r.int32() != 0 {
  1110. errorf("unexpected authentication response: %q", t)
  1111. }
  1112. default:
  1113. errorf("unknown authentication response: %d", code)
  1114. }
  1115. }
  // format is a wire-protocol format code for a parameter or result column.
  type format int

  // The two format codes: text and binary.
  const (
  	formatText   format = 0
  	formatBinary format = 1
  )

  // colFmtDataAllBinary holds one result-column format code with the value 1
  // (i.e. all binary): an int16 count of 1 followed by the int16 code 1.
  var colFmtDataAllBinary = []byte{0, 1, 0, 1}

  // colFmtDataAllText holds no result-column format codes (i.e. all text):
  // just an int16 count of 0.
  var colFmtDataAllText = []byte{0, 0}
  // stmt is a server-side prepared statement on connection cn. Its methods
  // (Close, Query, Exec, NumInput) match database/sql/driver's Stmt
  // interface.
  type stmt struct {
  	cn         *conn
  	name       string     // statement name sent in Bind and Close messages
  	colNames   []string   // result column names handed to rows by Query
  	colFmts    []format   // per-column result formats handed to rows by Query
  	colFmtData []byte     // pre-encoded result-column format codes appended to Bind
  	colTyps    []oid.Oid  // result column type OIDs handed to rows by Query
  	paramTyps  []oid.Oid  // parameter type OIDs; len is the required argument count
  	closed     bool       // set once the server acknowledges Close
  }
  1133. func (st *stmt) Close() (err error) {
  1134. if st.closed {
  1135. return nil
  1136. }
  1137. if st.cn.bad {
  1138. return driver.ErrBadConn
  1139. }
  1140. defer st.cn.errRecover(&err)
  1141. w := st.cn.writeBuf('C')
  1142. w.byte('S')
  1143. w.string(st.name)
  1144. st.cn.send(w)
  1145. st.cn.send(st.cn.writeBuf('S'))
  1146. t, _ := st.cn.recv1()
  1147. if t != '3' {
  1148. st.cn.bad = true
  1149. errorf("unexpected close response: %q", t)
  1150. }
  1151. st.closed = true
  1152. t, r := st.cn.recv1()
  1153. if t != 'Z' {
  1154. st.cn.bad = true
  1155. errorf("expected ready for query, but got: %q", t)
  1156. }
  1157. st.cn.processReadyForQuery(r)
  1158. return nil
  1159. }
  1160. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1161. if st.cn.bad {
  1162. return nil, driver.ErrBadConn
  1163. }
  1164. defer st.cn.errRecover(&err)
  1165. st.exec(v)
  1166. return &rows{
  1167. cn: st.cn,
  1168. colNames: st.colNames,
  1169. colTyps: st.colTyps,
  1170. colFmts: st.colFmts,
  1171. }, nil
  1172. }
  1173. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1174. if st.cn.bad {
  1175. return nil, driver.ErrBadConn
  1176. }
  1177. defer st.cn.errRecover(&err)
  1178. st.exec(v)
  1179. res, _, err = st.cn.readExecuteResponse("simple query")
  1180. return res, err
  1181. }
  1182. func (st *stmt) exec(v []driver.Value) {
  1183. if len(v) >= 65536 {
  1184. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1185. }
  1186. if len(v) != len(st.paramTyps) {
  1187. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1188. }
  1189. cn := st.cn
  1190. w := cn.writeBuf('B')
  1191. w.byte(0) // unnamed portal
  1192. w.string(st.name)
  1193. if cn.binaryParameters {
  1194. cn.sendBinaryParameters(w, v)
  1195. } else {
  1196. w.int16(0)
  1197. w.int16(len(v))
  1198. for i, x := range v {
  1199. if x == nil {
  1200. w.int32(-1)
  1201. } else {
  1202. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1203. w.int32(len(b))
  1204. w.bytes(b)
  1205. }
  1206. }
  1207. }
  1208. w.bytes(st.colFmtData)
  1209. w.next('E')
  1210. w.byte(0)
  1211. w.int32(0)
  1212. w.next('S')
  1213. cn.send(w)
  1214. cn.readBindResponse()
  1215. cn.postExecuteWorkaround()
  1216. }
  1217. func (st *stmt) NumInput() int {
  1218. return len(st.paramTyps)
  1219. }
  1220. // parseComplete parses the "command tag" from a CommandComplete message, and
  1221. // returns the number of rows affected (if applicable) and a string
  1222. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1223. // command tag could not be parsed, parseComplete panics.
  1224. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1225. commandsWithAffectedRows := []string{
  1226. "SELECT ",
  1227. // INSERT is handled below
  1228. "UPDATE ",
  1229. "DELETE ",
  1230. "FETCH ",
  1231. "MOVE ",
  1232. "COPY ",
  1233. }
  1234. var affectedRows *string
  1235. for _, tag := range commandsWithAffectedRows {
  1236. if strings.HasPrefix(commandTag, tag) {
  1237. t := commandTag[len(tag):]
  1238. affectedRows = &t
  1239. commandTag = tag[:len(tag)-1]
  1240. break
  1241. }
  1242. }
  1243. // INSERT also includes the oid of the inserted row in its command tag.
  1244. // Oids in user tables are deprecated, and the oid is only returned when
  1245. // exactly one row is inserted, so it's unlikely to be of value to any
  1246. // real-world application and we can ignore it.
  1247. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1248. parts := strings.Split(commandTag, " ")
  1249. if len(parts) != 3 {
  1250. cn.bad = true
  1251. errorf("unexpected INSERT command tag %s", commandTag)
  1252. }
  1253. affectedRows = &parts[len(parts)-1]
  1254. commandTag = "INSERT"
  1255. }
  1256. // There should be no affected rows attached to the tag, just return it
  1257. if affectedRows == nil {
  1258. return driver.RowsAffected(0), commandTag
  1259. }
  1260. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1261. if err != nil {
  1262. cn.bad = true
  1263. errorf("could not parse commandTag: %s", err)
  1264. }
  1265. return driver.RowsAffected(n), commandTag
  1266. }
  // rows iterates over the result set(s) of an executed query, reading
  // messages directly from cn.
  type rows struct {
  	cn       *conn
  	colNames []string  // names of the current result set's columns
  	colTyps  []oid.Oid // type OIDs of the current result set's columns
  	colFmts  []format  // wire formats of the current result set's columns
  	done     bool      // set once ReadyForQuery has been received
  	rb       readBuf   // buffer reused for every received message
  }
  1275. func (rs *rows) Close() error {
  1276. // no need to look at cn.bad as Next() will
  1277. for {
  1278. err := rs.Next(nil)
  1279. switch err {
  1280. case nil:
  1281. case io.EOF:
  1282. return nil
  1283. default:
  1284. return err
  1285. }
  1286. }
  1287. }
  1288. func (rs *rows) Columns() []string {
  1289. return rs.colNames
  1290. }
  // Next implements driver.Rows. It reads server messages into rs.rb until
  // it can fill dest from a DataRow ('D') or the current result set ends.
  //
  // It returns:
  //   - nil after filling dest from a DataRow;
  //   - io.EOF when a RowDescription ('T') arrives, after switching the
  //     column metadata to the new result set (rs.done stays false);
  //   - the ErrorResponse ('E') seen earlier, or else io.EOF, once
  //     ReadyForQuery ('Z') arrives, at which point rs.done is set. After
  //     that, every call returns io.EOF immediately;
  //   - driver.ErrBadConn if the connection is already marked bad.
  //
  // Protocol violations mark the connection bad and panic via errorf. The
  // deferred errRecover presumably converts those panics into err.
  func (rs *rows) Next(dest []driver.Value) (err error) {
  	if rs.done {
  		return io.EOF
  	}

  	conn := rs.cn
  	if conn.bad {
  		return driver.ErrBadConn
  	}
  	defer conn.errRecover(&err)

  	for {
  		t := conn.recv1Buf(&rs.rb)
  		switch t {
  		case 'E':
  			// Remember the error but keep reading; it is only returned
  			// once ReadyForQuery arrives.
  			err = parseError(&rs.rb)
  		case 'C', 'I':
  			// Nothing to hand back for these; keep reading.
  			continue
  		case 'Z':
  			conn.processReadyForQuery(&rs.rb)
  			rs.done = true
  			if err != nil {
  				return err
  			}
  			return io.EOF
  		case 'D':
  			n := rs.rb.int16()
  			if err != nil {
  				// A DataRow after an ErrorResponse is unexpected.
  				conn.bad = true
  				errorf("unexpected DataRow after error %s", err)
  			}
  			// Decode at most n (the row's column count) values.
  			if n < len(dest) {
  				dest = dest[:n]
  			}
  			for i := range dest {
  				l := rs.rb.int32()
  				if l == -1 {
  					// A length of -1 denotes NULL.
  					dest[i] = nil
  					continue
  				}
  				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i])
  			}
  			return
  		case 'T':
  			// Start of another result set (presumably from a
  			// multi-statement query): adopt its columns and end this one.
  			rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
  			return io.EOF
  		default:
  			errorf("unexpected message after execute: %q", t)
  		}
  	}
  }
  1340. func (rs *rows) HasNextResultSet() bool {
  1341. return !rs.done
  1342. }
  1343. func (rs *rows) NextResultSet() error {
  1344. return nil
  1345. }
  // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  // used as part of an SQL statement. For example:
  //
  //	tblname := "my_table"
  //	data := "my_data"
  //	_, err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
  //
  // Any double quotes in name will be escaped by doubling them. The quoted
  // identifier will be case sensitive when used in a query. If the input
  // string contains a zero byte, the result will be truncated immediately
  // before it.
  func QuoteIdentifier(name string) string {
  	if nul := strings.IndexByte(name, 0); nul >= 0 {
  		name = name[:nul]
  	}
  	escaped := strings.Replace(name, `"`, `""`, -1)
  	return `"` + escaped + `"`
  }
  // md5s returns the lower-case hexadecimal MD5 digest of s.
  func md5s(s string) string {
  	sum := md5.Sum([]byte(s))
  	return fmt.Sprintf("%x", sum)
  }
  1368. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1369. // Do one pass over the parameters to see if we're going to send any of
  1370. // them over in binary. If we are, create a paramFormats array at the
  1371. // same time.
  1372. var paramFormats []int
  1373. for i, x := range args {
  1374. _, ok := x.([]byte)
  1375. if ok {
  1376. if paramFormats == nil {
  1377. paramFormats = make([]int, len(args))
  1378. }
  1379. paramFormats[i] = 1
  1380. }
  1381. }
  1382. if paramFormats == nil {
  1383. b.int16(0)
  1384. } else {
  1385. b.int16(len(paramFormats))
  1386. for _, x := range paramFormats {
  1387. b.int16(x)
  1388. }
  1389. }
  1390. b.int16(len(args))
  1391. for _, x := range args {
  1392. if x == nil {
  1393. b.int32(-1)
  1394. } else {
  1395. datum := binaryEncode(&cn.parameterStatus, x)
  1396. b.int32(len(datum))
  1397. b.bytes(datum)
  1398. }
  1399. }
  1400. }
  // sendBinaryModeQuery sends an unnamed extended-protocol query as one
  // write: Parse ('P'), Bind ('B'), Describe portal ('D'), Execute ('E') and
  // Sync ('S'). Parameters are sent via sendBinaryParameters and all result
  // columns are requested as text (colFmtDataAllText). It only sends; the
  // caller must read the responses. It panics via errorf if there are 65536
  // or more arguments.
  func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  	if len(args) >= 65536 {
  		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  	}

  	// Parse: unnamed statement, the query text, and no parameter type OIDs.
  	b := cn.writeBuf('P')
  	b.byte(0) // unnamed statement
  	b.string(query)
  	b.int16(0)

  	// Bind. NOTE(review): int16(0) presumably emits two zero bytes, i.e. the
  	// empty portal name followed by the empty statement name.
  	b.next('B')
  	b.int16(0) // unnamed portal and statement
  	cn.sendBinaryParameters(b, args)
  	b.bytes(colFmtDataAllText)

  	// Describe the unnamed portal ('P' selects a portal).
  	b.next('D')
  	b.byte('P')
  	b.byte(0) // unnamed portal

  	// Execute the unnamed portal with a row limit of 0, then Sync.
  	b.next('E')
  	b.byte(0)
  	b.int32(0)

  	b.next('S')
  	cn.send(b)
  }
  1422. func (c *conn) processParameterStatus(r *readBuf) {
  1423. var err error
  1424. param := r.string()
  1425. switch param {
  1426. case "server_version":
  1427. var major1 int
  1428. var major2 int
  1429. var minor int
  1430. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1431. if err == nil {
  1432. c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1433. }
  1434. case "TimeZone":
  1435. c.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1436. if err != nil {
  1437. c.parameterStatus.currentLocation = nil
  1438. }
  1439. default:
  1440. // ignore
  1441. }
  1442. }
  1443. func (c *conn) processReadyForQuery(r *readBuf) {
  1444. c.txnStatus = transactionStatus(r.byte())
  1445. }
  1446. func (cn *conn) readReadyForQuery() {
  1447. t, r := cn.recv1()
  1448. switch t {
  1449. case 'Z':
  1450. cn.processReadyForQuery(r)
  1451. return
  1452. default:
  1453. cn.bad = true
  1454. errorf("unexpected message %q; expected ReadyForQuery", t)
  1455. }
  1456. }
  1457. func (cn *conn) readParseResponse() {
  1458. t, r := cn.recv1()
  1459. switch t {
  1460. case '1':
  1461. return
  1462. case 'E':
  1463. err := parseError(r)
  1464. cn.readReadyForQuery()
  1465. panic(err)
  1466. default:
  1467. cn.bad = true
  1468. errorf("unexpected Parse response %q", t)
  1469. }
  1470. }
  1471. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) {
  1472. for {
  1473. t, r := cn.recv1()
  1474. switch t {
  1475. case 't':
  1476. nparams := r.int16()
  1477. paramTyps = make([]oid.Oid, nparams)
  1478. for i := range paramTyps {
  1479. paramTyps[i] = r.oid()
  1480. }
  1481. case 'n':
  1482. return paramTyps, nil, nil
  1483. case 'T':
  1484. colNames, colTyps = parseStatementRowDescribe(r)
  1485. return paramTyps, colNames, colTyps
  1486. case 'E':
  1487. err := parseError(r)
  1488. cn.readReadyForQuery()
  1489. panic(err)
  1490. default:
  1491. cn.bad = true
  1492. errorf("unexpected Describe statement response %q", t)
  1493. }
  1494. }
  1495. }
  1496. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) {
  1497. t, r := cn.recv1()
  1498. switch t {
  1499. case 'T':
  1500. return parsePortalRowDescribe(r)
  1501. case 'n':
  1502. return nil, nil, nil
  1503. case 'E':
  1504. err := parseError(r)
  1505. cn.readReadyForQuery()
  1506. panic(err)
  1507. default:
  1508. cn.bad = true
  1509. errorf("unexpected Describe response %q", t)
  1510. }
  1511. panic("not reached")
  1512. }
  1513. func (cn *conn) readBindResponse() {
  1514. t, r := cn.recv1()
  1515. switch t {
  1516. case '2':
  1517. return
  1518. case 'E':
  1519. err := parseError(r)
  1520. cn.readReadyForQuery()
  1521. panic(err)
  1522. default:
  1523. cn.bad = true
  1524. errorf("unexpected Bind response %q", t)
  1525. }
  1526. }
  1527. func (cn *conn) postExecuteWorkaround() {
  1528. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1529. // any errors from rows.Next, which masks errors that happened during the
  1530. // execution of the query. To avoid the problem in common cases, we wait
  1531. // here for one more message from the database. If it's not an error the
  1532. // query will likely succeed (or perhaps has already, if it's a
  1533. // CommandComplete), so we push the message into the conn struct; recv1
  1534. // will return it as the next message for rows.Next or rows.Close.
  1535. // However, if it's an error, we wait until ReadyForQuery and then return
  1536. // the error to our caller.
  1537. for {
  1538. t, r := cn.recv1()
  1539. switch t {
  1540. case 'E':
  1541. err := parseError(r)
  1542. cn.readReadyForQuery()
  1543. panic(err)
  1544. case 'C', 'D', 'I':
  1545. // the query didn't fail, but we can't process this message
  1546. cn.saveMessage(t, r)
  1547. return
  1548. default:
  1549. cn.bad = true
  1550. errorf("unexpected message during extended query execution: %q", t)
  1551. }
  1552. }
  1553. }
  // readExecuteResponse reads the response to an executed query up to and
  // including ReadyForQuery ('Z'), discarding any row data. Only for Exec(),
  // since we ignore the returned data.
  //
  // res and commandTag come from the most recent CommandComplete ('C'), or
  // res is emptyRows after an EmptyQueryResponse ('I'). An ErrorResponse
  // ('E') is returned as err once ReadyForQuery arrives. Reaching
  // ReadyForQuery with neither a result nor an error yields
  // errUnexpectedReady. protocolState only appears in the error message for
  // an unknown message type. Protocol violations mark the connection bad and
  // panic via errorf.
  func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  	for {
  		t, r := cn.recv1()
  		switch t {
  		case 'C':
  			if err != nil {
  				cn.bad = true
  				errorf("unexpected CommandComplete after error %s", err)
  			}
  			res, commandTag = cn.parseComplete(r.string())
  		case 'Z':
  			cn.processReadyForQuery(r)
  			if res == nil && err == nil {
  				err = errUnexpectedReady
  			}
  			return res, commandTag, err
  		case 'E':
  			// Keep reading until ReadyForQuery before reporting it.
  			err = parseError(r)
  		case 'T', 'D', 'I':
  			if err != nil {
  				cn.bad = true
  				errorf("unexpected %q after error %s", t, err)
  			}
  			if t == 'I' {
  				res = emptyRows
  			}
  			// ignore any results
  		default:
  			cn.bad = true
  			errorf("unknown %s response: %q", protocolState, t)
  		}
  	}
  }
  // parseStatementRowDescribe parses a RowDescription ('T') received in
  // reply to a Describe of a statement and returns each column's name and
  // type OID. Per the protocol's RowDescription layout, each field is: name,
  // table OID (int32), column number (int16), type OID (int32), type size
  // (int16), type modifier (int32), format code (int16). The skipped 6-byte
  // runs are table OID + column number and type size + type modifier.
  func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) {
  	n := r.int16()
  	colNames = make([]string, n)
  	colTyps = make([]oid.Oid, n)
  	for i := range colNames {
  		colNames[i] = r.string()
  		r.next(6)
  		colTyps[i] = r.oid()
  		r.next(6)
  		// format code not known when describing a statement; always 0
  		r.next(2)
  	}
  	return
  }
  // parsePortalRowDescribe parses a RowDescription ('T') describing a
  // portal and returns each column's name, result format and type OID.
  // Fields follow the protocol's RowDescription layout: the skipped 6-byte
  // runs are table OID + column number and type size + type modifier. Unlike
  // the statement variant, the trailing format code is meaningful here.
  func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) {
  	n := r.int16()
  	colNames = make([]string, n)
  	colFmts = make([]format, n)
  	colTyps = make([]oid.Oid, n)
  	for i := range colNames {
  		colNames[i] = r.string()
  		r.next(6)
  		colTyps[i] = r.oid()
  		r.next(6)
  		colFmts[i] = format(r.int16())
  	}
  	return
  }
  // parseEnviron tries to mimic some of libpq's environment handling.
  //
  // To ease testing, it does not directly reference os.Environ, but is
  // designed to accept its output: "KEY=value" strings. Entries without an
  // "=" are skipped; previously they caused an index-out-of-range panic.
  //
  // Environment-set connection information is intended to have a higher
  // precedence than a library default but lower than any explicitly
  // passed information (such as in the URL or connection string).
  //
  // Well-defined libpq variables that pq does not support cause a panic
  // ("setting X not supported"); they should be unset prior to execution.
  func parseEnviron(env []string) (out map[string]string) {
  	out = make(map[string]string)

  	for _, v := range env {
  		parts := strings.SplitN(v, "=", 2)
  		if len(parts) != 2 {
  			// Not a KEY=value entry; there is no value to accrue.
  			continue
  		}
  		key, value := parts[0], parts[1]

  		accrue := func(keyname string) {
  			out[keyname] = value
  		}
  		unsupported := func() {
  			panic(fmt.Sprintf("setting %v not supported", key))
  		}

  		// The order of these is the same as is seen in the
  		// PostgreSQL 9.1 manual. Unsupported but well-defined
  		// keys cause a panic; these should be unset prior to
  		// execution. Options which pq expects to be set to a
  		// certain value are allowed, but must be set to that
  		// value if present (they can, of course, be absent).
  		switch key {
  		case "PGHOST":
  			accrue("host")
  		case "PGHOSTADDR":
  			unsupported()
  		case "PGPORT":
  			accrue("port")
  		case "PGDATABASE":
  			accrue("dbname")
  		case "PGUSER":
  			accrue("user")
  		case "PGPASSWORD":
  			accrue("password")
  		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  			unsupported()
  		case "PGOPTIONS":
  			accrue("options")
  		case "PGAPPNAME":
  			accrue("application_name")
  		case "PGSSLMODE":
  			accrue("sslmode")
  		case "PGSSLCERT":
  			accrue("sslcert")
  		case "PGSSLKEY":
  			accrue("sslkey")
  		case "PGSSLROOTCERT":
  			accrue("sslrootcert")
  		case "PGREQUIRESSL", "PGSSLCRL":
  			unsupported()
  		case "PGREQUIREPEER":
  			unsupported()
  		case "PGKRBSRVNAME", "PGGSSLIB":
  			unsupported()
  		case "PGCONNECT_TIMEOUT":
  			accrue("connect_timeout")
  		case "PGCLIENTENCODING":
  			accrue("client_encoding")
  		case "PGDATESTYLE":
  			accrue("datestyle")
  		case "PGTZ":
  			accrue("timezone")
  		case "PGGEQO":
  			accrue("geqo")
  		case "PGSYSCONFDIR", "PGLOCALEDIR":
  			unsupported()
  		}
  	}

  	return out
  }
  // isUTF8 reports whether name is a fuzzy spelling of "UTF-8". Like
  // Postgres, it ignores case and every character that is not an ASCII
  // letter or digit, then accepts "utf8" or "unicode".
  func isUTF8(name string) bool {
  	switch strings.Map(alnumLowerASCII, name) {
  	case "utf8", "unicode":
  		return true
  	}
  	return false
  }

  // alnumLowerASCII is a strings.Map callback. It lower-cases ASCII letters,
  // keeps ASCII lower-case letters and digits, and returns -1 (drop) for
  // every other rune.
  func alnumLowerASCII(ch rune) rune {
  	switch {
  	case 'A' <= ch && ch <= 'Z':
  		return ch + ('a' - 'A')
  	case 'a' <= ch && ch <= 'z', '0' <= ch && ch <= '9':
  		return ch
  	default:
  		return -1 // discard
  	}
  }