浏览代码

Added WebSocket transport

Shelikhoo 9 年之前
父节点
当前提交
fecf9cc5b8

+ 21 - 0
transport/internet/ws/config.go

@@ -0,0 +1,21 @@
+package ws
+
+type Config struct {
+	ConnectionReuse bool
+	Path            string
+	Pto             string
+	Cert            string
+	PrivKey         string
+}
+
+func (this *Config) Apply() {
+	effectiveConfig = this
+}
+
+var (
+	effectiveConfig = &Config{
+		ConnectionReuse: true,
+		Path:            "",
+		Pto:             "",
+	}
+)

+ 25 - 0
transport/internet/ws/config_json.go

@@ -0,0 +1,25 @@
+package ws
+
+import (
+	"encoding/json"
+)
+
+func (this *Config) UnmarshalJSON(data []byte) error {
+	type JsonConfig struct {
+		ConnectionReuse bool   `json:"connectionReuse"`
+		Path            string `json:"Path"`
+		Pto             string `json:"Pto"`
+	}
+	jsonConfig := &JsonConfig{
+		ConnectionReuse: true,
+		Path:            "",
+		Pto:             "",
+	}
+	if err := json.Unmarshal(data, jsonConfig); err != nil {
+		return err
+	}
+	this.ConnectionReuse = jsonConfig.ConnectionReuse
+	this.Path = jsonConfig.Path
+	this.Pto = jsonConfig.Pto
+	return nil
+}

+ 110 - 0
transport/internet/ws/connection.go

@@ -0,0 +1,110 @@
+package ws
+
+import (
+	"errors"
+	"io"
+	"net"
+	"reflect"
+	"time"
+)
+
+var (
+	ErrInvalidConn = errors.New("Invalid Connection.")
+)
+
+type ConnectionManager interface {
+	Recycle(string, *wsconn)
+}
+
+type Connection struct {
+	dest     string
+	conn     *wsconn
+	listener ConnectionManager
+	reusable bool
+}
+
+func NewConnection(dest string, conn *wsconn, manager ConnectionManager) *Connection {
+	return &Connection{
+		dest:     dest,
+		conn:     conn,
+		listener: manager,
+		reusable: effectiveConfig.ConnectionReuse,
+	}
+}
+
+func (this *Connection) Read(b []byte) (int, error) {
+	if this == nil || this.conn == nil {
+		return 0, io.EOF
+	}
+
+	return this.conn.Read(b)
+}
+
+func (this *Connection) Write(b []byte) (int, error) {
+	if this == nil || this.conn == nil {
+		return 0, io.ErrClosedPipe
+	}
+	return this.conn.Write(b)
+}
+
+func (this *Connection) Close() error {
+	if this == nil || this.conn == nil {
+		return io.ErrClosedPipe
+	}
+	if this.Reusable() {
+		this.listener.Recycle(this.dest, this.conn)
+		return nil
+	}
+	err := this.conn.Close()
+	this.conn = nil
+	return err
+}
+
+func (this *Connection) LocalAddr() net.Addr {
+	return this.conn.LocalAddr()
+}
+
+func (this *Connection) RemoteAddr() net.Addr {
+	return this.conn.RemoteAddr()
+}
+
+func (this *Connection) SetDeadline(t time.Time) error {
+	return this.conn.SetDeadline(t)
+}
+
+func (this *Connection) SetReadDeadline(t time.Time) error {
+	return this.conn.SetReadDeadline(t)
+}
+
+func (this *Connection) SetWriteDeadline(t time.Time) error {
+	return this.conn.SetWriteDeadline(t)
+}
+
+func (this *Connection) SetReusable(reusable bool) {
+	if !effectiveConfig.ConnectionReuse {
+		return
+	}
+	this.reusable = reusable
+}
+
+func (this *Connection) Reusable() bool {
+	return this.reusable
+}
+
+func (this *Connection) SysFd() (int, error) {
+	return getSysFd(this.conn)
+}
+
+func getSysFd(conn net.Conn) (int, error) {
+	cv := reflect.ValueOf(conn)
+	switch ce := cv.Elem(); ce.Kind() {
+	case reflect.Struct:
+		netfd := ce.FieldByName("conn").FieldByName("fd")
+		switch fe := netfd.Elem(); fe.Kind() {
+		case reflect.Struct:
+			fd := fe.FieldByName("sysfd")
+			return int(fd.Int()), nil
+		}
+	}
+	return 0, ErrInvalidConn
+}

