You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1824 lines
42 KiB

  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "database/sql"
  6. "database/sql/driver"
  7. "encoding/binary"
  8. "errors"
  9. "fmt"
  10. "io"
  11. "net"
  12. "os"
  13. "os/user"
  14. "path"
  15. "path/filepath"
  16. "strconv"
  17. "strings"
  18. "time"
  19. "unicode"
  20. "github.com/lib/pq/oid"
  21. )
  22. // Common error types
  23. var (
  24. ErrNotSupported = errors.New("pq: Unsupported command")
  25. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  26. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  27. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  28. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
  29. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  30. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  31. errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  32. )
  33. type Driver struct{}
  34. func (d *Driver) Open(name string) (driver.Conn, error) {
  35. return Open(name)
  36. }
  37. func init() {
  38. sql.Register("postgres", &Driver{})
  39. }
// parameterStatus caches session information reported by the server.
// NOTE(review): presumably populated by processParameterStatus (called for
// 'S' messages, defined outside this section) — confirm.
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
  48. type transactionStatus byte
  49. const (
  50. txnStatusIdle transactionStatus = 'I'
  51. txnStatusIdleInTransaction transactionStatus = 'T'
  52. txnStatusInFailedTransaction transactionStatus = 'E'
  53. )
  54. func (s transactionStatus) String() string {
  55. switch s {
  56. case txnStatusIdle:
  57. return "idle"
  58. case txnStatusIdleInTransaction:
  59. return "idle in transaction"
  60. case txnStatusInFailedTransaction:
  61. return "in a failed transaction"
  62. default:
  63. errorf("unknown transactionStatus %d", s)
  64. }
  65. panic("not reached")
  66. }
  67. type Dialer interface {
  68. Dial(network, address string) (net.Conn, error)
  69. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  70. }
  71. type defaultDialer struct{}
  72. func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
  73. return net.Dial(ntw, addr)
  74. }
  75. func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
  76. return net.DialTimeout(ntw, addr, timeout)
  77. }
