request.go 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. package socks5
  2. import (
  3. "fmt"
  4. "io"
  5. "net"
  6. "strconv"
  7. "strings"
  8. "golang.org/x/net/context"
  9. )
  // Command codes a client may place in the CMD field of a SOCKS5 request.
const (
	// ConnectCommand asks the server to open a TCP connection to the destination.
	ConnectCommand uint8 = 1
	// BindCommand asks the server to accept an inbound connection.
	BindCommand uint8 = 2
	// AssociateCommand asks the server to set up UDP relaying.
	AssociateCommand uint8 = 3
)

// Address type (ATYP) codes used when encoding and decoding an AddrSpec.
const (
	ipv4Address uint8 = 1
	fqdnAddress uint8 = 3
	ipv6Address uint8 = 4
)
  // Reply codes written into the REP field of a SOCKS5 reply by sendReply.
const (
	successReply         uint8 = 0
	serverFailure        uint8 = 1
	ruleFailure          uint8 = 2
	networkUnreachable   uint8 = 3
	hostUnreachable      uint8 = 4
	connectionRefused    uint8 = 5
	ttlExpired           uint8 = 6
	commandNotSupported  uint8 = 7
	addrTypeNotSupported uint8 = 8
)
  // unrecognizedAddrType is returned by readAddrSpec when the address-type
// byte is not one of ipv4Address, fqdnAddress or ipv6Address.
var unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
  32. // AddressRewriter is used to rewrite a destination transparently
  33. type AddressRewriter interface {
  34. Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
  35. }
  // AddrSpec is a SOCKS5 address: an IPv4 or IPv6 address, a fully-qualified
// domain name, or both (once the FQDN has been resolved), plus a port.
type AddrSpec struct {
	FQDN string
	IP   net.IP
	Port int
}

// String formats the address for display.
//
// With an FQDN set it produces "fqdn (ip):port". Otherwise it produces
// "ip:port"; IPv6 addresses are not bracketed here.
func (a *AddrSpec) String() string {
	if a.FQDN == "" {
		return fmt.Sprintf("%s:%d", a.IP, a.Port)
	}
	return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
}

// Address returns a "host:port" string suitable for dialing.
//
// The IP is used whenever one is present; otherwise the FQDN is the host.
// net.JoinHostPort adds brackets around IPv6 hosts.
func (a AddrSpec) Address() string {
	host := a.FQDN
	if len(a.IP) > 0 {
		host = a.IP.String()
	}
	return net.JoinHostPort(host, strconv.Itoa(a.Port))
}
  57. // A Request represents request received by a server
  58. type Request struct {
  59. // Protocol version
  60. Version uint8
  61. // Requested command
  62. Command uint8
  63. // AuthContext provided during negotiation
  64. AuthContext *AuthContext
  65. // AddrSpec of the the network that sent the request
  66. RemoteAddr *AddrSpec
  67. // AddrSpec of the desired destination
  68. DestAddr *AddrSpec
  69. // AddrSpec of the actual destination (might be affected by rewrite)
  70. realDestAddr *AddrSpec
  71. bufConn io.Reader
  72. }
  // conn is the subset of a client connection that request handling relies
