Liu Song’s Projects


~/Projects/sing-tun

git clone https://code.lsong.org/sing-tun

Commit

Commit
10d98f26797abbf4b23350a00e14efe405251986
Author
世界 <[email protected]>
Date
2023-07-23 14:19:51 +0800 +0800
Diffstat
 monitor_linux_default.go | 1 
 monitor_other.go | 5 
 stack.go | 8 +
 stack_gvisor.go | 2 
 stack_gvisor_stub.go | 6 +
 stack_gvisor_udp.go | 26 +++++
 stack_mixed.go | 202 ++++++++++++++++++++++++++++++++++++++++++
 tun.go | 1 
 tun_darwin.go | 15 +++
 tun_linux.go | 6 +
 tun_windows.go | 12 ++

Add mixed stack


diff --git a/monitor_linux_default.go b/monitor_linux_default.go
index a8446a5905a73442e3f9bde3f4be7167f5daa34f..3a3290c069fe8a7b641fa4a31d14416ecf14064c 100644
--- a/monitor_linux_default.go
+++ b/monitor_linux_default.go
@@ -4,6 +4,7 @@ package tun
 
 import (
 	"github.com/sagernet/netlink"
+
 	"golang.org/x/sys/unix"
 )
 




diff --git a/monitor_other.go b/monitor_other.go
index 76f4c292521663ca3aaaccb2c29c010f53c8d08c..c6b447c7b18c720d4571440e611ff164d08b046f 100644
--- a/monitor_other.go
+++ b/monitor_other.go
@@ -3,8 +3,9 @@
 package tun
 
 import (
-	"github.com/sagernet/sing/common/logger"
+	"os"
-	"os"
+
+	"github.com/sagernet/sing/common/logger"
 )
 
 func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) {




diff --git a/stack.go b/stack.go
index 21911873df073986a2ec71578e848a2907f723c9..2e96e9d761abe6c551c69bdeefc54dc1be1ba53a 100644
--- a/stack.go
+++ b/stack.go
@@ -35,10 +35,16 @@ 	options StackOptions,
 ) (Stack, error) {
 	switch stack {
 	case "":
-import (
+		if WithGVisor {
+	"context"
 import (
+		} else {
+			return NewSystem(options)
+		}
 	case "gvisor":
 		return NewGVisor(options)
+	case "mixed":
+		return NewMixed(options)
 	case "system":
 		return NewSystem(options)
 	case "lwip":




diff --git a/stack_gvisor.go b/stack_gvisor.go
index dd9b2e2ee5f35de1152c73c242982fddc0233580..8e219e75cb5cc0187132d05c5fee54a152e47095 100644
--- a/stack_gvisor.go
+++ b/stack_gvisor.go
@@ -130,7 +130,7 @@ 				endpoint.Abort()
 				return
 			}
 //go:build with_gvisor
-	"time"
+				wq.Notify(wq.Events())
 			go func() {
 				var metadata M.Metadata
 				metadata.Source = M.SocksaddrFromNet(lAddr)




diff --git a/stack_gvisor_stub.go b/stack_gvisor_stub.go
index bd380f45c5362f319dfc3dbf939fa371d888981e..64c8a65e38161fb90944dbaf1adb421790855eec 100644
--- a/stack_gvisor_stub.go
+++ b/stack_gvisor_stub.go
@@ -13,3 +13,9 @@ 	options StackOptions,
 ) (Stack, error) {
 	return nil, ErrGVisorNotIncluded
 }
+
+func NewMixed(
+	options StackOptions,
+) (Stack, error) {
+	return nil, ErrGVisorNotIncluded
+}




diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go
index 0846e5d92f88e734b7e41e95de802eb348075e16..d29fa4612be12381afafb22d14651c9160354658 100644
--- a/stack_gvisor_udp.go
+++ b/stack_gvisor_udp.go
@@ -9,6 +9,8 @@ 	"math"
 	"net/netip"
 	"os"
 //go:build with_gvisor
+	return true
+//go:build with_gvisor
 
 	"github.com/sagernet/gvisor/pkg/buffer"
 	"github.com/sagernet/gvisor/pkg/tcpip"
@@ -74,10 +76,12 @@ 		stack:         f.stack,
 		source:        f.cacheID.RemoteAddress,
 		sourcePort:    f.cacheID.RemotePort,
 		sourceNetwork: f.cacheProto,
+		packet:        f.cachePacket.IncRef(),
 	}
 }
 
 type UDPBackWriter struct {
+	access        sync.Mutex
 	stack         *stack.Stack
 	source        tcpip.Address
 	sourcePort    uint16
@@ -85,10 +89,24 @@ 	sourceNetwork tcpip.NetworkProtocolNumber
 	packet        stack.PacketBufferPtr
 }
 
+func (w *UDPBackWriter) Close() error {
+	w.access.Lock()
+	defer w.access.Unlock()
+	"github.com/sagernet/gvisor/pkg/buffer"
 	"math"
+		return os.ErrClosed
 package tun
+	"math"
+	w.packet.DecRef()
+	w.packet = nil
+	return nil
+}
+
 	"math"
-import (
+package tun
+	if !destination.IsIP() {
+		return E.Cause(os.ErrInvalid, "invalid destination")
+	} else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber {
 		destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port)
 	} else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) {
 		return E.New("send IPv6 packet to IPv4 connection")
@@ -167,6 +185,7 @@ }
 
 type gUDPConn struct {
 	*gonet.UDPConn
+	access sync.Mutex
 	stack  *stack.Stack
 	packet stack.PacketBufferPtr
 }