// conn is a single connection to a PostgreSQL server. It implements the
// driver connection methods defined in this file (Prepare, Begin, Query,
// Exec, Close, ...) and also acts as its own transaction object: begin
// returns cn as the driver.Tx.
// NOTE(review): no field here is guarded by a lock; presumably a conn is used
// by one goroutine at a time, as database/sql does — confirm.
type conn struct {
	// Network connection to the server; ssl() may replace it with the
	// upgraded (SSL) connection.
	c net.Conn
	// Buffered reader over c, created after SSL negotiation; recvMessage
	// reads incoming messages through it.
	buf *bufio.Reader
	// Counter used by gname to generate prepared-statement names.
	namei int
	// Shared buffer used by writeBuf to build outgoing messages and by
	// recvMessage to hold small incoming payloads; its contents may be
	// overwritten by the next message built or received.
	scratch [512]byte
	// Transaction status as last recorded (see isInTransaction).
	txnStatus transactionStatus
	// Callback run by closeTxn when a transaction ends; may be nil. It is
	// assigned outside this section of the file.
	txnFinish func()

	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values

	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int

	// Session state reported by the server.
	parameterStatus parameterStatus

	// A message stashed by saveMessage and handed back by the next
	// recvMessage call; saveMessageType is 0 when nothing is stashed.
	saveMessageType   byte
	saveMessageBuffer []byte

	// If true, this connection is bad and all public-facing functions should
	// return ErrBadConn.
	bad bool

	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging.
	disablePreparedBinaryResult bool

	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls.
	binaryParameters bool

	// If true this connection is in the middle of a COPY
	inCopy bool
}
  107. // Handle driver-side settings in parsed connection string.
  108. func (cn *conn) handleDriverSettings(o values) (err error) {
  109. boolSetting := func(key string, val *bool) error {
  110. if value, ok := o[key]; ok {
  111. if value == "yes" {
  112. *val = true
  113. } else if value == "no" {
  114. *val = false
  115. } else {
  116. return fmt.Errorf("unrecognized value %q for %s", value, key)
  117. }
  118. }
  119. return nil
  120. }
  121. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  122. if err != nil {
  123. return err
  124. }
  125. return boolSetting("binary_parameters", &cn.binaryParameters)
  126. }
  127. func (cn *conn) handlePgpass(o values) {
  128. // if a password was supplied, do not process .pgpass
  129. if _, ok := o["password"]; ok {
  130. return
  131. }
  132. filename := os.Getenv("PGPASSFILE")
  133. if filename == "" {
  134. // XXX this code doesn't work on Windows where the default filename is
  135. // XXX %APPDATA%\postgresql\pgpass.conf
  136. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  137. userHome := os.Getenv("HOME")
  138. if userHome == "" {
  139. user, err := user.Current()
  140. if err != nil {
  141. return
  142. }
  143. userHome = user.HomeDir
  144. }
  145. filename = filepath.Join(userHome, ".pgpass")
  146. }
  147. fileinfo, err := os.Stat(filename)
  148. if err != nil {
  149. return
  150. }
  151. mode := fileinfo.Mode()
  152. if mode&(0x77) != 0 {
  153. // XXX should warn about incorrect .pgpass permissions as psql does
  154. return
  155. }
  156. file, err := os.Open(filename)
  157. if err != nil {
  158. return
  159. }
  160. defer file.Close()
  161. scanner := bufio.NewScanner(io.Reader(file))
  162. hostname := o["host"]
  163. ntw, _ := network(o)
  164. port := o["port"]
  165. db := o["dbname"]
  166. username := o["user"]
  167. // From: https://github.com/tg/pgpass/blob/master/reader.go
  168. getFields := func(s string) []string {
  169. fs := make([]string, 0, 5)
  170. f := make([]rune, 0, len(s))
  171. var esc bool
  172. for _, c := range s {
  173. switch {
  174. case esc:
  175. f = append(f, c)
  176. esc = false
  177. case c == '\\':
  178. esc = true
  179. case c == ':':
  180. fs = append(fs, string(f))
  181. f = f[:0]
  182. default:
  183. f = append(f, c)
  184. }
  185. }
  186. return append(fs, string(f))
  187. }
  188. for scanner.Scan() {
  189. line := scanner.Text()
  190. if len(line) == 0 || line[0] == '#' {
  191. continue
  192. }
  193. split := getFields(line)
  194. if len(split) != 5 {
  195. continue
  196. }
  197. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  198. o["password"] = split[4]
  199. return
  200. }
  201. }
  202. }
  203. func (cn *conn) writeBuf(b byte) *writeBuf {
  204. cn.scratch[0] = b
  205. return &writeBuf{
  206. buf: cn.scratch[:5],
  207. pos: 1,
  208. }
  209. }
  210. func Open(name string) (_ driver.Conn, err error) {
  211. return DialOpen(defaultDialer{}, name)
  212. }
  213. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  214. // Handle any panics during connection initialization. Note that we
  215. // specifically do *not* want to use errRecover(), as that would turn any
  216. // connection errors into ErrBadConns, hiding the real error message from
  217. // the user.
  218. defer errRecoverNoErrBadConn(&err)
  219. o := make(values)
  220. // A number of defaults are applied here, in this order:
  221. //
  222. // * Very low precedence defaults applied in every situation
  223. // * Environment variables
  224. // * Explicitly passed connection information
  225. o["host"] = "localhost"
  226. o["port"] = "5432"
  227. // N.B.: Extra float digits should be set to 3, but that breaks
  228. // Postgres 8.4 and older, where the max is 2.
  229. o["extra_float_digits"] = "2"
  230. for k, v := range parseEnviron(os.Environ()) {
  231. o[k] = v
  232. }
  233. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  234. name, err = ParseURL(name)
  235. if err != nil {
  236. return nil, err
  237. }
  238. }
  239. if err := parseOpts(name, o); err != nil {
  240. return nil, err
  241. }
  242. // Use the "fallback" application name if necessary
  243. if fallback, ok := o["fallback_application_name"]; ok {
  244. if _, ok := o["application_name"]; !ok {
  245. o["application_name"] = fallback
  246. }
  247. }
  248. // We can't work with any client_encoding other than UTF-8 currently.
  249. // However, we have historically allowed the user to set it to UTF-8
  250. // explicitly, and there's no reason to break such programs, so allow that.
  251. // Note that the "options" setting could also set client_encoding, but
  252. // parsing its value is not worth it. Instead, we always explicitly send
  253. // client_encoding as a separate run-time parameter, which should override
  254. // anything set in options.
  255. if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
  256. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  257. }
  258. o["client_encoding"] = "UTF8"
  259. // DateStyle needs a similar treatment.
  260. if datestyle, ok := o["datestyle"]; ok {
  261. if datestyle != "ISO, MDY" {
  262. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  263. "ISO, MDY", datestyle))
  264. }
  265. } else {
  266. o["datestyle"] = "ISO, MDY"
  267. }
  268. // If a user is not provided by any other means, the last
  269. // resort is to use the current operating system provided user
  270. // name.
  271. if _, ok := o["user"]; !ok {
  272. u, err := userCurrent()
  273. if err != nil {
  274. return nil, err
  275. }
  276. o["user"] = u
  277. }
  278. cn := &conn{
  279. opts: o,
  280. dialer: d,
  281. }
  282. err = cn.handleDriverSettings(o)
  283. if err != nil {
  284. return nil, err
  285. }
  286. cn.handlePgpass(o)
  287. cn.c, err = dial(d, o)
  288. if err != nil {
  289. return nil, err
  290. }
  291. cn.ssl(o)
  292. cn.buf = bufio.NewReader(cn.c)
  293. cn.startup(o)
  294. // reset the deadline, in case one was set (see dial)
  295. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  296. err = cn.c.SetDeadline(time.Time{})
  297. }
  298. return cn, err
  299. }
  300. func dial(d Dialer, o values) (net.Conn, error) {
  301. ntw, addr := network(o)
  302. // SSL is not necessary or supported over UNIX domain sockets
  303. if ntw == "unix" {
  304. o["sslmode"] = "disable"
  305. }
  306. // Zero or not specified means wait indefinitely.
  307. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  308. seconds, err := strconv.ParseInt(timeout, 10, 0)
  309. if err != nil {
  310. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  311. }
  312. duration := time.Duration(seconds) * time.Second
  313. // connect_timeout should apply to the entire connection establishment
  314. // procedure, so we both use a timeout for the TCP connection
  315. // establishment and set a deadline for doing the initial handshake.
  316. // The deadline is then reset after startup() is done.
  317. deadline := time.Now().Add(duration)
  318. conn, err := d.DialTimeout(ntw, addr, duration)
  319. if err != nil {
  320. return nil, err
  321. }
  322. err = conn.SetDeadline(deadline)
  323. return conn, err
  324. }
  325. return d.Dial(ntw, addr)
  326. }
  327. func network(o values) (string, string) {
  328. host := o["host"]
  329. if strings.HasPrefix(host, "/") {
  330. sockPath := path.Join(host, ".s.PGSQL."+o["port"])
  331. return "unix", sockPath
  332. }
  333. return "tcp", net.JoinHostPort(host, o["port"])
  334. }
  335. type values map[string]string
  336. // scanner implements a tokenizer for libpq-style option strings.
  337. type scanner struct {
  338. s []rune
  339. i int
  340. }
  341. // newScanner returns a new scanner initialized with the option string s.
  342. func newScanner(s string) *scanner {
  343. return &scanner{[]rune(s), 0}
  344. }
  345. // Next returns the next rune.
  346. // It returns 0, false if the end of the text has been reached.
  347. func (s *scanner) Next() (rune, bool) {
  348. if s.i >= len(s.s) {
  349. return 0, false
  350. }
  351. r := s.s[s.i]
  352. s.i++
  353. return r, true
  354. }
  355. // SkipSpaces returns the next non-whitespace rune.
  356. // It returns 0, false if the end of the text has been reached.
  357. func (s *scanner) SkipSpaces() (rune, bool) {
  358. r, ok := s.Next()
  359. for unicode.IsSpace(r) && ok {
  360. r, ok = s.Next()
  361. }
  362. return r, ok
  363. }
  364. // parseOpts parses the options from name and adds them to the values.
  365. //
  366. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  367. func parseOpts(name string, o values) error {
  368. s := newScanner(name)
  369. for {
  370. var (
  371. keyRunes, valRunes []rune
  372. r rune
  373. ok bool
  374. )
  375. if r, ok = s.SkipSpaces(); !ok {
  376. break
  377. }
  378. // Scan the key
  379. for !unicode.IsSpace(r) && r != '=' {
  380. keyRunes = append(keyRunes, r)
  381. if r, ok = s.Next(); !ok {
  382. break
  383. }
  384. }
  385. // Skip any whitespace if we're not at the = yet
  386. if r != '=' {
  387. r, ok = s.SkipSpaces()
  388. }
  389. // The current character should be =
  390. if r != '=' || !ok {
  391. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  392. }
  393. // Skip any whitespace after the =
  394. if r, ok = s.SkipSpaces(); !ok {
  395. // If we reach the end here, the last value is just an empty string as per libpq.
  396. o[string(keyRunes)] = ""
  397. break
  398. }
  399. if r != '\'' {
  400. for !unicode.IsSpace(r) {
  401. if r == '\\' {
  402. if r, ok = s.Next(); !ok {
  403. return fmt.Errorf(`missing character after backslash`)
  404. }
  405. }
  406. valRunes = append(valRunes, r)
  407. if r, ok = s.Next(); !ok {
  408. break
  409. }
  410. }
  411. } else {
  412. quote:
  413. for {
  414. if r, ok = s.Next(); !ok {
  415. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  416. }
  417. switch r {
  418. case '\'':
  419. break quote
  420. case '\\':
  421. r, _ = s.Next()
  422. fallthrough
  423. default:
  424. valRunes = append(valRunes, r)
  425. }
  426. }
  427. }
  428. o[string(keyRunes)] = string(valRunes)
  429. }
  430. return nil
  431. }
  432. func (cn *conn) isInTransaction() bool {
  433. return cn.txnStatus == txnStatusIdleInTransaction ||
  434. cn.txnStatus == txnStatusInFailedTransaction
  435. }
  436. func (cn *conn) checkIsInTransaction(intxn bool) {
  437. if cn.isInTransaction() != intxn {
  438. cn.bad = true
  439. errorf("unexpected transaction status %v", cn.txnStatus)
  440. }
  441. }
  442. func (cn *conn) Begin() (_ driver.Tx, err error) {
  443. return cn.begin("")
  444. }
  445. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  446. if cn.bad {
  447. return nil, driver.ErrBadConn
  448. }
  449. defer cn.errRecover(&err)
  450. cn.checkIsInTransaction(false)
  451. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  452. if err != nil {
  453. return nil, err
  454. }
  455. if commandTag != "BEGIN" {
  456. cn.bad = true
  457. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  458. }
  459. if cn.txnStatus != txnStatusIdleInTransaction {
  460. cn.bad = true
  461. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  462. }
  463. return cn, nil
  464. }
  465. func (cn *conn) closeTxn() {
  466. if finish := cn.txnFinish; finish != nil {
  467. finish()
  468. }
  469. }
  470. func (cn *conn) Commit() (err error) {
  471. defer cn.closeTxn()
  472. if cn.bad {
  473. return driver.ErrBadConn
  474. }
  475. defer cn.errRecover(&err)
  476. cn.checkIsInTransaction(true)
  477. // We don't want the client to think that everything is okay if it tries
  478. // to commit a failed transaction. However, no matter what we return,
  479. // database/sql will release this connection back into the free connection
  480. // pool so we have to abort the current transaction here. Note that you
  481. // would get the same behaviour if you issued a COMMIT in a failed
  482. // transaction, so it's also the least surprising thing to do here.
  483. if cn.txnStatus == txnStatusInFailedTransaction {
  484. if err := cn.Rollback(); err != nil {
  485. return err
  486. }
  487. return ErrInFailedTransaction
  488. }
  489. _, commandTag, err := cn.simpleExec("COMMIT")
  490. if err != nil {
  491. if cn.isInTransaction() {
  492. cn.bad = true
  493. }
  494. return err
  495. }
  496. if commandTag != "COMMIT" {
  497. cn.bad = true
  498. return fmt.Errorf("unexpected command tag %s", commandTag)
  499. }
  500. cn.checkIsInTransaction(false)
  501. return nil
  502. }
  503. func (cn *conn) Rollback() (err error) {
  504. defer cn.closeTxn()
  505. if cn.bad {
  506. return driver.ErrBadConn
  507. }
  508. defer cn.errRecover(&err)
  509. cn.checkIsInTransaction(true)
  510. _, commandTag, err := cn.simpleExec("ROLLBACK")
  511. if err != nil {
  512. if cn.isInTransaction() {
  513. cn.bad = true
  514. }
  515. return err
  516. }
  517. if commandTag != "ROLLBACK" {
  518. return fmt.Errorf("unexpected command tag %s", commandTag)
  519. }
  520. cn.checkIsInTransaction(false)
  521. return nil
  522. }
  523. func (cn *conn) gname() string {
  524. cn.namei++
  525. return strconv.FormatInt(int64(cn.namei), 10)
  526. }
  527. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  528. b := cn.writeBuf('Q')
  529. b.string(q)
  530. cn.send(b)
  531. for {
  532. t, r := cn.recv1()
  533. switch t {
  534. case 'C':
  535. res, commandTag = cn.parseComplete(r.string())
  536. case 'Z':
  537. cn.processReadyForQuery(r)
  538. if res == nil && err == nil {
  539. err = errUnexpectedReady
  540. }
  541. // done
  542. return
  543. case 'E':
  544. err = parseError(r)
  545. case 'I':
  546. res = emptyRows
  547. case 'T', 'D':
  548. // ignore any results
  549. default:
  550. cn.bad = true
  551. errorf("unknown response for simple query: %q", t)
  552. }
  553. }
  554. }
