You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

516 lines
14 KiB

  1. package ldap
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "log"
  7. "net"
  8. "net/url"
  9. "sync"
  10. "sync/atomic"
  11. "time"
  12. "gopkg.in/asn1-ber.v1"
  13. )
  // Operation codes carried in messagePacket.Op and dispatched by the
// processMessages loop.
const (
	// MessageQuit causes the processMessages loop to exit
	MessageQuit = iota
	// MessageRequest sends a request to the server
	MessageRequest
	// MessageResponse receives a response from the server
	MessageResponse
	// MessageFinish indicates the client considers a particular message ID to be finished
	MessageFinish
	// MessageTimeout indicates the client-specified timeout for a particular message ID has been reached
	MessageTimeout
)

// Well-known LDAP ports; DialURL falls back to these when the URL omits a port.
const (
	// DefaultLdapPort default ldap port for pure TCP connection
	DefaultLdapPort = "389"
	// DefaultLdapsPort default ldap port for SSL connection
	DefaultLdapsPort = "636"
)
  32. // PacketResponse contains the packet or error encountered reading a response
  33. type PacketResponse struct {
  34. // Packet is the packet read from the server
  35. Packet *ber.Packet
  36. // Error is an error encountered while reading
  37. Error error
  38. }
  39. // ReadPacket returns the packet or an error
  40. func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
  41. if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
  42. return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
  43. }
  44. return pr.Packet, pr.Error
  45. }
  46. type messageContext struct {
  47. id int64
  48. // close(done) should only be called from finishMessage()
  49. done chan struct{}
  50. // close(responses) should only be called from processMessages(), and only sent to from sendResponse()
  51. responses chan *PacketResponse
  52. }
  53. // sendResponse should only be called within the processMessages() loop which
  54. // is also responsible for closing the responses channel.
  55. func (msgCtx *messageContext) sendResponse(packet *PacketResponse) {
  56. select {
  57. case msgCtx.responses <- packet:
  58. // Successfully sent packet to message handler.
  59. case <-msgCtx.done:
  60. // The request handler is done and will not receive more
  61. // packets.
  62. }
  63. }
  64. type messagePacket struct {
  65. Op int
  66. MessageID int64
  67. Packet *ber.Packet
  68. Context *messageContext
  69. }
  70. type sendMessageFlags uint
  71. const (
  72. startTLS sendMessageFlags = 1 << iota
  73. )
  74. // Conn represents an LDAP Connection
  75. type Conn struct {
  76. // requestTimeout is loaded atomically
  77. // so we need to ensure 64-bit alignment on 32-bit platforms.
  78. requestTimeout int64
  79. conn net.Conn
  80. isTLS bool
  81. closing uint32
  82. closeErr atomic.Value
  83. isStartingTLS bool
  84. Debug debugging
  85. chanConfirm chan struct{}
  86. messageContexts map[int64]*messageContext
  87. chanMessage chan *messagePacket
  88. chanMessageID chan int64
  89. wgClose sync.WaitGroup
  90. outstandingRequests uint
  91. messageMutex sync.Mutex
  92. }
  93. var _ Client = &Conn{}
  // DefaultTimeout is the package-wide timeout applied when establishing
