@ -11,6 +11,7 @@ import (
"log"
"net"
"sync"
"sync/atomic"
"time"
"gopkg.in/asn1-ber.v1"
@ -82,20 +83,18 @@ const (
type Conn struct {
conn net . Conn
isTLS bool
isClosing bool
closeErr error
closing uint32
closeErr atomicValue
isStartingTLS bool
Debug debugging
chanConfirm chan bool
chanConfirm chan struct { }
messageContexts map [ int64 ] * messageContext
chanMessage chan * messagePacket
chanMessageID chan int64
wgSender sync . WaitGroup
wgClose sync . WaitGroup
once sync . Once
outstandingRequests uint
messageMutex sync . Mutex
requestTimeout time . Duration
requestTimeout int64
}
var _ Client = & Conn { }
@ -142,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
func NewConn ( conn net . Conn , isTLS bool ) * Conn {
return & Conn {
conn : conn ,
chanConfirm : make ( chan bool ) ,
chanConfirm : make ( chan struct { } ) ,
chanMessageID : make ( chan int64 ) ,
chanMessage : make ( chan * messagePacket , 10 ) ,
messageContexts : map [ int64 ] * messageContext { } ,
@ -158,12 +157,22 @@ func (l *Conn) Start() {
l . wgClose . Add ( 1 )
}
// isClosing returns whether or not we're currently closing.
func ( l * Conn ) isClosing ( ) bool {
return atomic . LoadUint32 ( & l . closing ) == 1
}
// setClosing sets the closing value to true
func ( l * Conn ) setClosing ( ) bool {
return atomic . CompareAndSwapUint32 ( & l . closing , 0 , 1 )
}
// Close closes the connection.
func ( l * Conn ) Close ( ) {
l . once . Do ( func ( ) {
l . isClosing = true
l . wgSender . Wait ( )
l . messageMutex . Lock ( )
defer l . messageMutex . Unlock ( )
if l . setClosing ( ) {
l . Debug . Printf ( "Sending quit message and waiting for confirmation" )
l . chanMessage <- & messagePacket { Op : MessageQuit }
<- l . chanConfirm
@ -171,27 +180,25 @@ func (l *Conn) Close() {
l . Debug . Printf ( "Closing network connection" )
if err := l . conn . Close ( ) ; err != nil {
log . Print ( err )
log . Println ( err )
}
l . wgClose . Done ( )
} )
}
l . wgClose . Wait ( )
}
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
func ( l * Conn ) SetTimeout ( timeout time . Duration ) {
if timeout > 0 {
l . requestTimeout = timeout
atomic . StoreInt64 ( & l. requestTimeout , int64 class="p">( timeout ) )
}
}
// Returns the next available messageID
func ( l * Conn ) nextMessageID ( ) int64 {
if l . chanMessageID != nil {
if messageID , ok := <- l . chanMessageID ; ok {
return messageID
}
if messageID , ok := <- l . chanMessageID ; ok {
return messageID
}
return 0
}
@ -258,7 +265,7 @@ func (l *Conn) sendMessage(packet *ber.Packet) (*messageContext, error) {
}
func ( l * Conn ) sendMessageWithFlags ( packet * ber . Packet , flags sendMessageFlags ) ( * messageContext , error ) {
if l . isClosing {
if l . isClosing ( ) {
return nil , NewError ( ErrorNetwork , errors . New ( "ldap: connection closed" ) )
}
l . messageMutex . Lock ( )
@ -297,7 +304,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
func ( l * Conn ) finishMessage ( msgCtx * messageContext ) {
close ( msgCtx . done )
if l . isClosing {
if l . isClosing ( ) {
return
}
@ -316,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
}
func ( l * Conn ) sendProcessMessage ( message * messagePacket ) bool {
if l . isClosing {
l . messageMutex . Lock ( )
defer l . messageMutex . Unlock ( )
if l . isClosing ( ) {
return false
}
l . wgSender . Add ( 1 )
l . chanMessage <- message
l . wgSender . Done ( )
return true
}
@ -333,15 +340,14 @@ func (l *Conn) processMessages() {
for messageID , msgCtx := range l . messageContexts {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l . isClosing && l . closeErr != nil {
msgCtx . sendResponse ( & PacketResponse { Error : l . closeErr } )
if l . isClosing ( ) && l . closeErr . Load ( ) != nil {
msgCtx . sendResponse ( & PacketResponse { Error : l . closeErr . Load ( ) . ( error ) })
}
l . Debug . Printf ( "Closing channel for MessageID %d" , messageID )
close ( msgCtx . responses )
delete ( l . messageContexts , messageID )
}
close ( l . chanMessageID )
l . chanConfirm <- true
close ( l . chanConfirm )
} ( )
@ -350,11 +356,7 @@ func (l *Conn) processMessages() {
select {
case l . chanMessageID <- messageID :
messageID ++
case message , ok := <- l . chanMessage :
if ! ok {
l . Debug . Printf ( "Shutting down - message channel is closed" )
return
}
case message := <- l . chanMessage :
switch message . Op {
case MessageQuit :
l . Debug . Printf ( "Shutting down - quit message received" )
@ -377,14 +379,15 @@ func (l *Conn) processMessages() {
l . messageContexts [ message . MessageID ] = message . Context
// Add timeout if defined
if l . requestTimeout > 0 {
requestTimeout := time . Duration ( atomic . LoadInt64 ( & l . requestTimeout ) )
if requestTimeout > 0 {
go func ( ) {
defer func ( ) {
if err := recover ( ) ; err != nil {
log . Printf ( "ldap: recovered panic in RequestTimeout: %v" , err )
}
} ( )
time . Sleep ( l . requestTimeout)
time . Sleep ( requestTimeout )
timeoutMessage := & messagePacket {
Op : MessageTimeout ,
MessageID : message . MessageID ,
@ -397,7 +400,7 @@ func (l *Conn) processMessages() {
if msgCtx , ok := l . messageContexts [ message . MessageID ] ; ok {
msgCtx . sendResponse ( & PacketResponse { message . Packet , nil } )
} else {
log . Printf ( "Received unexpected message %d, %v" , message . MessageID , l . isClosing )
log . Printf ( "Received unexpected message %d, %v" , message . MessageID , l . isClosing ( ) )
ber . PrintPacket ( message . Packet )
}
case MessageTimeout :
@ -439,8 +442,8 @@ func (l *Conn) reader() {
packet , err := ber . ReadPacket ( l . conn )
if err != nil {
// A read error is expected here if we are closing the connection...
if ! l . isClosing {
l . closeErr = fmt . Errorf ( "unable to read LDAP response packet: %s" , err )
if ! l . isClosing ( ) {
l . closeErr . Store class="p">( fmt . Errorf ( "unable to read LDAP response packet: %s" , err ) )
l . Debug . Printf ( "reader error: %s" , err . Error ( ) )
}
return