// simpleQuery runs q using the simple query protocol and returns a rows
// object for its results.
//
// When a DataRow arrives, it is stashed with saveMessage and simpleQuery
// returns at once. The remaining messages, including ReadyForQuery, are left
// for the returned rows to consume (presumably in rows.Next, per the comment
// below). Otherwise it reads through ReadyForQuery and returns a rows object
// already marked done. After an ErrorResponse, res is reset to nil and err
// holds the parsed error. Protocol violations mark the connection bad and
// are raised via errorf; the deferred errRecover converts panics into err.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I':
			// CommandComplete / EmptyQueryResponse.
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if err != nil {
				cn.bad = true
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' && res.colNames == nil {
				res.result, res.tag = cn.parseComplete(r.string())
			}
			res.done = true
		case 'Z':
			// ReadyForQuery: the server has finished with q.
			cn.processReadyForQuery(r)
			// done
			return
		case 'E':
			// ErrorResponse: drop any partial result and record the error.
			res = nil
			err = parseError(r)
		case 'D':
			// DataRow: only valid after a RowDescription created res.
			if res == nil {
				cn.bad = true
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T':
			// RowDescription.
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)

			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
  612. type noRows struct{}
  613. var emptyRows noRows
  614. var _ driver.Result = noRows{}
  615. func (noRows) LastInsertId() (int64, error) {
  616. return 0, errNoLastInsertID
  617. }
  618. func (noRows) RowsAffected() (int64, error) {
  619. return 0, errNoRowsAffected
  620. }
  621. // Decides which column formats to use for a prepared statement. The input is
  622. // an array of type oids, one element per result column.
  623. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  624. if len(colTyps) == 0 {
  625. return nil, colFmtDataAllText
  626. }
  627. colFmts = make([]format, len(colTyps))
  628. if forceText {
  629. return colFmts, colFmtDataAllText
  630. }
  631. allBinary := true
  632. allText := true
  633. for i, t := range colTyps {
  634. switch t.OID {
  635. // This is the list of types to use binary mode for when receiving them
  636. // through a prepared statement. If a type appears in this list, it
  637. // must also be implemented in binaryDecode in encode.go.
  638. case oid.T_bytea:
  639. fallthrough
  640. case oid.T_int8:
  641. fallthrough
  642. case oid.T_int4:
  643. fallthrough
  644. case oid.T_int2:
  645. fallthrough
  646. case oid.T_uuid:
  647. colFmts[i] = formatBinary
  648. allText = false
  649. default:
  650. allBinary = false
  651. }
  652. }
  653. if allBinary {
  654. return colFmts, colFmtDataAllBinary
  655. } else if allText {
  656. return colFmts, colFmtDataAllText
  657. } else {
  658. colFmtData = make([]byte, 2+len(colFmts)*2)
  659. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  660. for i, v := range colFmts {
  661. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  662. }
  663. return colFmts, colFmtData
  664. }
  665. }
