You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

128 lines
2.5 KiB

  1. // +build go1.8
  2. package pq
  3. import (
  4. "context"
  5. "database/sql"
  6. "database/sql/driver"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. )
  11. // Implement the "QueryerContext" interface
  12. func (cn *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
  13. list := make([]driver.Value, len(args))
  14. for i, nv := range args {
  15. list[i] = nv.Value
  16. }
  17. finish := cn.watchCancel(ctx)
  18. r, err := cn.query(query, list)
  19. if err != nil {
  20. if finish != nil {
  21. finish()
  22. }
  23. return nil, err
  24. }
  25. r.finish = finish
  26. return r, nil
  27. }
  28. // Implement the "ExecerContext" interface
  29. func (cn *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
  30. list := make([]driver.Value, len(args))
  31. for i, nv := range args {
  32. list[i] = nv.Value
  33. }
  34. if finish := cn.watchCancel(ctx); finish != nil {
  35. defer finish()
  36. }
  37. return cn.Exec(query, list)
  38. }
  39. // Implement the "ConnBeginTx" interface
  40. func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
  41. var mode string
  42. switch sql.IsolationLevel(opts.Isolation) {
  43. case sql.LevelDefault:
  44. // Don't touch mode: use the server's default
  45. case sql.LevelReadUncommitted:
  46. mode = " ISOLATION LEVEL READ UNCOMMITTED"
  47. case sql.LevelReadCommitted:
  48. mode = " ISOLATION LEVEL READ COMMITTED"
  49. case sql.LevelRepeatableRead:
  50. mode = " ISOLATION LEVEL REPEATABLE READ"
  51. case sql.LevelSerializable:
  52. mode = " ISOLATION LEVEL SERIALIZABLE"
  53. default:
  54. return nil, fmt.Errorf("pq: isolation level not supported: %d", opts.Isolation)
  55. }
  56. if opts.ReadOnly {
  57. mode += " READ ONLY"
  58. } else {
  59. mode += " READ WRITE"
  60. }
  61. tx, err := cn.begin(mode)
  62. if err != nil {
  63. return nil, err
  64. }
  65. cn.txnFinish = cn.watchCancel(ctx)
  66. return tx, nil
  67. }
  68. func (cn *conn) watchCancel(ctx context.Context) func() {
  69. if done := ctx.Done(); done != nil {
  70. finished := make(chan struct{})
  71. go func() {
  72. select {
  73. case <-done:
  74. _ = cn.cancel()
  75. finished <- struct{}{}
  76. case <-finished:
  77. }
  78. }()
  79. return func() {
  80. select {
  81. case <-finished:
  82. case finished <- struct{}{}:
  83. }
  84. }
  85. }
  86. return nil
  87. }
  88. func (cn *conn) cancel() error {
  89. c, err := dial(cn.dialer, cn.opts)
  90. if err != nil {
  91. return err
  92. }
  93. defer c.Close()
  94. {
  95. can := conn{
  96. c: c,
  97. }
  98. can.ssl(cn.opts)
  99. w := can.writeBuf(0)
  100. w.int32(80877102) // cancel request code
  101. w.int32(cn.processID)
  102. w.int32(cn.secretKey)
  103. if err := can.sendStartupPacket(w); err != nil {
  104. return err
  105. }
  106. }
  107. // Read until EOF to ensure that the server received the cancel.
  108. {
  109. _, err := io.Copy(ioutil.Discard, c)
  110. return err
  111. }
  112. }