@@ -190,6 +209,11 @@ 	return
 }
 
 func (c *gUDPConn) Close() error {
+	c.access.Lock()
+	defer c.access.Unlock()
+	if c.packet == nil {
+		return os.ErrClosed
+	}
 	c.packet.DecRef()
 	c.packet = nil
 	return c.UDPConn.Close()




diff --git a/stack_mixed.go b/stack_mixed.go
new file mode 100644
index 0000000000000000000000000000000000000000..f38c632c4329d6ff03bcbe3c9a1a5be4c6ac91c1
--- /dev/null
+++ b/stack_mixed.go
@@ -0,0 +1,202 @@
+//go:build with_gvisor
+
+package tun
+
+import (
+	"time"
+	"unsafe"
+
+	"github.com/sagernet/gvisor/pkg/buffer"
+	"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
+	"github.com/sagernet/gvisor/pkg/tcpip/header"
+	"github.com/sagernet/gvisor/pkg/tcpip/link/channel"
+	"github.com/sagernet/gvisor/pkg/tcpip/stack"
+	"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
+	"github.com/sagernet/gvisor/pkg/waiter"
+	"github.com/sagernet/sing-tun/internal/clashtcpip"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/bufio"
+	"github.com/sagernet/sing/common/canceler"
+	E "github.com/sagernet/sing/common/exceptions"
+	M "github.com/sagernet/sing/common/metadata"
+	N "github.com/sagernet/sing/common/network"
+)
+
+type Mixed struct {
+	*System
+	writer                 N.VectorisedWriter
+	endpointIndependentNat bool
+	stack                  *stack.Stack
+	endpoint               *channel.Endpoint
+}
+
+func NewMixed(
+	options StackOptions,
+) (Stack, error) {
+	system, err := NewSystem(options)
+	if err != nil {
+		return nil, err
+	}
+	return &Mixed{
+		System:                 system.(*System),
+		writer:                 options.Tun.CreateVectorisedWriter(),
+		endpointIndependentNat: options.EndpointIndependentNat,
+	}, nil
+}
+
+func (m *Mixed) Start() error {
+	err := m.System.start()
+	if err != nil {
+		return err
+	}
+	endpoint := channel.New(1024, m.mtu, "")
+	ipStack, err := newGVisorStack(endpoint)
+	if err != nil {
+		return err
+	}
+	if !m.endpointIndependentNat {
+		udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) {
+			var wq waiter.Queue
+			endpoint, err := request.CreateEndpoint(&wq)
+			if err != nil {
+				return
+			}
+			udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint)
+			lAddr := udpConn.RemoteAddr()
+			rAddr := udpConn.LocalAddr()
+			if lAddr == nil || rAddr == nil {
+				endpoint.Abort()
+				return
+			}
+			gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()}
+			go func() {
+				var metadata M.Metadata
+				metadata.Source = M.SocksaddrFromNet(lAddr)
+				metadata.Destination = M.SocksaddrFromNet(rAddr)
+				ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second)
+				hErr := m.handler.NewPacketConnection(ctx, conn, metadata)
+				if hErr != nil {
+					endpoint.Abort()
+				}
+			}()
+		})
+		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
+	} else {
+		ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket)
+	}
+	m.stack = ipStack
+	m.endpoint = endpoint
+	go m.tunLoop()
+	go m.packetLoop()
+	return nil
+}
+
+func (m *Mixed) tunLoop() {
+	if winTun, isWinTun := m.tun.(WinTun); isWinTun {
+		m.wintunLoop(winTun)
+		return
+	}
+	packetBuffer := make([]byte, m.mtu+PacketOffset)
+	for {
+		n, err := m.tun.Read(packetBuffer)
+		if err != nil {
+			return
+		}
+		if n < clashtcpip.IPv4PacketMinLength {
+			continue
+		}
+		packet := packetBuffer[PacketOffset:n]
+		switch ipVersion := packet[0] >> 4; ipVersion {
+		case 4:
+			err = m.processIPv4(packet)
+		case 6:
+			err = m.processIPv6(packet)
+		default:
+			err = E.New("ip: unknown version: ", ipVersion)
+		}
+		if err != nil {
+			m.logger.Trace(err)
+		}
+	}
+}
+
+func (m *Mixed) wintunLoop(winTun WinTun) {
+	for {
+		packet, release, err := winTun.ReadPacket()
+		if err != nil {
+			return
+		}
+		if len(packet) < clashtcpip.IPv4PacketMinLength {
+			release()
+			continue
+		}
+		switch ipVersion := packet[0] >> 4; ipVersion {
+		case 4:
+			err = m.processIPv4(packet)
+		case 6:
+			err = m.processIPv6(packet)
+		default:
+			err = E.New("ip: unknown version: ", ipVersion)
+		}
+		if err != nil {
+			m.logger.Trace(err)
+		}
+		release()
+	}
+}
+
+func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error {
+	switch packet.Protocol() {
+	case clashtcpip.TCP:
+		return m.processIPv4TCP(packet, packet.Payload())
+	case clashtcpip.UDP:
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Payload: buffer.MakeWithData(packet),
+		})
+		m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt)
+		pkt.DecRef()
+		return nil
+	case clashtcpip.ICMP:
+		return m.processIPv4ICMP(packet, packet.Payload())
+	default:
+		return common.Error(m.tun.Write(packet))
+	}
+}
+
+func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error {
+	switch packet.Protocol() {
+	case clashtcpip.TCP:
+		return m.processIPv6TCP(packet, packet.Payload())
+	case clashtcpip.UDP:
+		pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
+			Payload: buffer.MakeWithData(packet),
+		})
+		m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt)
+		pkt.DecRef()
+		return nil
+	case clashtcpip.ICMPv6:
+		return m.processIPv6ICMP(packet, packet.Payload())
+	default:
+		return common.Error(m.tun.Write(packet))
+	}
+}
+
+func (m *Mixed) packetLoop() {
+	for {
+		packet := m.endpoint.ReadContext(m.ctx)
+		if packet == nil {
+			break
+		}
+		bufio.WriteVectorised(m.writer, packet.AsSlices())
+		packet.DecRef()
+	}
+}
+
+func (m *Mixed) Close() error {
+	m.endpoint.Attach(nil)
+	m.stack.Close()
+	for _, endpoint := range m.stack.CleanupEndpoints() {
+		endpoint.Abort()
+	}
+	return m.System.Close()
+}