// prepareTo parses q on the server as a prepared statement named stmtName
// and returns a stmt describing it. An empty stmtName selects the unnamed
// statement, which query and Exec use.
//
// Parse, Describe (statement) and Sync are queued in one buffer and sent
// with a single write. The replies are then read in order: ParseComplete,
// the parameter and row descriptions, and ReadyForQuery. Result formats come
// from decideColumnFormats, forced to text when
// cn.disablePreparedBinaryResult is set. Failures are raised by the helpers
// it calls (panics), so callers need errRecover deferred.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}

	// Parse: statement name, query text, and a count of zero pre-specified
	// parameter types (the server infers them).
	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	b.int16(0)

	// Describe the statement ('S') to learn its parameter and result types.
	// NOTE(review): b.next presumably finishes the current message and starts
	// a new one in the same buffer — confirm in writeBuf.
	b.next('D')
	b.byte('S')
	b.string(st.name)

	// Sync, so the server replies and ends with ReadyForQuery.
	b.next('S')
	cn.send(b)

	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
  683. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  684. if cn.bad {
  685. return nil, driver.ErrBadConn
  686. }
  687. defer cn.errRecover(&err)
  688. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  689. s, err := cn.prepareCopyIn(q)
  690. if err == nil {
  691. cn.inCopy = true
  692. }
  693. return s, err
  694. }
  695. return cn.prepareTo(q, cn.gname()), nil
  696. }
// Close sends a Terminate ('X') message and closes the network connection.
// The connection is closed even if cn is marked bad or sending Terminate
// fails. The error from sending Terminate takes precedence over the error
// from closing the connection.
func (cn *conn) Close() (err error) {
	// Skip cn.bad return here because we always want to close a connection.
	defer cn.errRecover(&err)

	// Ensure that cn.c.Close is always run. Since error handling is done with
	// panics and cn.errRecover, the Close must be in a defer.
	defer func() {
		cerr := cn.c.Close()
		if err == nil {
			err = cerr
		}
	}()

	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	return cn.sendSimpleMessage('X')
}
  712. // Implement the "Queryer" interface
  713. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  714. return cn.query(query, args)
  715. }
