You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

307 lines
6.1 KiB

  1. package pq
  2. import (
  3. "database/sql/driver"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "sync"
  8. )
  9. var (
  10. errCopyInClosed = errors.New("pq: copyin statement has already been closed")
  11. errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
  12. errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
  13. errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
  14. errCopyInProgress = errors.New("pq: COPY in progress")
  15. )
  16. // CopyIn creates a COPY FROM statement which can be prepared with
  17. // Tx.Prepare(). The target table should be visible in search_path.
  18. func CopyIn(table string, columns ...string) string {
  19. stmt := "COPY " + QuoteIdentifier(table) + " ("
  20. for i, col := range columns {
  21. if i != 0 {
  22. stmt += ", "
  23. }
  24. stmt += QuoteIdentifier(col)
  25. }
  26. stmt += ") FROM STDIN"
  27. return stmt
  28. }
  29. // CopyInSchema creates a COPY FROM statement which can be prepared with
  30. // Tx.Prepare().
  31. func CopyInSchema(schema, table string, columns ...string) string {
  32. stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
  33. for i, col := range columns {
  34. if i != 0 {
  35. stmt += ", "
  36. }
  37. stmt += QuoteIdentifier(col)
  38. }
  39. stmt += ") FROM STDIN"
  40. return stmt
  41. }
  42. type copyin struct {
  43. cn *conn
  44. buffer []byte
  45. rowData chan []byte
  46. done chan bool
  47. driver.Result
  48. closed bool
  49. sync.Mutex // guards err
  50. err error
  51. }
  52. const ciBufferSize = 64 * 1024
  53. // flush buffer before the buffer is filled up and needs reallocation
  54. const ciBufferFlushSize = 63 * 1024
// prepareCopyIn sends the COPY statement q to the server as a simple query.
// If the server answers with a text-format CopyInResponse, it starts the
// resploop goroutine and returns a *copyin statement for streaming rows.
//
// COPY is only accepted inside a transaction. Otherwise
// errCopyNotSupportedOutsideTxn is returned without contacting the server.
// If the server announces a binary-format COPY FROM or a COPY TO
// (CopyOutResponse), a CopyFail carrying the error text is sent. The matching
// error is returned once the server reports ReadyForQuery. An ErrorResponse is
// returned to the caller after the following ReadyForQuery. Unexpected message
// types mark the connection bad and are reported through errorf.
//
// NOTE(review): errorf is defined elsewhere. This function appears to rely on
// it not returning (e.g. panicking, to be recovered by a deferred handler in
// the caller) — confirm.
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
	if !cn.isInTransaction() {
		return nil, errCopyNotSupportedOutsideTxn
	}

	ci := &copyin{
		cn:      cn,
		buffer:  make([]byte, 0, ciBufferSize),
		rowData: make(chan []byte),
		done:    make(chan bool, 1),
	}
	// add CopyData identifier + 4 bytes for message length
	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)

	// Send q as a simple Query ('Q') message.
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)

