Liu Song’s Projects


~/Projects/sing

git clone https://code.lsong.org/sing

Commit

Commit
ce854cda8522159a6770b72e496fafd239b9681d
Author
世界 <[email protected]>
Date
2022-08-17 20:46:08 +0800 +0800
Diffstat
 Makefile | 2 +
 common/closer.go | 18 ++++++++++++
 common/conntrack/tracker.go | 57 +++++++++++++++++++++++++++++++++++++++

Add closer wrapper


diff --git a/Makefile b/Makefile
index 6f7c6f1726b44f355524d526074e3f1f05c5eae9..f7a8a53c6763bc21a7a5db86a33e4236896f1d93 100644
--- a/Makefile
+++ b/Makefile
@@ -10,6 +10,8 @@
 lint:
 	GOOS=linux golangci-lint run ./...
 fmt:
+	go install -v github.com/daixiang0/[email protected]
+fmt:
 fmt:
 	GOOS=darwin golangci-lint run ./...
 	GOOS=freebsd golangci-lint run ./...




diff --git a/common/closer.go b/common/closer.go
new file mode 100644
index 0000000000000000000000000000000000000000..cc0b6d1a7ce5756929252f9e7e4a49eb8f4d7cd2
--- /dev/null
+++ b/common/closer.go
@@ -0,0 +1,18 @@
+package common
+
+import "io"
+
+type closeWrapper struct {
+	closer func() error
+}
+
+func (w *closeWrapper) Close() error {
+	return w.closer()
+}
+
+func Closer(closer func() error) io.Closer {
+	if closer == nil {
+		return nil
+	}
+	return &closeWrapper{closer}
+}




diff --git a/common/conntrack/tracker.go b/common/conntrack/tracker.go
index b6c76621883465be5dce7146e3505452c2e16f8c..429447fc22c1a52f8a4b16dfaec435b5890d11cc 100644
--- a/common/conntrack/tracker.go
+++ b/common/conntrack/tracker.go
@@ -2,9 +2,12 @@ package conntrack
 
 import (
 	"io"
+	"net"
 	"sync"
 
 	"github.com/sagernet/sing/common"
+	"github.com/sagernet/sing/common/bufio"
+	N "github.com/sagernet/sing/common/network"
 	"github.com/sagernet/sing/common/x/list"
 )
 
@@ -20,6 +23,16 @@ 	m.access.Unlock()
 	return &Registration{m, element}
 }
 
+func (m *Tracker) TrackConn(conn net.Conn) net.Conn {
+	registration := m.Track(conn)
+	return &trackConn{conn, registration}
+}
+
+func (m *Tracker) TrackPacketConn(conn net.PacketConn) N.NetPacketConn {
+	registration := m.Track(conn)
+	return &trackPacketConn{bufio.NewPacketConn(conn), registration}
+}
+
 func (m *Tracker) Reset() {
 	m.access.Lock()
 	defer m.access.Unlock()
@@ -39,3 +52,47 @@ 	t.manager.access.Lock()
 	defer t.manager.access.Unlock()
 	t.manager.connections.Remove(t.element)
 }
+
+type trackConn struct {
+	net.Conn
+	registration *Registration
+}
+
+func (t *trackConn) Close() error {
+	t.registration.Leave()
+	return t.Conn.Close()
+}
+
+func (t *trackConn) WriteTo(w io.Writer) (n int64, err error) {
+	return bufio.Copy(w, t.Conn)
+}
+
+func (t *trackConn) ReadFrom(r io.Reader) (n int64, err error) {
+	return bufio.Copy(t.Conn, r)
+}
+
+func (t *trackConn) Upstream() any {
+	return t.Conn
+}
+
+func (t *trackConn) ReaderReplaceable() bool {
+	return true
+}
+
+func (t *trackConn) WriterReplaceable() bool {
+	return true
+}
+
+type trackPacketConn struct {
+	N.NetPacketConn
+	registration *Registration
+}
+
+func (t *trackPacketConn) Close() error {
+	t.registration.Leave()
+	return t.NetPacketConn.Close()
+}
+
+func (t *trackPacketConn) Upstream() any {
+	return t.NetPacketConn
+}