Parcourir la source

simplify buf.BufferedReader

Darien Raymond il y a 7 ans
Parent
commit
148a7d064d

+ 2 - 2
app/proxyman/mux/mux.go

@@ -258,7 +258,7 @@ func (m *Client) fetchOutput() {
 		common.Must(m.done.Close())
 	}()
 
-	reader := buf.NewBufferedReader(m.link.Reader)
+	reader := &buf.BufferedReader{Reader: m.link.Reader}
 
 	for {
 		meta, err := ReadMetadata(reader)
@@ -456,7 +456,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead
 
 func (w *ServerWorker) run(ctx context.Context) {
 	input := w.link.Reader
-	reader := buf.NewBufferedReader(input)
+	reader := &buf.BufferedReader{Reader: input}
 
 	defer w.sessionManager.Close()
 

+ 1 - 1
app/proxyman/mux/mux_test.go

@@ -59,7 +59,7 @@ func TestReaderWriter(t *testing.T) {
 	assert(writePayload(writer2, 'y'), IsNil)
 	writer2.Close()
 
-	bytesReader := buf.NewBufferedReader(pReader)
+	bytesReader := &buf.BufferedReader{Reader: pReader}
 
 	meta, err := ReadMetadata(bytesReader)
 	assert(err, IsNil)

+ 34 - 49
common/buf/reader.go

@@ -75,32 +75,17 @@ func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
 
 // BufferedReader is a Reader that keeps its internal buffer.
 type BufferedReader struct {
-	stream   Reader
-	leftOver MultiBuffer
-	buffered bool
-}
-
-// NewBufferedReader returns a new BufferedReader.
-func NewBufferedReader(reader Reader) *BufferedReader {
-	return &BufferedReader{
-		stream:   reader,
-		buffered: true,
-	}
-}
-
-// SetBuffered sets whether to keep the interal buffer.
-func (r *BufferedReader) SetBuffered(f bool) {
-	r.buffered = f
-}
-
-// IsBuffered returns true if internal buffer is used.
-func (r *BufferedReader) IsBuffered() bool {
-	return r.buffered
+	// Reader is the underlying reader to be read from
+	Reader Reader
+	// Buffer is the internal buffer to be read from first
+	Buffer MultiBuffer
+	// Direct indicates whether or not to use the internal buffer
+	Direct bool
 }
 
 // BufferedBytes returns the number of bytes that is cached in this reader.
 func (r *BufferedReader) BufferedBytes() int32 {
-	return r.leftOver.Len()
+	return r.Buffer.Len()
 }
 
 // ReadByte implements io.ByteReader.
@@ -112,26 +97,26 @@ func (r *BufferedReader) ReadByte() (byte, error) {
 
 // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
 func (r *BufferedReader) Read(b []byte) (int, error) {
-	if r.leftOver != nil {
-		nBytes, _ := r.leftOver.Read(b)
-		if r.leftOver.IsEmpty() {
-			r.leftOver.Release()
-			r.leftOver = nil
+	if r.Buffer != nil {
+		nBytes, _ := r.Buffer.Read(b)
+		if r.Buffer.IsEmpty() {
+			r.Buffer.Release()
+			r.Buffer = nil
 		}
 		return nBytes, nil
 	}
 
-	if !r.buffered {
-		if reader, ok := r.stream.(io.Reader); ok {
+	if r.Direct {
+		if reader, ok := r.Reader.(io.Reader); ok {
 			return reader.Read(b)
 		}
 	}
 
-	mb, err := r.stream.ReadMultiBuffer()
+	mb, err := r.Reader.ReadMultiBuffer()
 	if mb != nil {
 		nBytes, _ := mb.Read(b)
 		if !mb.IsEmpty() {
-			r.leftOver = mb
+			r.Buffer = mb
 		}
 		return nBytes, err
 	}
@@ -140,28 +125,28 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
 
 // ReadMultiBuffer implements Reader.
 func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
-	if r.leftOver != nil {
-		mb := r.leftOver
-		r.leftOver = nil
+	if r.Buffer != nil {
+		mb := r.Buffer
+		r.Buffer = nil
 		return mb, nil
 	}
 
-	return r.stream.ReadMultiBuffer()
+	return r.Reader.ReadMultiBuffer()
 }
 
 // ReadAtMost returns a MultiBuffer with at most size.
 func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) {
-	if r.leftOver == nil {
-		mb, err := r.stream.ReadMultiBuffer()
+	if r.Buffer == nil {
+		mb, err := r.Reader.ReadMultiBuffer()
 		if mb.IsEmpty() && err != nil {
 			return nil, err
 		}
-		r.leftOver = mb
+		r.Buffer = mb
 	}
 
-	mb := r.leftOver.SliceBySize(size)
-	if r.leftOver.IsEmpty() {
-		r.leftOver = nil
+	mb := r.Buffer.SliceBySize(size)
+	if r.Buffer.IsEmpty() {
+		r.Buffer = nil
 	}
 	return mb, nil
 }
@@ -169,16 +154,16 @@ func (r *BufferedReader) ReadAtMost(size int32) (MultiBuffer, error) {
 func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
 	mbWriter := NewWriter(writer)
 	totalBytes := int64(0)
-	if r.leftOver != nil {
-		totalBytes += int64(r.leftOver.Len())
-		if err := mbWriter.WriteMultiBuffer(r.leftOver); err != nil {
+	if r.Buffer != nil {
+		totalBytes += int64(r.Buffer.Len())
+		if err := mbWriter.WriteMultiBuffer(r.Buffer); err != nil {
 			return 0, err
 		}
-		r.leftOver = nil
+		r.Buffer = nil
 	}
 
 	for {
-		mb, err := r.stream.ReadMultiBuffer()
+		mb, err := r.Reader.ReadMultiBuffer()
 		if mb != nil {
 			totalBytes += int64(mb.Len())
 			if werr := mbWriter.WriteMultiBuffer(mb); werr != nil {
@@ -202,8 +187,8 @@ func (r *BufferedReader) WriteTo(writer io.Writer) (int64, error) {
 
 // Close implements io.Closer.
 func (r *BufferedReader) Close() error {
-	if !r.leftOver.IsEmpty() {
-		r.leftOver.Release()
+	if !r.Buffer.IsEmpty() {
+		r.Buffer.Release()
 	}
-	return common.Close(r.stream)
+	return common.Close(r.Reader)
 }

+ 2 - 2
common/buf/reader_test.go

@@ -39,7 +39,7 @@ func TestBytesReaderWriteTo(t *testing.T) {
 	assert := With(t)
 
 	pReader, pWriter := pipe.New()
-	reader := NewBufferedReader(pReader)
+	reader := &BufferedReader{Reader: pReader}
 	b1 := New()
 	b1.AppendBytes('a', 'b', 'c')
 	b2 := New()
@@ -66,7 +66,7 @@ func TestBytesReaderMultiBuffer(t *testing.T) {
 	assert := With(t)
 
 	pReader, pWriter := pipe.New()
-	reader := NewBufferedReader(pReader)
+	reader := &BufferedReader{Reader: pReader}
 	b1 := New()
 	b1.AppendBytes('a', 'b', 'c')
 	b2 := New()

+ 1 - 1
common/buf/writer_test.go

@@ -67,7 +67,7 @@ func TestDiscardBytesMultiBuffer(t *testing.T) {
 	common.Must2(buffer.ReadFrom(io.LimitReader(rand.Reader, size)))
 
 	r := NewReader(buffer)
-	nBytes, err := io.Copy(DiscardBytes, NewBufferedReader(r))
+	nBytes, err := io.Copy(DiscardBytes, &BufferedReader{Reader: r})
 	assert(nBytes, Equals, int64(size))
 	assert(err, IsNil)
 }

+ 1 - 1
common/crypto/auth.go

@@ -91,7 +91,7 @@ type AuthenticationReader struct {
 func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader {
 	return &AuthenticationReader{
 		auth:         auth,
-		reader:       buf.NewBufferedReader(buf.NewReader(reader)),
+		reader:       &buf.BufferedReader{Reader: buf.NewReader(reader)},
 		sizeParser:   sizeParser,
 		transferType: transferType,
 		size:         -1,

+ 1 - 1
common/crypto/chunk.go

@@ -68,7 +68,7 @@ type ChunkStreamReader struct {
 func NewChunkStreamReader(sizeDecoder ChunkSizeDecoder, reader io.Reader) *ChunkStreamReader {
 	return &ChunkStreamReader{
 		sizeDecoder: sizeDecoder,
-		reader:      buf.NewBufferedReader(buf.NewReader(reader)),
+		reader:      &buf.BufferedReader{Reader: buf.NewReader(reader)},
 		buffer:      make([]byte, sizeDecoder.SizeBytes()),
 	}
 }

+ 2 - 2
common/net/connection.go

@@ -38,13 +38,13 @@ func ConnectionInputMulti(writer buf.Writer) ConnectionOption {
 
 func ConnectionOutput(reader io.Reader) ConnectionOption {
 	return func(c *connection) {
-		c.reader = buf.NewBufferedReader(buf.NewReader(reader))
+		c.reader = &buf.BufferedReader{Reader: buf.NewReader(reader)}
 	}
 }
 
 func ConnectionOutputMulti(reader buf.Reader) ConnectionOption {
 	return func(c *connection) {
-		c.reader = buf.NewBufferedReader(reader)
+		c.reader = &buf.BufferedReader{Reader: reader}
 	}
 }
 

+ 1 - 1
proxy/http/server.go

@@ -268,7 +268,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 	}
 
 	responseDone := func() error {
-		responseReader := bufio.NewReaderSize(buf.NewBufferedReader(link.Reader), buf.Size)
+		responseReader := bufio.NewReaderSize(&buf.BufferedReader{Reader: link.Reader}, buf.Size)
 		response, err := http.ReadResponse(responseReader, request)
 		if err == nil {
 			http_proto.RemoveHopByHopHeaders(response.Header)

+ 2 - 2
proxy/shadowsocks/protocol.go

@@ -52,7 +52,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 	if err != nil {
 		return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
 	}
-	br := buf.NewBufferedReader(r)
+	br := &buf.BufferedReader{Reader: r}
 	reader = nil
 
 	authenticator := NewAuthenticator(HeaderKeyGenerator(account.Key, iv))
@@ -109,7 +109,7 @@ func ReadTCPSession(user *protocol.User, reader io.Reader) (*protocol.RequestHea
 		return nil, nil, newError("invalid remote address.")
 	}
 
-	br.SetBuffered(false)
+	br.Direct = true
 
 	var chunkReader buf.Reader
 	if request.Option.Has(RequestOptionOneTimeAuth) {

+ 3 - 3
proxy/shadowsocks/server.go

@@ -140,8 +140,8 @@ func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
 	sessionPolicy := s.v.PolicyManager().ForLevel(s.user.Level)
 	conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake))
-	bufferedReader := buf.NewBufferedReader(buf.NewReader(conn))
-	request, bodyReader, err := ReadTCPSession(s.user, bufferedReader)
+	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
+	request, bodyReader, err := ReadTCPSession(s.user, &bufferedReader)
 	if err != nil {
 		log.Record(&log.AccessMessage{
 			From:   conn.RemoteAddr(),
@@ -153,7 +153,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 	}
 	conn.SetReadDeadline(time.Time{})
 
-	bufferedReader.SetBuffered(false)
+	bufferedReader.Direct = true
 
 	dest := request.Destination()
 	log.Record(&log.AccessMessage{

+ 1 - 1
proxy/socks/server.go

@@ -70,7 +70,7 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa
 		newError("failed to set deadline").Base(err).WithContext(ctx).WriteToLog()
 	}
 
-	reader := buf.NewBufferedReader(buf.NewReader(conn))
+	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
 
 	inboundDest, ok := proxy.InboundEntryPointFromContext(ctx)
 	if !ok {

+ 1 - 1
proxy/vmess/inbound/inbound.go

@@ -224,7 +224,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		return newError("unable to set read deadline").Base(err).AtWarning()
 	}
 
-	reader := buf.NewBufferedReader(buf.NewReader(connection))
+	reader := &buf.BufferedReader{Reader: buf.NewReader(connection)}
 
 	session := encoding.NewServerSession(h.clients, h.sessionHistory)
 	request, err := session.DecodeRequestHeader(reader)

+ 2 - 2
proxy/vmess/outbound/outbound.go

@@ -146,14 +146,14 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 	responseDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
 
-		reader := buf.NewBufferedReader(buf.NewReader(conn))
+		reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
 		header, err := session.DecodeResponseHeader(reader)
 		if err != nil {
 			return newError("failed to read header").Base(err)
 		}
 		v.handleCommand(rec.Destination(), header.Command)
 
-		reader.SetBuffered(false)
+		reader.Direct = true
 		bodyReader := session.DecodeResponseBody(request, reader)
 
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))

+ 2 - 2
transport/internet/http/dialer.go

@@ -84,11 +84,11 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error
 	}
 
 	preader, pwriter := pipe.New(pipe.WithSizeLimit(20 * 1024))
-	breader := buf.NewBufferedReader(preader)
+	breader := &buf.BufferedReader{Reader: preader}
 	request := &http.Request{
 		Method: "PUT",
 		Host:   httpSettings.getRandomHost(),
-		Body:   buf.NewBufferedReader(preader),
+		Body:   breader,
 		URL: &url.URL{
 			Scheme: "https",
 			Host:   dest.NetAddr(),