+ 112 - 0
transport/internet/ws/connection_cache.go

@@ -0,0 +1,112 @@
+package ws
+
+import (
+	"net"
+	"sync"
+	"time"
+
+	"github.com/v2ray/v2ray-core/common/signal"
+)
+
+type AwaitingConnection struct {
+	conn   *wsconn
+	expire time.Time
+}
+
+func (this *AwaitingConnection) Expired() bool {
+	return this.expire.Before(time.Now())
+}
+
+type ConnectionCache struct {
+	sync.Mutex
+	cache       map[string][]*AwaitingConnection
+	cleanupOnce signal.Once
+}
+
+func NewConnectionCache() *ConnectionCache {
+	return &ConnectionCache{
+		cache: make(map[string][]*AwaitingConnection),
+	}
+}
+
+func (this *ConnectionCache) Cleanup() {
+	defer this.cleanupOnce.Reset()
+
+	for len(this.cache) > 0 {
+		time.Sleep(time.Second * 4)
+		this.Lock()
+		for key, value := range this.cache {
+			size := len(value)
+			changed := false
+			for i := 0; i < size; {
+				if value[i].Expired() {
+					value[i].conn.Close()
+					value[i] = value[size-1]
+					size--
+					changed = true
+				} else {
+					i++
+				}
+			}
+			if changed {
+				for i := size; i < len(value); i++ {
+					value[i] = nil
+				}
+				value = value[:size]
+				this.cache[key] = value
+			}
+		}
+		this.Unlock()
+	}
+}
+
+func (this *ConnectionCache) Recycle(dest string, conn *wsconn) {
+	this.Lock()
+	defer this.Unlock()
+
+	aconn := &AwaitingConnection{
+		conn:   conn,
+		expire: time.Now().Add(time.Second * 4),
+	}
+
+	var list []*AwaitingConnection
+	if v, found := this.cache[dest]; found {
+		v = append(v, aconn)
+		list = v
+	} else {
+		list = []*AwaitingConnection{aconn}
+	}
+	this.cache[dest] = list
+
+	go this.cleanupOnce.Do(this.Cleanup)
+}
+
+func FindFirstValid(list []*AwaitingConnection) int {
+	for idx, conn := range list {
+		if !conn.Expired() && !conn.conn.connClosing {
+			return idx
+		}
+		go conn.conn.Close()
+	}
+	return -1
+}
+
+func (this *ConnectionCache) Get(dest string) net.Conn {
+	this.Lock()
+	defer this.Unlock()
+
+	list, found := this.cache[dest]
+	if !found {
+		return nil
+	}
+
+	firstValid := FindFirstValid(list)
+	if firstValid == -1 {
+		delete(this.cache, dest)
+		return nil
+	}
+	res := list[firstValid].conn
+	list = list[firstValid+1:]
+	this.cache[dest] = list
+	return res
+}

+ 19 - 0
transport/internet/ws/connection_test.go

@@ -0,0 +1,19 @@
+package ws_test
+
+import (
+	"net"
+	"testing"
+
+	"github.com/v2ray/v2ray-core/testing/assert"
+	. "github.com/v2ray/v2ray-core/transport/internet/tcp"
+)
+
+func TestRawConnection(t *testing.T) {
+	assert := assert.On(t)
+
+	rawConn := RawConnection{net.TCPConn{}}
+	assert.Bool(rawConn.Reusable()).IsFalse()
+
+	rawConn.SetReusable(true)
+	assert.Bool(rawConn.Reusable()).IsFalse()
+}

+ 119 - 0
transport/internet/ws/dialer.go

