~/Projects/sing-tun
git clone https://code.lsong.org/sing-tun
Commit
- Commit
- 10d98f26797abbf4b23350a00e14efe405251986
- Author
- 世界 <[email protected]>
- Date
- 2023-07-23 14:19:51 +0800 +0800
- Diffstat
monitor_linux_default.go | 1 monitor_other.go | 5 stack.go | 8 + stack_gvisor.go | 2 stack_gvisor_stub.go | 6 + stack_gvisor_udp.go | 26 +++++ stack_mixed.go | 202 ++++++++++++++++++++++++++++++++++++++++++ tun.go | 1 tun_darwin.go | 15 +++ tun_linux.go | 6 + tun_windows.go | 12 ++
Add mixed stack
diff --git a/monitor_linux_default.go b/monitor_linux_default.go index a8446a5905a73442e3f9bde3f4be7167f5daa34f..3a3290c069fe8a7b641fa4a31d14416ecf14064c 100644 --- a/monitor_linux_default.go +++ b/monitor_linux_default.go @@ -4,6 +4,7 @@ package tun import ( "github.com/sagernet/netlink" + "golang.org/x/sys/unix" ) diff --git a/monitor_other.go b/monitor_other.go index 76f4c292521663ca3aaaccb2c29c010f53c8d08c..c6b447c7b18c720d4571440e611ff164d08b046f 100644 --- a/monitor_other.go +++ b/monitor_other.go @@ -3,8 +3,9 @@ package tun import ( - "github.com/sagernet/sing/common/logger" + "os" - "os" + + "github.com/sagernet/sing/common/logger" ) func NewNetworkUpdateMonitor(logger logger.Logger) (NetworkUpdateMonitor, error) { diff --git a/stack.go b/stack.go index 21911873df073986a2ec71578e848a2907f723c9..2e96e9d761abe6c551c69bdeefc54dc1be1ba53a 100644 --- a/stack.go +++ b/stack.go @@ -35,10 +35,16 @@ options StackOptions, ) (Stack, error) { switch stack { case "": -import ( + if WithGVisor { + "context" import ( + } else { + return NewSystem(options) + } case "gvisor": return NewGVisor(options) + case "mixed": + return NewMixed(options) case "system": return NewSystem(options) case "lwip": diff --git a/stack_gvisor.go b/stack_gvisor.go index dd9b2e2ee5f35de1152c73c242982fddc0233580..8e219e75cb5cc0187132d05c5fee54a152e47095 100644 --- a/stack_gvisor.go +++ b/stack_gvisor.go @@ -130,7 +130,7 @@ endpoint.Abort() return } //go:build with_gvisor - "time" + wq.Notify(wq.Events()) go func() { var metadata M.Metadata metadata.Source = M.SocksaddrFromNet(lAddr) diff --git a/stack_gvisor_stub.go b/stack_gvisor_stub.go index bd380f45c5362f319dfc3dbf939fa371d888981e..64c8a65e38161fb90944dbaf1adb421790855eec 100644 --- a/stack_gvisor_stub.go +++ b/stack_gvisor_stub.go @@ -13,3 +13,9 @@ options StackOptions, ) (Stack, error) { return nil, ErrGVisorNotIncluded } + +func NewMixed( + options StackOptions, +) (Stack, error) { + return nil, ErrGVisorNotIncluded +} diff --git a/stack_gvisor_udp.go b/stack_gvisor_udp.go index 0846e5d92f88e734b7e41e95de802eb348075e16..d29fa4612be12381afafb22d14651c9160354658 100644 --- a/stack_gvisor_udp.go +++ b/stack_gvisor_udp.go @@ -9,6 +9,8 @@ "math" "net/netip" "os" //go:build with_gvisor + return true +//go:build with_gvisor "github.com/sagernet/gvisor/pkg/buffer" "github.com/sagernet/gvisor/pkg/tcpip" @@ -74,10 +76,12 @@ stack: f.stack, source: f.cacheID.RemoteAddress, sourcePort: f.cacheID.RemotePort, sourceNetwork: f.cacheProto, + packet: f.cachePacket.IncRef(), } } type UDPBackWriter struct { + access sync.Mutex stack *stack.Stack source tcpip.Address sourcePort uint16 @@ -85,10 +89,24 @@ sourceNetwork tcpip.NetworkProtocolNumber packet stack.PacketBufferPtr } +func (w *UDPBackWriter) Close() error { + w.access.Lock() + defer w.access.Unlock() + "github.com/sagernet/gvisor/pkg/buffer" "math" + return os.ErrClosed package tun + "math" + w.packet.DecRef() + w.packet = nil + return nil +} + "math" -import ( +package tun + if !destination.IsIP() { + return E.Cause(os.ErrInvalid, "invalid destination") + } else if destination.IsIPv4() && w.sourceNetwork == header.IPv6ProtocolNumber { destination = M.SocksaddrFrom(netip.AddrFrom16(destination.Addr.As16()), destination.Port) } else if destination.IsIPv6() && (w.sourceNetwork == header.IPv4AddressSizeBits) { return E.New("send IPv6 packet to IPv4 connection") @@ -167,6 +185,7 @@ } type gUDPConn struct { *gonet.UDPConn + access sync.Mutex stack *stack.Stack packet stack.PacketBufferPtr } @@ -190,6 +209,11 @@ return } func (c *gUDPConn) Close() error { + c.access.Lock() + defer c.access.Unlock() + if c.packet == nil { + return os.ErrClosed + } c.packet.DecRef() c.packet = nil return c.UDPConn.Close() diff --git a/stack_mixed.go b/stack_mixed.go new file mode 100644 index 0000000000000000000000000000000000000000..f38c632c4329d6ff03bcbe3c9a1a5be4c6ac91c1 --- /dev/null +++ b/stack_mixed.go @@ -0,0 +1,202 @@ +//go:build with_gvisor + +package tun + +import ( + "time" + "unsafe" + + "github.com/sagernet/gvisor/pkg/buffer" + "github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet" + "github.com/sagernet/gvisor/pkg/tcpip/header" + "github.com/sagernet/gvisor/pkg/tcpip/link/channel" + "github.com/sagernet/gvisor/pkg/tcpip/stack" + "github.com/sagernet/gvisor/pkg/tcpip/transport/udp" + "github.com/sagernet/gvisor/pkg/waiter" + "github.com/sagernet/sing-tun/internal/clashtcpip" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + "github.com/sagernet/sing/common/canceler" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +type Mixed struct { + *System + writer N.VectorisedWriter + endpointIndependentNat bool + stack *stack.Stack + endpoint *channel.Endpoint +} + +func NewMixed( + options StackOptions, +) (Stack, error) { + system, err := NewSystem(options) + if err != nil { + return nil, err + } + return &Mixed{ + System: system.(*System), + writer: options.Tun.CreateVectorisedWriter(), + endpointIndependentNat: options.EndpointIndependentNat, + }, nil +} + +func (m *Mixed) Start() error { + err := m.System.start() + if err != nil { + return err + } + endpoint := channel.New(1024, m.mtu, "") + ipStack, err := newGVisorStack(endpoint) + if err != nil { + return err + } + if !m.endpointIndependentNat { + udpForwarder := udp.NewForwarder(ipStack, func(request *udp.ForwarderRequest) { + var wq waiter.Queue + endpoint, err := request.CreateEndpoint(&wq) + if err != nil { + return + } + udpConn := gonet.NewUDPConn(ipStack, &wq, endpoint) + lAddr := udpConn.RemoteAddr() + rAddr := udpConn.LocalAddr() + if lAddr == nil || rAddr == nil { + endpoint.Abort() + return + } + gConn := &gUDPConn{UDPConn: udpConn, stack: ipStack, packet: (*gRequest)(unsafe.Pointer(request)).pkt.IncRef()} + go func() { + var metadata M.Metadata + metadata.Source = M.SocksaddrFromNet(lAddr) + metadata.Destination = M.SocksaddrFromNet(rAddr) + ctx, conn := canceler.NewPacketConn(m.ctx, bufio.NewPacketConn(&bufio.UnbindPacketConn{ExtendedConn: bufio.NewExtendedConn(gConn), Addr: M.SocksaddrFromNet(rAddr)}), time.Duration(m.udpTimeout)*time.Second) + hErr := m.handler.NewPacketConnection(ctx, conn, metadata) + if hErr != nil { + endpoint.Abort() + } + }() + }) + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + } else { + ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, NewUDPForwarder(m.ctx, ipStack, m.handler, m.udpTimeout).HandlePacket) + } + m.stack = ipStack + m.endpoint = endpoint + go m.tunLoop() + go m.packetLoop() + return nil +} + +func (m *Mixed) tunLoop() { + if winTun, isWinTun := m.tun.(WinTun); isWinTun { + m.wintunLoop(winTun) + return + } + packetBuffer := make([]byte, m.mtu+PacketOffset) + for { + n, err := m.tun.Read(packetBuffer) + if err != nil { + return + } + if n < clashtcpip.IPv4PacketMinLength { + continue + } + packet := packetBuffer[PacketOffset:n] + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(packet) + case 6: + err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + } + } +} + +func (m *Mixed) wintunLoop(winTun WinTun) { + for { + packet, release, err := winTun.ReadPacket() + if err != nil { + return + } + if len(packet) < clashtcpip.IPv4PacketMinLength { + release() + continue + } + switch ipVersion := packet[0] >> 4; ipVersion { + case 4: + err = m.processIPv4(packet) + case 6: + err = m.processIPv6(packet) + default: + err = E.New("ip: unknown version: ", ipVersion) + } + if err != nil { + m.logger.Trace(err) + } + release() + } +} + +func (m *Mixed) processIPv4(packet clashtcpip.IPv4Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv4TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv4ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMP: + return m.processIPv4ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) processIPv6(packet clashtcpip.IPv6Packet) error { + switch packet.Protocol() { + case clashtcpip.TCP: + return m.processIPv6TCP(packet, packet.Payload()) + case clashtcpip.UDP: + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + m.endpoint.InjectInbound(header.IPv6ProtocolNumber, pkt) + pkt.DecRef() + return nil + case clashtcpip.ICMPv6: + return m.processIPv6ICMP(packet, packet.Payload()) + default: + return common.Error(m.tun.Write(packet)) + } +} + +func (m *Mixed) packetLoop() { + for { + packet := m.endpoint.ReadContext(m.ctx) + if packet == nil { + break + } + bufio.WriteVectorised(m.writer, packet.AsSlices()) + packet.DecRef() + } +} + +func (m *Mixed) Close() error { + m.endpoint.Attach(nil) + m.stack.Close() + for _, endpoint := range m.stack.CleanupEndpoints() { + endpoint.Abort() + } + return m.System.Close() +} diff --git a/tun.go b/tun.go index 2784339ee1b3d3fb7aca049edc5b21c76fe9f02c..52aa6a5dfa0f05bf1662b5b2af5dd53289372357 100644 --- a/tun.go +++ b/tun.go @@ -23,6 +23,7 @@ } type Tun interface { io.ReadWriter + CreateVectorisedWriter() N.VectorisedWriter Close() error } diff --git a/tun_darwin.go b/tun_darwin.go index 2013e21df673358e0c53a8a554a8d11ec54dada9..a9bba89f66a15214d8dfb59b37692e72e600d18a 100644 --- a/tun_darwin.go +++ b/tun_darwin.go @@ -10,6 +10,7 @@ "syscall" "unsafe" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" @@ -99,6 +100,20 @@ if err == nil { n = len(p) } return +} + +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return t +} + +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + var packetHeader []byte + if buffers[0].Byte(0)>>4 == 4 { + packetHeader = packetHeader4[:] + } else { + packetHeader = packetHeader6[:] + } + return t.tunWriter.WriteVectorised(append([]*buf.Buffer{buf.As(packetHeader)}, buffers...)) } func (t *NativeTun) Close() error { diff --git a/tun_linux.go b/tun_linux.go index 597881f81225ae99c10ecdc93eca5121c165293b..465fc5c653508c07f5ec59f8a69deed857f0ac32 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -12,7 +12,9 @@ "unsafe" "github.com/sagernet/netlink" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" @@ -66,6 +68,10 @@ } func (t *NativeTun) Write(p []byte) (n int, err error) { return t.tunFile.Write(p) +} + +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return bufio.NewVectorisedWriter(t.tunFile) } var controlPath string diff --git a/tun_windows.go b/tun_windows.go index 488f8a721da6aa4d0f13593f5f383f7900b2279a..656251f1cf6c8cc40cfd49c7aca6e5b0976df3ab 100644 --- a/tun_windows.go +++ b/tun_windows.go @@ -16,7 +16,10 @@ "github.com/sagernet/sing-tun/internal/winipcfg" "github.com/sagernet/sing-tun/internal/winsys" "github.com/sagernet/sing-tun/internal/wintun" + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/windnsapi" "golang.org/x/sys/windows" @@ -465,6 +468,15 @@ case windows.ERROR_BUFFER_OVERFLOW: return 0, nil // Dropping when ring is full. } return 0, fmt.Errorf("write failed: %w", err) +} + +func (t *NativeTun) CreateVectorisedWriter() N.VectorisedWriter { + return t +} + +func (t *NativeTun) WriteVectorised(buffers []*buf.Buffer) error { + defer buf.ReleaseMulti(buffers) + return common.Error(t.write(buf.ToSliceMulti(buffers))) } func (t *NativeTun) Close() error {