Darien Raymond il y a 8 ans
Parent
commit
326a54baea
1 fichiers modifiés avec 6 ajouts et 12 suppressions
  1. 6 12
      proxy/shadowsocks/protocol.go

+ 6 - 12
proxy/shadowsocks/protocol.go

@@ -32,8 +32,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 	defer buffer.Release()
 
 	ivLen := account.Cipher.IVSize()
-	err = buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen))
-	if err != nil {
+	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen)); err != nil {
 		return nil, nil, newError("failed to read IV").Base(err)
 	}
 
@@ -52,15 +51,13 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 		Command: protocol.RequestCommandTCP,
 	}
 
-	buffer.Clear()
-	err = buffer.AppendSupplier(buf.ReadFullFrom(reader, 1))
-	if err != nil {
+	if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil {
 		return nil, nil, newError("failed to read address type").Base(err)
 	}
 
 	addrType := (buffer.Byte(0) & 0x0F)
 	if (buffer.Byte(0) & 0x10) == 0x10 {
-		request.Option |= RequestOptionOneTimeAuth
+		request.Option.Set(RequestOptionOneTimeAuth)
 	}
 
 	if request.Option.Has(RequestOptionOneTimeAuth) && account.OneTimeAuth == Account_Disabled {
@@ -73,20 +70,17 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 
 	switch addrType {
 	case AddrTypeIPv4:
-		err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4))
-		if err != nil {
+		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
 			return nil, nil, newError("failed to read IPv4 address").Base(err)
 		}
 		request.Address = v2net.IPAddress(buffer.BytesFrom(-4))
 	case AddrTypeIPv6:
-		err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16))
-		if err != nil {
+		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
 			return nil, nil, newError("failed to read IPv6 address").Base(err)
 		}
 		request.Address = v2net.IPAddress(buffer.BytesFrom(-16))
 	case AddrTypeDomain:
-		err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1))
-		if err != nil {
+		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
 			return nil, nil, newError("failed to read domain lenth.").Base(err)
 		}
 		domainLength := int(buffer.BytesFrom(-1)[0])