浏览代码

bug fixes

Darien Raymond 8 年之前
父节点
当前提交
6d7aaa6535

+ 3 - 7
app/proxyman/mux/mux.go

@@ -63,10 +63,6 @@ func (m *ClientManager) onClientFinish() {
 	m.access.Lock()
 	defer m.access.Unlock()
 
-	if len(m.clients) < 10 {
-		return
-	}
-
 	activeClients := make([]*Client, 0, len(m.clients))
 
 	for _, client := range m.clients {
@@ -158,8 +154,8 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 }
 
 func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool {
-	numSession := m.sessionManager.Size()
-	if numSession >= int(m.concurrency) || numSession >= maxTotal {
+	sm := m.sessionManager
+	if sm.Size() >= int(m.concurrency) || sm.Count() >= maxTotal {
 		return false
 	}
 
@@ -169,7 +165,7 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
 	default:
 	}
 
-	s := m.sessionManager.Allocate()
+	s := sm.Allocate()
 	if s == nil {
 		return false
 	}

+ 12 - 5
app/proxyman/mux/session.go

@@ -8,8 +8,8 @@ import (
 
 type SessionManager struct {
 	sync.RWMutex
-	count    uint16
 	sessions map[uint16]*Session
+	count    uint16
 	closed   bool
 }
 
@@ -27,6 +27,13 @@ func (m *SessionManager) Size() int {
 	return len(m.sessions)
 }
 
+func (m *SessionManager) Count() int {
+	m.RLock()
+	defer m.RUnlock()
+
+	return int(m.count)
+}
+
 func (m *SessionManager) Allocate() *Session {
 	m.Lock()
 	defer m.Unlock()
@@ -71,8 +78,8 @@ func (m *SessionManager) Get(id uint16) (*Session, bool) {
 }
 
 func (m *SessionManager) CloseIfNoSession() bool {
-	m.RLock()
-	defer m.RUnlock()
+	m.Lock()
+	defer m.Unlock()
 
 	if m.closed {
 		return true
@@ -87,8 +94,8 @@ func (m *SessionManager) CloseIfNoSession() bool {
 }
 
 func (m *SessionManager) Close() {
-	m.RLock()
-	defer m.RUnlock()
+	m.Lock()
+	defer m.Unlock()
 
 	if m.closed {
 		return

+ 2 - 4
app/proxyman/mux/writer.go

@@ -66,14 +66,12 @@ func (w *Writer) writeData(mb buf.MultiBuffer) error {
 		return err
 	}
 	runtime.KeepAlive(meta)
-
-	mb2 := buf.NewMultiBuffer()
-	mb2.Append(frame)
-
 	if err := frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len()))); err != nil {
 		return err
 	}
 
+	mb2 := buf.NewMultiBuffer()
+	mb2.Append(frame)
 	mb2.AppendMulti(mb)
 	return w.writer.Write(mb2)
 }

+ 6 - 0
common/buf/io.go

@@ -96,6 +96,12 @@ func ToBytesReader(stream Reader) io.Reader {
 
 // NewWriter creates a new Writer.
 func NewWriter(writer io.Writer) Writer {
+	if mw, ok := writer.(MultiBufferWriter); ok {
+		return &writerAdapter{
+			writer: mw,
+		}
+	}
+
 	return &BufferToBytesWriter{
 		writer: writer,
 	}

+ 3 - 11
common/buf/multi_buffer.go

@@ -1,9 +1,6 @@
 package buf
 
-import (
-	"io"
-	"net"
-)
+import "net"
 
 type MultiBufferWriter interface {
 	WriteMultiBuffer(MultiBuffer) (int, error)
@@ -32,17 +29,11 @@ func (b *MultiBuffer) AppendMulti(mb MultiBuffer) {
 }
 
 func (mb *MultiBuffer) Read(b []byte) (int, error) {
-	if len(*mb) == 0 {
-		return 0, io.EOF
-	}
 	endIndex := len(*mb)
 	totalBytes := 0
 	for i, bb := range *mb {
-		nBytes, err := bb.Read(b)
+		nBytes, _ := bb.Read(b)
 		totalBytes += nBytes
-		if err != nil {
-			return totalBytes, err
-		}
 		b = b[nBytes:]
 		if bb.IsEmpty() {
 			bb.Release()
@@ -96,6 +87,7 @@ func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
 			endIndex = i
 			break
 		}
+		sliceSize += b.Len()
 		slice.Append(b)
 	}
 	*mb = (*mb)[endIndex:]

+ 4 - 4
common/buf/reader.go

@@ -42,12 +42,12 @@ type bufferToBytesReader struct {
 
 func (r *bufferToBytesReader) Read(b []byte) (int, error) {
 	if r.leftOver != nil {
-		nBytes, err := r.leftOver.Read(b)
+		nBytes, _ := r.leftOver.Read(b)
 		if r.leftOver.IsEmpty() {
 			r.leftOver.Release()
 			r.leftOver = nil
 		}
-		return nBytes, err
+		return nBytes, nil
 	}
 
 	mb, err := r.stream.Read()
@@ -55,11 +55,11 @@ func (r *bufferToBytesReader) Read(b []byte) (int, error) {
 		return 0, err
 	}
 
-	nBytes, err := mb.Read(b)
+	nBytes, _ := mb.Read(b)
 	if !mb.IsEmpty() {
 		r.leftOver = mb
 	}
-	return nBytes, err
+	return nBytes, nil
 }
 
 func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {

+ 9 - 5
common/buf/writer.go

@@ -9,11 +9,6 @@ type BufferToBytesWriter struct {
 
 // Write implements Writer.Write(). Write() takes ownership of the given buffer.
 func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
-	if mw, ok := w.writer.(MultiBufferWriter); ok {
-		_, err := mw.WriteMultiBuffer(mb)
-		return err
-	}
-
 	defer mb.Release()
 
 	bs := mb.ToNetBuffers()
@@ -21,6 +16,15 @@ func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
 	return err
 }
 
+type writerAdapter struct {
+	writer MultiBufferWriter
+}
+
+func (w *writerAdapter) Write(mb MultiBuffer) error {
+	_, err := w.writer.WriteMultiBuffer(mb)
+	return err
+}
+
 type bytesToBufferWriter struct {
 	writer Writer
 }

+ 1 - 4
common/crypto/auth.go

@@ -250,10 +250,7 @@ func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error)
 	const StartIndex = 17 * 1024
 	var totalBytes int
 	for {
-		payloadLen, err := mb.Read(w.buffer[StartIndex:])
-		if err != nil {
-			return 0, err
-		}
+		payloadLen, _ := mb.Read(w.buffer[StartIndex:])
 		nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen])
 		totalBytes += nBytes
 		if err != nil {

+ 1 - 4
proxy/shadowsocks/ota.go

@@ -121,10 +121,7 @@ func (w *ChunkWriter) Write(mb buf.MultiBuffer) error {
 	defer mb.Release()
 
 	for {
-		payloadLen, err := mb.Read(w.buffer[2+AuthSize:])
-		if err != nil {
-			return err
-		}
+		payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
 		serial.Uint16ToBytes(uint16(payloadLen), w.buffer[:0])
 		w.auth.Authenticate(w.buffer[2+AuthSize : 2+AuthSize+payloadLen])(w.buffer[2:])
 		if _, err := w.writer.Write(w.buffer[:2+AuthSize+payloadLen]); err != nil {

+ 4 - 4
transport/internet/headers/http/http.go

@@ -133,13 +133,13 @@ func (c *HttpConn) Read(b []byte) (int, error) {
 		c.oneTimeReader = nil
 	}
 
-	if c.readBuffer.Len() > 0 {
-		nBytes, err := c.readBuffer.Read(b)
-		if nBytes == c.readBuffer.Len() {
+	if !c.readBuffer.IsEmpty() {
+		nBytes, _ := c.readBuffer.Read(b)
+		if c.readBuffer.IsEmpty() {
 			c.readBuffer.Release()
 			c.readBuffer = nil
 		}
-		return nBytes, err
+		return nBytes, nil
 	}
 
 	return c.Conn.Read(b)

+ 1 - 4
transport/internet/websocket/connection.go

@@ -62,10 +62,7 @@ func (c *connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
 	}
 	totalBytes := 0
 	for !mb.IsEmpty() {
-		nBytes, err := mb.Read(c.writeBuffer)
-		if err != nil {
-			return totalBytes, err
-		}
+		nBytes, _ := mb.Read(c.writeBuffer)
 		totalBytes += nBytes
 		if _, err := c.Write(c.writeBuffer[:nBytes]); err != nil {
 			return totalBytes, err