~/Projects/mqtt-go
git clone https://code.lsong.org/mqtt-go
Commit
- Commit
- 92cd935a161cad82959693e15ebf47cff28d3d7c
- Author
- JB <28275108+[email protected]>
- Date
- 2022-12-21 11:38:28 +0000 +0000
- Diffstat
listeners/unixsock.go | 98 ++++++++++++++++++++++++++++++++++++++++ listeners/unixsock_test.go | 96 +++++++++++++++++++++++++++++++++++++++
Merge branch 'master' into master
diff --git a/listeners/unixsock.go b/listeners/unixsock.go new file mode 100644 index 0000000000000000000000000000000000000000..a16352dfe281117fdf8448baa6850f64241b9452 --- /dev/null +++ b/listeners/unixsock.go @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileContributor: [email protected] + +package listeners + +import ( + "net" + "os" + "sync" + "sync/atomic" + + "github.com/rs/zerolog" +) + +// UnixSock is a listener for establishing client connections on basic UnixSock protocol. +type UnixSock struct { + sync.RWMutex + id string // the internal id of the listener. + address string // the network address to bind to. + listen net.Listener // a net.Listener which will listen for new clients. + log *zerolog.Logger // server logger + end uint32 // ensure the close methods are only called once. +} + +// NewUnixSock initialises and returns a new UnixSock listener, listening on an address. +func NewUnixSock(id, address string) *UnixSock { + return &UnixSock{ + id: id, + address: address, + } +} + +// ID returns the id of the listener. +func (l *UnixSock) ID() string { + return l.id +} + +// Address returns the address of the listener. +func (l *UnixSock) Address() string { + return l.address +} + +// Protocol returns the address of the listener. +func (l *UnixSock) Protocol() string { + return "unix" +} + +// Init initializes the listener. +func (l *UnixSock) Init(log *zerolog.Logger) error { + l.log = log + + var err error + _ = os.Remove(l.address) + l.listen, err = net.Listen("unix", l.address) + return err +} + +// Serve starts waiting for new UnixSock connections, and calls the establish +// connection callback for any received. +func (l *UnixSock) Serve(establish EstablishFn) { + for { + if atomic.LoadUint32(&l.end) == 1 { + return + } + + conn, err := l.listen.Accept() + if err != nil { + return + } + + if atomic.LoadUint32(&l.end) == 0 { + go func() { + err = establish(l.id, conn) + if err != nil { + l.log.Warn().Err(err).Send() + } + }() + } + } +} + +// Close closes the listener and any client connections. +func (l *UnixSock) Close(closeClients CloseFn) { + l.Lock() + defer l.Unlock() + + if atomic.CompareAndSwapUint32(&l.end, 0, 1) { + closeClients(l.id) + } + + if l.listen != nil { + err := l.listen.Close() + if err != nil { + return + } + } +} diff --git a/listeners/unixsock_test.go b/listeners/unixsock_test.go new file mode 100644 index 0000000000000000000000000000000000000000..d09f7764db888ef696df913514020e2acb02c2d9 --- /dev/null +++ b/listeners/unixsock_test.go @@ -0,0 +1,96 @@ +// SPDX-License-Identifier: MIT +// SPDX-FileCopyrightText: 2022 mochi-co +// SPDX-FileContributor: [email protected] + +package listeners + +import ( + "errors" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +const testUnixAddr = "mochi.sock" + +func TestNewUnixSock(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + require.Equal(t, "t1", l.id) + require.Equal(t, testUnixAddr, l.address) +} + +func TestUnixSockID(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + require.Equal(t, "t1", l.ID()) +} + +func TestUnixSockAddress(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + require.Equal(t, testUnixAddr, l.Address()) +} + +func TestUnixSockProtocol(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + require.Equal(t, "unix", l.Protocol()) +} + +func TestUnixSockInit(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + err := l.Init(&logger) + l.Close(MockCloser) + require.NoError(t, err) + + l2 := NewUnixSock("t2", testUnixAddr) + err = l2.Init(&logger) + l2.Close(MockCloser) + require.NoError(t, err) +} + +func TestUnixSockServeAndClose(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + err := l.Init(&logger) + require.NoError(t, err) + + o := make(chan bool) + go func(o chan bool) { + l.Serve(MockEstablisher) + o <- true + }(o) + + time.Sleep(time.Millisecond) + + var closed bool + l.Close(func(id string) { + closed = true + }) + + require.True(t, closed) + <-o + + l.Close(MockCloser) // coverage: close closed + l.Serve(MockEstablisher) // coverage: serve closed +} + +func TestUnixSockEstablishThenEnd(t *testing.T) { + l := NewUnixSock("t1", testUnixAddr) + err := l.Init(&logger) + require.NoError(t, err) + + o := make(chan bool) + established := make(chan bool) + go func() { + l.Serve(func(id string, c net.Conn) error { + established <- true + return errors.New("ending") // return an error to exit immediately + }) + o <- true + }() + + time.Sleep(time.Millisecond) + net.Dial("unix", l.listen.Addr().String()) + require.Equal(t, true, <-established) + l.Close(MockCloser) + <-o +}