|
|
- // Copyright 2013 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
-
- package ssh
-
- import (
- "encoding/binary"
- "fmt"
- "io"
- "log"
- "sync"
- "sync/atomic"
- )
-
// debugMux, if set, causes messages in the connection protocol to be
// logged. When true, newMux also gives each mux a distinct chanList
// offset, and loop, onePacket and handleDisconnect log what they see.
// It is a constant, so the guarded debug branches never run while it is
// false.
const debugMux = false
-
// chanList is a thread safe channel list. It maps local channel IDs to
// channels; a slot freed by remove is reused by a later add.
type chanList struct {
	// protects concurrent access to chans
	sync.Mutex

	// chans are indexed by the local id of the channel, which the
	// other side should send in the PeersId field. A nil entry marks
	// a free slot.
	chans []*channel

	// This is a debugging aid: it offsets all IDs by this
	// amount. This helps distinguish otherwise identical
	// server/client muxes. In this file it is only written by newMux,
	// before the mux loop is started; getChan and remove read it
	// without holding the lock.
	offset uint32
}
-
- // Assigns a channel ID to the given channel.
- func (c *chanList) add(ch *channel) uint32 {
- c.Lock()
- defer c.Unlock()
- for i := range c.chans {
- if c.chans[i] == nil {
- c.chans[i] = ch
- return uint32(i) + c.offset
- }
- }
- c.chans = append(c.chans, ch)
- return uint32(len(c.chans)-1) + c.offset
- }
-
- // getChan returns the channel for the given ID.
- func (c *chanList) getChan(id uint32) *channel {
- id -= c.offset
-
- c.Lock()
- defer c.Unlock()
- if id < uint32(len(c.chans)) {
- return c.chans[id]
- }
- return nil
- }
-
- func (c *chanList) remove(id uint32) {
- id -= c.offset
- c.Lock()
- if id < uint32(len(c.chans)) {
- c.chans[id] = nil
- }
- c.Unlock()
- }
-
- // dropAll forgets all channels it knows, returning them in a slice.
- func (c *chanList) dropAll() []*channel {
- c.Lock()
- defer c.Unlock()
- var r []*channel
-
- for _, ch := range c.chans {
- if ch == nil {
- continue
- }
- r = append(r, ch)
- }
- c.chans = nil
- return r
- }
-
// mux represents the state for the SSH connection protocol, which
// multiplexes many channels onto a single packet transport.
type mux struct {
	// conn is the packet transport: onePacket reads from it and
	// sendMessage writes to it.
	conn packetConn
	// chanList maps local channel IDs to channels.
	chanList chanList

	// incomingChannels receives channels the peer asked to open (sent by
	// handleChannelOpen). loop closes it on exit.
	incomingChannels chan NewChannel

	// globalSentMu is held by SendRequest for the whole round trip of a
	// request with wantReply set, so at most one such request is
	// outstanding at a time.
	globalSentMu sync.Mutex
	// globalResponses carries global request success/failure replies
	// from handleGlobalPacket to SendRequest. loop closes it on exit.
	globalResponses chan interface{}
	// incomingRequests receives global requests sent by the peer. loop
	// closes it on exit.
	incomingRequests chan *Request

	// errCond's lock guards err; loop broadcasts on it after storing
	// the error that ended it, and Wait blocks on it.
	errCond *sync.Cond
	err     error
}
-
// When debugging, each new chanList instantiation has a different
// offset. globalOff holds the last offset handed out; newMux increments
// it with atomic.AddUint32, so any other access should also be atomic.
var globalOff uint32
-
- func (m *mux) Wait() error {
- m.errCond.L.Lock()
- defer m.errCond.L.Unlock()
- for m.err == nil {
- m.errCond.Wait()
- }
- return m.err
- }
-
- // newMux returns a mux that runs over the given connection.
- func newMux(p packetConn) *mux {
- m := &mux{
- conn: p,
- incomingChannels: make(chan NewChannel, 16),
- globalResponses: make(chan interface{}, 1),
- incomingRequests: make(chan *Request, 16),
- errCond: newCond(),
- }
- if debugMux {
- m.chanList.offset = atomic.AddUint32(&globalOff, 1)
- }
-
- go m.loop()
- return m
- }
-
- func (m *mux) sendMessage(msg interface{}) error {
- p := Marshal(msg)
- return m.conn.writePacket(p)
- }
-
- func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
- if wantReply {
- m.globalSentMu.Lock()
- defer m.globalSentMu.Unlock()
- }
-
- if err := m.sendMessage(globalRequestMsg{
- Type: name,
- WantReply: wantReply,
- Data: payload,
- }); err != nil {
- return false, nil, err
- }
-
- if !wantReply {
- return false, nil, nil
- }
-
- msg, ok := <-m.globalResponses
- if !ok {
- return false, nil, io.EOF
- }
- switch msg := msg.(type) {
- case *globalRequestFailureMsg:
- return false, msg.Data, nil
- case *globalRequestSuccessMsg:
- return true, msg.Data, nil
- default:
- return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
- }
- }
-
- // ackRequest must be called after processing a global request that
- // has WantReply set.
- func (m *mux) ackRequest(ok bool, data []byte) error {
- if ok {
- return m.sendMessage(globalRequestSuccessMsg{Data: data})
- }
- return m.sendMessage(globalRequestFailureMsg{Data: data})
- }
-
- // TODO(hanwen): Disconnect is a transport layer message. We should
- // probably send and receive Disconnect somewhere in the transport
- // code.
-
- // Disconnect sends a disconnect message.
- func (m *mux) Disconnect(reason uint32, message string) error {
- return m.sendMessage(disconnectMsg{
- Reason: reason,
- Message: message,
- })
- }
-
- func (m *mux) Close() error {
- return m.conn.Close()
- }
-
// loop runs the connection machine. It will process packets until an
// error is encountered. To synchronize on loop exit, use mux.Wait.
//
// On exit it tears down in this order:
//  1. It closes every channel still in chanList.
//  2. It closes incomingChannels, incomingRequests and globalResponses,
//     so a waiting SendRequest sees a closed channel and returns io.EOF.
//  3. It closes the transport.
//  4. Only then does it publish err and wake Wait.
func (m *mux) loop() {
	var err error
	for err == nil {
		err = m.onePacket()
	}

	// Forget all registered channels and close each one.
	for _, ch := range m.chanList.dropAll() {
		ch.close()
	}

	close(m.incomingChannels)
	close(m.incomingRequests)
	close(m.globalResponses)

	// The Close error is discarded; the error reported through Wait is
	// the one that ended the read loop above.
	m.conn.Close()

	m.errCond.L.Lock()
	m.err = err
	m.errCond.Broadcast()
	m.errCond.L.Unlock()

	if debugMux {
		log.Println("loop exit", err)
	}
}
-
- // onePacket reads and processes one packet.
- func (m *mux) onePacket() error {
- packet, err := m.conn.readPacket()
- if err != nil {
- return err
- }
-
- if debugMux {
- if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
- log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
- } else {
- p, _ := decode(packet)
- log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
- }
- }
-
- switch packet[0] {
- case msgNewKeys:
- // Ignore notification of key change.
- return nil
- case msgDisconnect:
- return m.handleDisconnect(packet)
- case msgChannelOpen:
- return m.handleChannelOpen(packet)
- case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
- return m.handleGlobalPacket(packet)
- }
-
- // assume a channel packet.
- if len(packet) < 5 {
- return parseError(packet[0])
- }
- id := binary.BigEndian.Uint32(packet[1:])
- ch := m.chanList.getChan(id)
- if ch == nil {
- return fmt.Errorf("ssh: invalid channel %d", id)
- }
-
- return ch.handlePacket(packet)
- }
-
- func (m *mux) handleDisconnect(packet []byte) error {
- var d disconnectMsg
- if err := Unmarshal(packet, &d); err != nil {
- return err
- }
-
- if debugMux {
- log.Printf("caught disconnect: %v", d)
- }
- return &d
- }
-
- func (m *mux) handleGlobalPacket(packet []byte) error {
- msg, err := decode(packet)
- if err != nil {
- return err
- }
-
- switch msg := msg.(type) {
- case *globalRequestMsg:
- m.incomingRequests <- &Request{
- Type: msg.Type,
- WantReply: msg.WantReply,
- Payload: msg.Data,
- mux: m,
- }
- case *globalRequestSuccessMsg, *globalRequestFailureMsg:
- m.globalResponses <- msg
- default:
- panic(fmt.Sprintf("not a global message %#v", msg))
- }
-
- return nil
- }
-
- // handleChannelOpen schedules a channel to be Accept()ed.
- func (m *mux) handleChannelOpen(packet []byte) error {
- var msg channelOpenMsg
- if err := Unmarshal(packet, &msg); err != nil {
- return err
- }
-
- if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
- failMsg := channelOpenFailureMsg{
- PeersId: msg.PeersId,
- Reason: ConnectionFailed,
- Message: "invalid request",
- Language: "en_US.UTF-8",
- }
- return m.sendMessage(failMsg)
- }
-
- c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
- c.remoteId = msg.PeersId
- c.maxRemotePayload = msg.MaxPacketSize
- c.remoteWin.add(msg.PeersWindow)
- m.incomingChannels <- c
- return nil
- }
-
- func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
- ch, err := m.openChannel(chanType, extra)
- if err != nil {
- return nil, nil, err
- }
-
- return ch, ch.incomingRequests, nil
- }
-
- func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
- ch := m.newChannel(chanType, channelOutbound, extra)
-
- ch.maxIncomingPayload = channelMaxPacket
-
- open := channelOpenMsg{
- ChanType: chanType,
- PeersWindow: ch.myWindow,
- MaxPacketSize: ch.maxIncomingPayload,
- TypeSpecificData: extra,
- PeersId: ch.localId,
- }
- if err := m.sendMessage(open); err != nil {
- return nil, err
- }
-
- switch msg := (<-ch.msg).(type) {
- case *channelOpenConfirmMsg:
- return ch, nil
- case *channelOpenFailureMsg:
- return nil, &OpenChannelError{msg.Reason, msg.Message}
- default:
- return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
- }
- }
|