|
|
- // Copyright 2013 The ql Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSES/QL-LICENSE file.
-
- // Copyright 2015 PingCAP, Inc.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // See the License for the specific language governing permissions and
- // limitations under the License.
-
- // database/sql/driver
-
- package tidb
-
- import (
- "database/sql"
- "database/sql/driver"
- "io"
- "net/url"
- "path/filepath"
- "strings"
- "sync"
-
- "github.com/juju/errors"
- "github.com/pingcap/tidb/ast"
- "github.com/pingcap/tidb/model"
- "github.com/pingcap/tidb/sessionctx"
- "github.com/pingcap/tidb/terror"
- "github.com/pingcap/tidb/util/types"
- )
-
- const (
- // DriverName is name of TiDB driver.
- DriverName = "tidb"
- )
-
- var (
- _ driver.Conn = (*driverConn)(nil)
- _ driver.Execer = (*driverConn)(nil)
- _ driver.Queryer = (*driverConn)(nil)
- _ driver.Tx = (*driverConn)(nil)
-
- _ driver.Result = (*driverResult)(nil)
- _ driver.Rows = (*driverRows)(nil)
- _ driver.Stmt = (*driverStmt)(nil)
- _ driver.Driver = (*sqlDriver)(nil)
-
- txBeginSQL = "BEGIN;"
- txCommitSQL = "COMMIT;"
- txRollbackSQL = "ROLLBACK;"
-
- errNoResult = errors.New("query statement does not produce a result set (no top level SELECT)")
- )
-
// errList accumulates errors so several failures can be reported as one.
type errList []error

// driverParams holds the settings extracted from a DSN by parseDriverDSN.
type driverParams struct {
	storePath string // storage URL with the trailing DB name removed
	dbName    string // database the connection switches to
	// parseTime, when true, makes mysql.Time values come back as time.Time
	// instead of strings. The option name matches the mysql driver's for
	// compatibility, although no extra parsing is actually performed here.
	parseTime bool
}

// append records err; nil errors are ignored.
func (e *errList) append(err error) {
	if err == nil {
		return
	}
	*e = append(*e, err)
}

// error converts the list into an error value: nil when nothing was
// recorded, otherwise the list itself.
func (e errList) error() error {
	if len(e) > 0 {
		return e
	}
	return nil
}

// Error joins the messages of every recorded error with newlines.
func (e errList) Error() string {
	msgs := make([]string, 0, len(e))
	for _, err := range e {
		msgs = append(msgs, err.Error())
	}
	return strings.Join(msgs, "\n")
}
-
// params copies driver values into a []interface{} so they can be passed as
// variadic arguments to ExecutePreparedStmt. The result is never nil.
func params(args []driver.Value) []interface{} {
	out := make([]interface{}, 0, len(args))
	for _, a := range args {
		out = append(out, a)
	}
	return out
}
-
- var (
- tidbDriver = &sqlDriver{}
- driverOnce sync.Once
- )
-
- // RegisterDriver registers TiDB driver.
- // The name argument can be optionally prefixed by "engine://". In that case the
- // prefix is recognized as a storage engine name.
- //
- // The name argument can be optionally prefixed by "memory://". In that case
- // the prefix is stripped before interpreting it as a name of a memory-only,
- // volatile DB.
- //
- // [0]: http://golang.org/pkg/database/sql/driver/
- func RegisterDriver() {
- driverOnce.Do(func() { sql.Register(DriverName, tidbDriver) })
- }
-
// sqlDriver implements the interface required by database/sql/driver. Its
// mutex serializes driver-wide work such as schema creation in Open.
type sqlDriver struct {
	mu sync.Mutex
}

// lock acquires the driver-wide mutex.
func (d *sqlDriver) lock() { d.mu.Lock() }

