request_test.go 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package socks5
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "io"
  6. "log"
  7. "net"
  8. "os"
  9. "strings"
  10. "testing"
  11. )
  12. type MockConn struct {
  13. buf bytes.Buffer
  14. }
  15. func (m *MockConn) Write(b []byte) (int, error) {
  16. return m.buf.Write(b)
  17. }
  18. func (m *MockConn) RemoteAddr() net.Addr {
  19. return &net.TCPAddr{IP: []byte{127, 0, 0, 1}, Port: 65432}
  20. }
  21. func TestRequest_Connect(t *testing.T) {
  22. // Create a local listener
  23. l, err := net.Listen("tcp", "127.0.0.1:0")
  24. if err != nil {
  25. t.Fatalf("err: %v", err)
  26. }
  27. go func() {
  28. conn, err := l.Accept()
  29. if err != nil {
  30. t.Fatalf("err: %v", err)
  31. }
  32. defer conn.Close()
  33. buf := make([]byte, 4)
  34. if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
  35. t.Fatalf("err: %v", err)
  36. }
  37. if !bytes.Equal(buf, []byte("ping")) {
  38. t.Fatalf("bad: %v", buf)
  39. }
  40. conn.Write([]byte("pong"))
  41. }()
  42. lAddr := l.Addr().(*net.TCPAddr)
  43. // Make server
  44. s := &Server{config: &Config{
  45. Rules: PermitAll(),
  46. Resolver: DNSResolver{},
  47. Logger: log.New(os.Stdout, "", log.LstdFlags),
  48. }}
  49. // Create the connect request
  50. buf := bytes.NewBuffer(nil)
  51. buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
  52. port := []byte{0, 0}
  53. binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
  54. buf.Write(port)
  55. // Send a ping
  56. buf.Write([]byte("ping"))
  57. // Handle the request
  58. resp := &MockConn{}
  59. req, err := NewRequest(buf)
  60. if err != nil {
  61. t.Fatalf("err: %v", err)
  62. }
  63. if err := s.handleRequest(req, resp); err != nil {
  64. t.Fatalf("err: %v", err)
  65. }
  66. // Verify response
  67. out := resp.buf.Bytes()
  68. expected := []byte{
  69. 5,
  70. 0,
  71. 0,
  72. 1,
  73. 127, 0, 0, 1,
  74. 0, 0,
  75. 'p', 'o', 'n', 'g',
  76. }
  77. // Ignore the port for both
  78. out[8] = 0
  79. out[9] = 0
  80. if !bytes.Equal(out, expected) {
  81. t.Fatalf("bad: %v %v", out, expected)
  82. }
  83. }
  84. func TestRequest_Connect_RuleFail(t *testing.T) {
  85. // Create a local listener
  86. l, err := net.Listen("tcp", "127.0.0.1:0")
  87. if err != nil {
  88. t.Fatalf("err: %v", err)
  89. }
  90. go func() {
  91. conn, err := l.Accept()
  92. if err != nil {
  93. t.Fatalf("err: %v", err)
  94. }
  95. defer conn.Close()
  96. buf := make([]byte, 4)
  97. if _, err := io.ReadAtLeast(conn, buf, 4); err != nil {
  98. t.Fatalf("err: %v", err)
  99. }
  100. if !bytes.Equal(buf, []byte("ping")) {
  101. t.Fatalf("bad: %v", buf)
  102. }
  103. conn.Write([]byte("pong"))
  104. }()
  105. lAddr := l.Addr().(*net.TCPAddr)
  106. // Make server
  107. s := &Server{config: &Config{
  108. Rules: PermitNone(),
  109. Resolver: DNSResolver{},
  110. Logger: log.New(os.Stdout, "", log.LstdFlags),
  111. }}
  112. // Create the connect request
  113. buf := bytes.NewBuffer(nil)
  114. buf.Write([]byte{5, 1, 0, 1, 127, 0, 0, 1})
  115. port := []byte{0, 0}
  116. binary.BigEndian.PutUint16(port, uint16(lAddr.Port))
  117. buf.Write(port)
  118. // Send a ping
  119. buf.Write([]byte("ping"))
  120. // Handle the request
  121. resp := &MockConn{}
  122. req, err := NewRequest(buf)
  123. if err != nil {
  124. t.Fatalf("err: %v", err)
  125. }
  126. if err := s.handleRequest(req, resp); !strings.Contains(err.Error(), "blocked by rules") {
  127. t.Fatalf("err: %v", err)
  128. }
  129. // Verify response
  130. out := resp.buf.Bytes()
  131. expected := []byte{
  132. 5,
  133. 2,
  134. 0,
  135. 1,
  136. 0, 0, 0, 0,
  137. 0, 0,
  138. }
  139. if !bytes.Equal(out, expected) {
  140. t.Fatalf("bad: %v %v", out, expected)
  141. }
  142. }