You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

282 lines
5.7 KiB

  1. package pq
  2. import (
  3. "database/sql/driver"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. )
  9. var (
  10. errCopyInClosed = errors.New("pq: copyin statement has already been closed")
  11. errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
  12. errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
  13. errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
  14. errCopyInProgress = errors.New("pq: COPY in progress")
  15. )
  16. // CopyIn creates a COPY FROM statement which can be prepared with
  17. // Tx.Prepare(). The target table should be visible in search_path.
  18. func CopyIn(table string, columns ...string) string {
  19. stmt := "COPY " + QuoteIdentifier(table) + " ("
  20. for i, col := range columns {
  21. if i != 0 {
  22. stmt += ", "
  23. }
  24. stmt += QuoteIdentifier(col)
  25. }
  26. stmt += ") FROM STDIN"
  27. return stmt
  28. }
  29. // CopyInSchema creates a COPY FROM statement which can be prepared with
  30. // Tx.Prepare().
  31. func CopyInSchema(schema, table string, columns ...string) string {
  32. stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
  33. for i, col := range columns {
  34. if i != 0 {
  35. stmt += ", "
  36. }
  37. stmt += QuoteIdentifier(col)
  38. }
  39. stmt += ") FROM STDIN"
  40. return stmt
  41. }
// copyin is the driver.Stmt returned by prepareCopyIn for a COPY ... FROM
// STDIN statement. Exec appends encoded rows to buffer and periodically
// writes it to the server as a CopyData message, while the resploop
// goroutine reads the server's replies and records any error.
type copyin struct {
	cn     *conn  // connection the COPY runs on
	buffer []byte // pending message: 'd', a 4-byte length placeholder, then row data
	// NOTE(review): rowData is allocated in prepareCopyIn but is not read or
	// written by any code visible here; confirm whether it is still needed.
	rowData chan []byte
	done    chan bool // resploop sends true when it stops; Close waits on it
	closed  bool      // set by Close; Exec returns errCopyInClosed once true

	// The mutex guards err. setBad and isBad also hold it while touching
	// cn.bad.
	sync.Mutex
	err error // first error recorded via setError
}
  51. const ciBufferSize = 64 * 1024
  52. // flush buffer before the buffer is filled up and needs reallocation
  53. const ciBufferFlushSize = 63 * 1024
// prepareCopyIn sends q as a simple Query message and waits for the server to
// enter copy-in mode. On success it starts the resploop goroutine and returns
// the new *copyin as the statement.
//
// It returns errCopyNotSupportedOutsideTxn without contacting the server when
// the connection is not in a transaction.
//
// If the server replies with a non-text CopyInResponse or with a
// CopyOutResponse, the COPY is aborted by sending CopyFail ('f') with the
// error text. The function then drains messages until ReadyForQuery and
// returns that error. An ErrorResponse followed by ReadyForQuery is returned
// as the parsed error.
//
// Unexpected message types mark the connection bad and call errorf.
// NOTE(review): errorf presumably panics, to be recovered by the caller;
// confirm.
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
	if !cn.isInTransaction() {
		return nil, errCopyNotSupportedOutsideTxn
	}

	ci := &copyin{
		cn:      cn,
		buffer:  make([]byte, 0, ciBufferSize),
		rowData: make(chan []byte),
		done:    make(chan bool, 1),
	}
	// add CopyData identifier + 4 bytes for message length
	// (flush overwrites the length bytes just before each write)
	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)

	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