// unlock releases the driver-wide mutex.
func (d *sqlDriver) unlock() { d.mu.Unlock() }
-
- // parseDriverDSN cuts off DB name from dsn. It returns error if the dsn is not
- // valid.
- func parseDriverDSN(dsn string) (params *driverParams, err error) {
- u, err := url.Parse(dsn)
- if err != nil {
- return nil, errors.Trace(err)
- }
- path := filepath.Join(u.Host, u.Path)
- dbName := filepath.Clean(filepath.Base(path))
- if dbName == "" || dbName == "." || dbName == string(filepath.Separator) {
- return nil, errors.Errorf("invalid DB name %q", dbName)
- }
- // cut off dbName
- path = filepath.Clean(filepath.Dir(path))
- if path == "" || path == "." || path == string(filepath.Separator) {
- return nil, errors.Errorf("invalid dsn %q", dsn)
- }
- u.Path, u.Host = path, ""
- params = &driverParams{
- storePath: u.String(),
- dbName: dbName,
- }
- // parse additional driver params
- query := u.Query()
- if parseTime := query.Get("parseTime"); parseTime == "true" {
- params.parseTime = true
- }
-
- return params, nil
- }
-
- // Open returns a new connection to the database.
- //
- // The dsn must be a URL format 'engine://path/dbname?params'.
- // Engine is the storage name registered with RegisterStore.
- // Path is the storage specific format.
- // Params is key-value pairs split by '&', optional params are storage specific.
- // Examples:
- // goleveldb://relative/path/test
- // boltdb:///absolute/path/test
- // hbase://zk1,zk2,zk3/hbasetbl/test?tso=zk
- //
- // Open may return a cached connection (one previously closed), but doing so is
- // unnecessary; the sql package maintains a pool of idle connections for
- // efficient re-use.
- //
- // The behavior of the mysql driver regarding time parsing can also be imitated
- // by passing ?parseTime
- //
- // The returned connection is only used by one goroutine at a time.
- func (d *sqlDriver) Open(dsn string) (driver.Conn, error) {
- params, err := parseDriverDSN(dsn)
- if err != nil {
- return nil, errors.Trace(err)
- }
- store, err := NewStore(params.storePath)
- if err != nil {
- return nil, errors.Trace(err)
- }
-
- sess, err := CreateSession(store)
- if err != nil {
- return nil, errors.Trace(err)
- }
- s := sess.(*session)
-
- d.lock()
- defer d.unlock()
-
- DBName := model.NewCIStr(params.dbName)
- domain := sessionctx.GetDomain(s)
- cs := &ast.CharsetOpt{
- Chs: "utf8",
- Col: "utf8_bin",
- }
- if !domain.InfoSchema().SchemaExists(DBName) {
- err = domain.DDL().CreateSchema(s, DBName, cs)
- if err != nil {
- return nil, errors.Trace(err)
- }
- }
- driver := &sqlDriver{}
- return newDriverConn(s, driver, DBName.O, params)
- }
-
- // driverConn is a connection to a database. It is not used concurrently by
- // multiple goroutines.
- //
- // Conn is assumed to be stateful.
- type driverConn struct {
- s Session
- driver *sqlDriver
- stmts map[string]driver.Stmt
- params *driverParams
- }
-
- func newDriverConn(sess *session, d *sqlDriver, schema string, params *driverParams) (driver.Conn, error) {
- r := &driverConn{
- driver: d,
- stmts: map[string]driver.Stmt{},
- s: sess,
- params: params,
- }
-
- _, err := r.s.Execute("use " + schema)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return r, nil
- }
-
- // Prepare returns a prepared statement, bound to this connection.
- func (c *driverConn) Prepare(query string) (driver.Stmt, error) {
- stmtID, paramCount, fields, err := c.s.PrepareStmt(query)
- if err != nil {
- return nil, err
- }
- s := &driverStmt{
- conn: c,
- query: query,
- stmtID: stmtID,
- paramCount: paramCount,
- isQuery: fields != nil,
- }
- c.stmts[query] = s
- return s, nil
- }
-
- // Close invalidates and potentially stops any current prepared statements and
- // transactions, marking this connection as no longer in use.
- //
- // Because the sql package maintains a free pool of connections and only calls
- // Close when there's a surplus of idle connections, it shouldn't be necessary
- // for drivers to do their own connection caching.
- func (c *driverConn) Close() error {
- var err errList
- for _, s := range c.stmts {
- stmt := s.(*driverStmt)
- err.append(stmt.conn.s.DropPreparedStmt(stmt.stmtID))
- }
-
- c.driver.lock()
- defer c.driver.unlock()
-
- return err.error()
- }
-
- // Begin starts and returns a new transaction.
- func (c *driverConn) Begin() (driver.Tx, error) {
- if c.s == nil {
- return nil, errors.Errorf("Need init first")
- }
-
- if _, err := c.s.Execute(txBeginSQL); err != nil {
- return nil, errors.Trace(err)
- }
-
- return c, nil
- }
-
- func (c *driverConn) Commit() error {
- if c.s == nil {
- return terror.CommitNotInTransaction
- }
- _, err := c.s.Execute(txCommitSQL)
-
- if err != nil {
- return errors.Trace(err)
- }
-
- err = c.s.FinishTxn(false)
- return errors.Trace(err)
- }
-
- func (c *driverConn) Rollback() error {
- if c.s == nil {
- return terror.RollbackNotInTransaction
- }
-
- if _, err := c.s.Execute(txRollbackSQL); err != nil {
- return errors.Trace(err)
- }
-
- return nil
- }
-
- // Execer is an optional interface that may be implemented by a Conn.
- //
- // If a Conn does not implement Execer, the sql package's DB.Exec will first
- // prepare a query, execute the statement, and then close the statement.
- //
- // Exec may return driver.ErrSkip.
- func (c *driverConn) Exec(query string, args []driver.Value) (driver.Result, error) {
- return c.driverExec(query, args)
-
- }
-
- func (c *driverConn) getStmt(query string) (stmt driver.Stmt, err error) {
- stmt, ok := c.stmts[query]
- if !ok {
- stmt, err = c.Prepare(query)
- if err != nil {
- return nil, errors.Trace(err)
- }
- }
- return
- }
-
- func (c *driverConn) driverExec(query string, args []driver.Value) (driver.Result, error) {
- if len(args) == 0 {
- if _, err := c.s.Execute(query); err != nil {
- return nil, errors.Trace(err)
- }
- r := &driverResult{}
- r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
- return r, nil
- }
- stmt, err := c.getStmt(query)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return stmt.Exec(args)
- }
-
- // Queryer is an optional interface that may be implemented by a Conn.
- //
- // If a Conn does not implement Queryer, the sql package's DB.Query will first
- // prepare a query, execute the statement, and then close the statement.
- //
- // Query may return driver.ErrSkip.
- func (c *driverConn) Query(query string, args []driver.Value) (driver.Rows, error) {
- return c.driverQuery(query, args)
- }
-
- func (c *driverConn) driverQuery(query string, args []driver.Value) (driver.Rows, error) {
- if len(args) == 0 {
- rss, err := c.s.Execute(query)
- if err != nil {
- return nil, errors.Trace(err)
- }
- if len(rss) == 0 {
- return nil, errors.Trace(errNoResult)
- }
- return &driverRows{params: c.params, rs: rss[0]}, nil
- }
- stmt, err := c.getStmt(query)
- if err != nil {
- return nil, errors.Trace(err)
- }
- return stmt.Query(args)
- }
-
// driverResult is the result of a statement execution. It holds the
// session's last insert ID and affected-row count as they were read right
// after the statement ran.
type driverResult struct {
	lastInsertID int64
	rowsAffected int64
}

