| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359 |
- package wireguard
- import (
- "context"
- "encoding/base64"
- "encoding/hex"
- core "github.com/v2fly/v2ray-core/v4"
- "github.com/v2fly/v2ray-core/v4/common"
- "github.com/v2fly/v2ray-core/v4/common/buf"
- "github.com/v2fly/v2ray-core/v4/common/net"
- "github.com/v2fly/v2ray-core/v4/common/protocol"
- "github.com/v2fly/v2ray-core/v4/common/session"
- "github.com/v2fly/v2ray-core/v4/common/signal"
- "github.com/v2fly/v2ray-core/v4/common/signal/done"
- "github.com/v2fly/v2ray-core/v4/common/task"
- "github.com/v2fly/v2ray-core/v4/features/dns"
- "github.com/v2fly/v2ray-core/v4/features/policy"
- "github.com/v2fly/v2ray-core/v4/features/routing"
- "github.com/v2fly/v2ray-core/v4/proxy"
- "github.com/v2fly/v2ray-core/v4/transport"
- "github.com/v2fly/v2ray-core/v4/transport/internet"
- "golang.zx2c4.com/wireguard/conn"
- "golang.zx2c4.com/wireguard/device"
- "golang.zx2c4.com/wireguard/tun"
- "strings"
- "sync"
- )
- func init() {
- common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
- o := new(Outbound)
- err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
- o.ctx = ctx
- o.dispatcher = dispatcher
- o.dnsClient = dnsClient
- o.init = done.New()
- return o.Init(config.(*Config), policyManager)
- })
- return o, err
- }))
- common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
- o := new(Outbound)
- err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, policyManager policy.Manager, dnsClient dns.Client) error {
- sf := config.(*SimplifiedConfig)
- cf := &Config{
- Server: &protocol.ServerEndpoint{
- Address: sf.Address,
- Port: sf.Port,
- },
- Network: sf.Network,
- PrivateKey: sf.PrivateKey,
- PeerPublicKey: sf.PeerPublicKey,
- PreSharedKey: sf.PreSharedKey,
- Mtu: sf.Mtu,
- UserLevel: sf.UserLevel,
- }
- o.ctx = ctx
- o.dispatcher = dispatcher
- o.dnsClient = dnsClient
- o.init = done.New()
- return o.Init(cf, policyManager)
- })
- return o, err
- }))
- }
- var _ proxy.Outbound = (*Outbound)(nil)
- var _ conn.Bind = (*Outbound)(nil)
- type Outbound struct {
- sync.Mutex
- ctx context.Context
- dispatcher routing.Dispatcher
- sessionPolicy policy.Session
- dnsClient dns.Client
- tun tun.Device
- dev *device.Device
- wire *Net
- dialer internet.Dialer
- init *done.Instance
- destination net.Destination
- endpoint *conn.StdNetEndpoint
- connection *remoteConnection
- }
- func (o *Outbound) Init(config *Config, policyManager policy.Manager) error {
- o.sessionPolicy = policyManager.ForLevel(config.UserLevel)
- spec, err := protocol.NewServerSpecFromPB(config.Server)
- if err != nil {
- return err
- }
- o.destination = spec.Destination()
- o.endpoint = &conn.StdNetEndpoint{
- Port: int(o.destination.Port),
- }
- if o.destination.Address.Family().IsDomain() {
- o.endpoint.IP = []byte{172, 19, 0, 3}
- } else {
- o.endpoint.IP = o.destination.Address.IP()
- }
- localAddress := make([]net.IP, len(config.LocalAddress))
- if len(localAddress) == 0 {
- return newError("empty local address")
- }
- for index, address := range config.LocalAddress {
- localAddress[index] = net.ParseIP(address)
- }
- var privateKey, peerPublicKey, preSharedKey string
- {
- decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PrivateKey))
- bytes, err := buf.ReadAllToBytes(decoder)
- if err != nil {
- return newError("failed to decode private key from base64: ", config.PrivateKey).Base(err)
- }
- privateKey = hex.EncodeToString(bytes)
- }
- {
- decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PeerPublicKey))
- bytes, err := buf.ReadAllToBytes(decoder)
- if err != nil {
- return newError("failed to decode peer public key from base64: ", config.PeerPublicKey).Base(err)
- }
- peerPublicKey = hex.EncodeToString(bytes)
- }
- if config.PreSharedKey != "" {
- decoder := base64.NewDecoder(base64.StdEncoding, strings.NewReader(config.PreSharedKey))
- bytes, err := buf.ReadAllToBytes(decoder)
- if err != nil {
- return newError("failed to decode pre share key from base64: ", config.PreSharedKey).Base(err)
- }
- preSharedKey = hex.EncodeToString(bytes)
- }
- ipcConf := "private_key=" + privateKey
- ipcConf += "\npublic_key=" + peerPublicKey
- ipcConf += "\nendpoint=" + o.endpoint.DstToString()
- if preSharedKey != "" {
- ipcConf += "\npreshared_key=" + preSharedKey
- }
- var has4, has6 bool
- for _, address := range localAddress {
- if address.To4() != nil {
- has4 = true
- } else {
- has6 = true
- }
- }
- if has4 {
- ipcConf += "\nallowed_ip=0.0.0.0/0"
- }
- if has6 {
- ipcConf += "\nallowed_ip=::/0"
- }
- mtu := int(config.Mtu)
- if mtu == 0 {
- mtu = 1450
- }
- tun, wire, err := CreateNetTUN(localAddress, mtu)
- if err != nil {
- return newError("failed to create wireguard device").Base(err)
- }
- dev := device.NewDevice(tun, o, device.NewLogger(device.LogLevelVerbose, ""))
- err = dev.IpcSet(ipcConf)
- if err != nil {
- return newError("failed to set wireguard ipc conf").Base(err)
- }
- o.tun = tun
- o.dev = dev
- o.wire = wire
- return nil
- }
- func (o *Outbound) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error {
- if o.dialer == nil {
- o.dialer = dialer
- }
- o.init.Close()
- outbound := session.OutboundFromContext(ctx)
- if outbound == nil || !outbound.Target.IsValid() {
- return newError("target not specified")
- }
- destination := outbound.Target
- if destination.Address.Family().IsDomain() {
- if c, ok := o.dnsClient.(dns.ClientWithIPOption); ok {
- c.SetFakeDNSOption(false)
- }
- ips, err := o.dnsClient.LookupIP(destination.Address.Domain())
- if err != nil {
- return newError("failed to lookup ip addresses for domain ", destination.Address.Domain()).Base(err)
- }
- destination.Address = net.IPAddress(ips[0])
- }
- var conn internet.Connection
- {
- var err error
- switch destination.Network {
- case net.Network_TCP:
- conn, err = o.wire.DialContextTCP(ctx, &net.TCPAddr{
- IP: destination.Address.IP(),
- Port: int(destination.Port),
- })
- case net.Network_UDP:
- conn, err = o.wire.DialUDP(nil, &net.UDPAddr{
- IP: destination.Address.IP(),
- Port: int(destination.Port),
- })
- }
- if err != nil {
- return err
- }
- }
- defer conn.Close()
- ctx, cancel := context.WithCancel(ctx)
- timer := signal.CancelAfterInactivity(ctx, cancel, o.sessionPolicy.Timeouts.ConnectionIdle)
- ctx = policy.ContextWithBufferPolicy(ctx, o.sessionPolicy.Buffer)
- uplink := func() error {
- defer timer.SetTimeout(o.sessionPolicy.Timeouts.UplinkOnly)
- if err := buf.Copy(link.Reader, buf.NewWriter(conn), buf.UpdateActivity(timer)); err != nil {
- return newError("failed to transport all TCP response").Base(err)
- }
- return nil
- }
- downlink := func() error {
- defer timer.SetTimeout(o.sessionPolicy.Timeouts.DownlinkOnly)
- if err := buf.Copy(buf.NewReader(conn), link.Writer, buf.UpdateActivity(timer)); err != nil {
- return newError("failed to transport all TCP request").Base(err)
- }
- return nil
- }
- if err := task.Run(ctx, uplink, downlink); err != nil {
- common.Interrupt(link.Reader)
- common.Interrupt(link.Writer)
- return newError("connection ends").Base(err)
- }
- return nil
- }
- type remoteConnection struct {
- internet.Connection
- done *done.Instance
- }
- func (r remoteConnection) Close() error {
- if !r.done.Done() {
- r.done.Close()
- }
- return r.Connection.Close()
- }
- func (o *Outbound) connect() (*remoteConnection, error) {
- if o.dialer == nil {
- <-o.init.Wait()
- }
- if c := o.connection; c != nil && !c.done.Done() {
- return c, nil
- }
- o.Lock()
- defer o.Unlock()
- if c := o.connection; c != nil && !c.done.Done() {
- return c, nil
- }
- conn, err := o.dialer.Dial(context.Background(), o.destination)
- if err == nil {
- o.connection = &remoteConnection{
- conn,
- done.New(),
- }
- }
- return o.connection, err
- }
- func (o *Outbound) Open(uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
- return []conn.ReceiveFunc{o.Receive}, 0, nil
- }
- func (o *Outbound) Receive(b []byte) (n int, ep conn.Endpoint, err error) {
- var c *remoteConnection
- c, err = o.connect()
- if err != nil {
- return
- }
- n, err = c.Read(b)
- if err != nil {
- common.Close(c)
- } else {
- ep = o.endpoint
- }
- return
- }
- func (o *Outbound) Close() error {
- o.Lock()
- defer o.Unlock()
- c := o.connection
- if c != nil {
- common.Close(c)
- }
- return nil
- }
- func (o *Outbound) SetMark(uint32) error {
- return nil
- }
- func (o *Outbound) Send(b []byte, _ conn.Endpoint) (err error) {
- var c *remoteConnection
- c, err = o.connect()
- if err != nil {
- return
- }
- _, err = c.Write(b)
- if err != nil {
- common.Close(c)
- }
- return err
- }
- func (o *Outbound) ParseEndpoint(string) (conn.Endpoint, error) {
- return o.endpoint, nil
- }
|