123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- package socks5
- import (
- "fmt"
- "io"
- )
// Authentication method codes exchanged during method selection.
const (
	// NoAuth is the method code for the "No Authentication" mode.
	NoAuth uint8 = 0
	// UserPassAuth is the method code for username/password authentication.
	UserPassAuth uint8 = 2
	// noAcceptable is the method code sent back when none of the methods
	// offered by the client is supported.
	noAcceptable uint8 = 255
)

// Constants used by the username/password sub-negotiation.
const (
	// userAuthVersion is the only sub-negotiation version accepted.
	userAuthVersion uint8 = 1
	// authSuccess is the status byte written when the credentials are valid.
	authSuccess uint8 = 0
	// authFailure is the status byte written when the credentials are rejected.
	authFailure uint8 = 1
)
// UserAuthFailed is returned when the supplied username/password pair is
// rejected by the credential store.
var UserAuthFailed = fmt.Errorf("User authentication failed")

// NoSupportedAuth is returned when the client offers no authentication
// method that the server is able to handle.
var NoSupportedAuth = fmt.Errorf("No supported authentication mechanism")
// AuthContext encapsulates the authentication state established during
// negotiation.
type AuthContext struct {
	// Method is the code of the authentication method that was used.
	Method uint8
	// Payload holds data provided during negotiation. Its keys depend on
	// the auth method: for UserPassAuth it contains "Username"; for NoAuth
	// it is nil.
	Payload map[string]string
}
- type Authenticator interface {
- Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error)
- GetCode() uint8
- }
- // NoAuthAuthenticator is used to handle the "No Authentication" mode
- type NoAuthAuthenticator struct{}
- func (a NoAuthAuthenticator) GetCode() uint8 {
- return NoAuth
- }
- func (a NoAuthAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
- _, err := writer.Write([]byte{socks5Version, NoAuth})
- return &AuthContext{NoAuth, nil}, err
- }
- // UserPassAuthenticator is used to handle username/password based
- // authentication
- type UserPassAuthenticator struct {
- Credentials CredentialStore
- }
- func (a UserPassAuthenticator) GetCode() uint8 {
- return UserPassAuth
- }
- func (a UserPassAuthenticator) Authenticate(reader io.Reader, writer io.Writer) (*AuthContext, error) {
- // Tell the client to use user/pass auth
- if _, err := writer.Write([]byte{socks5Version, UserPassAuth}); err != nil {
- return nil, err
- }
- // Get the version and username length
- header := []byte{0, 0}
- if _, err := io.ReadAtLeast(reader, header, 2); err != nil {
- return nil, err
- }
- // Ensure we are compatible
- if header[0] != userAuthVersion {
- return nil, fmt.Errorf("Unsupported auth version: %v", header[0])
- }
- // Get the user name
- userLen := int(header[1])
- user := make([]byte, userLen)
- if _, err := io.ReadAtLeast(reader, user, userLen); err != nil {
- return nil, err
- }
- // Get the password length
- if _, err := reader.Read(header[:1]); err != nil {
- return nil, err
- }
- // Get the password
- passLen := int(header[0])
- pass := make([]byte, passLen)
- if _, err := io.ReadAtLeast(reader, pass, passLen); err != nil {
- return nil, err
- }
- // Verify the password
- if a.Credentials.Valid(string(user), string(pass)) {
- if _, err := writer.Write([]byte{userAuthVersion, authSuccess}); err != nil {
- return nil, err
- }
- } else {
- if _, err := writer.Write([]byte{userAuthVersion, authFailure}); err != nil {
- return nil, err
- }
- return nil, UserAuthFailed
- }
- // Done
- return &AuthContext{UserPassAuth, map[string]string{"Username": string(user)}}, nil
- }
- // authenticate is used to handle connection authentication
- func (s *Server) authenticate(conn io.Writer, bufConn io.Reader) (*AuthContext, error) {
- // Get the methods
- methods, err := readMethods(bufConn)
- if err != nil {
- return nil, fmt.Errorf("Failed to get auth methods: %v", err)
- }
- // Select a usable method
- for _, method := range methods {
- cator, found := s.authMethods[method]
- if found {
- return cator.Authenticate(bufConn, conn)
- }
- }
- // No usable method found
- return nil, noAcceptableAuth(conn)
- }
- // noAcceptableAuth is used to handle when we have no eligible
- // authentication mechanism
- func noAcceptableAuth(conn io.Writer) error {
- conn.Write([]byte{socks5Version, noAcceptable})
- return NoSupportedAuth
- }
// readMethods reads the method count byte from r followed by that many
// authentication method codes, and returns the codes in the order received.
//
// Both reads use io.ReadFull: an io.Reader is allowed to return (0, nil),
// and a bare Read would then treat the untouched zero byte as "no methods
// offered" and leave the real count byte unread in the stream.
//
// On any read error the returned slice is nil.
func readMethods(r io.Reader) ([]byte, error) {
	var count [1]byte
	if _, err := io.ReadFull(r, count[:]); err != nil {
		return nil, err
	}

	methods := make([]byte, int(count[0]))
	if _, err := io.ReadFull(r, methods); err != nil {
		return nil, err
	}
	return methods, nil
}
|