// query runs query with args and returns a rows object for its results.
//
// It fails with driver.ErrBadConn if the connection is marked bad, and with
// errCopyInProgress while a COPY set up by Prepare is active. It picks one of
// three execution paths:
//   - no args: the simple query protocol (simpleQuery);
//   - cn.binaryParameters set: sendBinaryModeQuery, whose replies are read
//     as Parse, Bind and portal Describe responses (single round trip);
//   - otherwise: prepare the unnamed statement and execute it with args.
//
// Panics raised by the protocol helpers are converted into err by the
// deferred errRecover.
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	if cn.inCopy {
		return nil, errCopyInProgress
	}
	defer cn.errRecover(&err)

	// Check to see if we can use the "simpleQuery" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		return cn.simpleQuery(query)
	}

	if cn.binaryParameters {
		cn.sendBinaryModeQuery(query, args)

		cn.readParseResponse()
		cn.readBindResponse()
		rows := &rows{cn: cn}
		rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
		cn.postExecuteWorkaround()
		return rows, nil
	}
	// Prepare and execute via the unnamed statement ("").
	st := cn.prepareTo(query, "")
	st.exec(args)
	return &rows{
		cn:       cn,
		colNames: st.colNames,
		colTyps:  st.colTyps,
		colFmts:  st.colFmts,
	}, nil
}
// Implement the optional "Execer" interface for one-shot queries
//
// Exec runs query with args and returns its driver.Result. A connection
// already marked bad yields driver.ErrBadConn. It picks one of three paths:
//   - no args: the simple query protocol (simpleExec);
//   - cn.binaryParameters set: a single round trip via sendBinaryModeQuery;
//   - otherwise: prepare the unnamed statement and execute it.
//
// NOTE(review): unlike query, Exec does not check cn.inCopy — confirm this is
// intended.
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	// Check to see if we can use the "simpleExec" interface, which is
	// *much* faster than going through prepare/exec
	if len(args) == 0 {
		// ignore commandTag, our caller doesn't care
		r, _, err := cn.simpleExec(query)
		return r, err
	}

	if cn.binaryParameters {
		cn.sendBinaryModeQuery(query, args)

		cn.readParseResponse()
		cn.readBindResponse()
		cn.readPortalDescribeResponse()
		cn.postExecuteWorkaround()
		res, _, err = cn.readExecuteResponse("Execute")
		return res, err
	}
	// Use the unnamed statement to defer planning until bind
	// time, or else value-based selectivity estimates cannot be
	// used.
	st := cn.prepareTo(query, "")
	r, err := st.Exec(args)
	if err != nil {
		// Re-raised so the deferred errRecover processes it.
		// NOTE(review): st.Exec may already run errRecover itself — confirm
		// this double handling is intended.
		panic(err)
	}
	return r, err
}
  779. func (cn *conn) send(m *writeBuf) {
  780. _, err := cn.c.Write(m.wrap())
  781. if err != nil {
  782. panic(err)
  783. }
  784. }
  785. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  786. _, err := cn.c.Write((m.wrap())[1:])
  787. return err
  788. }
  789. // Send a message of type typ to the server on the other end of cn. The
  790. // message should have no payload. This method does not use the scratch
  791. // buffer.
  792. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  793. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  794. return err
  795. }
  796. // saveMessage memorizes a message and its buffer in the conn struct.
  797. // recvMessage will then return these values on the next call to it. This
  798. // method is useful in cases where you have to see what the next message is
  799. // going to be (e.g. to see whether it's an error or not) but you can't handle
  800. // the message yourself.
  801. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  802. if cn.saveMessageType != 0 {
  803. cn.bad = true
  804. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  805. }
  806. cn.saveMessageType = typ
  807. cn.saveMessageBuffer = *buf
  808. }
  809. // recvMessage receives any message from the backend, or returns an error if
  810. // a problem occurred while reading the message.
  811. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  812. // workaround for a QueryRow bug, see exec
  813. if cn.saveMessageType != 0 {
  814. t := cn.saveMessageType
  815. *r = cn.saveMessageBuffer
  816. cn.saveMessageType = 0
  817. cn.saveMessageBuffer = nil
  818. return t, nil
  819. }
  820. x := cn.scratch[:5]
  821. _, err := io.ReadFull(cn.buf, x)
  822. if err != nil {
  823. return 0, err
  824. }
  825. // read the type and length of the message that follows
  826. t := x[0]
  827. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  828. var y []byte
  829. if n <= len(cn.scratch) {
  830. y = cn.scratch[:n]
  831. } else {
  832. y = make([]byte, n)
  833. }
  834. _, err = io.ReadFull(cn.buf, y)
  835. if err != nil {
  836. return 0, err
  837. }
  838. *r = y
  839. return t, nil
  840. }
  841. // recv receives a message from the backend, but if an error happened while
  842. // reading the message or the received message was an ErrorResponse, it panics.
  843. // NoticeResponses are ignored. This function should generally be used only
  844. // during the startup sequence.
  845. func (cn *conn) recv() (t byte, r *readBuf) {
  846. for {
  847. var err error
  848. r = &readBuf{}
  849. t, err = cn.recvMessage(r)
  850. if err != nil {
  851. panic(err)
  852. }
  853. switch t {
  854. case 'E':
  855. panic(parseError(r))
  856. case 'N':
  857. // ignore
  858. default:
  859. return
  860. }
  861. }
  862. }
  863. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  864. // the caller to avoid an allocation.
  865. func (cn *conn) recv1Buf(r *readBuf) byte {
  866. for {
  867. t, err := cn.recvMessage(r)
  868. if err != nil {
  869. panic(err)
  870. }
  871. switch t {
  872. case 'A', 'N':
  873. // ignore
  874. case 'S':
  875. cn.processParameterStatus(r)
  876. default:
  877. return t
  878. }
  879. }
  880. }
  881. // recv1 receives a message from the backend, panicking if an error occurs
  882. // while attempting to read it. All asynchronous messages are ignored, with
  883. // the exception of ErrorResponse.
  884. func (cn *conn) recv1() (t byte, r *readBuf) {
  885. r = &readBuf{}
  886. t = cn.recv1Buf(r)
  887. return t, r
  888. }
  889. func (cn *conn) ssl(o values) {
  890. upgrade := ssl(o)
  891. if upgrade == nil {
  892. // Nothing to do
  893. return
  894. }
  895. w := cn.writeBuf(0)
  896. w.int32(80877103)
  897. if err := cn.sendStartupPacket(w); err != nil {
  898. panic(err)
  899. }
  900. b := cn.scratch[:1]
  901. _, err := io.ReadFull(cn.c, b)
  902. if err != nil {
  903. panic(err)
  904. }
  905. if b[0] != 'S' {
  906. panic(ErrSSLNotSupported)
  907. }
  908. cn.c = upgrade(cn.c)
  909. }
  910. // isDriverSetting returns true iff a setting is purely for configuring the
  911. // driver's options and should not be sent to the server in the connection
  912. // startup packet.
  913. func isDriverSetting(key string) bool {
  914. switch key {
  915. case "host", "port":
  916. return true
  917. case "password":
  918. return true
  919. case "sslmode", "sslcert", "sslkey", "sslrootcert":
  920. return true
  921. case "fallback_application_name":
  922. return true
  923. case "connect_timeout":
  924. return true
  925. case "disable_prepared_binary_result":
  926. return true
  927. case "binary_parameters":
  928. return true
  929. default:
  930. return false
  931. }
  932. }
  933. func (cn *conn) startup(o values) {
  934. w := cn.writeBuf(0)
  935. w.int32(196608)
  936. // Send the backend the name of the database we want to connect to, and the
  937. // user we want to connect as. Additionally, we send over any run-time
  938. // parameters potentially included in the connection string. If the server
  939. // doesn't recognize any of them, it will reply with an error.
  940. for k, v := range o {
  941. if isDriverSetting(k) {
  942. // skip options which can't be run-time parameters
  943. continue
  944. }
  945. // The protocol requires us to supply the database name as "database"
  946. // instead of "dbname".
  947. if k == "dbname" {
  948. k = "database"
  949. }
  950. w.string(k)
  951. w.string(v)
  952. }
  953. w.string("")
  954. if err := cn.sendStartupPacket(w); err != nil {
  955. panic(err)
  956. }
  957. for {
  958. t, r := cn.recv()
  959. switch t {
  960. case 'K':
  961. cn.processBackendKeyData(r)
  962. case 'S':
  963. cn.processParameterStatus(r)
  964. case 'R':
  965. cn.auth(r, o)
  966. case 'Z':
  967. cn.processReadyForQuery(r)
  968. return
  969. default:
  970. errorf("unknown response for startup: %q", t)
  971. }
  972. }
  973. }
  974. func (cn *conn) auth(r *readBuf, o values) {
  975. switch code := r.int32(); code {
  976. case 0:
  977. // OK
  978. case 3:
  979. w := cn.writeBuf('p')
  980. w.string(o["password"])
  981. cn.send(w)
  982. t, r := cn.recv()
  983. if t != 'R' {
  984. errorf("unexpected password response: %q", t)
  985. }
  986. if r.int32() != 0 {
  987. errorf("unexpected authentication response: %q", t)
  988. }
  989. case 5:
  990. s := string(r.next(4))
  991. w := cn.writeBuf('p')
  992. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  993. cn.send(w)
  994. t, r := cn.recv()
  995. if t != 'R' {
  996. errorf("unexpected password response: %q", t)
  997. }
  998. if r.int32() != 0 {
  999. errorf("unexpected authentication response: %q", t)
  1000. }
  1001. default:
  1002. errorf("unknown authentication response: %d", code)
  1003. }
  1004. }
  1005. type format int
  1006. const formatText format = 0
  1007. const formatBinary format = 1
  1008. // One result-column format code with the value 1 (i.e. all binary).
  1009. var colFmtDataAllBinary = []byte{0, 1, 0, 1}
  1010. // No result-column format codes (i.e. all text).
  1011. var colFmtDataAllText = []byte{0, 0}
