|
|
@@ -170,38 +170,40 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
|
|
|
// 1 bytes reserved
|
|
|
request.Command = protocol.RequestCommand(buffer[37])
|
|
|
|
|
|
- request.Port = net.PortFromBytes(buffer[38:40])
|
|
|
-
|
|
|
- switch buffer[40] {
|
|
|
- case AddrTypeIPv4:
|
|
|
- _, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
|
|
|
- bufferLen += 4
|
|
|
- if err != nil {
|
|
|
- return nil, newError("failed to read IPv4 address").Base(err)
|
|
|
- }
|
|
|
- request.Address = net.IPAddress(buffer[41:45])
|
|
|
- case AddrTypeIPv6:
|
|
|
- _, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
|
|
|
- bufferLen += 16
|
|
|
- if err != nil {
|
|
|
- return nil, newError("failed to read IPv6 address").Base(err)
|
|
|
- }
|
|
|
- request.Address = net.IPAddress(buffer[41:57])
|
|
|
- case AddrTypeDomain:
|
|
|
- _, err = io.ReadFull(decryptor, buffer[41:42])
|
|
|
- if err != nil {
|
|
|
- return nil, newError("failed to read domain address").Base(err)
|
|
|
- }
|
|
|
- domainLength := int(buffer[41])
|
|
|
- if domainLength == 0 {
|
|
|
- return nil, newError("zero length domain").Base(err)
|
|
|
- }
|
|
|
- _, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
|
|
|
- if err != nil {
|
|
|
- return nil, newError("failed to read domain address").Base(err)
|
|
|
+ if request.Command != protocol.RequestCommandMux {
|
|
|
+ request.Port = net.PortFromBytes(buffer[38:40])
|
|
|
+
|
|
|
+ switch buffer[40] {
|
|
|
+ case AddrTypeIPv4:
|
|
|
+ _, err = io.ReadFull(decryptor, buffer[41:45]) // 4 bytes
|
|
|
+ bufferLen += 4
|
|
|
+ if err != nil {
|
|
|
+ return nil, newError("failed to read IPv4 address").Base(err)
|
|
|
+ }
|
|
|
+ request.Address = net.IPAddress(buffer[41:45])
|
|
|
+ case AddrTypeIPv6:
|
|
|
+ _, err = io.ReadFull(decryptor, buffer[41:57]) // 16 bytes
|
|
|
+ bufferLen += 16
|
|
|
+ if err != nil {
|
|
|
+ return nil, newError("failed to read IPv6 address").Base(err)
|
|
|
+ }
|
|
|
+ request.Address = net.IPAddress(buffer[41:57])
|
|
|
+ case AddrTypeDomain:
|
|
|
+ _, err = io.ReadFull(decryptor, buffer[41:42])
|
|
|
+ if err != nil {
|
|
|
+ return nil, newError("failed to read domain address").Base(err)
|
|
|
+ }
|
|
|
+ domainLength := int(buffer[41])
|
|
|
+ if domainLength == 0 {
|
|
|
+ return nil, newError("zero length domain").Base(err)
|
|
|
+ }
|
|
|
+ _, err = io.ReadFull(decryptor, buffer[42:42+domainLength])
|
|
|
+ if err != nil {
|
|
|
+ return nil, newError("failed to read domain address").Base(err)
|
|
|
+ }
|
|
|
+ bufferLen += 1 + domainLength
|
|
|
+ request.Address = net.DomainAddress(string(buffer[42 : 42+domainLength]))
|
|
|
}
|
|
|
- bufferLen += 1 + domainLength
|
|
|
- request.Address = net.DomainAddress(string(buffer[42 : 42+domainLength]))
|
|
|
}
|
|
|
|
|
|
if padingLen > 0 {
|