You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

394 lines
7.9 KiB

  1. // Copyright 2016 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package builder
  5. import (
  6. sql2 "database/sql"
  7. "fmt"
  8. "sort"
  9. )
// optype identifies the kind of SQL statement a Builder renders.
type optype byte

// The zero value, condType, marks a Builder that only holds conditions;
// WriteTo has no case for it and returns ErrNotSupportType. The remaining
// values are set by Select, Insert, Update, Delete and Union respectively.
const (
	condType   optype = iota // only conditions
	selectType               // select
	insertType               // insert
	updateType               // update
	deleteType               // delete
	unionType                // union
)
// Dialect names understood by this package (see Dialect and its shortcuts).
// ToSQL rewrites placeholders for POSTGRES ("$"), MSSQL ("@p") and
// ORACLE (":p"); for MSSQL and ORACLE it also wraps every argument as a
// sql.NamedArg. For MYSQL and SQLITE the SQL produced by WriteTo is left
// unchanged.
const (
	POSTGRES = "postgres"
	SQLITE   = "sqlite3"
	MYSQL    = "mysql"
	MSSQL    = "mssql"
	ORACLE   = "oracle"
)
// join is one JOIN clause recorded by Builder.Join, kept in call order.
type join struct {
	// joinType is the join keyword passed to Join, e.g. "INNER", "LEFT",
	// "RIGHT", "CROSS" or "FULL" from the *Join helpers.
	joinType string
	// joinTable is the joined table as given by the caller.
	joinTable string
	// joinCond is the join condition; string conditions are wrapped with Expr.
	joinCond Cond
}

// union is one member of a UNION statement recorded by Builder.Union.
type union struct {
	// unionType is the modifier passed to Union; it is "" for the first
	// member (the builder Union was originally called on).
	unionType string
	builder   *Builder
}

// limit holds the values set by Builder.Limit.
type limit struct {
	limitN int
	// offset stays 0 unless an offset argument was passed to Limit.
	offset int
}
// Builder describes a SQL statement. Dialect (or one of its shortcuts such
// as MySQL or Postgres) creates one; chained methods such as Select, Where,
// Join, Insert, Update or Delete configure it, and ToSQL / ToBoundSQL
// render it.
type Builder struct {
	// optype is the statement kind; condType until a statement method sets it.
	optype
	// dialect is the target database name, compared against the dialect
	// constants in ToSQL; it may be empty.
	dialect string
	// isNested is set by From when a sub-builder is given without an alias.
	isNested bool
	// into is the INSERT target table (see Into).
	into string
	// from is a table name, optionally followed by " alias", or the alias
	// of a sub-query passed to From.
	from string
	// subQuery is the sub-builder passed to From, if any.
	subQuery *Builder
	// cond accumulates conditions from Where, And, Or and Delete.
	cond Cond
	// selects lists the columns given to Select (shared, not copied).
	selects []string
	// joins holds JOIN clauses in the order they were added.
	joins []join
	// unions holds UNION members; see Union.
	unions []union
	// limitation is nil unless Limit was called.
	limitation *limit
	// insertCols and insertVals are filled by Insert; when both have the
	// same length they are kept sorted together by column name.
	insertCols []string
	insertVals []interface{}
	// updates holds the valid Eq sets passed to Update.
	updates []Eq
	// NOTE(review): orderBy, groupBy and having are not set anywhere in
	// this view; presumably raw clause text set by other methods — confirm.
	orderBy string
	groupBy string
	having  string
}
  59. // Dialect sets the db dialect of Builder.
  60. func Dialect(dialect string) *Builder {
  61. builder := &Builder{cond: NewCond(), dialect: dialect}
  62. return builder
  63. }
  64. // MySQL is shortcut of Dialect(MySQL)
  65. func MySQL() *Builder {
  66. return Dialect(MYSQL)
  67. }
  68. // MsSQL is shortcut of Dialect(MsSQL)
  69. func MsSQL() *Builder {
  70. return Dialect(MSSQL)
  71. }
  72. // Oracle is shortcut of Dialect(Oracle)
  73. func Oracle() *Builder {
  74. return Dialect(ORACLE)
  75. }
  76. // Postgres is shortcut of Dialect(Postgres)
  77. func Postgres() *Builder {
  78. return Dialect(POSTGRES)
  79. }
  80. // SQLite is shortcut of Dialect(SQLITE)
  81. func SQLite() *Builder {
  82. return Dialect(SQLITE)
  83. }
  84. // Where sets where SQL
  85. func (b *Builder) Where(cond Cond) *Builder {
  86. if b.cond.IsValid() {
  87. b.cond = b.cond.And(cond)
  88. } else {
  89. b.cond = cond
  90. }
  91. return b
  92. }
  93. // From sets from subject(can be a table name in string or a builder pointer) and its alias
  94. func (b *Builder) From(subject interface{}, alias ...string) *Builder {
  95. switch subject.(type) {
  96. case *Builder:
  97. b.subQuery = subject.(*Builder)
  98. if len(alias) > 0 {
  99. b.from = alias[0]
  100. } else {
  101. b.isNested = true
  102. }
  103. case string:
  104. b.from = subject.(string)
  105. if len(alias) > 0 {
  106. b.from = b.from + " " + alias[0]
  107. }
  108. }
  109. return b
  110. }
  111. // TableName returns the table name
  112. func (b *Builder) TableName() string {
  113. if b.optype == insertType {
  114. return b.into
  115. }
  116. return b.from
  117. }
  118. // Into sets insert table name
  119. func (b *Builder) Into(tableName string) *Builder {
  120. b.into = tableName
  121. return b
  122. }
  123. // Join sets join table and conditions
  124. func (b *Builder) Join(joinType, joinTable string, joinCond interface{}) *Builder {
  125. switch joinCond.(type) {
  126. case Cond:
  127. b.joins = append(b.joins, join{joinType, joinTable, joinCond.(Cond)})
  128. case string:
  129. b.joins = append(b.joins, join{joinType, joinTable, Expr(joinCond.(string))})
  130. }
  131. return b
  132. }