@@ -0,0 +1,119 @@
+package ws
+
+import (
+	"fmt"
+	"net"
+
+	"github.com/gorilla/websocket"
+	"github.com/v2ray/v2ray-core/common/log"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/transport/internet"
+)
+
+var (
+	globalCache = NewConnectionCache()
+)
+
+func Dial(src v2net.Address, dest v2net.Destination) (internet.Connection, error) {
+	log.Info("Dailing WS to ", dest)
+	if src == nil {
+		src = v2net.AnyIP
+	}
+	id := src.String() + "-" + dest.NetAddr()
+	var conn *wsconn
+	if dest.IsTCP() && effectiveConfig.ConnectionReuse {
+		connt := globalCache.Get(id)
+		if connt != nil {
+			conn = connt.(*wsconn)
+		}
+	}
+	if conn == nil {
+		var err error
+		conn, err = wsDial(src, dest)
+		if err != nil {
+			log.Warning("WS Dial failed:" + err.Error())
+			return nil, err
+		}
+	}
+	return NewConnection(id, conn, globalCache), nil
+}
+
+func init() {
+	internet.WSDialer = Dial
+}
+
+func wsDial(src v2net.Address, dest v2net.Destination) (*wsconn, error) {
+	//internet.DialToDest(src, dest)
+	commonDial := func(network, addr string) (net.Conn, error) {
+		return internet.DialToDest(src, dest)
+	}
+
+	dialer := websocket.Dialer{NetDial: commonDial, ReadBufferSize: 65536, WriteBufferSize: 65536}
+
+	effpto := func(dst v2net.Destination) string {
+
+		if effectiveConfig.Pto != "" {
+			return effectiveConfig.Pto
+		}
+
+		switch dst.Port().Value() {
+		/*
+				Since the value is not given explicitly,
+				We are guessing it now.
+
+				HTTP Port:
+						80
+				    8080
+				    8880
+				    2052
+			      2082
+			      2086
+			      2095
+
+				HTTPS Port:
+						443
+				    2053
+				    2083
+			      2087
+			      2096
+			      8443
+
+				if the port you are using is not well-known,
+				specify it to avoid this process.
+
+				We will re		return "CRASH"turn "unknown" if we can't guess it, cause Dial to fail.
+		*/
+		case 80:
+		case 8080:
+		case 8880:
+		case 2052:
+		case 2082:
+		case 2086:
+		case 2095:
+			return "ws"
+		case 443:
+		case 2053:
+		case 2083:
+		case 2087:
+		case 2096:
+		case 8443:
+			return "wss"
+		default:
+			return "unknown"
+		}
+		panic("Runtime unstable. Please report this bug to developers.")
+	}(dest)
+
+	uri := func(dst v2net.Destination, pto string, path string) string {
+		return fmt.Sprintf("%v://%v:%v/%v", pto, dst.NetAddr(), dst.Port(), path)
+	}(dest, effpto, effectiveConfig.Path)
+	conn, _, err := dialer.Dial(uri, nil)
+	if err != nil {
+		return nil, err
+	}
+	return func() internet.Connection {
+		connv2ray := &wsconn{wsc: conn, connClosing: false}
+		connv2ray.setup()
+		return connv2ray
+	}().(*wsconn), nil
+}

+ 162 - 0
transport/internet/ws/hub.go