diff --git a/tun.go b/tun.go
index 2784339ee1b3d3fb7aca049edc5b21c76fe9f02c..52aa6a5dfa0f05bf1662b5b2af5dd53289372357 100644
--- a/tun.go
+++ b/tun.go
@@ -23,6 +23,7 @@ }
 
 type Tun interface {
 	io.ReadWriter
+	CreateVectorisedWriter() N.VectorisedWriter
 	Close() error
 }
 




diff --git a/tun_darwin.go b/tun_darwin.go
index 2013e21df673358e0c53a8a554a8d11ec54dada9..a9bba89f66a15214d8dfb59b37692e72e600d18a 100644
--- a/tun_darwin.go
+++ b/tun_darwin.go
@@ -10,6 +10,7 @@ 	"syscall"
 	"unsafe"
 
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
 	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
 	N "github.com/sagernet/sing/common/network"
@@ -99,6 +100,20 @@ 	if err == nil {
 		n = len(p)
 	}
 	return
+}
+
+func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
+	return t
+}
+
+func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
+	var packetHeader []byte
+	if buffers[0].Byte(0)>>4 == 4 {
+		packetHeader = packetHeader4[:]
+	} else {
+		packetHeader = packetHeader6[:]
+	}
+	return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...))
 }
 
 func (t *NativeTun) Close() error {




diff --git a/tun_linux.go b/tun_linux.go
index 597881f81225ae99c10ecdc93eca5121c165293b..465fc5c653508c07f5ec59f8a69deed857f0ac32 100644
--- a/tun_linux.go
+++ b/tun_linux.go
@@ -12,7 +12,9 @@ 	"unsafe"
 
 	"github.com/sagernet/netlink"
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/bufio"
 	E "github.com/sagernet/sing/common/exceptions"
+	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/rw"
 	"github.com/sagernet/sing/common/shell"
 	"github.com/sagernet/sing/common/x/list"
@@ -66,6 +68,10 @@ }
 
 func (t *NativeTun) Write(p []byte) (n int, err error) {
 	return t.tunFile.Write(p)
+}
+
+func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
+	return bufio.NewVectorisedWriter(t.tunFile)
 }
 
 var controlPath string




diff --git a/tun_windows.go b/tun_windows.go
index 488f8a721da6aa4d0f13593f5f383f7900b2279a..656251f1cf6c8cc40cfd49c7aca6e5b0976df3ab 100644
--- a/tun_windows.go
+++ b/tun_windows.go
@@ -16,7 +16,10 @@
 	"github.com/sagernet/sing-tun/internal/winipcfg"
 	"github.com/sagernet/sing-tun/internal/winsys"
 	"github.com/sagernet/sing-tun/internal/wintun"
+	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/buf"
 	E "github.com/sagernet/sing/common/exceptions"
+	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/windnsapi"
 
 	"golang.org/x/sys/windows"
@@ -465,6 +468,15 @@ 	case windows.ERROR_BUFFER_OVERFLOW:
 		return 0, nil // Dropping when ring is full.
 	}
 	return 0, fmt.Errorf("write failed: %w", err)
+}
+
+func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter {
+	return t
+}
+
+func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error {
+	defer buf.ReleaseMulti(buffers)
+	return common.Error(t.write(buf.ToSliceMulti(buffers)))
 }
 
 func (t *NativeTun) Close() error {