// connections in Dial and DialTLS (and therefore DialURL).
//
// WARNING: since this is a package-level variable, setting this value from
// multiple places will probably result in undesired behaviour.
var DefaultTimeout = time.Minute
  100. // Dial connects to the given address on the given network using net.Dial
  101. // and then returns a new Conn for the connection.
  102. func Dial(network, addr string) (*Conn, error) {
  103. c, err := net.DialTimeout(network, addr, DefaultTimeout)
  104. if err != nil {
  105. return nil, NewError(ErrorNetwork, err)
  106. }
  107. conn := NewConn(c, false)
  108. conn.Start()
  109. return conn, nil
  110. }
  111. // DialTLS connects to the given address on the given network using tls.Dial
  112. // and then returns a new Conn for the connection.
  113. func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
  114. c, err := tls.DialWithDialer(&net.Dialer{Timeout: DefaultTimeout}, network, addr, config)
  115. if err != nil {
  116. return nil, NewError(ErrorNetwork, err)
  117. }
  118. conn := NewConn(c, true)
  119. conn.Start()
  120. return conn, nil
  121. }
  122. // DialURL connects to the given ldap URL vie TCP using tls.Dial or net.Dial if ldaps://
  123. // or ldap:// specified as protocol. On success a new Conn for the connection
  124. // is returned.
  125. func DialURL(addr string) (*Conn, error) {
  126. lurl, err := url.Parse(addr)
  127. if err != nil {
  128. return nil, NewError(ErrorNetwork, err)
  129. }
  130. host, port, err := net.SplitHostPort(lurl.Host)
  131. if err != nil {
  132. // we asume that error is due to missing port
  133. host = lurl.Host
  134. port = ""
  135. }
  136. switch lurl.Scheme {
  137. case "ldap":
  138. if port == "" {
  139. port = DefaultLdapPort
  140. }
  141. return Dial("tcp", net.JoinHostPort(host, port))
  142. case "ldaps":
  143. if port == "" {
  144. port = DefaultLdapsPort
  145. }
  146. tlsConf := &tls.Config{
  147. ServerName: host,
  148. }
  149. return DialTLS("tcp", net.JoinHostPort(host, port), tlsConf)
  150. }
  151. return nil, NewError(ErrorNetwork, fmt.Errorf("Unknown scheme '%s'", lurl.Scheme))
  152. }
  153. // NewConn returns a new Conn using conn for network I/O.
  154. func NewConn(conn net.Conn, isTLS bool) *Conn {
  155. return &Conn{
  156. conn: conn,
  157. chanConfirm: make(chan struct{}),
  158. chanMessageID: make(chan int64),
  159. chanMessage: make(chan *messagePacket, 10),
  160. messageContexts: map[int64]*messageContext{},
  161. requestTimeout: 0,
  162. isTLS: isTLS,
  163. }
  164. }
  165. // Start initializes goroutines to read responses and process messages
  166. func (l *Conn) Start() {
  167. go l.reader()
  168. go l.processMessages()
  169. l.wgClose.Add(1)
  170. }
  171. // IsClosing returns whether or not we're currently closing.
  172. func (l *Conn) IsClosing() bool {
  173. return atomic.LoadUint32(&l.closing) == 1
  174. }
  175. // setClosing sets the closing value to true
  176. func (l *Conn) setClosing() bool {
  177. return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
  178. }
  179. // Close closes the connection.
  180. func (l *Conn) Close() {
  181. l.messageMutex.Lock()
  182. defer l.messageMutex.Unlock()
  183. if l.setClosing() {
  184. l.Debug.Printf("Sending quit message and waiting for confirmation")
  185. l.chanMessage <- &messagePacket{Op: MessageQuit}
  186. <-l.chanConfirm
  187. close(l.chanMessage)
  188. l.Debug.Printf("Closing network connection")
  189. if err := l.conn.Close(); err != nil {
  190. log.Println(err)
  191. }
  192. l.wgClose.Done()
  193. }
  194. l.wgClose.Wait()
  195. }
  196. // SetTimeout sets the time after a request is sent that a MessageTimeout triggers
  197. func (l *Conn) SetTimeout(timeout time.Duration) {
  198. if timeout > 0 {
  199. atomic.StoreInt64(&l.requestTimeout, int64(timeout))
  200. }
  201. }
  202. // Returns the next available messageID
  203. func (l *Conn) nextMessageID() int64 {
  204. if messageID, ok := <-l.chanMessageID; ok {
  205. return messageID
  206. }
  207. return 0
  208. }
  209. // StartTLS sends the command to start a TLS session and then creates a new TLS Client
  210. func (l *Conn) StartTLS(config *tls.Config) error {
  211. if l.isTLS {
  212. return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
  213. }
  214. packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
  215. packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, l.nextMessageID(), "MessageID"))
  216. request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
  217. request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
  218. packet.AppendChild(request)
  219. l.Debug.PrintPacket(packet)
  220. msgCtx, err := l.sendMessageWithFlags(packet, startTLS)
  221. if err != nil {
  222. return err
  223. }
  224. defer l.finishMessage(msgCtx)
  225. l.Debug.Printf("%d: waiting for response", msgCtx.id)
  226. packetResponse, ok := <-msgCtx.responses
  227. if !ok {
  228. return NewError(ErrorNetwork, errors.New("ldap: response channel closed"))
  229. }
  230. packet, err = packetResponse.ReadPacket()
  231. l.Debug.Printf("%d: got response %p", msgCtx.id, packet)
  232. if err != nil {
  233. return err
  234. }
  235. if l.Debug {
  236. if err := addLDAPDescriptions(packet); err != nil {
  237. l.Close()
  238. return err
  239. }
  240. ber.PrintPacket(packet)
  241. }
  242. if err := GetLDAPError(packet); err == nil {
  243. conn := tls.Client(l.conn, config)
  244. if connErr := conn.Handshake(); connErr != nil {
  245. l.Close()
  246. return NewError(ErrorNetwork, fmt.Errorf("TLS handshake failed (%v)", connErr))
  247. }
  248. l.isTLS = true
  249. l.conn = conn
  250. } else {
  251. return err
  252. }
  253. go l.reader()
  254. return nil
  255. }
  256. // TLSConnectionState returns the client's TLS connection state.
  257. // The return values are their zero values if StartTLS did
  258. // not succeed.
  259. func (l *Conn) TLSConnectionState() (state tls.ConnectionState, ok bool) {
  260. tc, ok := l.conn.(*tls.Conn)
  261. if !ok {
  262. return
  263. }
  264. return tc.ConnectionState(), true
  265. }
  266. func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
  267. return l.sendMessageWithFlags(packet, 0)
  268. }
  269. func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (*messageContext, error) {
  270. if l.IsClosing() {
  271. return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
  272. }
  273. l.messageMutex.Lock()
  274. l.Debug.Printf("flags&startTLS = %d", flags&startTLS)
  275. if l.isStartingTLS {
  276. l.messageMutex.Unlock()
  277. return nil, NewError(ErrorNetwork, errors.New("ldap: connection is in startls phase"))
  278. }
  279. if flags&startTLS != 0 {
  280. if l.outstandingRequests != 0 {
  281. l.messageMutex.Unlock()
  282. return nil, NewError(ErrorNetwork, errors.New("ldap: cannot StartTLS with outstanding requests"))
  283. }
  284. l.isStartingTLS = true
  285. }
  286. l.outstandingRequests++
  287. l.messageMutex.Unlock()
  288. responses := make(chan *PacketResponse)
  289. messageID := packet.Children[0].Value.(int64)
  290. message := &messagePacket{
  291. Op: MessageRequest,
  292. MessageID: messageID,
  293. Packet: packet,
  294. Context: &messageContext{
  295. id: messageID,
  296. done: make(chan struct{}),
  297. responses: responses,
  298. },
  299. }
  300. l.sendProcessMessage(message)
  301. return message.Context, nil
  302. }
  303. func (l *Conn) finishMessage(msgCtx *messageContext) {
  304. close(msgCtx.done)
  305. if l.IsClosing() {
  306. return
  307. }
  308. l.messageMutex.Lock()
  309. l.outstandingRequests--
  310. if l.isStartingTLS {
  311. l.isStartingTLS = false
  312. }
  313. l.messageMutex.Unlock()
  314. message := &messagePacket{
  315. Op: MessageFinish,
  316. MessageID: msgCtx.id,
  317. }
  318. l.sendProcessMessage(message)
  319. }
  320. func (l *Conn) sendProcessMessage(message *messagePacket) bool {
  321. l.messageMutex.Lock()
  322. defer l.messageMutex.Unlock()
  323. if l.IsClosing() {
  324. return false
  325. }
  326. l.chanMessage <- message
  327. return true
  328. }
  329. func (l *Conn) processMessages() {
  330. defer func() {
  331. if err := recover(); err != nil {
  332. log.Printf("ldap: recovered panic in processMessages: %v", err)
  333. }
  334. for messageID, msgCtx := range l.messageContexts {
  335. // If we are closing due to an error, inform anyone who
  336. // is waiting about the error.
  337. if l.IsClosing() && l.closeErr.Load() != nil {
  338. msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)})
  339. }
  340. l.Debug.Printf("Closing channel for MessageID %d", messageID)
  341. close(msgCtx.responses)
  342. delete(l.messageContexts, messageID)
  343. }
  344. close(l.chanMessageID)
  345. close(l.chanConfirm)
  346. }()
  347. var messageID int64 = 1
  348. for {
  349. select {
  350. case l.chanMessageID <- messageID:
  351. messageID++
  352. case message := <-l.chanMessage:
  353. switch message.Op {
  354. case MessageQuit:
  355. l.Debug.Printf("Shutting down - quit message received")
  356. return
  357. case MessageRequest:
  358. // Add to message list and write to network
  359. l.Debug.Printf("Sending message %d", message.MessageID)
  360. buf := message.Packet.Bytes()
  361. _, err := l.conn.Write(buf)
  362. if err != nil {
  363. l.Debug.Printf("Error Sending Message: %s", err.Error())
  364. message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)})
  365. close(message.Context.responses)
  366. break
  367. }
  368. // Only add to messageContexts if we were able to
  369. // successfully write the message.
  370. l.messageContexts[message.MessageID] = message.Context
  371. // Add timeout if defined
  372. requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
  373. if requestTimeout > 0 {
  374. go func() {
  375. defer func() {
  376. if err := recover(); err != nil {
  377. log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
  378. }
  379. }()
  380. time.Sleep(requestTimeout)
  381. timeoutMessage := &messagePacket{
  382. Op: MessageTimeout,
  383. MessageID: message.MessageID,
  384. }
  385. l.sendProcessMessage(timeoutMessage)
  386. }()
  387. }
  388. case MessageResponse:
  389. l.Debug.Printf("Receiving message %d", message.MessageID)
  390. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  391. msgCtx.sendResponse(&PacketResponse{message.Packet, nil})
  392. } else {
  393. log.Printf("Received unexpected message %d, %v", message.MessageID, l.IsClosing())
  394. ber.PrintPacket(message.Packet)
  395. }
  396. case MessageTimeout:
  397. // Handle the timeout by closing the channel
  398. // All reads will return immediately
  399. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  400. l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
  401. msgCtx.sendResponse(&PacketResponse{message.Packet, errors.New("ldap: connection timed out")})
  402. delete(l.messageContexts, message.MessageID)
  403. close(msgCtx.responses)
  404. }
  405. case MessageFinish:
  406. l.Debug.Printf("Finished message %d", message.MessageID)
  407. if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
  408. delete(l.messageContexts, message.MessageID)
  409. close(msgCtx.responses)
  410. }
  411. }
  412. }
  413. }
  414. }
  415. func (l *Conn) reader() {
  416. cleanstop := false
  417. defer func() {
  418. if err := recover(); err != nil {
  419. log.Printf("ldap: recovered panic in reader: %v", err)
  420. }
  421. if !cleanstop {
  422. l.Close()
  423. }
  424. }()
  425. for {
  426. if cleanstop {
  427. l.Debug.Printf("reader clean stopping (without closing the connection)")
  428. return
  429. }
  430. packet, err := ber.ReadPacket(l.conn)
  431. if err != nil {
  432. // A read error is expected here if we are closing the connection...
  433. if !l.IsClosing() {
  434. l.closeErr.Store(fmt.Errorf("unable to read LDAP response packet: %s", err))
  435. l.Debug.Printf("reader error: %s", err.Error())
  436. }
  437. return
  438. }
  439. addLDAPDescriptions(packet)
  440. if len(packet.Children) == 0 {
  441. l.Debug.Printf("Received bad ldap packet")
  442. continue
  443. }
  444. l.messageMutex.Lock()
  445. if l.isStartingTLS {
  446. cleanstop = true
  447. }
  448. l.messageMutex.Unlock()
  449. message := &messagePacket{
  450. Op: MessageResponse,
  451. MessageID: packet.Children[0].Value.(int64),
  452. Packet: packet,
  453. }
  454. if !l.sendProcessMessage(message) {
  455. return
  456. }
  457. }
  458. }