tcp.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. package tcp
  2. import (
  3. "fmt"
  4. "io"
  5. "v2ray.com/core/common/buf"
  6. "v2ray.com/core/common/net"
  7. "v2ray.com/core/common/task"
  8. "v2ray.com/core/transport/pipe"
  9. )
  10. type Server struct {
  11. Port net.Port
  12. MsgProcessor func(msg []byte) []byte
  13. ShouldClose bool
  14. SendFirst []byte
  15. Listen net.Address
  16. listener *net.TCPListener
  17. }
  18. func (server *Server) Start() (net.Destination, error) {
  19. listenerAddr := server.Listen
  20. if listenerAddr == nil {
  21. listenerAddr = net.LocalHostIP
  22. }
  23. listener, err := net.ListenTCP("tcp", &net.TCPAddr{
  24. IP: listenerAddr.IP(),
  25. Port: int(server.Port),
  26. Zone: "",
  27. })
  28. if err != nil {
  29. return net.Destination{}, err
  30. }
  31. server.Port = net.Port(listener.Addr().(*net.TCPAddr).Port)
  32. server.listener = listener
  33. go server.acceptConnections(listener)
  34. localAddr := listener.Addr().(*net.TCPAddr)
  35. return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
  36. }
  37. func (server *Server) acceptConnections(listener *net.TCPListener) {
  38. for {
  39. conn, err := listener.Accept()
  40. if err != nil {
  41. fmt.Printf("Failed accept TCP connection: %v\n", err)
  42. return
  43. }
  44. go server.handleConnection(conn)
  45. }
  46. }
  47. func (server *Server) handleConnection(conn net.Conn) {
  48. if len(server.SendFirst) > 0 {
  49. conn.Write(server.SendFirst)
  50. }
  51. pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
  52. err := task.Run(task.Parallel(func() error {
  53. defer pWriter.Close() // nolint: errcheck
  54. for {
  55. b := buf.New()
  56. if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil {
  57. if err == io.EOF {
  58. return nil
  59. }
  60. return err
  61. }
  62. copy(b.Bytes(), server.MsgProcessor(b.Bytes()))
  63. if err := pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b)); err != nil {
  64. return err
  65. }
  66. }
  67. }, func() error {
  68. defer pReader.CloseError()
  69. w := buf.NewWriter(conn)
  70. for {
  71. mb, err := pReader.ReadMultiBuffer()
  72. if err != nil {
  73. if err == io.EOF {
  74. return nil
  75. }
  76. return err
  77. }
  78. if err := w.WriteMultiBuffer(mb); err != nil {
  79. return err
  80. }
  81. }
  82. }))()
  83. if err != nil {
  84. fmt.Println("failed to transfer data: ", err.Error())
  85. }
  86. conn.Close() // nolint: errcheck
  87. }
  88. func (server *Server) Close() error {
  89. return server.listener.Close()
  90. }