client.go 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359
  1. package wireguard
  2. import (
  3. "context"
  4. "encoding/base64"
  5. "encoding/hex"
  6. core "github.com/v2fly/v2ray-core/v4"
  7. "github.com/v2fly/v2ray-core/v4/common"
  8. "github.com/v2fly/v2ray-core/v4/common/buf"
  9. "github.com/v2fly/v2ray-core/v4/common/net"
  10. "github.com/v2fly/v2ray-core/v4/common/protocol"
  11. "github.com/v2fly/v2ray-core/v4/common/session"
  12. "github.com/v2fly/v2ray-core/v4/common/signal"
  13. "github.com/v2fly/v2ray-core/v4/common/signal/done"
  14. "github.com/v2fly/v2ray-core/v4/common/task"
  15. "github.com/v2fly/v2ray-core/v4/features/dns"
  16. "github.com/v2fly/v2ray-core/v4/features/policy"
  17. "github.com/v2fly/v2ray-core/v4/features/routing"
  18. "github.com/v2fly/v2ray-core/v4/proxy"
  19. "github.com/v2fly/v2ray-core/v4/transport"
  20. "github.com/v2fly/v2ray-core/v4/transport/internet"
  21. "golang.zx2c4.com/wireguard/conn"
  22. "golang.zx2c4.com/wireguard/device"
  23. "golang.zx2c4.com/wireguard/tun"
  24. "strings"
  25. "sync"
  26. )
  27. func init() {
  28. common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  29. o := new(Outbound)
  30. err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
  31. o.ctx = ctx
  32. o.dispatcher = dispatcher
  33. o.dnsClient = dnsClient
  34. o.init = done.New()
  35. return o.Init(config.(*Config), policyManager)
  36. })
  37. return o, err
  38. }))
  39. common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
  40. o := new(Outbound)
  41. err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
  42. sf := config.(*SimplifiedConfig)
  43. cf := &Config{
  44. Server: &protocol.ServerEndpoint{
  45. Address: sf.Address,
  46. Port: sf.Port,
  47. },
  48. Network: sf.Network,
  49. PrivateKey: sf.PrivateKey,
  50. PeerPublicKey: sf.PeerPublicKey,
  51. PreSharedKey: sf.PreSharedKey,
  52. Mtu: sf.Mtu,
  53. UserLevel: sf.UserLevel,
  54. }
  55. o.ctx = ctx
  56. o.dispatcher = dispatcher
  57. o.dnsClient = dnsClient
  58. o.init = done.New()
  59. return o.Init(cf, policyManager)
  60. })
  61. return o, err
  62. }))
  63. }
  64. var _ proxy.Outbound = (*Outbound)(nil)
  65. var _ conn.Bind = (*Outbound)(nil)
  66. type Outbound struct {
  67. sync.Mutex
  68. ctx context.Context
  69. dispatcher routing.Dispatcher
  70. sessionPolicy policy.Session
  71. dnsClient dns.Client
  72. tun tun.Device
  73. dev *device.Device
  74. wire *Net
  75. dialer internet.Dialer
  76. init *done.Instance
  77. destination net.Destination
  78. endpoint *conn.StdNetEndpoint
  79. connection *remoteConnection
  80. }
  81. func (o *Outbound) Init(config *Config, policyManager policy.Manager) error {
  82. o.sessionPolicy = policyManager.ForLevel(config.UserLevel)
  83. spec, err := protocol.NewServerSpecFromPB(config.Server)
  84. if err != nil {
  85. return err
  86. }
  87. o.destination = spec.Destination()
  88. o.endpoint = &conn.StdNetEndpoint{
  89. Port: int(o.destination.Port),
  90. }
  91. if o.destination.Address.Family().IsDomain() {
  92. o.endpoint.IP = []byte{172, 19, 0, 3}
  93. } else {
  94. o.endpoint.IP = o.destination.Address.IP()
  95. }
  96. localAddress := make([]net.IP, len(config.LocalAddress))
  97. if len(localAddress) == 0 {
  98. return newError("empty local address")
  99. }
  100. for index, address := range config.LocalAddress {
  101. localAddress[index] = net.ParseIP(address)
  102. }
  103. var privateKey, peerPublicKey, preSharedKey string
  104. {
  105. decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PrivateKey))
  106. bytes, err := buf.ReadAllToBytes(decoder)
  107. if err != nil {
  108. return newError("failed to decode private key from base64: ", config.PrivateKey).Base(err)
  109. }
  110. privateKey = hex.EncodeToString(bytes)
  111. }
  112. {
  113. decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PeerPublicKey))
  114. bytes, err := buf.ReadAllToBytes(decoder)
  115. if err != nil {
  116. return newError("failed to decode peer public key from base64: ", config.PeerPublicKey).Base(err)
  117. }
  118. peerPublicKey = hex.EncodeToString(bytes)
  119. }
  120. if config.PreSharedKey != "" {
  121. decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PreSharedKey))
  122. bytes, err := buf.ReadAllToBytes(decoder)
  123. if err != nil {
  124. return newError("failed to decode pre share key from base64: ", config.PreSharedKey).Base(err)
  125. }
  126. preSharedKey = hex.EncodeToString(bytes)
  127. }
  128. ipcConf := "private_key=" + privateKey
  129. ipcConf += "\npublic_key=" + peerPublicKey
  130. ipcConf += "\nendpoint=" + o.endpoint.DstToString()
  131. if preSharedKey != "" {
  132. ipcConf += "\npreshared_key=" + preSharedKey
  133. }
  134. var has4, has6 bool
  135. for _, address := range localAddress {
  136. if address.To4() != nil {
  137. has4 = true
  138. } else {
  139. has6 = true
  140. }
  141. }
  142. if has4 {
  143. ipcConf += "\nallowed_ip=0.0.0.0/0"
  144. }
  145. if has6 {
  146. ipcConf += "\nallowed_ip=::/0"
  147. }
  148. mtu := int(config.Mtu)
  149. if mtu == 0 {
  150. mtu = 1450
  151. }
  152. tun, wire, err := CreateNetTUN(localAddress, mtu)
  153. if err != nil {
  154. return newError("failed to create wireguard device").Base(err)
  155. }
  156. dev := device.NewDevice(tun, o, device.NewLogger(device.LogLevelVerbose, ""))
  157. err = dev.IpcSet(ipcConf)
  158. if err != nil {
  159. return newError("failed to set wireguard ipc conf").Base(err)
  160. }
  161. o.tun = tun
  162. o.dev = dev
  163. o.wire = wire
  164. return nil
  165. }
  166. func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
  167. if o.dialer == nil {
  168. o.dialer = dialer
  169. }
  170. o.init.Close()
  171. outbound := session.OutboundFromContext(ctx)
  172. if outbound == nil || !outbound.Target.IsValid() {
  173. return newError("target not specified")
  174. }
  175. destination := outbound.Target
  176. if destination.Address.Family().IsDomain() {
  177. if c, ok := o.dnsClient.(dns.ClientWithIPOption); ok {
  178. c.SetFakeDNSOption(false)
  179. }
  180. ips, err := o.dnsClient.LookupIP(destination.Address.Domain())
  181. if err != nil {
  182. return newError("failed to lookup ip addresses for domain ", destination.Address.Domain()).Base(err)
  183. }
  184. destination.Address = net.IPAddress(ips[0])
  185. }
  186. var conn internet.Connection
  187. {
  188. var err error
  189. switch destination.Network {
  190. case net.Network_TCP:
  191. conn, err = o.wire.DialContextTCP(ctx, &net.TCPAddr{
  192. IP: destination.Address.IP(),
  193. Port: int(destination.Port),
  194. })
  195. case net.Network_UDP:
  196. conn, err = o.wire.DialUDP(nil, &net.UDPAddr{
  197. IP: destination.Address.IP(),
  198. Port: int(destination.Port),
  199. })
  200. }
  201. if err != nil {
  202. return err
  203. }
  204. }
  205. defer conn.Close()
  206. ctx, cancel := context.WithCancel(ctx)
  207. timer := signal.CancelAfterInactivity(ctx, cancel, o.sessionPolicy.Timeouts.ConnectionIdle)
  208. ctx = policy.ContextWithBufferPolicy(ctx, o.sessionPolicy.Buffer)
  209. uplink := func() error {
  210. defer timer.SetTimeout(o.sessionPolicy.Timeouts.UplinkOnly)
  211. if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
  212. return newError("failed to transport all TCP response").Base(err)
  213. }
  214. return nil
  215. }
  216. downlink := func() error {
  217. defer timer.SetTimeout(o.sessionPolicy.Timeouts.DownlinkOnly)
  218. if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
  219. return newError("failed to transport all TCP request").Base(err)
  220. }
  221. return nil
  222. }
  223. if err := task.Run(ctx, uplink, downlink); err != nil {
  224. common.Interrupt(link.Reader)
  225. common.Interrupt(link.Writer)
  226. return newError("connection ends").Base(err)
  227. }
  228. return nil
  229. }
  230. type remoteConnection struct {
  231. internet.Connection
  232. done *done.Instance
  233. }
  234. func (r remoteConnection) Close() error {
  235. if !r.done.Done() {
  236. r.done.Close()
  237. }
  238. return r.Connection.Close()
  239. }
  240. func (o *Outbound) connect() (*remoteConnection, error) {
  241. if o.dialer == nil {
  242. <-o.init.Wait()
  243. }
  244. if c := o.connection; c != nil && !c.done.Done() {
  245. return c, nil
  246. }
  247. o.Lock()
  248. defer o.Unlock()
  249. if c := o.connection; c != nil && !c.done.Done() {
  250. return c, nil
  251. }
  252. conn, err := o.dialer.Dial(context.Background(), o.destination)
  253. if err == nil {
  254. o.connection = &remoteConnection{
  255. conn,
  256. done.New(),
  257. }
  258. }
  259. return o.connection, err
  260. }
  261. func (o *Outbound) Open(uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
  262. return []conn.ReceiveFunc{o.Receive}, 0, nil
  263. }
  264. func (o *Outbound) Receive(b []byte) (n int, ep conn.Endpoint, err error) {
  265. var c *remoteConnection
  266. c, err = o.connect()
  267. if err != nil {
  268. return
  269. }
  270. n, err = c.Read(b)
  271. if err != nil {
  272. common.Close(c)
  273. } else {
  274. ep = o.endpoint
  275. }
  276. return
  277. }
  278. func (o *Outbound) Close() error {
  279. o.Lock()
  280. defer o.Unlock()
  281. c := o.connection
  282. if c != nil {
  283. common.Close(c)
  284. }
  285. return nil
  286. }
  287. func (o *Outbound) SetMark(uint32) error {
  288. return nil
  289. }
  290. func (o *Outbound) Send(b []byte, _ conn.Endpoint) (err error) {
  291. var c *remoteConnection
  292. c, err = o.connect()
  293. if err != nil {
  294. return
  295. }
  296. _, err = c.Write(b)
  297. if err != nil {
  298. common.Close(c)
  299. }
  300. return err
  301. }
  302. func (o *Outbound) ParseEndpoint(string) (conn.Endpoint, error) {
  303. return o.endpoint, nil
  304. }