123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364 |
- package socks5
- import (
- "fmt"
- "io"
- "net"
- "strconv"
- "strings"
- "golang.org/x/net/context"
- )
// Command codes and address-type codes as they appear in request and reply
// messages.
const (
	// ConnectCommand is the command code dispatched to handleConnect.
	ConnectCommand uint8 = 1
	// BindCommand is the command code dispatched to handleBind.
	BindCommand uint8 = 2
	// AssociateCommand is the command code dispatched to handleAssociate.
	AssociateCommand uint8 = 3

	// Address-type codes that prefix an encoded address.
	ipv4Address uint8 = 1
	fqdnAddress uint8 = 3
	ipv6Address uint8 = 4
)
// Reply codes written into the second byte of a reply message. The values
// are spelled out because they are wire values, not arbitrary enumerators.
const (
	successReply         uint8 = 0
	serverFailure        uint8 = 1
	ruleFailure          uint8 = 2
	networkUnreachable   uint8 = 3
	hostUnreachable      uint8 = 4
	connectionRefused    uint8 = 5
	ttlExpired           uint8 = 6
	commandNotSupported  uint8 = 7
	addrTypeNotSupported uint8 = 8
)
// unrecognizedAddrType is returned by readAddrSpec when the address-type
// byte is not one of the IPv4, FQDN or IPv6 codes.
var unrecognizedAddrType = fmt.Errorf("Unrecognized address type")
- // AddressRewriter is used to rewrite a destination transparently
- type AddressRewriter interface {
- Rewrite(ctx context.Context, request *Request) (context.Context, *AddrSpec)
- }
// AddrSpec describes a target endpoint. The host may be given as an IPv4 or
// IPv6 address, as a fully-qualified domain name, or both.
type AddrSpec struct {
	FQDN string // domain name; empty when the endpoint was given by IP only
	IP   net.IP // IPv4 or IPv6 address; may be unset when only FQDN is known
	Port int    // port number
}
- func (a *AddrSpec) String() string {
- if a.FQDN != "" {
- return fmt.Sprintf("%s (%s):%d", a.FQDN, a.IP, a.Port)
- }
- return fmt.Sprintf("%s:%d", a.IP, a.Port)
- }
- // Address returns a string suitable to dial; prefer returning IP-based
- // address, fallback to FQDN
- func (a AddrSpec) Address() string {
- if 0 != len(a.IP) {
- return net.JoinHostPort(a.IP.String(), strconv.Itoa(a.Port))
- }
- return net.JoinHostPort(a.FQDN, strconv.Itoa(a.Port))
- }
- // A Request represents request received by a server
- type Request struct {
- // Protocol version
- Version uint8
- // Requested command
- Command uint8
- // AuthContext provided during negotiation
- AuthContext *AuthContext
- // AddrSpec of the the network that sent the request
- RemoteAddr *AddrSpec
- // AddrSpec of the desired destination
- DestAddr *AddrSpec
- // AddrSpec of the actual destination (might be affected by rewrite)
- realDestAddr *AddrSpec
- bufConn io.Reader
- }
// conn is the part of a client connection that request handling needs: it
// can be written to and it reports the remote peer's address.
type conn interface {
	io.Writer
	RemoteAddr() net.Addr
}
- // NewRequest creates a new Request from the tcp connection
- func NewRequest(bufConn io.Reader) (*Request, error) {
- // Read the version byte
- header := []byte{0, 0, 0}
- if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
- return nil, fmt.Errorf("Failed to get command version: %v", err)
- }
- // Ensure we are compatible
- if header[0] != socks5Version {
- return nil, fmt.Errorf("Unsupported command version: %v", header[0])
- }
- // Read in the destination address
- dest, err := readAddrSpec(bufConn)
- if err != nil {
- return nil, err
- }
- request := &Request{
- Version: socks5Version,
- Command: header[1],
- DestAddr: dest,
- bufConn: bufConn,
- }
- return request, nil
- }
- // handleRequest is used for request processing after authentication
- func (s *Server) handleRequest(req *Request, conn conn) error {
- ctx := context.Background()
- // Resolve the address if we have a FQDN
- dest := req.DestAddr
- if dest.FQDN != "" {
- ctx_, addr, err := s.config.Resolver.Resolve(ctx, dest.FQDN)
- if err != nil {
- if err := sendReply(conn, hostUnreachable, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Failed to resolve destination '%v': %v", dest.FQDN, err)
- }
- ctx = ctx_
- dest.IP = addr
- }
- // Apply any address rewrites
- req.realDestAddr = req.DestAddr
- if s.config.Rewriter != nil {
- ctx, req.realDestAddr = s.config.Rewriter.Rewrite(ctx, req)
- }
- // Switch on the command
- switch req.Command {
- case ConnectCommand:
- return s.handleConnect(ctx, conn, req)
- case BindCommand:
- return s.handleBind(ctx, conn, req)
- case AssociateCommand:
- return s.handleAssociate(ctx, conn, req)
- default:
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Unsupported command: %v", req.Command)
- }
- }
- // handleConnect is used to handle a connect command
- func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
- // Check if this is allowed
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Connect to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
- // Attempt to connect
- dial := s.config.Dial
- if dial == nil {
- dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
- return net.Dial(net_, addr)
- }
- }
- target, err := dial(ctx, "tcp", req.realDestAddr.Address())
- if err != nil {
- msg := err.Error()
- resp := hostUnreachable
- if strings.Contains(msg, "refused") {
- resp = connectionRefused
- } else if strings.Contains(msg, "network is unreachable") {
- resp = networkUnreachable
- }
- if err := sendReply(conn, resp, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Connect to %v failed: %v", req.DestAddr, err)
- }
- defer target.Close()
- // Send success
- local := target.LocalAddr().(*net.TCPAddr)
- bind := AddrSpec{IP: local.IP, Port: local.Port}
- if err := sendReply(conn, successReply, &bind); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- // Start proxying
- errCh := make(chan error, 2)
- go proxy(target, req.bufConn, errCh)
- go proxy(conn, target, errCh)
- // Wait
- for i := 0; i < 2; i++ {
- e := <-errCh
- if e != nil {
- // return from this function closes target (and conn).
- return e
- }
- }
- return nil
- }
- // handleBind is used to handle a connect command
- func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
- // Check if this is allowed
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Bind to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
- // TODO: Support bind
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return nil
- }
- // handleAssociate is used to handle a connect command
- func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
- // Check if this is allowed
- if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
- if err := sendReply(conn, ruleFailure, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
- } else {
- ctx = ctx_
- }
- // TODO: Support associate
- if err := sendReply(conn, commandNotSupported, nil); err != nil {
- return fmt.Errorf("Failed to send reply: %v", err)
- }
- return nil
- }
- // readAddrSpec is used to read AddrSpec.
- // Expects an address type byte, follwed by the address and port
- func readAddrSpec(r io.Reader) (*AddrSpec, error) {
- d := &AddrSpec{}
- // Get the address type
- addrType := []byte{0}
- if _, err := r.Read(addrType); err != nil {
- return nil, err
- }
- // Handle on a per type basis
- switch addrType[0] {
- case ipv4Address:
- addr := make([]byte, 4)
- if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
- return nil, err
- }
- d.IP = net.IP(addr)
- case ipv6Address:
- addr := make([]byte, 16)
- if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
- return nil, err
- }
- d.IP = net.IP(addr)
- case fqdnAddress:
- if _, err := r.Read(addrType); err != nil {
- return nil, err
- }
- addrLen := int(addrType[0])
- fqdn := make([]byte, addrLen)
- if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
- return nil, err
- }
- d.FQDN = string(fqdn)
- default:
- return nil, unrecognizedAddrType
- }
- // Read the port
- port := []byte{0, 0}
- if _, err := io.ReadAtLeast(r, port, 2); err != nil {
- return nil, err
- }
- d.Port = (int(port[0]) << 8) | int(port[1])
- return d, nil
- }
- // sendReply is used to send a reply message
- func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
- // Format the address
- var addrType uint8
- var addrBody []byte
- var addrPort uint16
- switch {
- case addr == nil:
- addrType = ipv4Address
- addrBody = []byte{0, 0, 0, 0}
- addrPort = 0
- case addr.FQDN != "":
- addrType = fqdnAddress
- addrBody = append([]byte{byte(len(addr.FQDN))}, addr.FQDN...)
- addrPort = uint16(addr.Port)
- case addr.IP.To4() != nil:
- addrType = ipv4Address
- addrBody = []byte(addr.IP.To4())
- addrPort = uint16(addr.Port)
- case addr.IP.To16() != nil:
- addrType = ipv6Address
- addrBody = []byte(addr.IP.To16())
- addrPort = uint16(addr.Port)
- default:
- return fmt.Errorf("Failed to format address: %v", addr)
- }
- // Format the message
- msg := make([]byte, 6+len(addrBody))
- msg[0] = socks5Version
- msg[1] = resp
- msg[2] = 0 // Reserved
- msg[3] = addrType
- copy(msg[4:], addrBody)
- msg[4+len(addrBody)] = byte(addrPort >> 8)
- msg[4+len(addrBody)+1] = byte(addrPort & 0xff)
- // Send the message
- _, err := w.Write(msg)
- return err
- }
// closeWriter is implemented by writers whose write side can be closed on
// its own. proxy calls CloseWrite on its destination when the destination
// implements this interface.
type closeWriter interface {
	CloseWrite() error
}
- // proxy is used to suffle data from src to destination, and sends errors
- // down a dedicated channel
- func proxy(dst io.Writer, src io.Reader, errCh chan error) {
- _, err := io.Copy(dst, src)
- if tcpConn, ok := dst.(closeWriter); ok {
- tcpConn.CloseWrite()
- }
- errCh <- err
- }
|