You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

560 lines
14 KiB

  1. // Copyright 2013 The ql Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSES/QL-LICENSE file.
  4. // Copyright 2015 PingCAP, Inc.
  5. //
  6. // Licensed under the Apache License, Version 2.0 (the "License");
  7. // you may not use this file except in compliance with the License.
  8. // You may obtain a copy of the License at
  9. //
  10. // http://www.apache.org/licenses/LICENSE-2.0
  11. //
  12. // Unless required by applicable law or agreed to in writing, software
  13. // distributed under the License is distributed on an "AS IS" BASIS,
  14. // See the License for the specific language governing permissions and
  15. // limitations under the License.
  16. // database/sql/driver
  17. package tidb
  18. import (
  19. "database/sql"
  20. "database/sql/driver"
  21. "io"
  22. "net/url"
  23. "path/filepath"
  24. "strings"
  25. "sync"
  26. "github.com/juju/errors"
  27. "github.com/pingcap/tidb/ast"
  28. "github.com/pingcap/tidb/model"
  29. "github.com/pingcap/tidb/sessionctx"
  30. "github.com/pingcap/tidb/terror"
  31. "github.com/pingcap/tidb/util/types"
  32. )
  33. const (
  34. // DriverName is name of TiDB driver.
  35. DriverName = "tidb"
  36. )
  37. var (
  38. _ driver.Conn = (*driverConn)(nil)
  39. _ driver.Execer = (*driverConn)(nil)
  40. _ driver.Queryer = (*driverConn)(nil)
  41. _ driver.Tx = (*driverConn)(nil)
  42. _ driver.Result = (*driverResult)(nil)
  43. _ driver.Rows = (*driverRows)(nil)
  44. _ driver.Stmt = (*driverStmt)(nil)
  45. _ driver.Driver = (*sqlDriver)(nil)
  46. txBeginSQL = "BEGIN;"
  47. txCommitSQL = "COMMIT;"
  48. txRollbackSQL = "ROLLBACK;"
  49. errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)")
  50. )
  51. type errList []error
  52. type driverParams struct {
  53. storePath string
  54. dbName string
  55. // when set to true `mysql.Time` isn't encoded as string but passed as `time.Time`
  56. // this option is named for compatibility the same as in the mysql driver
  57. // while we actually do not have additional parsing to do
  58. parseTime bool
  59. }
  60. func (e *errList) append(err error) {
  61. if err != nil {
  62. *e = append(*e, err)
  63. }
  64. }
  65. func (e errList) error() error {
  66. if len(e) == 0 {
  67. return nil
  68. }
  69. return e
  70. }
  71. func (e errList) Error() string {
  72. a := make([]string, len(e))
  73. for i, v := range e {
  74. a[i] = v.Error()
  75. }
  76. return strings.Join(a, "\n")
  77. }
  78. func params(args []driver.Value) []interface{} {
  79. r := make([]interface{}, len(args))
  80. for i, v := range args {
  81. r[i] = interface{}(v)
  82. }
  83. return r
  84. }
  85. var (
  86. tidbDriver = &sqlDriver{}
  87. driverOnce sync.Once
  88. )
  89. // RegisterDriver registers TiDB driver.
  90. // The name argument can be optionally prefixed by "engine://". In that case the
  91. // prefix is recognized as a storage engine name.
  92. //
  93. // The name argument can be optionally prefixed by "memory://". In that case
  94. // the prefix is stripped before interpreting it as a name of a memory-only,
  95. // volatile DB.
  96. //
  97. // [0]: http://golang.org/pkg/database/sql/driver/
  98. func RegisterDriver() {
  99. driverOnce.Do(func() { sql.Register(DriverName, tidbDriver) })
  100. }
  101. // sqlDriver implements the interface required by database/sql/driver.
  102. type sqlDriver struct {
  103. mu sync.Mutex
  104. }
  105. func (d *sqlDriver) lock() {
  106. d.mu.Lock()
  107. }
  108. func (d *sqlDriver) unlock() {
  109. d.mu.Unlock()
  110. }
  111. // parseDriverDSN cuts off DB name from dsn. It returns error if the dsn is not
  112. // valid.
  113. func parseDriverDSN(dsn string) (params *driverParams, err error) {
  114. u, err := url.Parse(dsn)
  115. if err != nil {
  116. return nil, errors.Trace(err)
  117. }
  118. path := filepath.Join(u.Host, u.Path)
  119. dbName := filepath.Clean(filepath.Base(path))
  120. if dbName == "" || dbName == "." || dbName == string(filepath.Separator) {
  121. return nil, errors.Errorf("invalid DB name %q", dbName)
  122. }
  123. // cut off dbName
  124. path = filepath.Clean(filepath.Dir(path))
  125. if path == "" || path == "." || path == string(filepath.Separator) {
  126. return nil, errors.Errorf("invalid dsn %q", dsn)
  127. }
  128. u.Path, u.Host = path, ""
  129. params = &driverParams{
  130. storePath: u.String(),
  131. dbName: dbName,
  132. }
  133. // parse additional driver params
  134. query := u.Query()
  135. if parseTime := query.Get("parseTime"); parseTime == "true" {
  136. params.parseTime = true
  137. }
  138. return params, nil
  139. }
  140. // Open returns a new connection to the database.
  141. //
  142. // The dsn must be a URL format 'engine://path/dbname?params'.
  143. // Engine is the storage name registered with RegisterStore.
  144. // Path is the storage specific format.
  145. // Params is key-value pairs split by '&', optional params are storage specific.
  146. // Examples:
  147. // goleveldb://relative/path/test
  148. // boltdb:///absolute/path/test
  149. // hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk
  150. //
  151. // Open may return a cached connection (one previously closed), but doing so is
  152. // unnecessary; the sql package maintains a pool of idle connections for
  153. // efficient re-use.
  154. //
  155. // The behavior of the mysql driver regarding time parsing can also be imitated
  156. // by passing ?parseTime
  157. //
  158. // The returned connection is only used by one goroutine at a time.
  159. func (d *sqlDriver) Open(dsn string) (driver.Conn, error) {
  160. params, err := parseDriverDSN(dsn)
  161. if err != nil {
  162. return nil, errors.Trace(err)
  163. }
  164. store, err := NewStore(params.storePath)
  165. if err != nil {
  166. return nil, errors.Trace(err)
  167. }
  168. sess, err := CreateSession(store)
  169. if err != nil {
  170. return nil, errors.Trace(err)
  171. }
  172. s := sess.(*session)
  173. d.lock()
  174. defer d.unlock()
  175. DBName := model.NewCIStr(params.dbName)
  176. domain := sessionctx.GetDomain(s)
  177. cs := &ast.CharsetOpt{
  178. Chs: "utf8",
  179. Col: "utf8_bin",
  180. }
  181. if !domain.InfoSchema().SchemaExists(DBName) {
  182. err = domain.DDL().CreateSchema(s, DBName, cs)
  183. if err != nil {
  184. return nil, errors.Trace(err)
  185. }
  186. }
  187. driver := &sqlDriver{}
  188. return newDriverConn(s, driver, DBName.O, params)
  189. }
  190. // driverConn is a connection to a database. It is not used concurrently by
  191. // multiple goroutines.
  192. //
  193. // Conn is assumed to be stateful.
  194. type driverConn struct {
  195. s Session
  196. driver *sqlDriver
  197. stmts map[string]driver.Stmt
  198. params *driverParams
  199. }
  200. func newDriverConn(sess *session, d *sqlDriver, schema string, params *driverParams) (driver.Conn, error) {
  201. r := &driverConn{
  202. driver: d,
  203. stmts: map[string]driver.Stmt{},
  204. s: sess,
  205. params: params,
  206. }
  207. _, err := r.s.Execute("use " + schema)
  208. if err != nil {
  209. return nil, errors.Trace(err)
  210. }
  211. return r, nil
  212. }
  213. // Prepare returns a prepared statement, bound to this connection.
  214. func (c *driverConn) Prepare(query string) (driver.Stmt, error) {
  215. stmtID, paramCount, fields, err := c.s.PrepareStmt(query)
  216. if err != nil {
  217. return nil, err
  218. }
  219. s := &driverStmt{
  220. conn: c,
  221. query: query,
  222. stmtID: stmtID,
  223. paramCount: paramCount,
  224. isQuery: fields != nil,
  225. }
  226. c.stmts[query] = s
  227. return s, nil
  228. }
  229. // Close invalidates and potentially stops any current prepared statements and
  230. // transactions, marking this connection as no longer in use.
  231. //
  232. // Because the sql package maintains a free pool of connections and only calls
  233. // Close when there's a surplus of idle connections, it shouldn't be necessary
  234. // for drivers to do their own connection caching.
  235. func (c *driverConn) Close() error {
  236. var err errList
  237. for _, s := range c.stmts {
  238. stmt := s.(*driverStmt)
  239. err.append(stmt.conn.s.DropPreparedStmt(stmt.stmtID))
  240. }
  241. c.driver.lock()
  242. defer c.driver.unlock()
  243. return err.error()
  244. }
  245. // Begin starts and returns a new transaction.
  246. func (c *driverConn) Begin() (driver.Tx, error) {
  247. if c.s == nil {
  248. return nil, errors.Errorf("Need init first")
  249. }
  250. if _, err := c.s.Execute(txBeginSQL); err != nil {
  251. return nil, errors.Trace(err)
  252. }
  253. return c, nil
  254. }
  255. func (c *driverConn) Commit() error {
  256. if c.s == nil {
  257. return terror.CommitNotInTransaction
  258. }
  259. _, err := c.s.Execute(txCommitSQL)
  260. if err != nil {
  261. return errors.Trace(err)
  262. }
  263. err = c.s.FinishTxn(false)
  264. return errors.Trace(err)
  265. }
  266. func (c *driverConn) Rollback() error {
  267. if c.s == nil {
  268. return terror.RollbackNotInTransaction
  269. }
  270. if _, err := c.s.Execute(txRollbackSQL); err != nil {
  271. return errors.Trace(err)
  272. }
  273. return nil
  274. }
  275. // Execer is an optional interface that may be implemented by a Conn.
  276. //
  277. // If a Conn does not implement Execer, the sql package's DB.Exec will first
  278. // prepare a query, execute the statement, and then close the statement.
  279. //
  280. // Exec may return driver.ErrSkip.
  281. func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
  282. return c.driverExec(query, args)
  283. }
  284. func (c *driverConn) getStmt(query string) (stmt driver.Stmt, err error) {
  285. stmt, ok := c.stmts[query]
  286. if !ok {
  287. stmt, err = c.Prepare(query)
  288. if err != nil {
  289. return nil, errors.Trace(err)
  290. }
  291. }
  292. return
  293. }
  294. func (c *driverConn) driverExec(query string, args []driver.Value) (driver.Result, error) {
  295. if len(args) == 0 {
  296. if _, err := c.s.Execute(query); err != nil {
  297. return nil, errors.Trace(err)
  298. }
  299. r := &driverResult{}
  300. r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
  301. return r, nil
  302. }
  303. stmt, err := c.getStmt(query)
  304. if err != nil {
  305. return nil, errors.Trace(err)
  306. }
  307. return stmt.Exec(args)
  308. }
  309. // Queryer is an optional interface that may be implemented by a Conn.
  310. //
  311. // If a Conn does not implement Queryer, the sql package's DB.Query will first
  312. // prepare a query, execute the statement, and then close the statement.
  313. //
  314. // Query may return driver.ErrSkip.
  315. func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
  316. return c.driverQuery(query, args)
  317. }
  318. func (c *driverConn) driverQuery(query string, args []driver.Value) (driver.Rows, error) {
  319. if len(args) == 0 {
  320. rss, err := c.s.Execute(query)
  321. if err != nil {
  322. return nil, errors.Trace(err)
  323. }
  324. if len(rss) == 0 {
  325. return nil, errors.Trace(errNoResult)
  326. }
  327. return &driverRows{params: c.params, rs: rss[0]}, nil
  328. }
  329. stmt, err := c.getStmt(query)
  330. if err != nil {
  331. return nil, errors.Trace(err)
  332. }
  333. return stmt.Query(args)
  334. }
  335. // driverResult is the result of a query execution.
  336. type driverResult struct {
  337. lastInsertID int64
  338. rowsAffected int64
  339. }
  340. // LastInsertID returns the database's auto-generated ID after, for example, an
  341. // INSERT into a table with primary key.
  342. func (r *driverResult) LastInsertId() (int64, error) { // -golint
  343. return r.lastInsertID, nil
  344. }
  345. // RowsAffected returns the number of rows affected by the query.
  346. func (r *driverResult) RowsAffected() (int64, error) {
  347. return r.rowsAffected, nil
  348. }
  349. // driverRows is an iterator over an executed query's results.
  350. type driverRows struct {
  351. rs ast.RecordSet
  352. params *driverParams
  353. }
  354. // Columns returns the names of the columns. The number of columns of the
  355. // result is inferred from the length of the slice. If a particular column
  356. // name isn't known, an empty string should be returned for that entry.
  357. func (r *driverRows) Columns() []string {
  358. if r.rs == nil {
  359. return []string{}
  360. }
  361. fs, _ := r.rs.Fields()
  362. names := make([]string, len(fs))
  363. for i, f := range fs {
  364. names[i] = f.ColumnAsName.O
  365. }
  366. return names
  367. }
  368. // Close closes the rows iterator.
  369. func (r *driverRows) Close() error {
  370. if r.rs != nil {
  371. return r.rs.Close()
  372. }
  373. return nil
  374. }
  375. // Next is called to populate the next row of data into the provided slice. The
  376. // provided slice will be the same size as the Columns() are wide.
  377. //
  378. // The dest slice may be populated only with a driver Value type, but excluding
  379. // string. All string values must be converted to []byte.
  380. //
  381. // Next should return io.EOF when there are no more rows.
  382. func (r *driverRows) Next(dest []driver.Value) error {
  383. if r.rs == nil {
  384. return io.EOF
  385. }
  386. row, err := r.rs.Next()
  387. if err != nil {
  388. return errors.Trace(err)
  389. }
  390. if row == nil {
  391. return io.EOF
  392. }
  393. if len(row.Data) != len(dest) {
  394. return errors.Errorf("field count mismatch: got %d, need %d", len(row.Data), len(dest))
  395. }
  396. for i, xi := range row.Data {
  397. switch xi.Kind() {
  398. case types.KindNull:
  399. dest[i] = nil
  400. case types.KindInt64:
  401. dest[i] = xi.GetInt64()
  402. case types.KindUint64:
  403. dest[i] = xi.GetUint64()
  404. case types.KindFloat32:
  405. dest[i] = xi.GetFloat32()
  406. case types.KindFloat64:
  407. dest[i] = xi.GetFloat64()
  408. case types.KindString:
  409. dest[i] = xi.GetString()
  410. case types.KindBytes:
  411. dest[i] = xi.GetBytes()
  412. case types.KindMysqlBit:
  413. dest[i] = xi.GetMysqlBit().ToString()
  414. case types.KindMysqlDecimal:
  415. dest[i] = xi.GetMysqlDecimal().String()
  416. case types.KindMysqlDuration:
  417. dest[i] = xi.GetMysqlDuration().String()
  418. case types.KindMysqlEnum:
  419. dest[i] = xi.GetMysqlEnum().String()
  420. case types.KindMysqlHex:
  421. dest[i] = xi.GetMysqlHex().ToString()
  422. case types.KindMysqlSet:
  423. dest[i] = xi.GetMysqlSet().String()
  424. case types.KindMysqlTime:
  425. t := xi.GetMysqlTime()
  426. if !r.params.parseTime {
  427. dest[i] = t.String()
  428. } else {
  429. dest[i] = t.Time
  430. }
  431. default:
  432. return errors.Errorf("unable to handle type %T", xi.GetValue())
  433. }
  434. }
  435. return nil
  436. }
  437. // driverStmt is a prepared statement. It is bound to a driverConn and not used
  438. // by multiple goroutines concurrently.
  439. type driverStmt struct {
  440. conn *driverConn
  441. query string
  442. stmtID uint32
  443. paramCount int
  444. isQuery bool
  445. }
  446. // Close closes the statement.
  447. //
  448. // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
  449. func (s *driverStmt) Close() error {
  450. s.conn.s.DropPreparedStmt(s.stmtID)
  451. delete(s.conn.stmts, s.query)
  452. return nil
  453. }
  454. // NumInput returns the number of placeholder parameters.
  455. //
  456. // If NumInput returns >= 0, the sql package will sanity check argument counts
  457. // from callers and return errors to the caller before the statement's Exec or
  458. // Query methods are called.
  459. //
  460. // NumInput may also return -1, if the driver doesn't know its number of
  461. // placeholders. In that case, the sql package will not sanity check Exec or
  462. // Query argument counts.
  463. func (s *driverStmt) NumInput() int {
  464. return s.paramCount
  465. }
  466. // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
  467. func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
  468. c := s.conn
  469. _, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
  470. if err != nil {
  471. return nil, errors.Trace(err)
  472. }
  473. r := &driverResult{}
  474. if s != nil {
  475. r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
  476. }
  477. return r, nil
  478. }
  479. // Exec executes a query that may return rows, such as a SELECT.
  480. func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
  481. c := s.conn
  482. rs, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
  483. if err != nil {
  484. return nil, errors.Trace(err)
  485. }
  486. if rs == nil {
  487. if s.isQuery {
  488. return nil, errors.Trace(errNoResult)
  489. }
  490. // The statement is not a query.
  491. return &driverRows{}, nil
  492. }
  493. return &driverRows{params: s.conn.params, rs: rs}, nil
  494. }
  495. func init() {
  496. RegisterDriver()
  497. }