|
|
@ -770,22 +770,25 @@ var ( |
|
|
|
postgresQuoter = schemas.Quoter{'"', '"', schemas.AlwaysReserve} |
|
|
|
) |
|
|
|
|
|
|
|
const postgresPublicSchema = "public" |
|
|
|
var ( |
|
|
|
// DefaultPostgresSchema default postgres schema
|
|
|
|
DefaultPostgresSchema = "public" |
|
|
|
) |
|
|
|
|
|
|
|
type postgres struct { |
|
|
|
Base |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) Init(d *core.DB, uri *URI) error { |
|
|
|
func (db *postgres) Init(uri *URI) error { |
|
|
|
db.quoter = postgresQuoter |
|
|
|
err := db.Base.Init(d, db, uri) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
if db.uri.Schema == "" { |
|
|
|
db.uri.Schema = postgresPublicSchema |
|
|
|
return db.Base.Init(db, uri) |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) getSchema() string { |
|
|
|
if db.uri.Schema != "" { |
|
|
|
return db.uri.Schema |
|
|
|
} |
|
|
|
return nil |
|
|
|
return DefaultPostgresSchema |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) needQuote(name string) bool { |
|
|
@ -817,10 +820,6 @@ func (db *postgres) SetQuotePolicy(quotePolicy QuotePolicy) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) DefaultSchema() string { |
|
|
|
return postgresPublicSchema |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) SQLType(c *schemas.Column) string { |
|
|
|
var res string |
|
|
|
switch t := c.SQLType.Name; t { |
|
|
@ -932,32 +931,32 @@ func (db *postgres) CreateTableSQL(table *schemas.Table, tableName string) ([]st |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) IndexCheckSQL(tableName, idxName string) (string, []interface{}) { |
|
|
|
if len(db.uri.Schema) == 0 { |
|
|
|
if len(db.getSchema()) == 0 { |
|
|
|
args := []interface{}{tableName, idxName} |
|
|
|
return `SELECT indexname FROM pg_indexes WHERE tablename = ? AND indexname = ?`, args |
|
|
|
} |
|
|
|
|
|
|
|
args := []interface{}{db.uri.Schema, tableName, idxName} |
|
|
|
args := []interface{}{db.getSchema(), tableName, idxName} |
|
|
|
return `SELECT indexname FROM pg_indexes ` + |
|
|
|
`WHERE schemaname = ? AND tablename = ? AND indexname = ?`, args |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) IsTableExist(ctx context.Context, tableName string) (bool, error) { |
|
|
|
if len(db.uri.Schema) == 0 { |
|
|
|
return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) |
|
|
|
func (db *postgres) IsTableExist(queryer core.Queryer, ctx context.Context, tableName string) (bool, error) { |
|
|
|
if len(db.getSchema()) == 0 { |
|
|
|
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE tablename = $1`, tableName) |
|
|
|
} |
|
|
|
|
|
|
|
return db.HasRecords(ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, |
|
|
|
db.uri.Schema, tableName) |
|
|
|
return db.HasRecords(queryer, ctx, `SELECT tablename FROM pg_tables WHERE schemaname = $1 AND tablename = $2`, |
|
|
|
db.getSchema(), tableName) |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) ModifyColumnSQL(tableName string, col *schemas.Column) string { |
|
|
|
if len(db.uri.Schema) == 0 || strings.Contains(tableName, ".") { |
|
|
|
if len(db.getSchema()) == 0 || strings.Contains(tableName, ".") { |
|
|
|
return fmt.Sprintf("alter table %s ALTER COLUMN %s TYPE %s", |
|
|
|
tableName, col.Name, db.SQLType(col)) |
|
|
|
} |
|
|
|
return fmt.Sprintf("alter table %s.%s ALTER COLUMN %s TYPE %s", |
|
|
|
db.uri.Schema, tableName, col.Name, db.SQLType(col)) |
|
|
|
db.getSchema(), tableName, col.Name, db.SQLType(col)) |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string { |
|
|
@ -974,23 +973,23 @@ func (db *postgres) DropIndexSQL(tableName string, index *schemas.Index) string |
|
|
|
idxName = fmt.Sprintf("IDX_%v_%v", tableName, index.Name) |
|
|
|
} |
|
|
|
} |
|
|
|
if db.uri.Schema != "" { |
|
|
|
idxName = db.uri.Schema + "." + idxName |
|
|
|
if db.getSchema() != "" { |
|
|
|
idxName = db.getSchema() + "." + idxName |
|
|
|
} |
|
|
|
return fmt.Sprintf("DROP INDEX %v", db.Quoter().Quote(idxName)) |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string) (bool, error) { |
|
|
|
args := []interface{}{db.uri.Schema, tableName, colName} |
|
|
|
func (db *postgres) IsColumnExist(queryer core.Queryer, ctx context.Context, tableName, colName string) (bool, error) { |
|
|
|
args := []interface{}{db.getSchema(), tableName, colName} |
|
|
|
query := "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_schema = $1 AND table_name = $2" + |
|
|
|
" AND column_name = $3" |
|
|
|
if len(db.uri.Schema) == 0 { |
|
|
|
if len(db.getSchema()) == 0 { |
|
|
|
args = []interface{}{tableName, colName} |
|
|
|
query = "SELECT column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name = $1" + |
|
|
|
" AND column_name = $2" |
|
|
|
} |
|
|
|
|
|
|
|
rows, err := db.DB().QueryContext(ctx, query, args...) |
|
|
|
rows, err := queryer.QueryContext(ctx, query, args...) |
|
|
|
if err != nil { |
|
|
|
return false, err |
|
|
|
} |
|
|
@ -999,8 +998,8 @@ func (db *postgres) IsColumnExist(ctx context.Context, tableName, colName string |
|
|
|
return rows.Next(), nil |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) GetColumns(ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { |
|
|
|
args := []interface{}{db.uri.Schema, tableName, db.uri.Schema} |
|
|
|
func (db *postgres) GetColumns(queryer core.Queryer, ctx context.Context, tableName string) ([]string, map[string]*schemas.Column, error) { |
|
|
|
args := []interface{}{tableName} |
|
|
|
s := `SELECT column_name, column_default, is_nullable, data_type, character_maximum_length, |
|
|
|
CASE WHEN p.contype = 'p' THEN true ELSE false END AS primarykey, |
|
|
|
CASE WHEN p.contype = 'u' THEN true ELSE false END AS uniquekey |
|
|
@ -1011,9 +1010,17 @@ FROM pg_attribute f |
|
|
|
LEFT JOIN pg_constraint p ON p.conrelid = c.oid AND f.attnum = ANY (p.conkey) |
|
|
|
LEFT JOIN pg_class AS g ON p.confrelid = g.oid |
|
|
|
LEFT JOIN INFORMATION_SCHEMA.COLUMNS s ON s.column_name=f.attname AND c.relname=s.table_name |
|
|
|
WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_schema = $3 AND f.attnum > 0 ORDER BY f.attnum;` |
|
|
|
WHERE n.nspname= s.table_schema AND c.relkind = 'r'::char AND c.relname = $1%s AND f.attnum > 0 ORDER BY f.attnum;` |
|
|
|
|
|
|
|
rows, err := db.DB().QueryContext(ctx, s, args...) |
|
|
|
schema := db.getSchema() |
|
|
|
if schema != "" { |
|
|
|
s = fmt.Sprintf(s, "AND s.table_schema = $2") |
|
|
|
args = append(args, schema) |
|
|
|
} else { |
|
|
|
s = fmt.Sprintf(s, "") |
|
|
|
} |
|
|
|
|
|
|
|
rows, err := queryer.QueryContext(ctx, s, args...) |
|
|
|
if err != nil { |
|
|
|
return nil, nil, err |
|
|
|
} |
|
|
@ -1132,15 +1139,16 @@ WHERE n.nspname= $1 AND c.relkind = 'r'::char AND c.relname = $2 AND s.table_sch |
|
|
|
return colSeq, cols, nil |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) GetTables(ctx context.Context) ([]*schemas.Table, error) { |
|
|
|
func (db *postgres) GetTables(queryer core.Queryer, ctx context.Context) ([]*schemas.Table, error) { |
|
|
|
args := []interface{}{} |
|
|
|
s := "SELECT tablename FROM pg_tables" |
|
|
|
if len(db.uri.Schema) != 0 { |
|
|
|
args = append(args, db.uri.Schema) |
|
|
|
schema := db.getSchema() |
|
|
|
if schema != "" { |
|
|
|
args = append(args, schema) |
|
|
|
s = s + " WHERE schemaname = $1" |
|
|
|
} |
|
|
|
|
|
|
|
rows, err := db.DB().QueryContext(ctx, s, args...) |
|
|
|
rows, err := queryer.QueryContext(ctx, s, args...) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
@ -1171,15 +1179,15 @@ func getIndexColName(indexdef string) []string { |
|
|
|
return colNames |
|
|
|
} |
|
|
|
|
|
|
|
func (db *postgres) GetIndexes(ctx context.Context, tableName string) (map[string]*schemas.Index, error) { |
|
|
|
func (db *postgres) GetIndexes(queryer core.Queryer, ctx context.Context, tableName string) (map[string]*schemas.Index, error) { |
|
|
|
args := []interface{}{tableName} |
|
|
|
s := fmt.Sprintf("SELECT indexname, indexdef FROM pg_indexes WHERE tablename=$1") |
|
|
|
if len(db.uri.Schema) != 0 { |
|
|
|
args = append(args, db.uri.Schema) |
|
|
|
if len(db.getSchema()) != 0 { |
|
|
|
args = append(args, db.getSchema()) |
|
|
|
s = s + " AND schemaname=$2" |
|
|
|
} |
|
|
|
|
|
|
|
rows, err := db.DB().QueryContext(ctx, s, args...) |
|
|
|
rows, err := queryer.QueryContext(ctx, s, args...) |
|
|
|
if err != nil { |
|
|
|
return nil, err |
|
|
|
} |
|
|
@ -1319,3 +1327,22 @@ func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*URI, error) { |
|
|
|
} |
|
|
|
return pgx.pqDriver.Parse(driverName, dataSourceName) |
|
|
|
} |
|
|
|
|
|
|
|
// QueryDefaultPostgresSchema returns the default postgres schema
|
|
|
|
func QueryDefaultPostgresSchema(ctx context.Context, queryer core.Queryer) (string, error) { |
|
|
|
rows, err := queryer.QueryContext(ctx, "SHOW SEARCH_PATH") |
|
|
|
if err != nil { |
|
|
|
return "", err |
|
|
|
} |
|
|
|
defer rows.Close() |
|
|
|
if rows.Next() { |
|
|
|
var defaultSchema string |
|
|
|
if err = rows.Scan(&defaultSchema); err != nil { |
|
|
|
return "", err |
|
|
|
} |
|
|
|
parts := strings.Split(defaultSchema, ",") |
|
|
|
return strings.TrimSpace(parts[len(parts)-1]), nil |
|
|
|
} |
|
|
|
|
|
|
|
return "", errors.New("No default schema") |
|
|
|
} |