@@ -0,0 +1,162 @@
+package ws
+
+import (
+	"errors"
+	"net"
+	"net/http"
+	"strconv"
+	"sync"
+	"time"
+
+	"github.com/gorilla/websocket"
+	"github.com/v2ray/v2ray-core/common/log"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/transport/internet"
+)
+
+var (
+	ErrClosedListener = errors.New("Listener is closed.")
+)
+
+type ConnectionWithError struct {
+	conn net.Conn
+	err  error
+}
+
+type WSListener struct {
+	sync.Mutex
+	acccepting    bool
+	awaitingConns chan *ConnectionWithError
+}
+
+func ListenWS(address v2net.Address, port v2net.Port) (internet.Listener, error) {
+
+	l := &WSListener{
+		acccepting:    true,
+		awaitingConns: make(chan *ConnectionWithError, 32),
+	}
+
+	err := l.listenws(address, port)
+
+	return l, err
+}
+
+func (wsl *WSListener) listenws(address v2net.Address, port v2net.Port) error {
+
+	http.HandleFunc("/"+effectiveConfig.Path, func(w http.ResponseWriter, r *http.Request) {
+		log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)! Accepting websocket")
+		con, err := wsl.converttovws(w, r)
+		if err != nil {
+			log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)!" + err.Error())
+			return
+		}
+
+		select {
+		case wsl.awaitingConns <- &ConnectionWithError{
+			conn: con,
+			err:  err,
+		}:
+			log.Warning("WS:WSListener->listenws->(HandleFunc,lambda 2)! transferd websocket")
+		default:
+			if con != nil {
+				con.Close()
+			}
+		}
+		//con.retloc.Wait()
+		return
+
+	})
+
+	errchan := make(chan error)
+
+	go func() {
+		err := http.ListenAndServe(address.String()+":"+strconv.Itoa(int(port.Value())), nil)
+		errchan <- err
+		return
+	}()
+
+	var err error
+	select {
+	case err = <-errchan:
+	case <-time.After(time.Second * 2):
+		//Should this listen fail after 2 sec, it could gone untracked.
+	}
+
+	if err != nil {
+		log.Error("WS:WSListener->listenws->ListenAndServe!" + err.Error())
+	}
+
+	return err
+
+}
+
+func (wsl *WSListener) converttovws(w http.ResponseWriter, r *http.Request) (*wsconn, error) {
+	var upgrader = websocket.Upgrader{
+		ReadBufferSize:  65536,
+		WriteBufferSize: 65536,
+	}
+	conn, err := upgrader.Upgrade(w, r, nil)
+
+	if err != nil {
+		return nil, err
+	}
+
+	wrapedConn := &wsconn{wsc: conn, connClosing: false}
+	wrapedConn.setup()
+	return wrapedConn, nil
+}
+
+func (this *WSListener) Accept() (internet.Connection, error) {
+	for this.acccepting {
+		select {
+		case connErr, open := <-this.awaitingConns:
+			log.Info("WSListener: conn accepted")
+			if !open {
+				return nil, ErrClosedListener
+			}
+			if connErr.err != nil {
+				return nil, connErr.err
+			}
+			return NewConnection("", connErr.conn.(*wsconn), this), nil
+		case <-time.After(time.Second * 2):
+		}
+	}
+	return nil, ErrClosedListener
+}
+
+func (this *WSListener) Recycle(dest string, conn *wsconn) {
+	this.Lock()
+	defer this.Unlock()
+	if !this.acccepting {
+		return
+	}
+	select {
+	case this.awaitingConns <- &ConnectionWithError{conn: conn}:
+	default:
+		conn.Close()
+	}
+}
+
+func (this *WSListener) Addr() net.Addr {
+	return nil
+}
+
+func (this *WSListener) Close() error {
+	this.Lock()
+	defer this.Unlock()
+	this.acccepting = false
+
+	log.Warning("WSListener: Yet to support close listening HTTP service")
+
+	close(this.awaitingConns)
+	for connErr := range this.awaitingConns {
+		if connErr.conn != nil {
+			go connErr.conn.Close()
+		}
+	}
+	return nil
+}
+
+func init() {
+	internet.WSListenFunc = ListenWS
+}

+ 199 - 0
transport/internet/ws/wsconn.go

