|
|
- // Copyright 2011 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
-
- package ldap
-
- import (
- "crypto/tls"
- "errors"
- "log"
- "net"
- "sync"
-
- "github.com/gogits/gogs/modules/asn1-ber"
- )
-
- const (
- MessageQuit = 0
- MessageRequest = 1
- MessageResponse = 2
- MessageFinish = 3
- )
-
- type messagePacket struct {
- Op int
- MessageID uint64
- Packet *ber.Packet
- Channel chan *ber.Packet
- }
-
- // Conn represents an LDAP Connection
- type Conn struct {
- conn net.Conn
- isTLS bool
- isClosing bool
- Debug debugging
- chanConfirm chan bool
- chanResults map[uint64]chan *ber.Packet
- chanMessage chan *messagePacket
- chanMessageID chan uint64
- wgSender sync.WaitGroup
- wgClose sync.WaitGroup
- once sync.Once
- }
-
- // Dial connects to the given address on the given network using net.Dial
- // and then returns a new Conn for the connection.
- func Dial(network, addr string) (*Conn, error) {
- c, err := net.Dial(network, addr)
- if err != nil {
- return nil, NewError(ErrorNetwork, err)
- }
- conn := NewConn(c)
- conn.start()
- return conn, nil
- }
-
- // DialTLS connects to the given address on the given network using tls.Dial
- // and then returns a new Conn for the connection.
- func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
- c, err := tls.Dial(network, addr, config)
- if err != nil {
- return nil, NewError(ErrorNetwork, err)
- }
- conn := NewConn(c)
- conn.isTLS = true
- conn.start()
- return conn, nil
- }
-
- // NewConn returns a new Conn using conn for network I/O.
- func NewConn(conn net.Conn) *Conn {
- return &Conn{
- conn: conn,
- chanConfirm: make(chan bool),
- chanMessageID: make(chan uint64),
- chanMessage: make(chan *messagePacket, 10),
- chanResults: map[uint64]chan *ber.Packet{},
- }
- }
-
- func (l *Conn) start() {
- go l.reader()
- go l.processMessages()
- l.wgClose.Add(1)
- }
-
- // Close closes the connection.
- func (l *Conn) Close() {
- l.once.Do(func() {
- l.isClosing = true
- l.wgSender.Wait()
-
- l.Debug.Printf("Sending quit message and waiting for confirmation")
- l.chanMessage <- &messagePacket{Op: MessageQuit}
- <-l.chanConfirm
- close(l.chanMessage)
-
- l.Debug.Printf("Closing network connection")
- if err := l.conn.Close(); err != nil {
- log.Print(err)
- }
-
- l.conn = nil
- l.wgClose.Done()
- })
- l.wgClose.Wait()
- }
-
- // Returns the next available messageID
- func (l *Conn) nextMessageID() uint64 {
- if l.chanMessageID != nil {
- if messageID, ok := <-l.chanMessageID; ok {
- return messageID
- }
- }
- return 0
- }
-
- // StartTLS sends the command to start a TLS session and then creates a new TLS Client
- func (l *Conn) StartTLS(config *tls.Config) error {
- messageID := l.nextMessageID()
-
- if l.isTLS {
- return NewError(ErrorNetwork, errors.New("ldap: already encrypted"))
- }
-
- packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
- packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
- request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS")
- request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command"))
- packet.AppendChild(request)
- l.Debug.PrintPacket(packet)
-
- _, err := l.conn.Write(packet.Bytes())
- if err != nil {
- return NewError(ErrorNetwork, err)
- }
-
- packet, err = ber.ReadPacket(l.conn)
- if err != nil {
- return NewError(ErrorNetwork, err)
- }
-
- if l.Debug {
- if err := addLDAPDescriptions(packet); err != nil {
- return err
- }
- ber.PrintPacket(packet)
- }
-
- if packet.Children[1].Children[0].Value.(uint64) == 0 {
- conn := tls.Client(l.conn, config)
- l.isTLS = true
- l.conn = conn
- }
-
- return nil
- }
-
- func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
- if l.isClosing {
- return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
- }
- out := make(chan *ber.Packet)
- message := &messagePacket{
- Op: MessageRequest,
- MessageID: packet.Children[0].Value.(uint64),
- Packet: packet,
- Channel: out,
- }
- l.sendProcessMessage(message)
- return out, nil
- }
-
- func (l *Conn) finishMessage(messageID uint64) {
- if l.isClosing {
- return
- }
- message := &messagePacket{
- Op: MessageFinish,
- MessageID: messageID,
- }
- l.sendProcessMessage(message)
- }
-
- func (l *Conn) sendProcessMessage(message *messagePacket) bool {
- if l.isClosing {
- return false
- }
- l.wgSender.Add(1)
- l.chanMessage <- message
- l.wgSender.Done()
- return true
- }
-
- func (l *Conn) processMessages() {
- defer func() {
- for messageID, channel := range l.chanResults {
- l.Debug.Printf("Closing channel for MessageID %d", messageID)
- close(channel)
- delete(l.chanResults, messageID)
- }
- close(l.chanMessageID)
- l.chanConfirm <- true
- close(l.chanConfirm)
- }()
-
- var messageID uint64 = 1
- for {
- select {
- case l.chanMessageID <- messageID:
- messageID++
- case messagePacket, ok := <-l.chanMessage:
- if !ok {
- l.Debug.Printf("Shutting down - message channel is closed")
- return
- }
- switch messagePacket.Op {
- case MessageQuit:
- l.Debug.Printf("Shutting down - quit message received")
- return
- case MessageRequest:
- // Add to message list and write to network
- l.Debug.Printf("Sending message %d", messagePacket.MessageID)
- l.chanResults[messagePacket.MessageID] = messagePacket.Channel
- // go routine
- buf := messagePacket.Packet.Bytes()
-
- _, err := l.conn.Write(buf)
- if err != nil {
- l.Debug.Printf("Error Sending Message: %s", err.Error())
- break
- }
- case MessageResponse:
- l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
- if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
- chanResult <- messagePacket.Packet
- } else {
- log.Printf("Received unexpected message %d", messagePacket.MessageID)
- ber.PrintPacket(messagePacket.Packet)
- }
- case MessageFinish:
- // Remove from message list
- l.Debug.Printf("Finished message %d", messagePacket.MessageID)
- close(l.chanResults[messagePacket.MessageID])
- delete(l.chanResults, messagePacket.MessageID)
- }
- }
- }
- }
-
- func (l *Conn) reader() {
- defer func() {
- l.Close()
- }()
-
- for {
- packet, err := ber.ReadPacket(l.conn)
- if err != nil {
- l.Debug.Printf("reader: %s", err.Error())
- return
- }
- addLDAPDescriptions(packet)
- message := &messagePacket{
- Op: MessageResponse,
- MessageID: packet.Children[0].Value.(uint64),
- Packet: packet,
- }
- if !l.sendProcessMessage(message) {
- return
- }
-
- }
- }
|