// stmt is a prepared statement on a single connection. Its methods (Close,
// Query, Exec, NumInput) provide the database/sql/driver.Stmt interface.
type stmt struct {
	// cn is the connection the statement was prepared on.
	cn *conn
	// name is the server-side statement name, sent in Bind and Close.
	name string
	// colNames and colFmts describe the result columns; Query hands them
	// (together with colTyps) to the rows it returns.
	colNames []string
	colFmts  []format
	// colFmtData is the pre-encoded result-column format code list that
	// exec appends to every Bind message.
	colFmtData []byte
	colTyps    []fieldDesc
	// paramTyps holds the parameter type OIDs; its length is the number of
	// parameters the statement expects (see NumInput and exec).
	paramTyps []oid.Oid
	// closed is set by Close once the server has confirmed the close, so
	// later calls to Close are no-ops.
	closed bool
}
  1022. func (st *stmt) Close() (err error) {
  1023. if st.closed {
  1024. return nil
  1025. }
  1026. if st.cn.bad {
  1027. return driver.ErrBadConn
  1028. }
  1029. defer st.cn.errRecover(&err)
  1030. w := st.cn.writeBuf('C')
  1031. w.byte('S')
  1032. w.string(st.name)
  1033. st.cn.send(w)
  1034. st.cn.send(st.cn.writeBuf('S'))
  1035. t, _ := st.cn.recv1()
  1036. if t != '3' {
  1037. st.cn.bad = true
  1038. errorf("unexpected close response: %q", t)
  1039. }
  1040. st.closed = true
  1041. t, r := st.cn.recv1()
  1042. if t != 'Z' {
  1043. st.cn.bad = true
  1044. errorf("expected ready for query, but got: %q", t)
  1045. }
  1046. st.cn.processReadyForQuery(r)
  1047. return nil
  1048. }
  1049. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1050. if st.cn.bad {
  1051. return nil, driver.ErrBadConn
  1052. }
  1053. defer st.cn.errRecover(&err)
  1054. st.exec(v)
  1055. return &rows{
  1056. cn: st.cn,
  1057. colNames: st.colNames,
  1058. colTyps: st.colTyps,
  1059. colFmts: st.colFmts,
  1060. }, nil
  1061. }
  1062. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1063. if st.cn.bad {
  1064. return nil, driver.ErrBadConn
  1065. }
  1066. defer st.cn.errRecover(&err)
  1067. st.exec(v)
  1068. res, _, err = st.cn.readExecuteResponse("simple query")
  1069. return res, err
  1070. }
  1071. func (st *stmt) exec(v []driver.Value) {
  1072. if len(v) >= 65536 {
  1073. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1074. }
  1075. if len(v) != len(st.paramTyps) {
  1076. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1077. }
  1078. cn := st.cn
  1079. w := cn.writeBuf('B')
  1080. w.byte(0) // unnamed portal
  1081. w.string(st.name)
  1082. if cn.binaryParameters {
  1083. cn.sendBinaryParameters(w, v)
  1084. } else {
  1085. w.int16(0)
  1086. w.int16(len(v))
  1087. for i, x := range v {
  1088. if x == nil {
  1089. w.int32(-1)
  1090. } else {
  1091. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1092. w.int32(len(b))
  1093. w.bytes(b)
  1094. }
  1095. }
  1096. }
  1097. w.bytes(st.colFmtData)
  1098. w.next('E')
  1099. w.byte(0)
  1100. w.int32(0)
  1101. w.next('S')
  1102. cn.send(w)
  1103. cn.readBindResponse()
  1104. cn.postExecuteWorkaround()
  1105. }
  1106. func (st *stmt) NumInput() int {
  1107. return len(st.paramTyps)
  1108. }
  1109. // parseComplete parses the "command tag" from a CommandComplete message, and
  1110. // returns the number of rows affected (if applicable) and a string
  1111. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1112. // command tag could not be parsed, parseComplete panics.
  1113. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1114. commandsWithAffectedRows := []string{
  1115. "SELECT ",
  1116. // INSERT is handled below
  1117. "UPDATE ",
  1118. "DELETE ",
  1119. "FETCH ",
  1120. "MOVE ",
  1121. "COPY ",
  1122. }
  1123. var affectedRows *string
  1124. for _, tag := range commandsWithAffectedRows {
  1125. if strings.HasPrefix(commandTag, tag) {
  1126. t := commandTag[len(tag):]
  1127. affectedRows = &t
  1128. commandTag = tag[:len(tag)-1]
  1129. break
  1130. }
  1131. }
  1132. // INSERT also includes the oid of the inserted row in its command tag.
  1133. // Oids in user tables are deprecated, and the oid is only returned when
  1134. // exactly one row is inserted, so it's unlikely to be of value to any
  1135. // real-world application and we can ignore it.
  1136. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1137. parts := strings.Split(commandTag, " ")
  1138. if len(parts) != 3 {
  1139. cn.bad = true
  1140. errorf("unexpected INSERT command tag %s", commandTag)
  1141. }
  1142. affectedRows = &parts[len(parts)-1]
  1143. commandTag = "INSERT"
  1144. }
  1145. // There should be no affected rows attached to the tag, just return it
  1146. if affectedRows == nil {
  1147. return driver.RowsAffected(0), commandTag
  1148. }
  1149. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1150. if err != nil {
  1151. cn.bad = true
  1152. errorf("could not parse commandTag: %s", err)
  1153. }
  1154. return driver.RowsAffected(n), commandTag
  1155. }
// rows is the driver.Rows value returned by stmt.Query; it reads result
// messages directly from the connection cn.
type rows struct {
	cn *conn
	// finish, if non-nil, is deferred by Close and runs when Close returns.
	finish func()
	// colNames, colTyps and colFmts describe the current result set's
	// columns; Next replaces them when a RowDescription ('T') arrives.
	colNames []string
	colTyps  []fieldDesc
	colFmts  []format
	// done is set once Next has consumed ReadyForQuery ('Z'); afterwards
	// Next returns io.EOF without reading from the connection.
	done bool
	// rb is the buffer each incoming message is received into by Next.
	rb readBuf
	// result and tag hold the parsed CommandComplete ('C') of the most
	// recently completed command (see Result and Tag).
	result driver.Result
	tag    string
}
  1167. func (rs *rows) Close() error {
  1168. if finish := rs.finish; finish != nil {
  1169. defer finish()
  1170. }
  1171. // no need to look at cn.bad as Next() will
  1172. for {
  1173. err := rs.Next(nil)
  1174. switch err {
  1175. case nil:
  1176. case io.EOF:
  1177. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1178. // description, used with HasNextResultSet). We need to fetch messages until
  1179. // we hit a 'Z', which is done by waiting for done to be set.
  1180. if rs.done {
  1181. return nil
  1182. }
  1183. default:
  1184. return err
  1185. }
  1186. }
  1187. }
  1188. func (rs *rows) Columns() []string {
  1189. return rs.colNames
  1190. }
  1191. func (rs *rows) Result() driver.Result {
  1192. if rs.result == nil {
  1193. return emptyRows
  1194. }
  1195. return rs.result
  1196. }
  1197. func (rs *rows) Tag() string {
  1198. return rs.tag
  1199. }