awaitCopyInResponse:
	for {
		t, r := cn.recv1()
		switch t {
		case 'G':
			// CopyInResponse: the first byte is the overall COPY format
			// (0 = text). Only text format is supported.
			if r.byte() != 0 {
				err = errBinaryCopyNotSupported
				break awaitCopyInResponse
			}
			go ci.resploop()
			return ci, nil
		case 'H':
			// CopyOutResponse: the statement turned out to be a COPY TO.
			// NOTE(review): during COPY TO the server streams CopyData ('d'),
			// which the drain loop below does not list and would treat as
			// unknown — confirm this path for non-empty results.
			err = errCopyToNotSupported
			break awaitCopyInResponse
		case 'E':
			// ErrorResponse: remember it and keep reading until ReadyForQuery.
			err = parseError(r)
		case 'Z':
			// ReadyForQuery without a CopyInResponse is only expected after
			// an ErrorResponse has been seen.
			if err == nil {
				ci.setBad()
				errorf("unexpected ReadyForQuery in response to COPY")
			}
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad()
			errorf("unknown response for copy query: %q", t)
		}
	}

	// something went wrong, abort COPY before we return
	// CopyFail ('f') carries the error text as the failure reason.
	b = cn.writeBuf('f')
	b.string(err.Error())
	cn.send(b)

	// Drain the server's replies until ReadyForQuery.
	for {
		t, r := cn.recv1()
		switch t {
		case 'c', 'C', 'E':
			// CopyDone, CommandComplete and ErrorResponse are skipped while
			// the aborted COPY winds down.
		case 'Z':
			// correctly aborted, we're done
			cn.processReadyForQuery(r)
			return nil, err
		default:
			ci.setBad()
			errorf("unknown response for CopyFail: %q", t)
		}
	}
}
  116. func (ci *copyin) flush(buf []byte) {
  117. // set message length (without message identifier)
  118. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  119. _, err := ci.cn.c.Write(buf)
  120. if err != nil {
  121. panic(err)
  122. }
  123. }
  124. func (ci *copyin) resploop() {
  125. for {
  126. var r readBuf
  127. t, err := ci.cn.recvMessage(&r)
  128. if err != nil {
  129. ci.setBad()
  130. ci.setError(err)
  131. ci.done <- true
  132. return
  133. }
  134. switch t {
  135. case 'C':
  136. // complete
  137. res, _ := ci.cn.parseComplete(r.string())
  138. ci.setResult(res)
  139. case 'N':
  140. if n := ci.cn.noticeHandler; n != nil {
  141. n(parseError(&r))
  142. }
  143. case 'Z':
  144. ci.cn.processReadyForQuery(&r)
  145. ci.done <- true
  146. return
  147. case 'E':
  148. err := parseError(&r)
  149. ci.setError(err)
  150. default:
  151. ci.setBad()
  152. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  153. ci.done <- true
  154. return
  155. }
  156. }
  157. }
  158. func (ci *copyin) setBad() {
  159. ci.Lock()
  160. ci.cn.bad = true
  161. ci.Unlock()
  162. }
  163. func (ci *copyin) isBad() bool {
  164. ci.Lock()
  165. b := ci.cn.bad
  166. ci.Unlock()
  167. return b
  168. }
  169. func (ci *copyin) isErrorSet() bool {
  170. ci.Lock()
  171. isSet := (ci.err != nil)
  172. ci.Unlock()
  173. return isSet
  174. }
  175. // setError() sets ci.err if one has not been set already. Caller must not be
  176. // holding ci.Mutex.
  177. func (ci *copyin) setError(err error) {
  178. ci.Lock()
  179. if ci.err == nil {
  180. ci.err = err
  181. }
  182. ci.Unlock()
  183. }
  184. func (ci *copyin) setResult(result driver.Result) {
  185. ci.Lock()
  186. ci.Result = result
  187. ci.Unlock()
  188. }
  189. func (ci *copyin) getResult() driver.Result {
  190. ci.Lock()
  191. result := ci.Result
  192. ci.Unlock()
  193. if result == nil {
  194. return driver.RowsAffected(0)
  195. }
  196. return result
  197. }
  198. func (ci *copyin) NumInput() int {
  199. return -1
  200. }
  201. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  202. return nil, ErrNotSupported
  203. }
  204. // Exec inserts values into the COPY stream. The insert is asynchronous
  205. // and Exec can return errors from previous Exec calls to the same
  206. // COPY stmt.
  207. //
  208. // You need to call Exec(nil) to sync the COPY stream and to get any
  209. // errors from pending data, since Stmt.Close() doesn't return errors
  210. // to the user.
  211. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  212. if ci.closed {
  213. return nil, errCopyInClosed
  214. }
  215. if ci.isBad() {
  216. return nil, driver.ErrBadConn
  217. }
  218. defer ci.cn.errRecover(&err)
  219. if ci.isErrorSet() {
  220. return nil, ci.err
  221. }
  222. if len(v) == 0 {
  223. if err := ci.Close(); err != nil {
  224. return driver.RowsAffected(0), err
  225. }
  226. return ci.getResult(), nil
  227. }
  228. numValues := len(v)
  229. for i, value := range v {
  230. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  231. if i < numValues-1 {
  232. ci.buffer = append(ci.buffer, '\t')
  233. }
  234. }
  235. ci.buffer = append(ci.buffer, '\n')
  236. if len(ci.buffer) > ciBufferFlushSize {
  237. ci.flush(ci.buffer)
  238. // reset buffer, keep bytes for message identifier and length
  239. ci.buffer = ci.buffer[:5]
  240. }
  241. return driver.RowsAffected(0), nil
  242. }
  243. func (ci *copyin) Close() (err error) {
  244. if ci.closed { // Don't do anything, we're already closed
  245. return nil
  246. }
  247. ci.closed = true
  248. if ci.isBad() {
  249. return driver.ErrBadConn
  250. }
  251. defer ci.cn.errRecover(&err)
  252. if len(ci.buffer) > 0 {
  253. ci.flush(ci.buffer)
  254. }
  255. // Avoid touching the scratch buffer as resploop could be using it.
  256. err = ci.cn.sendSimpleMessage('c')
  257. if err != nil {
  258. return err
  259. }
  260. <-ci.done
  261. ci.cn.inCopy = false
  262. if ci.isErrorSet() {
  263. err = ci.err
  264. return err
  265. }
  266. return nil
  267. }