tcp.go 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package tcp
  2. import (
  3. "fmt"
  4. "io"
  5. "v2ray.com/core/common/buf"
  6. "v2ray.com/core/common/net"
  7. "v2ray.com/core/common/task"
  8. "v2ray.com/core/transport/pipe"
  9. )
  10. type Server struct {
  11. Port net.Port
  12. MsgProcessor func(msg []byte) []byte
  13. ShouldClose bool
  14. SendFirst []byte
  15. Listen net.Address
  16. listener *net.TCPListener
  17. }
  18. func (server *Server) Start() (net.Destination, error) {
  19. listenerAddr := server.Listen
  20. if listenerAddr == nil {
  21. listenerAddr = net.LocalHostIP
  22. }
  23. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  24. IP: listenerAddr.IP(),
  25. Port: int(server.Port),
  26. Zone: "",
  27. })
  28. if err != nil {
  29. return net.Destination{}, err
  30. }
  31. server.Port = net.Port(listener.Addr().(*net.TCPAddr).Port)
  32. server.listener = listener
  33. go server.acceptConnections(listener)
  34. localAddr := listener.Addr().(*net.TCPAddr)
  35. return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
  36. }
  37. func (server *Server) acceptConnections(listener *net.TCPListener) {
  38. for {
  39. conn, err := listener.Accept()
  40. if err != nil {
  41. fmt.Printf("Failed accept TCP connection: %v\n", err)
  42. return
  43. }
  44. go server.handleConnection(conn)
  45. }
  46. }
  47. func (server *Server) handleConnection(conn net.Conn) {
  48. if len(server.SendFirst) > 0 {
  49. conn.Write(server.SendFirst)
  50. }
  51. pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
  52. err := task.Run(task.Parallel(func() error {
  53. for {
  54. b := buf.New()
  55. if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil {
  56. if err == io.EOF {
  57. return nil
  58. }
  59. return err
  60. }
  61. copy(b.Bytes(), server.MsgProcessor(b.Bytes()))
  62. if err := pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b)); err != nil {
  63. return err
  64. }
  65. }
  66. }, func() error {
  67. w := buf.NewWriter(conn)
  68. for {
  69. mb, err := pReader.ReadMultiBuffer()
  70. if err != nil {
  71. return err
  72. }
  73. if err := w.WriteMultiBuffer(mb); err != nil {
  74. return err
  75. }
  76. }
  77. }))()
  78. if err != nil {
  79. fmt.Println("failed to transfer data: ", err.Error())
  80. }
  81. conn.Close() // nolint: errcheck
  82. }
  83. func (server *Server) Close() error {
  84. return server.listener.Close()
  85. }