You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1070 lines
25 KiB

  1. package mssql
  2. import (
  3. "crypto/tls"
  4. "crypto/x509"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "io/ioutil"
  10. "net"
  11. "os"
  12. "sort"
  13. "strconv"
  14. "strings"
  15. "time"
  16. "unicode/utf16"
  17. "unicode/utf8"
  18. )
  19. func parseInstances(msg []byte) map[string]map[string]string {
  20. results := map[string]map[string]string{}
  21. if len(msg) > 3 && msg[0] == 5 {
  22. out_s := string(msg[3:])
  23. tokens := strings.Split(out_s, ";")
  24. instdict := map[string]string{}
  25. got_name := false
  26. var name string
  27. for _, token := range tokens {
  28. if got_name {
  29. instdict[name] = token
  30. got_name = false
  31. } else {
  32. name = token
  33. if len(name) == 0 {
  34. if len(instdict) == 0 {
  35. break
  36. }
  37. results[strings.ToUpper(instdict["InstanceName"])] = instdict
  38. instdict = map[string]string{}
  39. continue
  40. }
  41. got_name = true
  42. }
  43. }
  44. }
  45. return results
  46. }
  47. func getInstances(address string) (map[string]map[string]string, error) {
  48. conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
  49. if err != nil {
  50. return nil, err
  51. }
  52. defer conn.Close()
  53. conn.SetDeadline(time.Now().Add(5 * time.Second))
  54. _, err = conn.Write([]byte{3})
  55. if err != nil {
  56. return nil, err
  57. }
  58. var resp = make([]byte, 16*1024-1)
  59. read, err := conn.Read(resp)
  60. if err != nil {
  61. return nil, err
  62. }
  63. return parseInstances(resp[:read]), nil
  64. }
  65. // tds versions
  66. const (
  67. verTDS70 = 0x70000000
  68. verTDS71 = 0x71000000
  69. verTDS71rev1 = 0x71000001
  70. verTDS72 = 0x72090002
  71. verTDS73A = 0x730A0003
  72. verTDS73 = verTDS73A
  73. verTDS73B = 0x730B0003
  74. verTDS74 = 0x74000004
  75. )
  76. // packet types
  77. const (
  78. packSQLBatch = 1
  79. packRPCRequest = 3
  80. packReply = 4
  81. packCancel = 6
  82. packBulkLoadBCP = 7
  83. packTransMgrReq = 14
  84. packNormal = 15
  85. packLogin7 = 16
  86. packSSPIMessage = 17
  87. packPrelogin = 18
  88. )
  89. // prelogin fields
  90. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  91. const (
  92. preloginVERSION = 0
  93. preloginENCRYPTION = 1
  94. preloginINSTOPT = 2
  95. preloginTHREADID = 3
  96. preloginMARS = 4
  97. preloginTRACEID = 5
  98. preloginTERMINATOR = 0xff
  99. )
  100. const (
  101. encryptOff = 0 // Encryption is available but off.
  102. encryptOn = 1 // Encryption is available and on.
  103. encryptNotSup = 2 // Encryption is not available.
  104. encryptReq = 3 // Encryption is required.
  105. )
  106. type tdsSession struct {
  107. buf *tdsBuffer
  108. loginAck loginAckStruct
  109. database string
  110. partner string
  111. columns []columnStruct
  112. tranid uint64
  113. logFlags uint64
  114. log *Logger
  115. routedServer string
  116. routedPort uint16
  117. }
  118. const (
  119. logErrors = 1
  120. logMessages = 2
  121. logRows = 4
  122. logSQL = 8
  123. logParams = 16
  124. logTransaction = 32
  125. )
  126. type columnStruct struct {
  127. UserType uint32
  128. Flags uint16
  129. ColName string
  130. ti typeInfo
  131. }
  132. type KeySlice []uint8
  133. func (p KeySlice) Len() int { return len(p) }
  134. func (p KeySlice) Less(i, j int) bool { return p[i] < p[j] }
  135. func (p KeySlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
  136. // http://msdn.microsoft.com/en-us/library/dd357559.aspx
  137. func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
  138. var err error
  139. w.BeginPacket(packPrelogin)
  140. offset := uint16(5*len(fields) + 1)
  141. keys := make(KeySlice, 0, len(fields))
  142. for k, _ := range fields {
  143. keys = append(keys, k)
  144. }
  145. sort.Sort(keys)
  146. // writing header
  147. for _, k := range keys {
  148. err = w.WriteByte(k)
  149. if err != nil {
  150. return err
  151. }
  152. err = binary.Write(w, binary.BigEndian, offset)
  153. if err != nil {
  154. return err
  155. }
  156. v := fields[k]
  157. size := uint16(len(v))
  158. err = binary.Write(w, binary.BigEndian, size)
  159. if err != nil {
  160. return err
  161. }
  162. offset += size
  163. }
  164. err = w.WriteByte(preloginTERMINATOR)
  165. if err != nil {
  166. return err
  167. }
  168. // writing values
  169. for _, k := range keys {
  170. v := fields[k]
  171. written, err := w.Write(v)
  172. if err != nil {
  173. return err
  174. }
  175. if written != len(v) {
  176. return errors.New("Write method didn't write the whole value")
  177. }
  178. }
  179. return w.FinishPacket()
  180. }
  181. func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
  182. packet_type, err := r.BeginRead()
  183. if err != nil {
  184. return nil, err
  185. }
  186. struct_buf, err := ioutil.ReadAll(r)
  187. if err != nil {
  188. return nil, err
  189. }
  190. if packet_type != 4 {
  191. return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
  192. }
  193. offset := 0
  194. results := map[uint8][]byte{}
  195. for true {
  196. rec_type := struct_buf[offset]
  197. if rec_type == preloginTERMINATOR {
  198. break
  199. }
  200. rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
  201. rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
  202. value := struct_buf[rec_offset : rec_offset+rec_len]
  203. results[rec_type] = value
  204. offset += 5
  205. }
  206. return results, nil
  207. }
  208. // OptionFlags2
  209. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  210. const (
  211. fLanguageFatal = 1
  212. fODBC = 2
  213. fTransBoundary = 4
  214. fCacheConnect = 8
  215. fIntSecurity = 0x80
  216. )
  217. // TypeFlags
  218. const (
  219. // 4 bits for fSQLType
  220. // 1 bit for fOLEDB
  221. fReadOnlyIntent = 32
  222. )
  223. type login struct {
  224. TDSVersion uint32
  225. PacketSize uint32
  226. ClientProgVer uint32
  227. ClientPID uint32
  228. ConnectionID uint32
  229. OptionFlags1 uint8
  230. OptionFlags2 uint8
  231. TypeFlags uint8
  232. OptionFlags3 uint8
  233. ClientTimeZone int32
  234. ClientLCID uint32
  235. HostName string
  236. UserName string
  237. Password string
  238. AppName string
  239. ServerName string
  240. CtlIntName string
  241. Language string
  242. Database string
  243. ClientID [6]byte
  244. SSPI []byte
  245. AtchDBFile string
  246. ChangePassword string
  247. }
  248. type loginHeader struct {
  249. Length uint32
  250. TDSVersion uint32
  251. PacketSize uint32
  252. ClientProgVer uint32
  253. ClientPID uint32
  254. ConnectionID uint32
  255. OptionFlags1 uint8
  256. OptionFlags2 uint8
  257. TypeFlags uint8
  258. OptionFlags3 uint8
  259. ClientTimeZone int32
  260. ClientLCID uint32
  261. HostNameOffset uint16
  262. HostNameLength uint16
  263. UserNameOffset uint16
  264. UserNameLength uint16
  265. PasswordOffset uint16
  266. PasswordLength uint16
  267. AppNameOffset uint16
  268. AppNameLength uint16
  269. ServerNameOffset uint16
  270. ServerNameLength uint16
  271. ExtensionOffset uint16
  272. ExtensionLenght uint16
  273. CtlIntNameOffset uint16
  274. CtlIntNameLength uint16
  275. LanguageOffset uint16
  276. LanguageLength uint16
  277. DatabaseOffset uint16
  278. DatabaseLength uint16
  279. ClientID [6]byte
  280. SSPIOffset uint16
  281. SSPILength uint16
  282. AtchDBFileOffset uint16
  283. AtchDBFileLength uint16
  284. ChangePasswordOffset uint16
  285. ChangePasswordLength uint16
  286. SSPILongLength uint32
  287. }
  288. // convert Go string to UTF-16 encoded []byte (littleEndian)
  289. // done manually rather than using bytes and binary packages
  290. // for performance reasons
  291. func str2ucs2(s string) []byte {
  292. res := utf16.Encode([]rune(s))
  293. ucs2 := make([]byte, 2*len(res))
  294. for i := 0; i < len(res); i++ {
  295. ucs2[2*i] = byte(res[i])
  296. ucs2[2*i+1] = byte(res[i] >> 8)
  297. }
  298. return ucs2
  299. }
  300. func ucs22str(s []byte) (string, error) {
  301. if len(s)%2 != 0 {
  302. return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
  303. }
  304. buf := make([]uint16, len(s)/2)
  305. for i := 0; i < len(s); i += 2 {
  306. buf[i/2] = binary.LittleEndian.Uint16(s[i:])
  307. }
  308. return string(utf16.Decode(buf)), nil
  309. }
  310. func manglePassword(password string) []byte {
  311. var ucs2password []byte = str2ucs2(password)
  312. for i, ch := range ucs2password {
  313. ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
  314. }
  315. return ucs2password
  316. }
  317. // http://msdn.microsoft.com/en-us/library/dd304019.aspx
  318. func sendLogin(w *tdsBuffer, login login) error {
  319. w.BeginPacket(packLogin7)
  320. hostname := str2ucs2(login.HostName)
  321. username := str2ucs2(login.UserName)
  322. password := manglePassword(login.Password)
  323. appname := str2ucs2(login.AppName)
  324. servername := str2ucs2(login.ServerName)
  325. ctlintname := str2ucs2(login.CtlIntName)
  326. language := str2ucs2(login.Language)
  327. database := str2ucs2(login.Database)
  328. atchdbfile := str2ucs2(login.AtchDBFile)
  329. changepassword := str2ucs2(login.ChangePassword)
  330. hdr := loginHeader{
  331. TDSVersion: login.TDSVersion,
  332. PacketSize: login.PacketSize,
  333. ClientProgVer: login.ClientProgVer,
  334. ClientPID: login.ClientPID,
  335. ConnectionID: login.ConnectionID,
  336. OptionFlags1: login.OptionFlags1,
  337. OptionFlags2: login.OptionFlags2,
  338. TypeFlags: login.TypeFlags,
  339. OptionFlags3: login.OptionFlags3,
  340. ClientTimeZone: login.ClientTimeZone,
  341. ClientLCID: login.ClientLCID,
  342. HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
  343. UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
  344. PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
  345. AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
  346. ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
  347. CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
  348. LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
  349. DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
  350. ClientID: login.ClientID,
  351. SSPILength: uint16(len(login.SSPI)),
  352. AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
  353. ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
  354. }
  355. offset := uint16(binary.Size(hdr))
  356. hdr.HostNameOffset = offset
  357. offset += uint16(len(hostname))
  358. hdr.UserNameOffset = offset
  359. offset += uint16(len(username))
  360. hdr.PasswordOffset = offset
  361. offset += uint16(len(password))
  362. hdr.AppNameOffset = offset
  363. offset += uint16(len(appname))
  364. hdr.ServerNameOffset = offset
  365. offset += uint16(len(servername))
  366. hdr.CtlIntNameOffset = offset
  367. offset += uint16(len(ctlintname))
  368. hdr.LanguageOffset = offset
  369. offset += uint16(len(language))
  370. hdr.DatabaseOffset = offset
  371. offset += uint16(len(database))
  372. hdr.SSPIOffset = offset
  373. offset += uint16(len(login.SSPI))
  374. hdr.AtchDBFileOffset = offset
  375. offset += uint16(len(atchdbfile))
  376. hdr.ChangePasswordOffset = offset
  377. offset += uint16(len(changepassword))
  378. hdr.Length = uint32(offset)
  379. var err error
  380. err = binary.Write(w, binary.LittleEndian, &hdr)
  381. if err != nil {
  382. return err
  383. }
  384. _, err = w.Write(hostname)
  385. if err != nil {
  386. return err
  387. }
  388. _, err = w.Write(username)
  389. if err != nil {
  390. return err
  391. }
  392. _, err = w.Write(password)
  393. if err != nil {
  394. return err
  395. }
  396. _, err = w.Write(appname)
  397. if err != nil {
  398. return err
  399. }
  400. _, err = w.Write(servername)
  401. if err != nil {
  402. return err
  403. }
  404. _, err = w.Write(ctlintname)
  405. if err != nil {
  406. return err
  407. }
  408. _, err = w.Write(language)
  409. if err != nil {
  410. return err
  411. }
  412. _, err = w.Write(database)
  413. if err != nil {
  414. return err
  415. }
  416. _, err = w.Write(login.SSPI)
  417. if err != nil {
  418. return err
  419. }
  420. _, err = w.Write(atchdbfile)
  421. if err != nil {
  422. return err
  423. }
  424. _, err = w.Write(changepassword)
  425. if err != nil {
  426. return err
  427. }
  428. return w.FinishPacket()
  429. }
  430. func readUcs2(r io.Reader, numchars int) (res string, err error) {
  431. buf := make([]byte, numchars*2)
  432. _, err = io.ReadFull(r, buf)
  433. if err != nil {
  434. return "", err
  435. }
  436. return ucs22str(buf)
  437. }
  438. func readUsVarChar(r io.Reader) (res string, err error) {
  439. var numchars uint16
  440. err = binary.Read(r, binary.LittleEndian, &numchars)
  441. if err != nil {
  442. return "", err
  443. }
  444. return readUcs2(r, int(numchars))
  445. }
  446. func writeUsVarChar(w io.Writer, s string) (err error) {
  447. buf := str2ucs2(s)
  448. var numchars int = len(buf) / 2
  449. if numchars > 0xffff {
  450. panic("invalid size for US_VARCHAR")
  451. }
  452. err = binary.Write(w, binary.LittleEndian, uint16(numchars))
  453. if err != nil {
  454. return
  455. }
  456. _, err = w.Write(buf)
  457. return
  458. }
  459. func readBVarChar(r io.Reader) (res string, err error) {
  460. var numchars uint8
  461. err = binary.Read(r, binary.LittleEndian, &numchars)
  462. if err != nil {
  463. return "", err
  464. }
  465. return readUcs2(r, int(numchars))
  466. }
  467. func writeBVarChar(w io.Writer, s string) (err error) {
  468. buf := str2ucs2(s)
  469. var numchars int = len(buf) / 2
  470. if numchars > 0xff {
  471. panic("invalid size for B_VARCHAR")
  472. }
  473. err = binary.Write(w, binary.LittleEndian, uint8(numchars))
  474. if err != nil {
  475. return
  476. }
  477. _, err = w.Write(buf)
  478. return
  479. }
  480. func readBVarByte(r io.Reader) (res []byte, err error) {
  481. var length uint8
  482. err = binary.Read(r, binary.LittleEndian, &length)
  483. if err != nil {
  484. return
  485. }
  486. res = make([]byte, length)
  487. _, err = io.ReadFull(r, res)
  488. return
  489. }
  490. func readUshort(r io.Reader) (res uint16, err error) {
  491. err = binary.Read(r, binary.LittleEndian, &res)
  492. return
  493. }
  494. func readByte(r io.Reader) (res byte, err error) {
  495. var b [1]byte
  496. _, err = r.Read(b[:])
  497. res = b[0]
  498. return
  499. }
  500. // Packet Data Stream Headers
  501. // http://msdn.microsoft.com/en-us/library/dd304953.aspx
  502. type headerStruct struct {
  503. hdrtype uint16
  504. data []byte
  505. }
  506. const (
  507. dataStmHdrQueryNotif = 1 // query notifications
  508. dataStmHdrTransDescr = 2 // MARS transaction descriptor (required)
  509. dataStmHdrTraceActivity = 3
  510. )
  511. // Query Notifications Header
  512. // http://msdn.microsoft.com/en-us/library/dd304949.aspx
  513. type queryNotifHdr struct {
  514. notifyId string
  515. ssbDeployment string
  516. notifyTimeout uint32
  517. }
  518. func (hdr queryNotifHdr) pack() (res []byte) {
  519. notifyId := str2ucs2(hdr.notifyId)
  520. ssbDeployment := str2ucs2(hdr.ssbDeployment)
  521. res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
  522. b := res
  523. binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
  524. b = b[2:]
  525. copy(b, notifyId)
  526. b = b[len(notifyId):]
  527. binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
  528. b = b[2:]
  529. copy(b, ssbDeployment)
  530. b = b[len(ssbDeployment):]
  531. binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
  532. return res
  533. }
  534. // MARS Transaction Descriptor Header
  535. // http://msdn.microsoft.com/en-us/library/dd340515.aspx
  536. type transDescrHdr struct {
  537. transDescr uint64 // transaction descriptor returned from ENVCHANGE
  538. outstandingReqCnt uint32 // outstanding request count
  539. }
  540. func (hdr transDescrHdr) pack() (res []byte) {
  541. res = make([]byte, 8+4)
  542. binary.LittleEndian.PutUint64(res, hdr.transDescr)
  543. binary.LittleEndian.PutUint32(res[8:], hdr.outstandingReqCnt)
  544. return res
  545. }
  546. func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
  547. // calculatint total length
  548. var totallen uint32 = 4
  549. for _, hdr := range headers {
  550. totallen += 4 + 2 + uint32(len(hdr.data))
  551. }
  552. // writing
  553. err = binary.Write(w, binary.LittleEndian, totallen)
  554. if err != nil {
  555. return err
  556. }
  557. for _, hdr := range headers {
  558. var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
  559. err = binary.Write(w, binary.LittleEndian, headerlen)
  560. if err != nil {
  561. return err
  562. }
  563. err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
  564. if err != nil {
  565. return err
  566. }
  567. _, err = w.Write(hdr.data)
  568. if err != nil {
  569. return err
  570. }
  571. }
  572. return nil
  573. }
  574. func sendSqlBatch72(buf *tdsBuffer,
  575. sqltext string,
  576. headers []headerStruct) (err error) {
  577. buf.BeginPacket(packSQLBatch)
  578. if err = writeAllHeaders(buf, headers); err != nil {
  579. return
  580. }
  581. _, err = buf.Write(str2ucs2(sqltext))
  582. if err != nil {
  583. return
  584. }
  585. return buf.FinishPacket()
  586. }
  587. type connectParams struct {
  588. logFlags uint64
  589. port uint64
  590. host string
  591. instance string
  592. database string
  593. user string
  594. password string
  595. dial_timeout time.Duration
  596. conn_timeout time.Duration
  597. keepAlive time.Duration
  598. encrypt bool
  599. disableEncryption bool
  600. trustServerCertificate bool
  601. certificate string
  602. hostInCertificate string
  603. serverSPN string
  604. workstation string
  605. appname string
  606. typeFlags uint8
  607. failOverPartner string
  608. failOverPort uint64
  609. }
  610. func splitConnectionString(dsn string) (res map[string]string) {
  611. res = map[string]string{}
  612. parts := strings.Split(dsn, ";")
  613. for _, part := range parts {
  614. if len(part) == 0 {
  615. continue
  616. }
  617. lst := strings.SplitN(part, "=", 2)
  618. name := strings.TrimSpace(strings.ToLower(lst[0]))
  619. if len(name) == 0 {
  620. continue
  621. }
  622. var value string = ""
  623. if len(lst) > 1 {
  624. value = strings.TrimSpace(lst[1])
  625. }
  626. res[name] = value
  627. }
  628. return res
  629. }
  630. func parseConnectParams(dsn string) (connectParams, error) {
  631. params := splitConnectionString(dsn)
  632. var p connectParams
  633. strlog, ok := params["log"]
  634. if ok {
  635. var err error
  636. p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
  637. if err != nil {
  638. return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
  639. }
  640. }
  641. server := params["server"]
  642. parts := strings.SplitN(server, "\\", 2)
  643. p.host = parts[0]
  644. if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
  645. p.host = "localhost"
  646. }
  647. if len(parts) > 1 {
  648. p.instance = parts[1]
  649. }
  650. p.database = params["database"]
  651. p.user = params["user id"]
  652. p.password = params["password"]
  653. p.port = 1433
  654. strport, ok := params["port"]
  655. if ok {
  656. var err error
  657. p.port, err = strconv.ParseUint(strport, 0, 16)
  658. if err != nil {
  659. f := "Invalid tcp port '%v': %v"
  660. return p, fmt.Errorf(f, strport, err.Error())
  661. }
  662. }
  663. p.dial_timeout = 5 * time.Second
  664. p.conn_timeout = 30 * time.Second
  665. strconntimeout, ok := params["connection timeout"]
  666. if ok {
  667. timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
  668. if err != nil {
  669. f := "Invalid connection timeout '%v': %v"
  670. return p, fmt.Errorf(f, strconntimeout, err.Error())
  671. }
  672. p.conn_timeout = time.Duration(timeout) * time.Second
  673. }
  674. strdialtimeout, ok := params["dial timeout"]
  675. if ok {
  676. timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
  677. if err != nil {
  678. f := "Invalid dial timeout '%v': %v"
  679. return p, fmt.Errorf(f, strdialtimeout, err.Error())
  680. }
  681. p.dial_timeout = time.Duration(timeout) * time.Second
  682. }
  683. keepAlive, ok := params["keepalive"]
  684. if ok {
  685. timeout, err := strconv.ParseUint(keepAlive, 0, 16)
  686. if err != nil {
  687. f := "Invalid keepAlive value '%s': %s"
  688. return p, fmt.Errorf(f, keepAlive, err.Error())
  689. }
  690. p.keepAlive = time.Duration(timeout) * time.Second
  691. }
  692. encrypt, ok := params["encrypt"]
  693. if ok {
  694. if strings.ToUpper(encrypt) == "DISABLE" {
  695. p.disableEncryption = true
  696. } else {
  697. var err error
  698. p.encrypt, err = strconv.ParseBool(encrypt)
  699. if err != nil {
  700. f := "Invalid encrypt '%s': %s"
  701. return p, fmt.Errorf(f, encrypt, err.Error())
  702. }
  703. }
  704. } else {
  705. p.trustServerCertificate = true
  706. }
  707. trust, ok := params["trustservercertificate"]
  708. if ok {
  709. var err error
  710. p.trustServerCertificate, err = strconv.ParseBool(trust)
  711. if err != nil {
  712. f := "Invalid trust server certificate '%s': %s"
  713. return p, fmt.Errorf(f, trust, err.Error())
  714. }
  715. }
  716. p.certificate = params["certificate"]
  717. p.hostInCertificate, ok = params["hostnameincertificate"]
  718. if !ok {
  719. p.hostInCertificate = p.host
  720. }
  721. serverSPN, ok := params["serverspn"]
  722. if ok {
  723. p.serverSPN = serverSPN
  724. } else {
  725. p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
  726. }
  727. workstation, ok := params["workstation id"]
  728. if ok {
  729. p.workstation = workstation
  730. } else {
  731. workstation, err := os.Hostname()
  732. if err == nil {
  733. p.workstation = workstation
  734. }
  735. }
  736. appname, ok := params["app name"]
  737. if !ok {
  738. appname = "go-mssqldb"
  739. }
  740. p.appname = appname
  741. appintent, ok := params["applicationintent"]
  742. if ok {
  743. if appintent == "ReadOnly" {
  744. p.typeFlags |= fReadOnlyIntent
  745. }
  746. }
  747. failOverPartner, ok := params["failoverpartner"]
  748. if ok {
  749. p.failOverPartner = failOverPartner
  750. }
  751. failOverPort, ok := params["failoverport"]
  752. if ok {
  753. var err error
  754. p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
  755. if err != nil {
  756. f := "Invalid tcp port '%v': %v"
  757. return p, fmt.Errorf(f, failOverPort, err.Error())
  758. }
  759. }
  760. return p, nil
  761. }
  762. type Auth interface {
  763. InitialBytes() ([]byte, error)
  764. NextBytes([]byte) ([]byte, error)
  765. Free()
  766. }
  767. // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
  768. // list of IP addresses. So if there is more than one, try them all and
  769. // use the first one that allows a connection.
  770. func dialConnection(p connectParams) (conn net.Conn, err error) {
  771. var ips []net.IP
  772. ips, err = net.LookupIP(p.host)
  773. if err != nil {
  774. ip := net.ParseIP(p.host)
  775. if ip == nil {
  776. return nil, err
  777. }
  778. ips = []net.IP{ip}
  779. }
  780. if len(ips) == 1 {
  781. d := createDialer(p)
  782. addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
  783. conn, err = d.Dial("tcp", addr)
  784. } else {
  785. //Try Dials in parallel to avoid waiting for timeouts.
  786. connChan := make(chan net.Conn, len(ips))
  787. errChan := make(chan error, len(ips))
  788. portStr := strconv.Itoa(int(p.port))
  789. for _, ip := range ips {
  790. go func(ip net.IP) {
  791. d := createDialer(p)
  792. addr := net.JoinHostPort(ip.String(), portStr)
  793. conn, err := d.Dial("tcp", addr)
  794. if err == nil {
  795. connChan <- conn
  796. } else {
  797. errChan <- err
  798. }
  799. }(ip)
  800. }
  801. // Wait for either the *first* successful connection, or all the errors
  802. wait_loop:
  803. for i, _ := range ips {
  804. select {
  805. case conn = <-connChan:
  806. // Got a connection to use, close any others
  807. go func(n int) {
  808. for i := 0; i < n; i++ {
  809. select {
  810. case conn := <-connChan:
  811. conn.Close()
  812. case <-errChan:
  813. }
  814. }
  815. }(len(ips) - i - 1)
  816. // Remove any earlier errors we may have collected
  817. err = nil
  818. break wait_loop
  819. case err = <-errChan:
  820. }
  821. }
  822. }
  823. // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
  824. if conn == nil {
  825. f := "Unable to open tcp connection with host '%v:%v': %v"
  826. return nil, fmt.Errorf(f, p.host, p.port, err.Error())
  827. }
  828. return conn, err
  829. }
  830. func connect(p connectParams) (res *tdsSession, err error) {
  831. res = nil
  832. // if instance is specified use instance resolution service
  833. if p.instance != "" {
  834. p.instance = strings.ToUpper(p.instance)
  835. instances, err := getInstances(p.host)
  836. if err != nil {
  837. f := "Unable to get instances from Sql Server Browser on host %v: %v"
  838. return nil, fmt.Errorf(f, p.host, err.Error())
  839. }
  840. strport, ok := instances[p.instance]["tcp"]
  841. if !ok {
  842. f := "No instance matching '%v' returned from host '%v'"
  843. return nil, fmt.Errorf(f, p.instance, p.host)
  844. }
  845. p.port, err = strconv.ParseUint(strport, 0, 16)
  846. if err != nil {
  847. f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
  848. return nil, fmt.Errorf(f, strport, err.Error())
  849. }
  850. }
  851. initiate_connection:
  852. conn, err := dialConnection(p)
  853. if err != nil {
  854. return nil, err
  855. }
  856. toconn := NewTimeoutConn(conn, p.conn_timeout)
  857. outbuf := newTdsBuffer(4096, toconn)
  858. sess := tdsSession{
  859. buf: outbuf,
  860. logFlags: p.logFlags,
  861. }
  862. instance_buf := []byte(p.instance)
  863. instance_buf = append(instance_buf, 0) // zero terminate instance name
  864. var encrypt byte
  865. if p.disableEncryption {
  866. encrypt = encryptNotSup
  867. } else if p.encrypt {
  868. encrypt = encryptOn
  869. } else {
  870. encrypt = encryptOff
  871. }
  872. fields := map[uint8][]byte{
  873. preloginVERSION: {0, 0, 0, 0, 0, 0},
  874. preloginENCRYPTION: {encrypt},
  875. preloginINSTOPT: instance_buf,
  876. preloginTHREADID: {0, 0, 0, 0},
  877. preloginMARS: {0}, // MARS disabled
  878. }
  879. err = writePrelogin(outbuf, fields)
  880. if err != nil {
  881. return nil, err
  882. }
  883. fields, err = readPrelogin(outbuf)
  884. if err != nil {
  885. return nil, err
  886. }
  887. encryptBytes, ok := fields[preloginENCRYPTION]
  888. if !ok {
  889. return nil, fmt.Errorf("Encrypt negotiation failed")
  890. }
  891. encrypt = encryptBytes[0]
  892. if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
  893. return nil, fmt.Errorf("Server does not support encryption")
  894. }
  895. if encrypt != encryptNotSup {
  896. var config tls.Config
  897. if p.certificate != "" {
  898. pem, err := ioutil.ReadFile(p.certificate)
  899. if err != nil {
  900. f := "Cannot read certificate '%s': %s"
  901. return nil, fmt.Errorf(f, p.certificate, err.Error())
  902. }
  903. certs := x509.NewCertPool()
  904. certs.AppendCertsFromPEM(pem)
  905. config.RootCAs = certs
  906. }
  907. if p.trustServerCertificate {
  908. config.InsecureSkipVerify = true
  909. }
  910. config.ServerName = p.hostInCertificate
  911. outbuf.transport = conn
  912. toconn.buf = outbuf
  913. tlsConn := tls.Client(toconn, &config)
  914. err = tlsConn.Handshake()
  915. toconn.buf = nil
  916. outbuf.transport = tlsConn
  917. if err != nil {
  918. f := "TLS Handshake failed: %s"
  919. return nil, fmt.Errorf(f, err.Error())
  920. }
  921. if encrypt == encryptOff {
  922. outbuf.afterFirst = func() {
  923. outbuf.transport = toconn
  924. }
  925. }
  926. }
  927. login := login{
  928. TDSVersion: verTDS74,
  929. PacketSize: uint32(len(outbuf.buf)),
  930. Database: p.database,
  931. OptionFlags2: fODBC, // to get unlimited TEXTSIZE
  932. HostName: p.workstation,
  933. ServerName: p.host,
  934. AppName: p.appname,
  935. TypeFlags: p.typeFlags,
  936. }
  937. auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
  938. if auth_ok {
  939. login.SSPI, err = auth.InitialBytes()
  940. if err != nil {
  941. return nil, err
  942. }
  943. login.OptionFlags2 |= fIntSecurity
  944. defer auth.Free()
  945. } else {
  946. login.UserName = p.user
  947. login.Password = p.password
  948. }
  949. err = sendLogin(outbuf, login)
  950. if err != nil {
  951. return nil, err
  952. }
  953. // processing login response
  954. var sspi_msg []byte
  955. continue_login:
  956. tokchan := make(chan tokenStruct, 5)
  957. go processResponse(&sess, tokchan)
  958. success := false
  959. for tok := range tokchan {
  960. switch token := tok.(type) {
  961. case sspiMsg:
  962. sspi_msg, err = auth.NextBytes(token)
  963. if err != nil {
  964. return nil, err
  965. }
  966. case loginAckStruct:
  967. success = true
  968. sess.loginAck = token
  969. case error:
  970. return nil, fmt.Errorf("Login error: %s", token.Error())
  971. }
  972. }
  973. if sspi_msg != nil {
  974. outbuf.BeginPacket(packSSPIMessage)
  975. _, err = outbuf.Write(sspi_msg)
  976. if err != nil {
  977. return nil, err
  978. }
  979. err = outbuf.FinishPacket()
  980. if err != nil {
  981. return nil, err
  982. }
  983. sspi_msg = nil
  984. goto continue_login
  985. }
  986. if !success {
  987. return nil, fmt.Errorf("Login failed")
  988. }
  989. if sess.routedServer != "" {
  990. toconn.Close()
  991. p.host = sess.routedServer
  992. p.port = uint64(sess.routedPort)
  993. goto initiate_connection
  994. }
  995. return &sess, nil
  996. }