awaitCopyInResponse:
	for {
		t, r := cn.recv1()
		switch t {
		case 'G':
			// CopyInResponse. A non-zero first byte is treated as a
			// non-text (binary) COPY, which is not supported.
			if r.byte() != 0 {
				err = errBinaryCopyNotSupported
				break awaitCopyInResponse
			}
			go ci.resploop()
			return ci, nil
		case 'H':
			// CopyOutResponse: the statement was a COPY ... TO.
			err = errCopyToNotSupported
			break awaitCopyInResponse
		case 'E':
			// Remember the error and keep reading until ReadyForQuery.
			err = parseError(r)
		case 'Z':
			// ReadyForQuery without a preceding ErrorResponse means the
			// query never entered copy mode; treat that as a protocol error.
			if err == nil {
				ci.setBad()
				errorf("unexpected ReadyForQuery in response to COPY")
			}
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad()
			errorf("unknown response for copy query: %q", t)
		}
	}

	// something went wrong, abort COPY before we return
	b = cn.writeBuf('f')
	b.string(err.Error())
	cn.send(b)

	// Drain until ReadyForQuery. 'c', 'C' and 'E' messages are skipped.
	// NOTE(review): after a CopyOutResponse ('H') the server may also send
	// CopyData ('d') messages, which this loop treats as unknown; confirm.
	for {
		t, r := cn.recv1()
		switch t {
		case 'c', 'C', 'E':
		case 'Z':
			// correctly aborted, we're done
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad()
			errorf("unknown response for CopyFail: %q", t)
		}
	}
}
  115. func (ci *copyin) flush(buf []byte) {
  116. // set message length (without message identifier)
  117. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  118. _, err := ci.cn.c.Write(buf)
  119. if err != nil {
  120. panic(err)
  121. }
  122. }
  123. func (ci *copyin) resploop() {
  124. for {
  125. var r readBuf
  126. t, err := ci.cn.recvMessage(&r)
  127. if err != nil {
  128. ci.setBad()
  129. ci.setError(err)
  130. ci.done <- true
  131. return
  132. }
  133. switch t {
  134. case 'C':
  135. // complete
  136. case 'N':
  137. // NoticeResponse
  138. case 'Z':
  139. ci.cn.processReadyForQuery(&r)
  140. ci.done <- true
  141. return
  142. case 'E':
  143. err := parseError(&r)
  144. ci.setError(err)
  145. default:
  146. ci.setBad()
  147. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  148. ci.done <- true
  149. return
  150. }
  151. }
  152. }
  153. func (ci *copyin) setBad() {
  154. ci.Lock()
  155. ci.cn.bad = true
  156. ci.Unlock()
  157. }
  158. func (ci *copyin) isBad() bool {
  159. ci.Lock()
  160. b := ci.cn.bad
  161. ci.Unlock()
  162. return b
  163. }
  164. func (ci *copyin) isErrorSet() bool {
  165. ci.Lock()
  166. isSet := (ci.err != nil)
  167. ci.Unlock()
  168. return isSet
  169. }
  170. // setError() sets ci.err if one has not been set already. Caller must not be
  171. // holding ci.Mutex.
  172. func (ci *copyin) setError(err error) {
  173. ci.Lock()
  174. if ci.err == nil {
  175. ci.err = err
  176. }
  177. ci.Unlock()
  178. }
  179. func (ci *copyin) NumInput() int {
  180. return -1
  181. }
  182. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  183. return nil, ErrNotSupported
  184. }
  185. // Exec inserts values into the COPY stream. The insert is asynchronous
  186. // and Exec can return errors from previous Exec calls to the same
  187. // COPY stmt.
  188. //
  189. // You need to call Exec(nil) to sync the COPY stream and to get any
  190. // errors from pending data, since Stmt.Close() doesn't return errors
  191. // to the user.
  192. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  193. if ci.closed {
  194. return nil, errCopyInClosed
  195. }
  196. if ci.isBad() {
  197. return nil, driver.ErrBadConn
  198. }
  199. defer ci.cn.errRecover(&err)
  200. if ci.isErrorSet() {
  201. return nil, ci.err
  202. }
  203. if len(v) == 0 {
  204. return nil, ci.Close()
  205. }
  206. numValues := len(v)
  207. for i, value := range v {
  208. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  209. if i < numValues-1 {
  210. ci.buffer = append(ci.buffer, '\t')
  211. }
  212. }
  213. ci.buffer = append(ci.buffer, '\n')
  214. if len(ci.buffer) > ciBufferFlushSize {
  215. ci.flush(ci.buffer)
  216. // reset buffer, keep bytes for message identifier and length
  217. ci.buffer = ci.buffer[:5]
  218. }
  219. return driver.RowsAffected(0), nil
  220. }
  221. func (ci *copyin) Close() (err error) {
  222. if ci.closed { // Don't do anything, we're already closed
  223. return nil
  224. }
  225. ci.closed = true
  226. if ci.isBad() {
  227. return driver.ErrBadConn
  228. }
  229. defer ci.cn.errRecover(&err)
  230. if len(ci.buffer) > 0 {
  231. ci.flush(ci.buffer)
  232. }
  233. // Avoid touching the scratch buffer as resploop could be using it.
  234. err = ci.cn.sendSimpleMessage('c')
  235. if err != nil {
  236. return err
  237. }
  238. <-ci.done
  239. ci.cn.inCopy = false
  240. if ci.isErrorSet() {
  241. err = ci.err
  242. return err
  243. }
  244. return nil
  245. }