forward.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. // Copyright 2014 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package agent
  5. import (
  6. "errors"
  7. "io"
  8. "net"
  9. "sync"
  10. "golang.org/x/crypto/ssh"
  11. )
  12. // RequestAgentForwarding sets up agent forwarding for the session.
  13. // ForwardToAgent or ForwardToRemote should be called to route
  14. // the authentication requests.
  15. func RequestAgentForwarding(session *ssh.Session) error {
  16. ok, err := session.SendRequest("auth-agent-req@openssh.com", true, nil)
  17. if err != nil {
  18. return err
  19. }
  20. if !ok {
  21. return errors.New("forwarding request denied")
  22. }
  23. return nil
  24. }
  25. // ForwardToAgent routes authentication requests to the given keyring.
  26. func ForwardToAgent(client *ssh.Client, keyring Agent) error {
  27. channels := client.HandleChannelOpen(channelType)
  28. if channels == nil {
  29. return errors.New("agent: already have handler for " + channelType)
  30. }
  31. go func() {
  32. for ch := range channels {
  33. channel, reqs, err := ch.Accept()
  34. if err != nil {
  35. continue
  36. }
  37. go ssh.DiscardRequests(reqs)
  38. go func() {
  39. ServeAgent(keyring, channel)
  40. channel.Close()
  41. }()
  42. }
  43. }()
  44. return nil
  45. }
  const (
	// channelType is the SSH channel type for which ForwardToAgent and
	// ForwardToRemote register their channel-open handlers.
	channelType = "auth-agent@openssh.com"
)
  47. // ForwardToRemote routes authentication requests to the ssh-agent
  48. // process serving on the given unix socket.
  49. func ForwardToRemote(client *ssh.Client, addr string) error {
  50. channels := client.HandleChannelOpen(channelType)
  51. if channels == nil {
  52. return errors.New("agent: already have handler for " + channelType)
  53. }
  54. conn, err := net.Dial("unix", addr)
  55. if err != nil {
  56. return err
  57. }
  58. conn.Close()
  59. go func() {
  60. for ch := range channels {
  61. channel, reqs, err := ch.Accept()
  62. if err != nil {
  63. continue
  64. }
  65. go ssh.DiscardRequests(reqs)
  66. go forwardUnixSocket(channel, addr)
  67. }
  68. }()
  69. return nil
  70. }
  71. func forwardUnixSocket(channel ssh.Channel, addr string) {
  72. conn, err := net.Dial("unix", addr)
  73. if err != nil {
  74. return
  75. }
  76. var wg sync.WaitGroup
  77. wg.Add(2)
  78. go func() {
  79. io.Copy(conn, channel)
  80. conn.(*net.UnixConn).CloseWrite()
  81. wg.Done()
  82. }()
  83. go func() {
  84. io.Copy(channel, conn)
  85. channel.CloseWrite()
  86. wg.Done()
  87. }()
  88. wg.Wait()
  89. conn.Close()
  90. channel.Close()
  91. }