@@ -0,0 +1,199 @@
+package ws
+
+import (
+	"bufio"
+	"io"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/v2ray/v2ray-core/common/log"
+
+	"github.com/gorilla/websocket"
+)
+
+type wsconn struct {
+	wsc         *websocket.Conn
+	readBuffer  *bufio.Reader
+	connClosing bool
+	reusable    bool
+	retloc      *sync.Cond
+	rlock       *sync.Mutex
+	wlock       *sync.Mutex
+}
+
+func (ws *wsconn) Read(b []byte) (n int, err error) {
+
+	//defer ws.rlock.Unlock()
+	//ws.checkifRWAfterClosing()
+	if ws.connClosing {
+
+		return 0, io.EOF
+	}
+	getNewBuffer := func() error {
+		_, r, err := ws.wsc.NextReader()
+		if err != nil {
+			log.Warning("WS transport: ws connection NewFrameReader return " + err.Error())
+			ws.connClosing = true
+			ws.Close()
+			return err
+		}
+		ws.readBuffer = bufio.NewReader(r)
+		return nil
+	}
+
+	readNext := func(b []byte) (n int, err error) {
+		if ws.readBuffer == nil {
+			err = getNewBuffer()
+			if err != nil {
+				//ws.Close()
+				return 0, err
+			}
+		}
+
+		n, err = ws.readBuffer.Read(b)
+
+		if err == nil {
+			return n, err
+		}
+
+		if err == io.EOF {
+			ws.readBuffer = nil
+			if n == 0 {
+				return ws.Read(b)
+			}
+			return n, nil
+		}
+		//ws.Close()
+		return n, err
+
+	}
+	n, err = readNext(b)
+
+	return n, err
+
+}
+
+func (ws *wsconn) Write(b []byte) (n int, err error) {
+
+	//defer
+	//ws.checkifRWAfterClosing()
+	if ws.connClosing {
+
+		return 0, io.EOF
+	}
+	writeWs := func(b []byte) (n int, err error) {
+		wr, err := ws.wsc.NextWriter(websocket.BinaryMessage)
+		if err != nil {
+			log.Warning("WS transport: ws connection NewFrameReader return " + err.Error())
+			ws.connClosing = true
+			ws.Close()
+			return 0, err
+		}
+		n, err = wr.Write(b)
+		if err != nil {
+			//ws.Close()
+			return 0, err
+		}
+		err = wr.Close()
+		if err != nil {
+			//ws.Close()
+			return 0, err
+		}
+		return n, err
+	}
+	n, err = writeWs(b)
+	return n, err
+}
+func (ws *wsconn) Close() error {
+	ws.connClosing = true
+	err := ws.wsc.Close()
+	ws.retloc.Broadcast()
+	return err
+}
+func (ws *wsconn) LocalAddr() net.Addr {
+	return ws.wsc.LocalAddr()
+}
+func (ws *wsconn) RemoteAddr() net.Addr {
+	return ws.wsc.RemoteAddr()
+}
+func (ws *wsconn) SetDeadline(t time.Time) error {
+	return func() error {
+		errr := ws.SetReadDeadline(t)
+		errw := ws.SetWriteDeadline(t)
+		if errr == nil || errw == nil {
+			return nil
+		}
+		if errr != nil {
+			return errr
+		}
+
+		return errw
+	}()
+}
+func (ws *wsconn) SetReadDeadline(t time.Time) error {
+	return ws.wsc.SetReadDeadline(t)
+}
+func (ws *wsconn) SetWriteDeadline(t time.Time) error {
+	return ws.wsc.SetWriteDeadline(t)
+}
+
+func (ws *wsconn) checkifRWAfterClosing() {
+	if ws.connClosing {
+		log.Error("WS transport: Read or Write After Conn have been marked closing, this can be dangerous.")
+		//panic("WS transport: Read or Write After Conn have been marked closing. Please report this crash to developer.")
+	}
+}
+
+func (ws *wsconn) setup() {
+	ws.connClosing = false
+
+	ws.rlock = &sync.Mutex{}
+	ws.wlock = &sync.Mutex{}
+
+	initConnectedCond := func() {
+		rsl := &sync.Mutex{}
+		ws.retloc = sync.NewCond(rsl)
+	}
+
+	initConnectedCond()
+	//ws.pingPong()
+}
+
+func (ws *wsconn) Reusable() bool {
+	return ws.reusable && !ws.connClosing
+}
+
+func (ws *wsconn) SetReusable(reusable bool) {
+	if !effectiveConfig.ConnectionReuse {
+		return
+	}
+	ws.reusable = reusable
+}
+
+func (ws *wsconn) pingPong() {
+	pongRcv := make(chan int, 0)
+	ws.wsc.SetPongHandler(func(data string) error {
+		pongRcv <- 0
+		return nil
+	})
+
+	go func() {
+		for !ws.connClosing {
+			ws.wsc.WriteMessage(websocket.PingMessage, nil)
+			tick := time.NewTicker(time.Second * 3)
+
+			select {
+			case <-pongRcv:
+				break
+			case <-tick.C:
+				ws.Close()
+			}
+			<-tick.C
+			tick.Stop()
+		}
+
+		return
+	}()
+
+}