123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103 |
- // Copyright 2014 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package agent
- import (
- "errors"
- "io"
- "net"
- "sync"
- "golang.org/x/crypto/ssh"
- )
- // RequestAgentForwarding sets up agent forwarding for the session.
- // ForwardToAgent or ForwardToRemote should be called to route
- // the authentication requests.
- func RequestAgentForwarding(session *ssh.Session) error {
- ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
- if err != nil {
- return err
- }
- if !ok {
- return errors.New("forwarding request denied")
- }
- return nil
- }
- // ForwardToAgent routes authentication requests to the given keyring.
- func ForwardToAgent(client *ssh.Client, keyring Agent) error {
- channels := client.HandleChannelOpen(channelType)
- if channels == nil {
- return errors.New("agent: already have handler for " + channelType)
- }
- go func() {
- for ch := range channels {
- channel, reqs, err := ch.Accept()
- if err != nil {
- continue
- }
- go ssh.DiscardRequests(reqs)
- go func() {
- ServeAgent(keyring, channel)
- channel.Close()
- }()
- }
- }()
- return nil
- }
// channelType is the SSH channel type name for which ForwardToAgent and
// ForwardToRemote register a channel-open handler on the client.
const (
	channelType = "auth-agent@openssh.com"
)
- // ForwardToRemote routes authentication requests to the ssh-agent
- // process serving on the given unix socket.
- func ForwardToRemote(client *ssh.Client, addr string) error {
- channels := client.HandleChannelOpen(channelType)
- if channels == nil {
- return errors.New("agent: already have handler for " + channelType)
- }
- conn, err := net.Dial("unix", addr)
- if err != nil {
- return err
- }
- conn.Close()
- go func() {
- for ch := range channels {
- channel, reqs, err := ch.Accept()
- if err != nil {
- continue
- }
- go ssh.DiscardRequests(reqs)
- go forwardUnixSocket(channel, addr)
- }
- }()
- return nil
- }
- func forwardUnixSocket(channel ssh.Channel, addr string) {
- conn, err := net.Dial("unix", addr)
- if err != nil {
- return
- }
- var wg sync.WaitGroup
- wg.Add(2)
- go func() {
- io.Copy(conn, channel)
- conn.(*net.UnixConn).CloseWrite()
- wg.Done()
- }()
- go func() {
- io.Copy(channel, conn)
- channel.CloseWrite()
- wg.Done()
- }()
- wg.Wait()
- conn.Close()
- channel.Close()
- }
|