// Next reads messages from the connection until it can deliver a row or the
// current result set ends.
//
//   - DataRow ('D'): decodes up to len(dest) columns into dest and returns nil.
//   - RowDescription ('T'): replaces the column metadata for the next result
//     set and returns io.EOF.
//   - ReadyForQuery ('Z'): sets rs.done and returns the remembered error, if
//     any, otherwise io.EOF.
//   - ErrorResponse ('E'): the error is remembered and returned once
//     ReadyForQuery arrives.
//   - CommandComplete ('C'): updates rs.result and rs.tag; reading continues.
//   - EmptyQueryResponse ('I'): skipped.
//
// Once rs.done is set, Next returns io.EOF without touching the connection.
// The deferred conn.errRecover(&err) presumably turns panics raised while
// reading (e.g. by errorf) into the returned err; verify against errRecover.
func (rs *rows) Next(dest []driver.Value) (err error) {
	if rs.done {
		return io.EOF
	}
	conn := rs.cn
	if conn.bad {
		return driver.ErrBadConn
	}
	defer conn.errRecover(&err)
	for {
		t := conn.recv1Buf(&rs.rb)
		switch t {
		case 'E':
			// Keep reading: the server still sends ReadyForQuery, at which
			// point the error is returned.
			err = parseError(&rs.rb)
		case 'C', 'I':
			if t == 'C' {
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
			}
			continue
		case 'Z':
			conn.processReadyForQuery(&rs.rb)
			rs.done = true
			if err != nil {
				return err
			}
			return io.EOF
		case 'D':
			n := rs.rb.int16()
			// A DataRow following an ErrorResponse is a protocol violation.
			if err != nil {
				conn.bad = true
				errorf("unexpected DataRow after error %s", err)
			}
			// Decode at most n columns; dest slots beyond n are left untouched.
			if n < len(dest) {
				dest = dest[:n]
			}
			for i := range dest {
				l := rs.rb.int32()
				if l == -1 {
					// A length of -1 denotes NULL.
					dest[i] = nil
					continue
				}
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
			}
			return
		case 'T':
			// Start of another result set (see HasNextResultSet).
			rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
			return io.EOF
		default:
			errorf("unexpected message after execute: %q", t)
		}
	}
}
  1252. func (rs *rows) HasNextResultSet() bool {
  1253. return !rs.done
  1254. }
  1255. func (rs *rows) NextResultSet() error {
  1256. return nil
  1257. }
  1258. // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  1259. // used as part of an SQL statement. For example:
  1260. //
  1261. // tblname := "my_table"
  1262. // data := "my_data"
  1263. // err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
  1264. //
  1265. // Any double quotes in name will be escaped. The quoted identifier will be
  1266. // case sensitive when used in a query. If the input string contains a zero
  1267. // byte, the result will be truncated immediately before it.
  1268. func QuoteIdentifier(name string) string {
  1269. end := strings.IndexRune(name, 0)
  1270. if end > -1 {
  1271. name = name[:end]
  1272. }
  1273. return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
  1274. }
  1275. func md5s(s string) string {
  1276. h := md5.New()
  1277. h.Write([]byte(s))
  1278. return fmt.Sprintf("%x", h.Sum(nil))
  1279. }
  1280. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1281. // Do one pass over the parameters to see if we're going to send any of
  1282. // them over in binary. If we are, create a paramFormats array at the
  1283. // same time.
  1284. var paramFormats []int
  1285. for i, x := range args {
  1286. _, ok := x.([]byte)
  1287. if ok {
  1288. if paramFormats == nil {
  1289. paramFormats = make([]int, len(args))
  1290. }
  1291. paramFormats[i] = 1
  1292. }
  1293. }
  1294. if paramFormats == nil {
  1295. b.int16(0)
  1296. } else {
  1297. b.int16(len(paramFormats))
  1298. for _, x := range paramFormats {
  1299. b.int16(x)
  1300. }
  1301. }
  1302. b.int16(len(args))
  1303. for _, x := range args {
  1304. if x == nil {
  1305. b.int32(-1)
  1306. } else {
  1307. datum := binaryEncode(&cn.parameterStatus, x)
  1308. b.int32(len(datum))
  1309. b.bytes(datum)
  1310. }
  1311. }
  1312. }
  1313. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1314. if len(args) >= 65536 {
  1315. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1316. }
  1317. b := cn.writeBuf('P')
  1318. b.byte(0) // unnamed statement
  1319. b.string(query)
  1320. b.int16(0)
  1321. b.next('B')
  1322. b.int16(0) // unnamed portal and statement
  1323. cn.sendBinaryParameters(b, args)
  1324. b.bytes(colFmtDataAllText)
  1325. b.next('D')
  1326. b.byte('P')
  1327. b.byte(0) // unnamed portal
  1328. b.next('E')
  1329. b.byte(0)
  1330. b.int32(0)
  1331. b.next('S')
  1332. cn.send(b)
  1333. }
  1334. func (cn *conn) processParameterStatus(r *readBuf) {
  1335. var err error
  1336. param := r.string()
  1337. switch param {
  1338. case "server_version":
  1339. var major1 int
  1340. var major2 int
  1341. var minor int
  1342. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1343. if err == nil {
  1344. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1345. }
  1346. case "TimeZone":
  1347. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1348. if err != nil {
  1349. cn.parameterStatus.currentLocation = nil
  1350. }
  1351. default:
  1352. // ignore
  1353. }
  1354. }
  1355. func (cn *conn) processReadyForQuery(r *readBuf) {
  1356. cn.txnStatus = transactionStatus(r.byte())
  1357. }
  1358. func (cn *conn) readReadyForQuery() {
  1359. t, r := cn.recv1()
  1360. switch t {
  1361. case 'Z':
  1362. cn.processReadyForQuery(r)
  1363. return
  1364. default:
  1365. cn.bad = true
  1366. errorf("unexpected message %q; expected ReadyForQuery", t)
  1367. }
  1368. }
  1369. func (cn *conn) processBackendKeyData(r *readBuf) {
  1370. cn.processID = r.int32()
  1371. cn.secretKey = r.int32()
  1372. }
  1373. func (cn *conn) readParseResponse() {
  1374. t, r := cn.recv1()
  1375. switch t {
  1376. case '1':
  1377. return
  1378. case 'E':
  1379. err := parseError(r)
  1380. cn.readReadyForQuery()
  1381. panic(err)
  1382. default:
  1383. cn.bad = true
  1384. errorf("unexpected Parse response %q", t)
  1385. }
  1386. }
  1387. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1388. for {
  1389. t, r := cn.recv1()
  1390. switch t {
  1391. case 't':
  1392. nparams := r.int16()
  1393. paramTyps = make([]oid.Oid, nparams)
  1394. for i := range paramTyps {
  1395. paramTyps[i] = r.oid()
  1396. }
  1397. case 'n':
  1398. return paramTyps, nil, nil
  1399. case 'T':
  1400. colNames, colTyps = parseStatementRowDescribe(r)
  1401. return paramTyps, colNames, colTyps
  1402. case 'E':
  1403. err := parseError(r)
  1404. cn.readReadyForQuery()
  1405. panic(err)
  1406. default:
  1407. cn.bad = true
  1408. errorf("unexpected Describe statement response %q", t)
  1409. }
  1410. }
  1411. }
  1412. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1413. t, r := cn.recv1()
  1414. switch t {
  1415. case 'T':
  1416. return parsePortalRowDescribe(r)
  1417. case 'n':
  1418. return nil, nil, nil
  1419. case 'E':
  1420. err := parseError(r)
  1421. cn.readReadyForQuery()
  1422. panic(err)
  1423. default:
  1424. cn.bad = true
  1425. errorf("unexpected Describe response %q", t)
  1426. }
  1427. panic("not reached")
  1428. }
  1429. func (cn *conn) readBindResponse() {
  1430. t, r := cn.recv1()
  1431. switch t {
  1432. case '2':
  1433. return
  1434. case 'E':
  1435. err := parseError(r)
  1436. cn.readReadyForQuery()
  1437. panic(err)
  1438. default:
  1439. cn.bad = true
  1440. errorf("unexpected Bind response %q", t)
  1441. }
  1442. }
  1443. func (cn *conn) postExecuteWorkaround() {
  1444. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1445. // any errors from rows.Next, which masks errors that happened during the
  1446. // execution of the query. To avoid the problem in common cases, we wait
  1447. // here for one more message from the database. If it's not an error the
  1448. // query will likely succeed (or perhaps has already, if it's a
  1449. // CommandComplete), so we push the message into the conn struct; recv1
  1450. // will return it as the next message for rows.Next or rows.Close.
  1451. // However, if it's an error, we wait until ReadyForQuery and then return
  1452. // the error to our caller.
  1453. for {
  1454. t, r := cn.recv1()
  1455. switch t {
  1456. case 'E':
  1457. err := parseError(r)
  1458. cn.readReadyForQuery()
  1459. panic(err)
  1460. case 'C', 'D', 'I':
  1461. // the query didn't fail, but we can't process this message
  1462. cn.saveMessage(t, r)
  1463. return
  1464. default:
  1465. cn.bad = true
  1466. errorf("unexpected message during extended query execution: %q", t)
  1467. }
  1468. }
  1469. }
  1470. // Only for Exec(), since we ignore the returned data
  1471. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1472. for {
  1473. t, r := cn.recv1()
  1474. switch t {
  1475. case 'C':
  1476. if err != nil {
  1477. cn.bad = true
  1478. errorf("unexpected CommandComplete after error %s", err)
  1479. }
  1480. res, commandTag = cn.parseComplete(r.string())
  1481. case 'Z':
  1482. cn.processReadyForQuery(r)
  1483. if res == nil && err == nil {
  1484. err = errUnexpectedReady
  1485. }
  1486. return res, commandTag, err
  1487. case 'E':
  1488. err = parseError(r)
  1489. case 'T', 'D', 'I':
  1490. if err != nil {
  1491. cn.bad = true
  1492. errorf("unexpected %q after error %s", t, err)
  1493. }
  1494. if t == 'I' {
  1495. res = emptyRows
  1496. }
  1497. // ignore any results
  1498. default:
  1499. cn.bad = true
  1500. errorf("unknown %s response: %q", protocolState, t)
  1501. }
  1502. }
  1503. }
  1504. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1505. n := r.int16()
  1506. colNames = make([]string, n)
  1507. colTyps = make([]fieldDesc, n)
  1508. for i := range colNames {
  1509. colNames[i] = r.string()
  1510. r.next(6)
  1511. colTyps[i].OID = r.oid()
  1512. colTyps[i].Len = r.int16()
  1513. colTyps[i].Mod = r.int32()
  1514. // format code not known when describing a statement; always 0
  1515. r.next(2)
  1516. }
  1517. return
  1518. }
  1519. func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
  1520. n := r.int16()
  1521. colNames = make([]string, n)
  1522. colFmts = make([]format, n)
  1523. colTyps = make([]fieldDesc, n)
  1524. for i := range colNames {
  1525. colNames[i] = r.string()
  1526. r.next(6)
  1527. colTyps[i].OID = r.oid()
  1528. colTyps[i].Len = r.int16()
  1529. colTyps[i].Mod = r.int32()
  1530. colFmts[i] = format(r.int16())
  1531. }
  1532. return
  1533. }
  1534. // parseEnviron tries to mimic some of libpq's environment handling
  1535. //
  1536. // To ease testing, it does not directly reference os.Environ, but is
  1537. // designed to accept its output.
  1538. //
  1539. // Environment-set connection information is intended to have a higher
  1540. // precedence than a library default but lower than any explicitly
  1541. // passed information (such as in the URL or connection string).
  1542. func parseEnviron(env []string) (out map[string]string) {
  1543. out = make(map[string]string)
  1544. for _, v := range env {
  1545. parts := strings.SplitN(v, "=", 2)
  1546. accrue := func(keyname string) {
  1547. out[keyname] = parts[1]
  1548. }
  1549. unsupported := func() {
  1550. panic(fmt.Sprintf("setting %v not supported", parts[0]))
  1551. }
  1552. // The order of these is the same as is seen in the
  1553. // PostgreSQL 9.1 manual. Unsupported but well-defined
  1554. // keys cause a panic; these should be unset prior to
  1555. // execution. Options which pq expects to be set to a
  1556. // certain value are allowed, but must be set to that
  1557. // value if present (they can, of course, be absent).
  1558. switch parts[0] {
  1559. case "PGHOST":
  1560. accrue("host")
  1561. case "PGHOSTADDR":
  1562. unsupported()
  1563. case "PGPORT":
  1564. accrue("port")
  1565. case "PGDATABASE":
  1566. accrue("dbname")
  1567. case "PGUSER":
  1568. accrue("user")
  1569. case "PGPASSWORD":
  1570. accrue("password")
  1571. case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  1572. unsupported()
  1573. case "PGOPTIONS":
  1574. accrue("options")
  1575. case "PGAPPNAME":
  1576. accrue("application_name")
  1577. case "PGSSLMODE":
  1578. accrue("sslmode")
  1579. case "PGSSLCERT":
  1580. accrue("sslcert")
  1581. case "PGSSLKEY":
  1582. accrue("sslkey")
  1583. case "PGSSLROOTCERT":
  1584. accrue("sslrootcert")
  1585. case "PGREQUIRESSL", "PGSSLCRL":
  1586. unsupported()
  1587. case "PGREQUIREPEER":
  1588. unsupported()
  1589. case "PGKRBSRVNAME", "PGGSSLIB":
  1590. unsupported()
  1591. case "PGCONNECT_TIMEOUT":
  1592. accrue("connect_timeout")
  1593. case "PGCLIENTENCODING":
  1594. accrue("client_encoding")
  1595. case "PGDATESTYLE":
  1596. accrue("datestyle")
  1597. case "PGTZ":
  1598. accrue("timezone")
  1599. case "PGGEQO":
  1600. accrue("geqo")
  1601. case "PGSYSCONFDIR", "PGLOCALEDIR":
  1602. unsupported()
  1603. }
  1604. }
  1605. return out
  1606. }
  1607. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1608. func isUTF8(name string) bool {
  1609. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1610. s := strings.Map(alnumLowerASCII, name)
  1611. return s == "utf8" || s == "unicode"
  1612. }
  1613. func alnumLowerASCII(ch rune) rune {
  1614. if 'A' <= ch && ch <= 'Z' {
  1615. return ch + ('a' - 'A')
  1616. }
  1617. if 'a' <= ch && ch <= 'z' || '0' <= ch && ch <= '9' {
  1618. return ch
  1619. }
  1620. return -1 // discard
  1621. }