Browse Source

check for nil

v2ray 9 years ago
parent
commit
2b966b039f
1 changed files with 47 additions and 19 deletions
  1. 47 19
      transport/internet/kcp/connection.go

+ 47 - 19
transport/internet/kcp/connection.go

@@ -12,9 +12,10 @@ import (
 )
 
 var (
-	errTimeout        = errors.New("i/o timeout")
-	errBrokenPipe     = errors.New("broken pipe")
-	errClosedListener = errors.New("Listener closed.")
+	errTimeout          = errors.New("i/o timeout")
+	errBrokenPipe       = errors.New("broken pipe")
+	errClosedListener   = errors.New("Listener closed.")
+	errClosedConnection = errors.New("Connection closed.")
 )
 
 const (
@@ -59,7 +60,7 @@ func nowMillisec() int64 {
 
 // UDPSession defines a KCP session implemented by UDP
 type UDPSession struct {
-	sync.Mutex
+	sync.RWMutex
 	state         ConnState
 	kcp           *KCP // the core ARQ
 	kcpAccess     sync.Mutex
@@ -114,24 +115,22 @@ func (this *UDPSession) Elapsed() uint32 {
 
 // Read implements the Conn Read method.
 func (s *UDPSession) Read(b []byte) (int, error) {
-	if s.state == ConnStateReadyToClose || s.state == ConnStateClosed {
+	if s == nil || s.state == ConnStateReadyToClose || s.state == ConnStateClosed {
 		return 0, io.EOF
 	}
 
 	for {
-		s.Lock()
+		s.RLock()
 		if s.state == ConnStateReadyToClose || s.state == ConnStateClosed {
-			s.Unlock()
+			s.RUnlock()
 			return 0, io.EOF
 		}
 
-		if !s.rd.IsZero() {
-			if time.Now().After(s.rd) {
-				s.Unlock()
-				return 0, errTimeout
-			}
+		if !s.rd.IsZero() && s.rd.Before(time.Now()) {
+			s.RUnlock()
+			return 0, errTimeout
 		}
-		s.Unlock()
+		s.RUnlock()
 
 		s.kcpAccess.Lock()
 		nBytes := s.kcp.Recv(b)
@@ -148,18 +147,22 @@ func (s *UDPSession) Read(b []byte) (int, error) {
 
 // Write implements the Conn Write method.
 func (s *UDPSession) Write(b []byte) (int, error) {
-	if s.state == ConnStateReadyToClose ||
+	if s == nil ||
+		s.state == ConnStateReadyToClose ||
 		s.state == ConnStatePeerClosed ||
 		s.state == ConnStateClosed {
 		return 0, io.ErrClosedPipe
 	}
 
 	for {
+		s.RLock()
 		if s.state == ConnStateReadyToClose ||
 			s.state == ConnStatePeerClosed ||
 			s.state == ConnStateClosed {
+			s.RUnlock()
 			return 0, io.ErrClosedPipe
 		}
+		s.RUnlock()
 
 		s.kcpAccess.Lock()
 		if s.kcp.WaitSnd() < int(s.kcp.snd_wnd) {
@@ -182,7 +185,7 @@ func (s *UDPSession) Write(b []byte) (int, error) {
 }
 
 func (this *UDPSession) Terminate() {
-	if this.state == ConnStateClosed {
+	if this == nil || this.state == ConnStateClosed {
 		return
 	}
 	this.Lock()
@@ -197,12 +200,12 @@ func (this *UDPSession) Terminate() {
 
 func (this *UDPSession) NotifyTermination() {
 	for i := 0; i < 16; i++ {
-		this.Lock()
+		this.RLock()
 		if this.state == ConnStateClosed {
-			this.Unlock()
+			this.RUnlock()
 			break
 		}
-		this.Unlock()
+		this.RUnlock()
 		buffer := alloc.NewSmallBuffer().Clear()
 		buffer.AppendBytes(byte(CommandTerminate), byte(OptionClose), byte(0), byte(0), byte(0), byte(0))
 		this.output(buffer)
@@ -215,6 +218,9 @@ func (this *UDPSession) NotifyTermination() {
 
 // Close closes the connection.
 func (s *UDPSession) Close() error {
+	if s == nil || s.state == ConnStateClosed || s.state == ConnStateReadyToClose {
+		return errClosedConnection
+	}
 	log.Debug("KCP|Connection: Closing connection to ", s.remote)
 	s.Lock()
 	defer s.Unlock()
@@ -235,14 +241,25 @@ func (s *UDPSession) Close() error {
 
 // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
 func (s *UDPSession) LocalAddr() net.Addr {
+	if s == nil {
+		return nil
+	}
 	return s.local
 }
 
 // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
-func (s *UDPSession) RemoteAddr() net.Addr { return s.remote }
+func (s *UDPSession) RemoteAddr() net.Addr {
+	if s == nil {
+		return nil
+	}
+	return s.remote
+}
 
 // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
 func (s *UDPSession) SetDeadline(t time.Time) error {
+	if s == nil || s.state != ConnStateActive {
+		return errClosedConnection
+	}
 	s.Lock()
 	defer s.Unlock()
 	s.rd = t
@@ -252,6 +269,9 @@ func (s *UDPSession) SetDeadline(t time.Time) error {
 
 // SetReadDeadline implements the Conn SetReadDeadline method.
 func (s *UDPSession) SetReadDeadline(t time.Time) error {
+	if s == nil || s.state != ConnStateActive {
+		return errClosedConnection
+	}
 	s.Lock()
 	defer s.Unlock()
 	s.rd = t
@@ -260,6 +280,9 @@ func (s *UDPSession) SetReadDeadline(t time.Time) error {
 
 // SetWriteDeadline implements the Conn SetWriteDeadline method.
 func (s *UDPSession) SetWriteDeadline(t time.Time) error {
+	if s == nil || s.state != ConnStateActive {
+		return errClosedConnection
+	}
 	s.Lock()
 	defer s.Unlock()
 	s.wd = t
@@ -268,7 +291,12 @@ func (s *UDPSession) SetWriteDeadline(t time.Time) error {
 
 func (s *UDPSession) output(payload *alloc.Buffer) {
 	defer payload.Release()
+	if s == nil {
+		return
+	}
 
+	s.RLock()
+	defer s.RUnlock()
 	if s.state == ConnStatePeerClosed || s.state == ConnStateClosed {
 		return
 	}