Browse Source

eliminate partial writes

Darien Raymond 7 years ago
parent
commit
ebea255c74

+ 11 - 0
common/buf/io.go

@@ -47,6 +47,17 @@ func ReadAtLeastFrom(reader io.Reader, size int) Supplier {
 	}
 }
 
+func WriteAllBytes(writer io.Writer, payload []byte) error {
+	for len(payload) > 0 {
+		n, err := writer.Write(payload)
+		if err != nil {
+			return err
+		}
+		payload = payload[n:]
+	}
+	return nil
+}
+
 // NewReader creates a new Reader.
 // The Reader instance doesn't take the ownership of reader.
 func NewReader(reader io.Reader) Reader {

+ 1 - 4
common/buf/writer.go

@@ -179,10 +179,7 @@ func (w *seqWriter) WriteMultiBuffer(mb MultiBuffer) error {
 	defer mb.Release()
 
 	for _, b := range mb {
-		if b.IsEmpty() {
-			continue
-		}
-		if _, err := w.writer.Write(b.Bytes()); err != nil {
+		if err := WriteAllBytes(w.writer, b.Bytes()); err != nil {
 			return err
 		}
 	}

+ 1 - 1
proxy/shadowsocks/ota.go

@@ -118,7 +118,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
 		serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0])
 		w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:])
-		if _, err := w.writer.Write(w.buffer[:2+AuthSize+payloadLen]); err != nil {
+		if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil {
 			return err
 		}
 		if mb.IsEmpty() {

+ 2 - 2
proxy/shadowsocks/protocol.go

@@ -132,7 +132,7 @@ func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Wri
 	if account.Cipher.IVSize() > 0 {
 		iv = make([]byte, account.Cipher.IVSize())
 		common.Must2(rand.Read(iv))
-		if _, err = writer.Write(iv); err != nil {
+		if err := buf.WriteAllBytes(writer, iv); err != nil {
 			return nil, newError("failed to write IV")
 		}
 	}
@@ -199,7 +199,7 @@ func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Wr
 	if account.Cipher.IVSize() > 0 {
 		iv = make([]byte, account.Cipher.IVSize())
 		common.Must2(rand.Read(iv))
-		if _, err = writer.Write(iv); err != nil {
+		if err := buf.WriteAllBytes(writer, iv); err != nil {
 			return nil, newError("failed to write IV.").Base(err)
 		}
 	}

+ 5 - 8
proxy/socks/protocol.go

@@ -234,8 +234,7 @@ func hasAuthMethod(expectedAuth byte, authCandidates []byte) bool {
 }
 
 func writeSocks5AuthenticationResponse(writer io.Writer, version byte, auth byte) error {
-	_, err := writer.Write([]byte{version, auth})
-	return err
+	return buf.WriteAllBytes(writer, []byte{version, auth})
 }
 
 func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@@ -247,8 +246,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
 		return err
 	}
 
-	_, err := writer.Write(buffer.Bytes())
-	return err
+	return buf.WriteAllBytes(writer, buffer.Bytes())
 }
 
 func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, port net.Port) error {
@@ -258,8 +256,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
 	common.Must2(buffer.AppendBytes(0x00, errCode))
 	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
 	common.Must2(buffer.Write(address.IP()))
-	_, err := writer.Write(buffer.Bytes())
-	return err
+	return buf.WriteAllBytes(writer, buffer.Bytes())
 }
 
 func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
@@ -365,7 +362,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		common.Must2(b.Write([]byte(account.Password)))
 	}
 
-	if _, err := writer.Write(b.Bytes()); err != nil {
+	if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
 		return nil, err
 	}
 
@@ -400,7 +397,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		return nil, err
 	}
 
-	if _, err := writer.Write(b.Bytes()); err != nil {
+	if err := buf.WriteAllBytes(writer, b.Bytes()); err != nil {
 		return nil, err
 	}
 

+ 1 - 1
transport/internet/headers/http/http.go

@@ -103,7 +103,7 @@ func (w *HeaderWriter) Write(writer io.Writer) error {
 	if w.header == nil {
 		return nil
 	}
-	_, err := writer.Write(w.header.Bytes())
+	err := buf.WriteAllBytes(writer, w.header.Bytes())
 	w.header.Release()
 	w.header = nil
 	return err