// LastInsertId returns the database's auto-generated ID after, for example, an
// INSERT into a table with primary key. It never returns an error.
func (r *driverResult) LastInsertId() (int64, error) { // -golint
	return r.lastInsertID, nil
}

// RowsAffected returns the number of rows affected by the statement. It never
// returns an error.
func (r *driverResult) RowsAffected() (int64, error) {
	return r.rowsAffected, nil
}
-
- // driverRows is an iterator over an executed query's results.
- type driverRows struct {
- rs ast.RecordSet
- params *driverParams
- }
-
- // Columns returns the names of the columns. The number of columns of the
- // result is inferred from the length of the slice. If a particular column
- // name isn't known, an empty string should be returned for that entry.
- func (r *driverRows) Columns() []string {
- if r.rs == nil {
- return []string{}
- }
- fs, _ := r.rs.Fields()
- names := make([]string, len(fs))
- for i, f := range fs {
- names[i] = f.ColumnAsName.O
- }
- return names
- }
-
- // Close closes the rows iterator.
- func (r *driverRows) Close() error {
- if r.rs != nil {
- return r.rs.Close()
- }
- return nil
- }
-
- // Next is called to populate the next row of data into the provided slice. The
- // provided slice will be the same size as the Columns() are wide.
- //
- // The dest slice may be populated only with a driver Value type, but excluding
- // string. All string values must be converted to []byte.
- //
- // Next should return io.EOF when there are no more rows.
- func (r *driverRows) Next(dest []driver.Value) error {
- if r.rs == nil {
- return io.EOF
- }
- row, err := r.rs.Next()
- if err != nil {
- return errors.Trace(err)
- }
- if row == nil {
- return io.EOF
- }
- if len(row.Data) != len(dest) {
- return errors.Errorf("field count mismatch: got %d, need %d", len(row.Data), len(dest))
- }
- for i, xi := range row.Data {
- switch xi.Kind() {
- case types.KindNull:
- dest[i] = nil
- case types.KindInt64:
- dest[i] = xi.GetInt64()
- case types.KindUint64:
- dest[i] = xi.GetUint64()
- case types.KindFloat32:
- dest[i] = xi.GetFloat32()
- case types.KindFloat64:
- dest[i] = xi.GetFloat64()
- case types.KindString:
- dest[i] = xi.GetString()
- case types.KindBytes:
- dest[i] = xi.GetBytes()
- case types.KindMysqlBit:
- dest[i] = xi.GetMysqlBit().ToString()
- case types.KindMysqlDecimal:
- dest[i] = xi.GetMysqlDecimal().String()
- case types.KindMysqlDuration:
- dest[i] = xi.GetMysqlDuration().String()
- case types.KindMysqlEnum:
- dest[i] = xi.GetMysqlEnum().String()
- case types.KindMysqlHex:
- dest[i] = xi.GetMysqlHex().ToString()
- case types.KindMysqlSet:
- dest[i] = xi.GetMysqlSet().String()
- case types.KindMysqlTime:
- t := xi.GetMysqlTime()
- if !r.params.parseTime {
- dest[i] = t.String()
- } else {
- dest[i] = t.Time
- }
- default:
- return errors.Errorf("unable to handle type %T", xi.GetValue())
- }
- }
- return nil
- }
-
- // driverStmt is a prepared statement. It is bound to a driverConn and not used
- // by multiple goroutines concurrently.
- type driverStmt struct {
- conn *driverConn
- query string
- stmtID uint32
- paramCount int
- isQuery bool
- }
-
- // Close closes the statement.
- //
- // As of Go 1.1, a Stmt will not be closed if it's in use by any queries.
- func (s *driverStmt) Close() error {
- s.conn.s.DropPreparedStmt(s.stmtID)
- delete(s.conn.stmts, s.query)
- return nil
- }
-
- // NumInput returns the number of placeholder parameters.
- //
- // If NumInput returns >= 0, the sql package will sanity check argument counts
- // from callers and return errors to the caller before the statement's Exec or
- // Query methods are called.
- //
- // NumInput may also return -1, if the driver doesn't know its number of
- // placeholders. In that case, the sql package will not sanity check Exec or
- // Query argument counts.
- func (s *driverStmt) NumInput() int {
- return s.paramCount
- }
-
- // Exec executes a query that doesn't return rows, such as an INSERT or UPDATE.
- func (s *driverStmt) Exec(args []driver.Value) (driver.Result, error) {
- c := s.conn
- _, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
- if err != nil {
- return nil, errors.Trace(err)
- }
- r := &driverResult{}
- if s != nil {
- r.lastInsertID, r.rowsAffected = int64(c.s.LastInsertID()), int64(c.s.AffectedRows())
- }
- return r, nil
- }
-
- // Exec executes a query that may return rows, such as a SELECT.
- func (s *driverStmt) Query(args []driver.Value) (driver.Rows, error) {
- c := s.conn
- rs, err := c.s.ExecutePreparedStmt(s.stmtID, params(args)...)
- if err != nil {
- return nil, errors.Trace(err)
- }
- if rs == nil {
- if s.isQuery {
- return nil, errors.Trace(errNoResult)
- }
- // The statement is not a query.
- return &driverRows{}, nil
- }
- return &driverRows{params: s.conn.params, rs: rs}, nil
- }
-
- func init() {
- RegisterDriver()
- }
|