You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

616 lines
14 KiB

  1. package mssql
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/binary"
  6. "fmt"
  7. "math"
  8. "reflect"
  9. "strconv"
  10. "strings"
  11. "time"
  12. )
// Bulk performs a TDS bulk insert (BCP) into a single destination table
// over an existing connection. Create one with Conn.CreateBulk, feed it
// rows with AddRow, and finish with Done.
type Bulk struct {
	cn          *Conn          // connection the bulk operation runs on
	metadata    []columnStruct // all destination-table columns (filled by getMetadata)
	bulkColumns []columnStruct // subset of metadata matching columnsName, in insert order
	columnsName []string       // caller-requested column names
	tablename   string         // destination table name
	numRows     int            // rows written so far (used in debug logging)

	headerSent bool // true once INSERT BULK and the column metadata have been sent

	Options BulkOptions // options rendered into the WITH (...) clause
	Debug   bool        // when true, dlogf output is emitted
}
// BulkOptions control the WITH (...) clause of the INSERT BULK statement.
type BulkOptions struct {
	CheckConstraints  bool     // emit CHECK_CONSTRAINTS
	FireTriggers      bool     // emit FIRE_TRIGGERS
	KeepNulls         bool     // emit KEEP_NULLS
	KilobytesPerBatch int      // emit KILOBYTES_PER_BATCH = n when > 0
	RowsPerBatch      int      // emit ROWS_PER_BATCH = n when > 0
	Order             []string // emit ORDER(col1,col2,...) when non-empty
	Tablock           bool     // emit TABLOCK
}
// DataValue is any Go value accepted as a bulk-load column value.
type DataValue interface{}
  34. func (cn *Conn) CreateBulk(table string, columns []string) (_ *Bulk) {
  35. b := Bulk{cn: cn, tablename: table, headerSent: false, columnsName: columns}
  36. b.Debug = false
  37. return &b
  38. }
  39. func (b *Bulk) sendBulkCommand() (err error) {
  40. //get table columns info
  41. err = b.getMetadata()
  42. if err != nil {
  43. return err
  44. }
  45. //match the columns
  46. for _, colname := range b.columnsName {
  47. var bulkCol *columnStruct
  48. for _, m := range b.metadata {
  49. if m.ColName == colname {
  50. bulkCol = &m
  51. break
  52. }
  53. }
  54. if bulkCol != nil {
  55. if bulkCol.ti.TypeId == typeUdt {
  56. //send udt as binary
  57. bulkCol.ti.TypeId = typeBigVarBin
  58. }
  59. b.bulkColumns = append(b.bulkColumns, *bulkCol)
  60. b.dlogf("Adding column %s %s %#x", colname, bulkCol.ColName, bulkCol.ti.TypeId)
  61. } else {
  62. return fmt.Errorf("Column %s does not exist in destination table %s", colname, b.tablename)
  63. }
  64. }
  65. //create the bulk command
  66. //columns definitions
  67. var col_defs bytes.Buffer
  68. for i, col := range b.bulkColumns {
  69. if i != 0 {
  70. col_defs.WriteString(", ")
  71. }
  72. col_defs.WriteString("[" + col.ColName + "] " + makeDecl(col.ti))
  73. }
  74. //options
  75. var with_opts []string
  76. if b.Options.CheckConstraints {
  77. with_opts = append(with_opts, "CHECK_CONSTRAINTS")
  78. }
  79. if b.Options.FireTriggers {
  80. with_opts = append(with_opts, "FIRE_TRIGGERS")
  81. }
  82. if b.Options.KeepNulls {
  83. with_opts = append(with_opts, "KEEP_NULLS")
  84. }
  85. if b.Options.KilobytesPerBatch > 0 {
  86. with_opts = append(with_opts, fmt.Sprintf("KILOBYTES_PER_BATCH = %d", b.Options.KilobytesPerBatch))
  87. }
  88. if b.Options.RowsPerBatch > 0 {
  89. with_opts = append(with_opts, fmt.Sprintf("ROWS_PER_BATCH = %d", b.Options.RowsPerBatch))
  90. }
  91. if len(b.Options.Order) > 0 {
  92. with_opts = append(with_opts, fmt.Sprintf("ORDER(%s)", strings.Join(b.Options.Order, ",")))
  93. }
  94. if b.Options.Tablock {
  95. with_opts = append(with_opts, "TABLOCK")
  96. }
  97. var with_part string
  98. if len(with_opts) > 0 {
  99. with_part = fmt.Sprintf("WITH (%s)", strings.Join(with_opts, ","))
  100. }
  101. query := fmt.Sprintf("INSERT BULK %s (%s) %s", b.tablename, col_defs.String(), with_part)
  102. stmt, err := b.cn.Prepare(query)
  103. if err != nil {
  104. return fmt.Errorf("Prepare failed: %s", err.Error())
  105. }
  106. b.dlogf(query)
  107. _, err = stmt.Exec(nil)
  108. if err != nil {
  109. return err
  110. }
  111. b.headerSent = true
  112. var buf = b.cn.sess.buf
  113. buf.BeginPacket(packBulkLoadBCP)
  114. // send the columns metadata
  115. columnMetadata := b.createColMetadata()
  116. _, err = buf.Write(columnMetadata)
  117. return
  118. }
  119. // AddRow immediately writes the row to the destination table.
  120. // The arguments are the row values in the order they were specified.
  121. func (b *Bulk) AddRow(row []interface{}) (err error) {
  122. if !b.headerSent {
  123. err = b.sendBulkCommand()
  124. if err != nil {
  125. return
  126. }
  127. }
  128. if len(row) != len(b.bulkColumns) {
  129. return fmt.Errorf("Row does not have the same number of columns than the destination table %d %d",
  130. len(row), len(b.bulkColumns))
  131. }
  132. bytes, err := b.makeRowData(row)
  133. if err != nil {
  134. return
  135. }
  136. _, err = b.cn.sess.buf.Write(bytes)
  137. if err != nil {
  138. return
  139. }
  140. b.numRows = b.numRows + 1
  141. return
  142. }
  143. func (b *Bulk) makeRowData(row []interface{}) ([]byte, error) {
  144. buf := new(bytes.Buffer)
  145. buf.WriteByte(byte(tokenRow))
  146. var logcol bytes.Buffer
  147. for i, col := range b.bulkColumns {
  148. if b.Debug {
  149. logcol.WriteString(fmt.Sprintf(" col[%d]='%v' ", i, row[i]))
  150. }
  151. param, err := b.makeParam(row[i], col)
  152. if err != nil {
  153. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  154. }
  155. if col.ti.Writer == nil {
  156. return nil, fmt.Errorf("no writer for column: %s, TypeId: %#x",
  157. col.ColName, col.ti.TypeId)
  158. }
  159. err = col.ti.Writer(buf, param.ti, param.buffer)
  160. if err != nil {
  161. return nil, fmt.Errorf("bulkcopy: %s", err.Error())
  162. }
  163. }
  164. b.dlogf("row[%d] %s\n", b.numRows, logcol.String())
  165. return buf.Bytes(), nil
  166. }
