~/Projects/sing
git clone https://code.lsong.org/sing
Commit
- Commit
- ce854cda8522159a6770b72e496fafd239b9681d
- Author
- 世界 <[email protected]>
- Date
- 2022-08-17 20:46:08 +0800 +0800
- Diffstat
Makefile | 2 + common/closer.go | 18 ++++++++++++ common/conntrack/tracker.go | 57 +++++++++++++++++++++++++++++++++++++++
Add closer wrapper
diff --git a/Makefile b/Makefile index 6f7c6f1726b44f355524d526074e3f1f05c5eae9..f7a8a53c6763bc21a7a5db86a33e4236896f1d93 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,8 @@ lint: GOOS=linux golangci-lint run ./... fmt: + go install -v github.com/daixiang0/[email protected] +fmt: fmt: GOOS=darwin golangci-lint run ./... GOOS=freebsd golangci-lint run ./... diff --git a/common/closer.go b/common/closer.go new file mode 100644 index 0000000000000000000000000000000000000000..cc0b6d1a7ce5756929252f9e7e4a49eb8f4d7cd2 --- /dev/null +++ b/common/closer.go @@ -0,0 +1,18 @@ +package common + +import "io" + +type closeWrapper struct { + closer func() error +} + +func (w *closeWrapper) Close() error { + return w.closer() +} + +func Closer(closer func() error) io.Closer { + if closer == nil { + return nil + } + return &closeWrapper{closer} +} diff --git a/common/conntrack/tracker.go b/common/conntrack/tracker.go index b6c76621883465be5dce7146e3505452c2e16f8c..429447fc22c1a52f8a4b16dfaec435b5890d11cc 100644 --- a/common/conntrack/tracker.go +++ b/common/conntrack/tracker.go @@ -2,9 +2,12 @@ package conntrack import ( "io" + "net" "sync" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/bufio" + N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" ) @@ -20,6 +23,16 @@ m.access.Unlock() return &Registration{m, element} } +func (m *Tracker) TrackConn(conn net.Conn) net.Conn { + registration := m.Track(conn) + return &trackConn{conn, registration} +} + +func (m *Tracker) TrackPacketConn(conn net.PacketConn) N.NetPacketConn { + registration := m.Track(conn) + return &trackPacketConn{bufio.NewPacketConn(conn), registration} +} + func (m *Tracker) Reset() { m.access.Lock() defer m.access.Unlock() @@ -39,3 +52,47 @@ t.manager.access.Lock() defer t.manager.access.Unlock() t.manager.connections.Remove(t.element) } + +type trackConn struct { + net.Conn + registration *Registration +} + +func (t *trackConn) Close() error { + t.registration.Leave() + return t.Conn.Close() +} + +func (t *trackConn) WriteTo(w io.Writer) (n int64, err error) { + return bufio.Copy(w, t.Conn) +} + +func (t *trackConn) ReadFrom(r io.Reader) (n int64, err error) { + return bufio.Copy(t.Conn, r) +} + +func (t *trackConn) Upstream() any { + return t.Conn +} + +func (t *trackConn) ReaderReplaceable() bool { + return true +} + +func (t *trackConn) WriterReplaceable() bool { + return true +} + +type trackPacketConn struct { + N.NetPacketConn + registration *Registration +} + +func (t *trackPacketConn) Close() error { + t.registration.Leave() + return t.NetPacketConn.Close() +} + +func (t *trackPacketConn) Upstream() any { + return t.NetPacketConn +}