|
|
@@ -11,14 +11,17 @@ import (
|
|
|
"github.com/v2fly/v2ray-core/v5/common/buf"
|
|
|
"github.com/v2fly/v2ray-core/v5/common/net"
|
|
|
"github.com/v2fly/v2ray-core/v5/transport/internet"
|
|
|
+
|
|
|
+ "github.com/pion/transport/v2/replaydetector"
|
|
|
)
|
|
|
|
|
|
func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
|
|
|
session := &ClientUDPSession{
|
|
|
- locker: &sync.Mutex{},
|
|
|
+ locker: &sync.RWMutex{},
|
|
|
conn: conn,
|
|
|
packetProcessor: packetProcessor,
|
|
|
sessionMap: make(map[string]*ClientUDPSessionConn),
|
|
|
+ sessionMapAlias: make(map[string]string),
|
|
|
}
|
|
|
session.ctx, session.finish = context.WithCancel(ctx)
|
|
|
|
|
|
@@ -27,16 +30,87 @@ func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetPro
|
|
|
}
|
|
|
|
|
|
type ClientUDPSession struct {
|
|
|
- locker *sync.Mutex
|
|
|
+ locker *sync.RWMutex
|
|
|
|
|
|
conn io.ReadWriteCloser
|
|
|
packetProcessor UDPClientPacketProcessor
|
|
|
sessionMap map[string]*ClientUDPSessionConn
|
|
|
|
|
|
+ sessionMapAlias map[string]string
|
|
|
+
|
|
|
ctx context.Context
|
|
|
finish func()
|
|
|
}
|
|
|
|
|
|
+func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
|
|
|
+ c.locker.RLock()
|
|
|
+ defer c.locker.RUnlock()
|
|
|
+
|
|
|
+ state, ok := c.sessionMap[sessionID]
|
|
|
+ if !ok {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ return state.cachedProcessorState
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
|
|
|
+ c.locker.RLock()
|
|
|
+ defer c.locker.RUnlock()
|
|
|
+
|
|
|
+ clientSessionID := c.getCachedStateAlias(serverSessionID)
|
|
|
+ if clientSessionID == "" {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ state, ok := c.sessionMap[clientSessionID]
|
|
|
+ if !ok {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
|
|
|
+ return nil
|
|
|
+ } else {
|
|
|
+ return serverState.cachedRecvProcessorState
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
|
|
|
+ state, ok := c.sessionMapAlias[serverSessionID]
|
|
|
+ if !ok {
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+ return state
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
|
|
|
+ c.locker.RLock()
|
|
|
+ defer c.locker.RUnlock()
|
|
|
+
|
|
|
+ state, ok := c.sessionMap[sessionID]
|
|
|
+ if !ok {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ state.cachedProcessorState = cache
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
|
|
|
+ c.locker.RLock()
|
|
|
+ defer c.locker.RUnlock()
|
|
|
+
|
|
|
+ clientSessionID := c.getCachedStateAlias(serverSessionID)
|
|
|
+ if clientSessionID == "" {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ state, ok := c.sessionMap[clientSessionID]
|
|
|
+ if !ok {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
|
|
|
+ serverState.cachedRecvProcessorState = cache
|
|
|
+ return
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
func (c *ClientUDPSession) Close() error {
|
|
|
c.finish()
|
|
|
return c.conn.Close()
|
|
|
@@ -45,7 +119,7 @@ func (c *ClientUDPSession) Close() error {
|
|
|
func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
|
|
|
buffer := buf.New()
|
|
|
defer buffer.Release()
|
|
|
- err := c.packetProcessor.EncodeUDPRequest(request, buffer)
|
|
|
+ err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
|
|
|
if request.Payload != nil {
|
|
|
request.Payload.Release()
|
|
|
}
|
|
|
@@ -69,7 +143,7 @@ func (c *ClientUDPSession) KeepReading() {
|
|
|
return
|
|
|
}
|
|
|
if n != 0 {
|
|
|
- err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp)
|
|
|
+ err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
|
|
|
if err != nil {
|
|
|
newError("unable to decode udp response").Base(err).WriteToLog()
|
|
|
continue
|
|
|
@@ -78,13 +152,14 @@ func (c *ClientUDPSession) KeepReading() {
|
|
|
{
|
|
|
timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
|
|
|
if timeDifference < -30 || timeDifference > 30 {
|
|
|
- newError("udp packet timestamp difference too large, packet discarded").WriteToLog()
|
|
|
+ newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
|
|
|
continue
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- c.locker.Lock()
|
|
|
+ c.locker.RLock()
|
|
|
session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
|
|
|
+ c.locker.RUnlock()
|
|
|
if ok {
|
|
|
select {
|
|
|
case session.readChan <- udpResp:
|
|
|
@@ -93,7 +168,6 @@ func (c *ClientUDPSession) KeepReading() {
|
|
|
} else {
|
|
|
newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
|
|
|
}
|
|
|
- c.locker.Unlock()
|
|
|
}
|
|
|
}
|
|
|
}
|
|
|
@@ -108,12 +182,13 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
|
|
|
connctx, connfinish := context.WithCancel(c.ctx)
|
|
|
|
|
|
sessionConn := &ClientUDPSessionConn{
|
|
|
- sessionID: string(sessionID),
|
|
|
- readChan: make(chan *UDPResponse, 16),
|
|
|
- parent: c,
|
|
|
- ctx: connctx,
|
|
|
- finish: connfinish,
|
|
|
- nextWritePacketID: 0,
|
|
|
+ sessionID: string(sessionID),
|
|
|
+ readChan: make(chan *UDPResponse, 128),
|
|
|
+ parent: c,
|
|
|
+ ctx: connctx,
|
|
|
+ finish: connfinish,
|
|
|
+ nextWritePacketID: 0,
|
|
|
+ trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
|
|
|
}
|
|
|
c.locker.Lock()
|
|
|
c.sessionMap[sessionConn.sessionID] = sessionConn
|
|
|
@@ -121,19 +196,33 @@ func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error)
|
|
|
return sessionConn, nil
|
|
|
}
|
|
|
|
|
|
+type ClientUDPSessionServerTracker struct {
|
|
|
+ cachedRecvProcessorState UDPClientPacketProcessorCachedState
|
|
|
+ rxReplayDetector replaydetector.ReplayDetector
|
|
|
+ lastSeen time.Time
|
|
|
+}
|
|
|
+
|
|
|
type ClientUDPSessionConn struct {
|
|
|
sessionID string
|
|
|
readChan chan *UDPResponse
|
|
|
parent *ClientUDPSession
|
|
|
|
|
|
- nextWritePacketID uint64
|
|
|
+ nextWritePacketID uint64
|
|
|
+ trackedServerSessionID map[string]*ClientUDPSessionServerTracker
|
|
|
+
|
|
|
+ cachedProcessorState UDPClientPacketProcessorCachedState
|
|
|
|
|
|
ctx context.Context
|
|
|
finish func()
|
|
|
}
|
|
|
|
|
|
func (c *ClientUDPSessionConn) Close() error {
|
|
|
+ c.parent.locker.Lock()
|
|
|
delete(c.parent.sessionMap, c.sessionID)
|
|
|
+ for k := range c.trackedServerSessionID {
|
|
|
+ delete(c.parent.sessionMapAlias, k)
|
|
|
+ }
|
|
|
+ c.parent.locker.Unlock()
|
|
|
c.finish()
|
|
|
return nil
|
|
|
}
|
|
|
@@ -160,13 +249,44 @@ func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err er
|
|
|
}
|
|
|
|
|
|
func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
|
|
|
- select {
|
|
|
- case <-c.ctx.Done():
|
|
|
- return 0, nil, io.EOF
|
|
|
- case resp := <-c.readChan:
|
|
|
- n = copy(p, resp.Payload.Bytes())
|
|
|
- resp.Payload.Release()
|
|
|
- addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
|
|
|
- }
|
|
|
- return
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-c.ctx.Done():
|
|
|
+ return 0, nil, io.EOF
|
|
|
+ case resp := <-c.readChan:
|
|
|
+ n = copy(p, resp.Payload.Bytes())
|
|
|
+ resp.Payload.Release()
|
|
|
+
|
|
|
+ var trackedState *ClientUDPSessionServerTracker
|
|
|
+ if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
|
|
|
+ for key, value := range c.trackedServerSessionID {
|
|
|
+ if time.Since(value.lastSeen) > 65*time.Second {
|
|
|
+ delete(c.trackedServerSessionID, key)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ state := &ClientUDPSessionServerTracker{
|
|
|
+ rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
|
|
|
+ }
|
|
|
+ c.trackedServerSessionID[string(resp.SessionID[:])] = state
|
|
|
+ c.parent.locker.RLock()
|
|
|
+ c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
|
|
|
+ c.parent.locker.RUnlock()
|
|
|
+ trackedState = state
|
|
|
+ } else {
|
|
|
+ trackedState = trackedStateReceived
|
|
|
+ }
|
|
|
+
|
|
|
+ if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
|
|
|
+ accept()
|
|
|
+ } else {
|
|
|
+ newError("misbehaving server: replayed packet").Base(err).WriteToLog()
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ trackedState.lastSeen = time.Now()
|
|
|
+
|
|
|
+ addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
|
|
|
+ }
|
|
|
+ return n, addr, nil
|
|
|
+ }
|
|
|
}
|