You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

377 lines
8.1 KiB

  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "net"
  12. "strconv"
  13. "strings"
  14. "time"
  15. )
  16. type mysqlConn struct {
  17. buf buffer
  18. netConn net.Conn
  19. affectedRows uint64
  20. insertId uint64
  21. cfg *Config
  22. maxAllowedPacket int
  23. maxWriteSize int
  24. writeTimeout time.Duration
  25. flags clientFlag
  26. status statusFlag
  27. sequence uint8
  28. parseTime bool
  29. strict bool
  30. }
  31. // Handles parameters set in DSN after the connection is established
  32. func (mc *mysqlConn) handleParams() (err error) {
  33. for param, val := range mc.cfg.Params {
  34. switch param {
  35. // Charset
  36. case "charset":
  37. charsets := strings.Split(val, ",")
  38. for i := range charsets {
  39. // ignore errors here - a charset may not exist
  40. err = mc.exec("SET NAMES " + charsets[i])
  41. if err == nil {
  42. break
  43. }
  44. }
  45. if err != nil {
  46. return
  47. }
  48. // System Vars
  49. default:
  50. err = mc.exec("SET " + param + "=" + val + "")
  51. if err != nil {
  52. return
  53. }
  54. }
  55. }
  56. return
  57. }
  58. func (mc *mysqlConn) Begin() (driver.Tx, error) {
  59. if mc.netConn == nil {
  60. errLog.Print(ErrInvalidConn)
  61. return nil, driver.ErrBadConn
  62. }
  63. err := mc.exec("START TRANSACTION")
  64. if err == nil {
  65. return &mysqlTx{mc}, err
  66. }
  67. return nil, err
  68. }
  69. func (mc *mysqlConn) Close() (err error) {
  70. // Makes Close idempotent
  71. if mc.netConn != nil {
  72. err = mc.writeCommandPacket(comQuit)
  73. }
  74. mc.cleanup()
  75. return
  76. }
  77. // Closes the network connection and unsets internal variables. Do not call this
  78. // function after successfully authentication, call Close instead. This function
  79. // is called before auth or on auth failure because MySQL will have already
  80. // closed the network connection.
  81. func (mc *mysqlConn) cleanup() {
  82. // Makes cleanup idempotent
  83. if mc.netConn != nil {
  84. if err := mc.netConn.Close(); err != nil {
  85. errLog.Print(err)
  86. }
  87. mc.netConn = nil
  88. }
  89. mc.cfg = nil
  90. mc.buf.nc = nil
  91. }
  92. func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
  93. if mc.netConn == nil {
  94. errLog.Print(ErrInvalidConn)
  95. return nil, driver.ErrBadConn
  96. }
  97. // Send command
  98. err := mc.writeCommandPacketStr(comStmtPrepare, query)
  99. if err != nil {
  100. return nil, err
  101. }
  102. stmt := &mysqlStmt{
  103. mc: mc,
  104. }
  105. // Read Result
  106. columnCount, err := stmt.readPrepareResultPacket()
  107. if err == nil {
  108. if stmt.paramCount > 0 {
  109. if err = mc.readUntilEOF(); err != nil {
  110. return nil, err
  111. }
  112. }
  113. if columnCount > 0 {
  114. err = mc.readUntilEOF()
  115. }
  116. }
  117. return stmt, err
  118. }
  119. func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
  120. // Number of ? should be same to len(args)
  121. if strings.Count(query, "?") != len(args) {
  122. return "", driver.ErrSkip
  123. }
  124. buf := mc.buf.takeCompleteBuffer()
  125. if buf == nil {
  126. // can not take the buffer. Something must be wrong with the connection
  127. errLog.Print(ErrBusyBuffer)
  128. return "", driver.ErrBadConn
  129. }
  130. buf = buf[:0]
  131. argPos := 0
  132. for i := 0; i < len(query); i++ {
  133. q := strings.IndexByte(query[i:], '?')
  134. if q == -1 {
  135. buf = append(buf, query[i:]...)
  136. break
  137. }
  138. buf = append(buf, query[i:i+q]...)
  139. i += q
  140. arg := args[argPos]
  141. argPos++
  142. if arg == nil {
  143. buf = append(buf, "NULL"...)
  144. continue
  145. }
  146. switch v := arg.(type) {
  147. case int64:
  148. buf = strconv.AppendInt(buf, v, 10)
  149. case float64:
  150. buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
  151. case bool:
  152. if v {
  153. buf = append(buf, '1')
  154. } else {
  155. buf = append(buf, '0')
  156. }
  157. case time.Time:
  158. if v.IsZero() {
  159. buf = append(buf, "'0000-00-00'"...)
  160. } else {
  161. v := v.In(mc.cfg.Loc)
  162. v = v.Add(time.Nanosecond * 500) // To round under microsecond
  163. year := v.Year()
  164. year100 := year / 100
  165. year1 := year % 100
  166. month := v.Month()
  167. day := v.Day()
  168. hour := v.Hour()
  169. minute := v.Minute()
  170. second := v.Second()
  171. micro := v.Nanosecond() / 1000
  172. buf = append(buf, []byte{
  173. '\'',
  174. digits10[year100], digits01[year100],
  175. digits10[year1], digits01[year1],
  176. '-',
  177. digits10[month], digits01[month],
  178. '-',
  179. digits10[day], digits01[day],
  180. ' ',
  181. digits10[hour], digits01[hour],
  182. ':',
  183. digits10[minute], digits01[minute],
  184. ':',
  185. digits10[second], digits01[second],
  186. }...)
  187. if micro != 0 {
  188. micro10000 := micro / 10000
  189. micro100 := micro / 100 % 100
  190. micro1 := micro % 100
  191. buf = append(buf, []byte{
  192. '.',
  193. digits10[micro10000], digits01[micro10000],
  194. digits10[micro100], digits01[micro100],
  195. digits10[micro1], digits01[micro1],
  196. }...)
  197. }
  198. buf = append(buf, '\'')
  199. }
  200. case []byte:
  201. if v == nil {
  202. buf = append(buf, "NULL"...)
  203. } else {
  204. buf = append(buf, "_binary'"...)
  205. if mc.status&statusNoBackslashEscapes == 0 {
  206. buf = escapeBytesBackslash(buf, v)
  207. } else {
  208. buf = escapeBytesQuotes(buf, v)
  209. }
  210. buf = append(buf, '\'')
  211. }
  212. case string:
  213. buf = append(buf, '\'')
  214. if mc.status&statusNoBackslashEscapes == 0 {
  215. buf = escapeStringBackslash(buf, v)
  216. } else {
  217. buf = escapeStringQuotes(buf, v)
  218. }
  219. buf = append(buf, '\'')
  220. default:
  221. return "", driver.ErrSkip
  222. }
  223. if len(buf)+4 > mc.maxAllowedPacket {
  224. return "", driver.ErrSkip
  225. }
  226. }
  227. if argPos != len(args) {
  228. return "", driver.ErrSkip
  229. }
  230. return string(buf), nil
  231. }
  232. func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  233. if mc.netConn == nil {
  234. errLog.Print(ErrInvalidConn)
  235. return nil, driver.ErrBadConn
  236. }
  237. if len(args) != 0 {
  238. if !mc.cfg.InterpolateParams {
  239. return nil, driver.ErrSkip
  240. }
  241. // try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
  242. prepared, err := mc.interpolateParams(query, args)
  243. if err != nil {
  244. return nil, err
  245. }
  246. query = prepared
  247. args = nil
  248. }
  249. mc.affectedRows = 0
  250. mc.insertId = 0
  251. err := mc.exec(query)
  252. if err == nil {
  253. return &mysqlResult{
  254. affectedRows: int64(mc.affectedRows),
  255. insertId: int64(mc.insertId),
  256. }, err
  257. }
  258. return nil, err
  259. }
  260. // Internal function to execute commands
  261. func (mc *mysqlConn) exec(query string) error {
  262. // Send command
  263. err := mc.writeCommandPacketStr(comQuery, query)
  264. if err != nil {
  265. return err
  266. }
  267. // Read Result
  268. resLen, err := mc.readResultSetHeaderPacket()
  269. if err == nil && resLen > 0 {
  270. if err = mc.readUntilEOF(); err != nil {
  271. return err
  272. }
  273. err = mc.readUntilEOF()
  274. }
  275. return err
  276. }
  277. func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  278. if mc.netConn == nil {
  279. errLog.Print(ErrInvalidConn)
  280. return nil, driver.ErrBadConn
  281. }
  282. if len(args) != 0 {
  283. if !mc.cfg.InterpolateParams {
  284. return nil, driver.ErrSkip
  285. }
  286. // try client-side prepare to reduce roundtrip
  287. prepared, err := mc.interpolateParams(query, args)
  288. if err != nil {
  289. return nil, err
  290. }
  291. query = prepared
  292. args = nil
  293. }
  294. // Send command
  295. err := mc.writeCommandPacketStr(comQuery, query)
  296. if err == nil {
  297. // Read Result
  298. var resLen int
  299. resLen, err = mc.readResultSetHeaderPacket()
  300. if err == nil {
  301. rows := new(textRows)
  302. rows.mc = mc
  303. if resLen == 0 {
  304. // no columns, no more data
  305. return emptyRows{}, nil
  306. }
  307. // Columns
  308. rows.columns, err = mc.readColumns(resLen)
  309. return rows, err
  310. }
  311. }
  312. return nil, err
  313. }
  314. // Gets the value of the given MySQL System Variable
  315. // The returned byte slice is only valid until the next read
  316. func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
  317. // Send command
  318. if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
  319. return nil, err
  320. }
  321. // Read Result
  322. resLen, err := mc.readResultSetHeaderPacket()
  323. if err == nil {
  324. rows := new(textRows)
  325. rows.mc = mc
  326. rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
  327. if resLen > 0 {
  328. // Columns
  329. if err := mc.readUntilEOF(); err != nil {
  330. return nil, err
  331. }
  332. }
  333. dest := make([]driver.Value, resLen)
  334. if err = rows.readRow(dest); err == nil {
  335. return dest[0].([]byte), mc.readUntilEOF()
  336. }
  337. }
  338. return nil, err
  339. }