// Done finalizes the bulk insert: it appends the DONE token to the BCP
// packet, flushes the packet, and reads the server response. It returns
// the row count reported by the server. If the header was never sent
// (no rows added) it returns (0, nil) without touching the connection.
func (b *Bulk) Done() (rowcount int64, err error) {
	if b.headerSent == false {
		//no rows had been sent
		return 0, nil
	}
	var buf = b.cn.sess.buf
	// DONE token wire layout: token byte, status (uint16), curcmd
	// (uint16), then the row count. Write order is protocol-critical.
	// NOTE(review): binary.Write errors to the session buffer are
	// ignored here; presumably failures surface on FinishPacket — verify.
	buf.WriteByte(byte(tokenDone))

	binary.Write(buf, binary.LittleEndian, uint16(doneFinal))
	binary.Write(buf, binary.LittleEndian, uint16(0)) // curcmd

	// TDS 7.2+ uses a 64-bit row count in DONE; older versions 32-bit.
	if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
		binary.Write(buf, binary.LittleEndian, uint64(0)) //rowcount 0
	} else {
		binary.Write(buf, binary.LittleEndian, uint32(0)) //rowcount 0
	}

	buf.FinishPacket()

	tokchan := make(chan tokenStruct, 5)
	go processResponse(context.Background(), b.cn.sess, tokchan, nil)

	// Drain the response stream; keep the last DONE row count seen.
	var rowCount int64
	for token := range tokchan {
		switch token := token.(type) {
		case doneStruct:
			if token.Status&doneCount != 0 {
				rowCount = int64(token.RowCount)
			}
			if token.isError() {
				return 0, token.getError()
			}
		case error:
			return 0, b.cn.checkBadConn(token)
		}
	}
	return rowCount, nil
}
// createColMetadata builds the COLMETADATA token that precedes the row
// data in the BCP packet: token byte, column count, then per column the
// usertype, flags, type info, and the UCS-2 column name.
func (b *Bulk) createColMetadata() []byte {
	buf := new(bytes.Buffer)
	buf.WriteByte(byte(tokenColMetadata))                              // token
	binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count

	for i, col := range b.bulkColumns {

		// TDS 7.2+ widened the usertype field from 16 to 32 bits.
		if b.cn.sess.loginAck.TDSVersion >= verTDS72 {
			binary.Write(buf, binary.LittleEndian, uint32(col.UserType)) // usertype, always 0?
		} else {
			binary.Write(buf, binary.LittleEndian, uint16(col.UserType))
		}
		binary.Write(buf, binary.LittleEndian, uint16(col.Flags))

		writeTypeInfo(buf, &b.bulkColumns[i].ti)

		// Legacy LOB types carry the table name (length-prefixed in
		// characters, then UCS-2 bytes) between type info and column name.
		if col.ti.TypeId == typeNText ||
			col.ti.TypeId == typeText ||
			col.ti.TypeId == typeImage {
			tablename_ucs2 := str2ucs2(b.tablename)
			binary.Write(buf, binary.LittleEndian, uint16(len(tablename_ucs2)/2))
			buf.Write(tablename_ucs2)
		}
		// Column name: one-byte length in characters, then UCS-2 bytes.
		colname_ucs2 := str2ucs2(col.ColName)
		buf.WriteByte(uint8(len(colname_ucs2) / 2))
		buf.Write(colname_ucs2)
	}

	return buf.Bytes()
}
// getMetadata fetches the destination table's column definitions by
// running a metadata-only SELECT * under SET FMTONLY, and caches the
// result in b.metadata.
// NOTE(review): the prepared statements are never closed here —
// presumably they are released with the connection; verify.
func (b *Bulk) getMetadata() (err error) {
	stmt, err := b.cn.Prepare("SET FMTONLY ON")
	if err != nil {
		return
	}

	_, err = stmt.Exec(nil)
	if err != nil {
		return
	}

	//get columns info
	// With FMTONLY ON the SELECT returns only column metadata, no rows;
	// the trailing SET FMTONLY OFF restores normal execution.
	stmt, err = b.cn.Prepare(fmt.Sprintf("select * from %s SET FMTONLY OFF", b.tablename))
	if err != nil {
		return
	}
	stmt2 := stmt.(*Stmt)
	cols, err := stmt2.QueryMeta()
	if err != nil {
		return fmt.Errorf("get columns info failed: %v", err.Error())
	}
	b.metadata = cols

	if b.Debug {
		for _, col := range b.metadata {
			b.dlogf("col: %s typeId: %#x size: %d scale: %d prec: %d flags: %d lcid: %#x\n",
				col.ColName, col.ti.TypeId, col.ti.Size, col.ti.Scale, col.ti.Prec,
				col.Flags, col.ti.Collation.LcidAndFlags)
		}
	}

	return nil
}
// QueryMeta is almost the same as mssql.Stmt.Query, but returns all the columns info.
// It sends the query and scans the response stream until either the
// column metadata or a completion token arrives.
func (s *Stmt) QueryMeta() (cols []columnStruct, err error) {
	if err = s.sendQuery(nil); err != nil {
		return
	}

	tokchan := make(chan tokenStruct, 5)
	go processResponse(context.Background(), s.c.sess, tokchan, s.c.outs)
	s.c.clearOuts()

loop:
	for tok := range tokchan {
		switch token := tok.(type) {
		case doneStruct:
			// Completed without producing column metadata; cols stays nil.
			break loop
		case []columnStruct:
			cols = token
			break loop
		case error:
			return nil, s.c.checkBadConn(token)
		}
	}
	// NOTE(review): breaking out early relies on the channel buffer (5)
	// letting processResponse finish writing any remaining tokens —
	// confirm this cannot leak the goroutine.
	return cols, nil
}
  276. func (b *Bulk) makeParam(val DataValue, col columnStruct) (res Param, err error) {
  277. res.ti.Size = col.ti.Size
  278. res.ti.TypeId = col.ti.TypeId
  279. if val == nil {
  280. res.ti.Size = 0
  281. return
  282. }
  283. switch col.ti.TypeId {
  284. case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN:
  285. var intvalue int64
  286. switch val := val.(type) {
  287. case int:
  288. intvalue = int64(val)
  289. case int32:
  290. intvalue = int64(val)
  291. case int64:
  292. intvalue = val
  293. default:
  294. err = fmt.Errorf("mssql: invalid type for int column")
  295. return
  296. }
  297. res.buffer = make([]byte, res.ti.Size)
  298. if col.ti.Size == 1 {
  299. res.buffer[0] = byte(intvalue)
  300. } else if col.ti.Size == 2 {
  301. binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue))
  302. } else if col.ti.Size == 4 {
  303. binary.LittleEndian.PutUint32(res.buffer, uint32(intvalue))
  304. } else if col.ti.Size == 8 {
  305. binary.LittleEndian.PutUint64(res.buffer, uint64(intvalue))
  306. }
  307. case typeFlt4, typeFlt8, typeFltN:
  308. var floatvalue float64
  309. switch val := val.(type) {
  310. case float32:
  311. floatvalue = float64(val)
  312. case float64:
  313. floatvalue = val
  314. case int:
  315. floatvalue = float64(val)
  316. case int64:
  317. floatvalue = float64(val)
  318. default:
  319. err = fmt.Errorf("mssql: invalid type for float column: %s", val)
  320. return
  321. }
  322. if col.ti.Size == 4 {
  323. res.buffer = make([]byte, 4)
  324. binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(float32(floatvalue)))
  325. } else if col.ti.Size == 8 {
  326. res.buffer = make([]byte, 8)
  327. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(floatvalue))
  328. }
  329. case typeNVarChar, typeNText, typeNChar:
  330. switch val := val.(type) {
  331. case string:
  332. res.buffer = str2ucs2(val)
  333. case []byte:
  334. res.buffer = val
  335. default:
  336. err = fmt.Errorf("mssql: invalid type for nvarchar column: %s", val)
  337. return
  338. }
  339. res.ti.Size = len(res.buffer)
  340. case typeVarChar, typeBigVarChar, typeText, typeChar, typeBigChar:
  341. switch val := val.(type) {
  342. case string:
  343. res.buffer = []byte(val)
  344. case []byte:
  345. res.buffer = val
  346. default:
  347. err = fmt.Errorf("mssql: invalid type for varchar column: %s", val)
  348. return
  349. }
  350. res.ti.Size = len(res.buffer)
  351. case typeBit, typeBitN:
  352. if reflect.TypeOf(val).Kind() != reflect.Bool {
  353. err = fmt.Errorf("mssql: invalid type for bit column: %s", val)
  354. return
  355. }
  356. res.ti.TypeId = typeBitN
  357. res.ti.Size = 1
  358. res.buffer = make([]byte, 1)
  359. if val.(bool) {
  360. res.buffer[0] = 1
  361. }
  362. case typeDateTime2N, typeDateTimeOffsetN:
  363. switch val := val.(type) {
  364. case time.Time:
  365. days, ns := dateTime2(val)
  366. ns /= int64(math.Pow10(int(col.ti.Scale)*-1) * 1000000000)
  367. var data = make([]byte, 5)
  368. data[0] = byte(ns)
  369. data[1] = byte(ns >> 8)
  370. data[2] = byte(ns >> 16)
  371. data[3] = byte(ns >> 24)
  372. data[4] = byte(ns >> 32)
  373. if col.ti.Scale <= 2 {
  374. res.ti.Size = 6
  375. } else if col.ti.Scale <= 4 {
  376. res.ti.Size = 7
  377. } else {
  378. res.ti.Size = 8
  379. }
  380. var buf []byte
  381. buf = make([]byte, res.ti.Size)
  382. copy(buf, data[0:res.ti.Size-3])
  383. buf[res.ti.Size-3] = byte(days)
  384. buf[res.ti.Size-2] = byte(days >> 8)
  385. buf[res.ti.Size-1] = byte(days >> 16)
  386. if col.ti.TypeId == typeDateTimeOffsetN {
  387. _, offset := val.Zone()
  388. var offsetMinute = uint16(offset / 60)
  389. buf = append(buf, byte(offsetMinute))
  390. buf = append(buf, byte(offsetMinute>>8))
  391. res.ti.Size = res.ti.Size + 2
  392. }
  393. res.buffer = buf
  394. default:
  395. err = fmt.Errorf("mssql: invalid type for datetime2 column: %s", val)
  396. return
  397. }
  398. case typeDateN:
  399. switch val := val.(type) {
  400. case time.Time:
  401. days, _ := dateTime2(val)
  402. res.ti.Size = 3
  403. res.buffer = make([]byte, 3)
  404. res.buffer[0] = byte(days)
  405. res.buffer[1] = byte(days >> 8)
  406. res.buffer[2] = byte(days >> 16)
  407. default:
  408. err = fmt.Errorf("mssql: invalid type for date column: %s", val)
  409. return
  410. }
  411. case typeDateTime, typeDateTimeN, typeDateTim4:
  412. switch val := val.(type) {
  413. case time.Time:
  414. if col.ti.Size == 4 {
  415. res.ti.Size = 4
  416. res.buffer = make([]byte, 4)
  417. ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
  418. dur := val.Sub(ref)
  419. days := dur / (24 * time.Hour)
  420. if days < 0 {
  421. err = fmt.Errorf("mssql: Date %s is out of range", val)
  422. return
  423. }
  424. mins := val.Hour()*60 + val.Minute()
  425. binary.LittleEndian.PutUint16(res.buffer[0:2], uint16(days))
  426. binary.LittleEndian.PutUint16(res.buffer[2:4], uint16(mins))
  427. } else if col.ti.Size == 8 {
  428. res.ti.Size = 8
  429. res.buffer = make([]byte, 8)
  430. days := divFloor(val.Unix(), 24*60*60)
  431. //25567 - number of days since Jan 1 1900 UTC to Jan 1 1970
  432. days = days + 25567
  433. tm := (val.Hour()*60*60+val.Minute()*60+val.Second())*300 + int(val.Nanosecond()/10000000*3)
  434. binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
  435. binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
  436. } else {
  437. err = fmt.Errorf("mssql: invalid size of column")
  438. }
  439. default:
  440. err = fmt.Errorf("mssql: invalid type for datetime column: %s", val)
  441. }
  442. // case typeMoney, typeMoney4, typeMoneyN:
  443. case typeDecimal, typeDecimalN, typeNumeric, typeNumericN:
  444. var value float64
  445. switch v := val.(type) {
  446. case int:
  447. value = float64(v)
  448. case int8:
  449. value = float64(v)
  450. case int16:
  451. value = float64(v)
  452. case int32:
  453. value = float64(v)
  454. case int64:
  455. value = float64(v)
  456. case float32:
  457. value = float64(v)
  458. case float64:
  459. value = v
  460. case string:
  461. if value, err = strconv.ParseFloat(v, 64); err != nil {
  462. return res, fmt.Errorf("bulk: unable to convert string to float: %v", err)
  463. }
  464. default:
  465. return res, fmt.Errorf("unknown value for decimal: %#v", v)
  466. }
  467. perc := col.ti.Prec
  468. scale := col.ti.Scale
  469. var dec Decimal
  470. dec, err = Float64ToDecimalScale(value, scale)
  471. if err != nil {
  472. return res, err
  473. }
  474. dec.prec = perc
  475. var length byte
  476. switch {
  477. case perc <= 9:
  478. length = 4
  479. case perc <= 19:
  480. length = 8
  481. case perc <= 28:
  482. length = 12
  483. default:
  484. length = 16
  485. }
  486. buf := make([]byte, length+1)
  487. // first byte length written by typeInfo.writer
  488. res.ti.Size = int(length) + 1
  489. // second byte sign
  490. if value < 0 {
  491. buf[0] = 0
  492. } else {
  493. buf[0] = 1
  494. }
  495. ub := dec.UnscaledBytes()
  496. l := len(ub)
  497. if l > int(length) {
  498. err = fmt.Errorf("decimal out of range: %s", dec)
  499. return res, err
  500. }
  501. // reverse the bytes
  502. for i, j := 1, l-1; j >= 0; i, j = i+1, j-1 {
  503. buf[i] = ub[j]
  504. }
  505. res.buffer = buf
  506. case typeBigVarBin:
  507. switch val := val.(type) {
  508. case []byte:
  509. res.ti.Size = len(val)
  510. res.buffer = val
  511. default:
  512. err = fmt.Errorf("mssql: invalid type for Binary column: %s", val)
  513. return
  514. }
  515. case typeGuid:
  516. switch val := val.(type) {
  517. case []byte:
  518. res.ti.Size = len(val)
  519. res.buffer = val
  520. default:
  521. err = fmt.Errorf("mssql: invalid type for Guid column: %s", val)
  522. return
  523. }
  524. default:
  525. err = fmt.Errorf("mssql: type %x not implemented", col.ti.TypeId)
  526. }
  527. return
  528. }
  529. func (b *Bulk) dlogf(format string, v ...interface{}) {
  530. if b.Debug {
  531. b.cn.sess.log.Printf(format, v...)
  532. }
  533. }