Liu Song’s Projects


~/Projects/mqtt-go

git clone https://code.lsong.org/mqtt-go

Commit

Commit
92cd935a161cad82959693e15ebf47cff28d3d7c
Author
JB <28275108+[email protected]>
Date
2022-12-21 11:38:28 +0000 +0000
Diffstat
 listeners/unixsock.go | 98 ++++++++++++++++++++++++++++++++++++++++
 listeners/unixsock_test.go | 96 +++++++++++++++++++++++++++++++++++++++

Merge branch 'master' into master


diff --git a/listeners/unixsock.go b/listeners/unixsock.go
new file mode 100644
index 0000000000000000000000000000000000000000..a16352dfe281117fdf8448baa6850f64241b9452
--- /dev/null
+++ b/listeners/unixsock.go
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-co
+// SPDX-FileContributor: [email protected]
+
+package listeners
+
+import (
+	"net"
+	"os"
+	"sync"
+	"sync/atomic"
+
+	"github.com/rs/zerolog"
+)
+
+// UnixSock is a listener for establishing client connections on basic UnixSock protocol.
+type UnixSock struct {
+	sync.RWMutex
+	id      string          // the internal id of the listener.
+	address string          // the network address to bind to.
+	listen  net.Listener    // a net.Listener which will listen for new clients.
+	log     *zerolog.Logger // server logger
+	end     uint32          // ensure the close methods are only called once.
+}
+
+// NewUnixSock initialises and returns a new UnixSock listener, listening on an address.
+func NewUnixSock(id, address string) *UnixSock {
+	return &UnixSock{
+		id:      id,
+		address: address,
+	}
+}
+
+// ID returns the id of the listener.
+func (l *UnixSock) ID() string {
+	return l.id
+}
+
+// Address returns the address of the listener.
+func (l *UnixSock) Address() string {
+	return l.address
+}
+
+// Protocol returns the address of the listener.
+func (l *UnixSock) Protocol() string {
+	return "unix"
+}
+
+// Init initializes the listener.
+func (l *UnixSock) Init(log *zerolog.Logger) error {
+	l.log = log
+
+	var err error
+	_ = os.Remove(l.address)
+	l.listen, err = net.Listen("unix", l.address)
+	return err
+}
+
+// Serve starts waiting for new UnixSock connections, and calls the establish
+// connection callback for any received.
+func (l *UnixSock) Serve(establish EstablishFn) {
+	for {
+		if atomic.LoadUint32(&l.end) == 1 {
+			return
+		}
+
+		conn, err := l.listen.Accept()
+		if err != nil {
+			return
+		}
+
+		if atomic.LoadUint32(&l.end) == 0 {
+			go func() {
+				err = establish(l.id, conn)
+				if err != nil {
+					l.log.Warn().Err(err).Send()
+				}
+			}()
+		}
+	}
+}
+
+// Close closes the listener and any client connections.
+func (l *UnixSock) Close(closeClients CloseFn) {
+	l.Lock()
+	defer l.Unlock()
+
+	if atomic.CompareAndSwapUint32(&l.end, 0, 1) {
+		closeClients(l.id)
+	}
+
+	if l.listen != nil {
+		err := l.listen.Close()
+		if err != nil {
+			return
+		}
+	}
+}




diff --git a/listeners/unixsock_test.go b/listeners/unixsock_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..d09f7764db888ef696df913514020e2acb02c2d9
--- /dev/null
+++ b/listeners/unixsock_test.go
@@ -0,0 +1,96 @@
+// SPDX-License-Identifier: MIT
+// SPDX-FileCopyrightText: 2022 mochi-co
+// SPDX-FileContributor: [email protected]
+
+package listeners
+
+import (
+	"errors"
+	"net"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+)
+
+const testUnixAddr = "mochi.sock"
+
+func TestNewUnixSock(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	require.Equal(t, "t1", l.id)
+	require.Equal(t, testUnixAddr, l.address)
+}
+
+func TestUnixSockID(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	require.Equal(t, "t1", l.ID())
+}
+
+func TestUnixSockAddress(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	require.Equal(t, testUnixAddr, l.Address())
+}
+
+func TestUnixSockProtocol(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	require.Equal(t, "unix", l.Protocol())
+}
+
+func TestUnixSockInit(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	err := l.Init(&logger)
+	l.Close(MockCloser)
+	require.NoError(t, err)
+
+	l2 := NewUnixSock("t2", testUnixAddr)
+	err = l2.Init(&logger)
+	l2.Close(MockCloser)
+	require.NoError(t, err)
+}
+
+func TestUnixSockServeAndClose(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	err := l.Init(&logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	go func(o chan bool) {
+		l.Serve(MockEstablisher)
+		o <- true
+	}(o)
+
+	time.Sleep(time.Millisecond)
+
+	var closed bool
+	l.Close(func(id string) {
+		closed = true
+	})
+
+	require.True(t, closed)
+	<-o
+
+	l.Close(MockCloser)      // coverage: close closed
+	l.Serve(MockEstablisher) // coverage: serve closed
+}
+
+func TestUnixSockEstablishThenEnd(t *testing.T) {
+	l := NewUnixSock("t1", testUnixAddr)
+	err := l.Init(&logger)
+	require.NoError(t, err)
+
+	o := make(chan bool)
+	established := make(chan bool)
+	go func() {
+		l.Serve(func(id string, c net.Conn) error {
+			established <- true
+			return errors.New("ending") // return an error to exit immediately
+		})
+		o <- true
+	}()
+
+	time.Sleep(time.Millisecond)
+	net.Dial("unix", l.listen.Addr().String())
+	require.Equal(t, true, <-established)
+	l.Close(MockCloser)
+	<-o
+}