// Union combines b with unionCond into a UNION statement and returns the
// builder that represents the union.
//
// If b is not already a union builder, a new unionType builder is created
// that copies b's dialect and selects (the selects slice is shared, not
// copied). Its members are b itself (with an empty union type) followed by
// any unions previously attached to b. Those unions are detached from b
// (b.unions is reset to nil), and their dialect is overwritten with b's.
// If b is already a union builder, it is reused and extended in place.
//
// unionTp is stored alongside unionCond. It is presumably a modifier such as
// "ALL" or "DISTINCT" — NOTE(review): confirm against unionWriteTo. A nil
// unionCond adds no member. A unionCond without a dialect inherits the union
// builder's dialect.
//
// Callers must use the returned builder: it differs from b whenever b was
// not already a union.
func (b *Builder) Union(unionTp string, unionCond *Builder) *Builder {
	var builder *Builder
	if b.optype != unionType {
		builder = &Builder{cond: NewCond()}
		builder.optype = unionType
		builder.dialect = b.dialect
		builder.selects = b.selects

		currentUnions := b.unions
		// erase sub unions (actually append to new Builder.unions)
		b.unions = nil

		// Overwrite the dialect of previously attached members with b's.
		for e := range currentUnions {
			currentUnions[e].builder.dialect = b.dialect
		}

		// b becomes the first member; its former unions follow in order.
		builder.unions = append(append(builder.unions, union{"", b}), currentUnions...)
	} else {
		builder = b
	}

	if unionCond != nil {
		// Inherit the union's dialect only when unionCond has none of its own.
		if unionCond.dialect == "" && builder.dialect != "" {
			unionCond.dialect = builder.dialect
		}

		builder.unions = append(builder.unions, union{unionTp, unionCond})
	}

	return builder
}
  159. // Limit sets limitN condition
  160. func (b *Builder) Limit(limitN int, offset ...int) *Builder {
  161. b.limitation = &limit{limitN: limitN}
  162. if len(offset) > 0 {
  163. b.limitation.offset = offset[0]
  164. }
  165. return b
  166. }
  167. // InnerJoin sets inner join
  168. func (b *Builder) InnerJoin(joinTable string, joinCond interface{}) *Builder {
  169. return b.Join("INNER", joinTable, joinCond)
  170. }
  171. // LeftJoin sets left join SQL
  172. func (b *Builder) LeftJoin(joinTable string, joinCond interface{}) *Builder {
  173. return b.Join("LEFT", joinTable, joinCond)
  174. }
  175. // RightJoin sets right join SQL
  176. func (b *Builder) RightJoin(joinTable string, joinCond interface{}) *Builder {
  177. return b.Join("RIGHT", joinTable, joinCond)
  178. }
  179. // CrossJoin sets cross join SQL
  180. func (b *Builder) CrossJoin(joinTable string, joinCond interface{}) *Builder {
  181. return b.Join("CROSS", joinTable, joinCond)
  182. }
  183. // FullJoin sets full join SQL
  184. func (b *Builder) FullJoin(joinTable string, joinCond interface{}) *Builder {
  185. return b.Join("FULL", joinTable, joinCond)
  186. }
  187. // Select sets select SQL
  188. func (b *Builder) Select(cols ...string) *Builder {
  189. b.selects = cols
  190. if b.optype == condType {
  191. b.optype = selectType
  192. }
  193. return b
  194. }
  195. // And sets AND condition
  196. func (b *Builder) And(cond Cond) *Builder {
  197. b.cond = And(b.cond, cond)
  198. return b
  199. }
  200. // Or sets OR condition
  201. func (b *Builder) Or(cond Cond) *Builder {
  202. b.cond = Or(b.cond, cond)
  203. return b
  204. }
  205. type insertColsSorter struct {
  206. cols []string
  207. vals []interface{}
  208. }
  209. func (s insertColsSorter) Len() int {
  210. return len(s.cols)
  211. }
  212. func (s insertColsSorter) Swap(i, j int) {
  213. s.cols[i], s.cols[j] = s.cols[j], s.cols[i]
  214. s.vals[i], s.vals[j] = s.vals[j], s.vals[i]
  215. }
  216. func (s insertColsSorter) Less(i, j int) bool {
  217. return s.cols[i] < s.cols[j]
  218. }
  219. // Insert sets insert SQL
  220. func (b *Builder) Insert(eq ...interface{}) *Builder {
  221. if len(eq) > 0 {
  222. var paramType = -1
  223. for _, e := range eq {
  224. switch t := e.(type) {
  225. case Eq:
  226. if paramType == -1 {
  227. paramType = 0
  228. }
  229. if paramType != 0 {
  230. break
  231. }
  232. for k, v := range t {
  233. b.insertCols = append(b.insertCols, k)
  234. b.insertVals = append(b.insertVals, v)
  235. }
  236. case string:
  237. if paramType == -1 {
  238. paramType = 1
  239. }
  240. if paramType != 1 {
  241. break
  242. }
  243. b.insertCols = append(b.insertCols, t)
  244. }
  245. }
  246. }
  247. if len(b.insertCols) == len(b.insertVals) {
  248. sort.Sort(insertColsSorter{
  249. cols: b.insertCols,
  250. vals: b.insertVals,
  251. })
  252. }
  253. b.optype = insertType
  254. return b
  255. }
  256. // Update sets update SQL
  257. func (b *Builder) Update(updates ...Eq) *Builder {
  258. b.updates = make([]Eq, 0, len(updates))
  259. for _, update := range updates {
  260. if update.IsValid() {
  261. b.updates = append(b.updates, update)
  262. }
  263. }
  264. b.optype = updateType
  265. return b
  266. }
  267. // Delete sets delete SQL
  268. func (b *Builder) Delete(conds ...Cond) *Builder {
  269. b.cond = b.cond.And(conds...)
  270. b.optype = deleteType
  271. return b
  272. }
  273. // WriteTo implements Writer interface
  274. func (b *Builder) WriteTo(w Writer) error {
  275. switch b.optype {
  276. /*case condType:
  277. return b.cond.WriteTo(w)*/
  278. case selectType:
  279. return b.selectWriteTo(w)
  280. case insertType:
  281. return b.insertWriteTo(w)
  282. case updateType:
  283. return b.updateWriteTo(w)
  284. case deleteType:
  285. return b.deleteWriteTo(w)
  286. case unionType:
  287. return b.unionWriteTo(w)
  288. }
  289. return ErrNotSupportType
  290. }
  291. // ToSQL convert a builder to SQL and args
  292. func (b *Builder) ToSQL() (string, []interface{}, error) {
  293. w := NewWriter()
  294. if err := b.WriteTo(w); err != nil {
  295. return "", nil, err
  296. }
  297. // in case of sql.NamedArg in args
  298. for e := range w.args {
  299. if namedArg, ok := w.args[e].(sql2.NamedArg); ok {
  300. w.args[e] = namedArg.Value
  301. }
  302. }
  303. var sql = w.writer.String()
  304. var err error
  305. switch b.dialect {
  306. case ORACLE, MSSQL:
  307. // This is for compatibility with different sql drivers
  308. for e := range w.args {
  309. w.args[e] = sql2.Named(fmt.Sprintf("p%d", e+1), w.args[e])
  310. }
  311. var prefix string
  312. if b.dialect == ORACLE {
  313. prefix = ":p"
  314. } else {
  315. prefix = "@p"
  316. }
  317. if sql, err = ConvertPlaceholder(sql, prefix); err != nil {
  318. return "", nil, err
  319. }
  320. case POSTGRES:
  321. if sql, err = ConvertPlaceholder(sql, "$"); err != nil {
  322. return "", nil, err
  323. }
  324. }
  325. return sql, w.args, nil
  326. }
  327. // ToBoundSQL
  328. func (b *Builder) ToBoundSQL() (string, error) {
  329. w := NewWriter()
  330. if err := b.WriteTo(w); err != nil {
  331. return "", err
  332. }
  333. return ConvertToBoundSQL(w.writer.String(), w.args)
  334. }