// on: writing replies and relayed data, and exposing the remote address.
// Any net.Conn satisfies it.
type conn interface {
	RemoteAddr() net.Addr
	Write(p []byte) (n int, err error)
}
  77. // NewRequest creates a new Request from the tcp connection
  78. func NewRequest(bufConn io.Reader) (*Request, error) {
  79. // Read the version byte
  80. header := []byte{0, 0, 0}
  81. if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
  82. return nil, fmt.Errorf("Failed to get command version: %v", err)
  83. }
  84. // Ensure we are compatible
  85. if header[0] != socks5Version {
  86. return nil, fmt.Errorf("Unsupported command version: %v", header[0])
  87. }
  88. // Read in the destination address
  89. dest, err := readAddrSpec(bufConn)
  90. if err != nil {
  91. return nil, err
  92. }
  93. request := &Request{
  94. Version: socks5Version,
  95. Command: header[1],
  96. DestAddr: dest,
  97. bufConn: bufConn,
  98. }
  99. return request, nil
  100. }
  101. // handleRequest is used for request processing after authentication
  102. func (s *Server) handleRequest(req *Request, conn conn) error {
  103. ctx := context.Background()
  104. // Resolve the address if we have a FQDN
  105. dest := req.DestAddr
  106. if dest.FQDN != "" {
  107. ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
  108. if err != nil {
  109. if err := sendReply(conn, hostUnreachable, nil); err != nil {
  110. return fmt.Errorf("Failed to send reply: %v", err)
  111. }
  112. return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
  113. }
  114. ctx = ctx_
  115. dest.IP = addr
  116. }
  117. // Apply any address rewrites
  118. req.realDestAddr = req.DestAddr
  119. if s.config.Rewriter != nil {
  120. ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
  121. }
  122. // Switch on the command
  123. switch req.Command {
  124. case ConnectCommand:
  125. return s.handleConnect(ctx, conn, req)
  126. case BindCommand:
  127. return s.handleBind(ctx, conn, req)
  128. case AssociateCommand:
  129. return s.handleAssociate(ctx, conn, req)
  130. default:
  131. if err := sendReply(conn, commandNotSupported, nil); err != nil {
  132. return fmt.Errorf("Failed to send reply: %v", err)
  133. }
  134. return fmt.Errorf("Unsupported command: %v", req.Command)
  135. }
  136. }
  137. // handleConnect is used to handle a connect command
  138. func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
  139. // Check if this is allowed
  140. if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
  141. if err := sendReply(conn, ruleFailure, nil); err != nil {
  142. return fmt.Errorf("Failed to send reply: %v", err)
  143. }
  144. return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
  145. } else {
  146. ctx = ctx_
  147. }
  148. // Attempt to connect
  149. dial := s.config.Dial
  150. if dial == nil {
  151. dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
  152. return net.Dial(net_, addr)
  153. }
  154. }
  155. target, err := dial(ctx, "tcp", req.realDestAddr.Address())
  156. if err != nil {
  157. msg := err.Error()
  158. resp := hostUnreachable
  159. if strings.Contains(msg, "refused") {
  160. resp = connectionRefused
  161. } else if strings.Contains(msg, "network is unreachable") {
  162. resp = networkUnreachable
  163. }
  164. if err := sendReply(conn, resp, nil); err != nil {
  165. return fmt.Errorf("Failed to send reply: %v", err)
  166. }
  167. return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
  168. }
  169. defer target.Close()
  170. // Send success
  171. local := target.LocalAddr().(*net.TCPAddr)
  172. bind := AddrSpec{IP: local.IP, Port: local.Port}
  173. if err := sendReply(conn, successReply, &bind); err != nil {
  174. return fmt.Errorf("Failed to send reply: %v", err)
  175. }
  176. // Start proxying
  177. errCh := make(chan error, 2)
  178. go proxy(target, req.bufConn, errCh)
  179. go proxy(conn, target, errCh)
  180. // Wait
  181. for i := 0; i < 2; i++ {
  182. e := <-errCh
  183. if e != nil {
  184. // return from this function closes target (and conn).
  185. return e
  186. }
  187. }
  188. return nil
  189. }
  190. // handleBind is used to handle a connect command
  191. func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
  192. // Check if this is allowed
  193. if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
  194. if err := sendReply(conn, ruleFailure, nil); err != nil {
  195. return fmt.Errorf("Failed to send reply: %v", err)
  196. }
  197. return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
  198. } else {
  199. ctx = ctx_
  200. }
  201. // TODO: Support bind
  202. if err := sendReply(conn, commandNotSupported, nil); err != nil {
  203. return fmt.Errorf("Failed to send reply: %v", err)
  204. }
  205. return nil
  206. }
  207. // handleAssociate is used to handle a connect command
  208. func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
  209. // Check if this is allowed
  210. if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
  211. if err := sendReply(conn, ruleFailure, nil); err != nil {
  212. return fmt.Errorf("Failed to send reply: %v", err)
  213. }
  214. return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
  215. } else {
  216. ctx = ctx_
  217. }
  218. // TODO: Support associate
  219. if err := sendReply(conn, commandNotSupported, nil); err != nil {
  220. return fmt.Errorf("Failed to send reply: %v", err)
  221. }
  222. return nil
  223. }
  224. // readAddrSpec is used to read AddrSpec.
  225. // Expects an address type byte, follwed by the address and port
  226. func readAddrSpec(r io.Reader) (*AddrSpec, error) {
  227. d := &AddrSpec{}
  228. // Get the address type
  229. addrType := []byte{0}
  230. if _, err := r.Read(addrType); err != nil {
  231. return nil, err
  232. }
  233. // Handle on a per type basis
  234. switch addrType[0] {
  235. case ipv4Address:
  236. addr := make([]byte, 4)
  237. if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
  238. return nil, err
  239. }
  240. d.IP = net.IP(addr)
  241. case ipv6Address:
  242. addr := make([]byte, 16)
  243. if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
  244. return nil, err
  245. }
  246. d.IP = net.IP(addr)
  247. case fqdnAddress:
  248. if _, err := r.Read(addrType); err != nil {
  249. return nil, err
  250. }
  251. addrLen := int(addrType[0])
  252. fqdn := make([]byte, addrLen)
  253. if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
  254. return nil, err
  255. }
  256. d.FQDN = string(fqdn)
  257. default:
  258. return nil, unrecognizedAddrType
  259. }
  260. // Read the port
  261. port := []byte{0, 0}
  262. if _, err := io.ReadAtLeast(r, port, 2); err != nil {
  263. return nil, err
  264. }
  265. d.Port = (int(port[0]) << 8) | int(port[1])
  266. return d, nil
  267. }
  268. // sendReply is used to send a reply message
  269. func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
  270. // Format the address
  271. var addrType uint8
  272. var addrBody []byte
  273. var addrPort uint16
  274. switch {
  275. case addr == nil:
  276. addrType = ipv4Address
  277. addrBody = []byte{0, 0, 0, 0}
  278. addrPort = 0
  279. case addr.FQDN != "":
  280. addrType = fqdnAddress
  281. addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
  282. addrPort = uint16(addr.Port)
  283. case addr.IP.To4() != nil:
  284. addrType = ipv4Address
  285. addrBody = []byte(addr.IP.To4())
  286. addrPort = uint16(addr.Port)
  287. case addr.IP.To16() != nil:
  288. addrType = ipv6Address
  289. addrBody = []byte(addr.IP.To16())
  290. addrPort = uint16(addr.Port)
  291. default:
  292. return fmt.Errorf("Failed to format address: %v", addr)
  293. }
  294. // Format the message
  295. msg := make([]byte, 6+len(addrBody))
  296. msg[0] = socks5Version
  297. msg[1] = resp
  298. msg[2] = 0 // Reserved
  299. msg[3] = addrType
  300. copy(msg[4:], addrBody)
  301. msg[4+len(addrBody)] = byte(addrPort >> 8)
  302. msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
  303. // Send the message
  304. _, err := w.Write(msg)
  305. return err
  306. }
  // closeWriter is implemented by connections that support closing only
// their write side, such as *net.TCPConn.
type closeWriter interface {
	CloseWrite() error
}

// proxy copies everything from src to dst.
//
// Once src is exhausted, or the copy fails, proxy half-closes dst when dst
// implements closeWriter, which signals end-of-stream to the peer. It then
// sends the io.Copy error (nil on a clean EOF) on errCh. The send blocks
// unless errCh is buffered or has a ready receiver.
func proxy(dst io.Writer, src io.Reader, errCh chan error) {
	_, copyErr := io.Copy(dst, src)
	if cw, ok := dst.(closeWriter); ok {
		// Best effort: only the copy result is reported.
		_ = cw.CloseWrite()
	}
	errCh <- copyErr
}