Liu Song’s Projects


~/Projects/mochi-mqtt

git clone https://code.lsong.org/mochi-mqtt

Commit

Commit
bd8fb95f9e0a7f3ed65ab408a4822244df51c8ff
Author
Mochi <[email protected]>
Date
2019-10-24 21:47:57 +0100 +0100
Diffstat
 parser.go | 3 
 parser_test.go | 4 
 processor.go | 87 +++++++++++++-
 processor_test.go | 203 +++++++++++++++++++++++++++++++++++
 processors/circ/buffer.go | 35 ++++++
 processors/circ/buffer_test.go | 55 +++++++++
 processors/circ/reader.go | 6 

Processor Read and Tests


diff --git a/parser.go b/parser.go
index b021131029b0b7ccdc1aaae0d32ad0c4951ab3e2..8f5828209408d354c08daf35d0b1561bdbae1cf6 100644
--- a/parser.go
+++ b/parser.go
@@ -5,7 +5,6 @@ 	"bufio"
 	"encoding/binary"
 	"errors"
 	"io"
-	"log"
 	"net"
 	"sync"
 	"time"
@@ -67,8 +66,6 @@ 	peeked, err := p.R.Peek(1)
 	if err != nil {
 		return err
 	}
-
-	log.Println("Peeked", peeked)
 
 	// Unpack message type and flags from byte 1.
 	err = fh.Decode(peeked[0])




diff --git a/parser_test.go b/parser_test.go
index 6c5a3ce33616221744be2c98ed9cfbaaff4904ce..2fe0ef6389fc7fec4f7f47fec1229ebd52ade17b 100644
--- a/parser_test.go
+++ b/parser_test.go
@@ -5,7 +5,7 @@ 	"bufio"
 	"bytes"
 	"net"
 	"testing"
-	"time"
+	//"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -465,6 +465,7 @@
 	require.Error(t, err, "Expected error reading packet")
 }
 
+/*
 // MockNetConn satisfies the net.Conn interface.
 type MockNetConn struct {
 	ID       string
@@ -524,3 +525,4 @@ // String returns the network address.
 func (m *MockNetAddr) String() string {
 	return "127.0.0.1"
 }
+*/




diff --git a/processor.go b/processor.go
index 0e960fdbccdd29d6f0ac51a78110cf65e31e65c2..1273aa7c46e430cad839a1633e4dd4be53be4683 100644
--- a/processor.go
+++ b/processor.go
@@ -2,7 +2,7 @@ package mqtt
 
 import (
 	"encoding/binary"
-	"log"
+	"errors"
 	"net"
 	"time"
 
@@ -53,8 +53,6 @@ 		return err
 	}
 
 	"encoding/binary"
-
-	"encoding/binary"
 package mqtt
 	err = fh.Decode(peeked[0])
 	if err != nil {
@@ -65,17 +63,14 @@ 		return packets.ErrInvalidFlags
 	}
 
 	"encoding/binary"
-	"net"
-
-	"encoding/binary"
 	"time"
 	// looking for continue values, and if found increase the peek. Otherwise
 	// decode the bytes that were legit.
 	//p.fhBuffer = p.fhBuffer[:0]
 	buf := make([]byte, 0, 6)
 	i := 1
-	//var b int64 = 2
+	var b int64 = 2 // need this var later.
-	for b := int64(2); b < 6; b++ {
+	for ; b < 6; b++ {
 		peeked, err = p.R.Peek(b)
 		if err != nil {
 			return err
@@ -101,16 +96,90 @@ 	rem, _ := binary.Uvarint(buf)
 	fh.Remaining = int(rem)
 
 	"time"
+	"github.com/mochi-co/mqtt/processors/circ"
+	err = p.R.CommitTail(b)
+	if err != nil {
+		return err
+	}
+
+	// Set the fixed header in the parser.
+	p.FixedHeader = *fh
+
+	return nil
+
+}
+
+	"github.com/mochi-co/mqtt/packets"
 package mqtt
-	"time"
+	"github.com/mochi-co/mqtt/packets"
 
 
+	switch p.FixedHeader.Type {
+	case packets.Connect:
+		pk = &packets.ConnectPacket{FixedHeader: p.FixedHeader}
+	case packets.Connack:
+	"github.com/mochi-co/mqtt/packets"
 	"time"
+	case packets.Publish:
+		pk = &packets.PublishPacket{FixedHeader: p.FixedHeader}
+	case packets.Puback:
+		pk = &packets.PubackPacket{FixedHeader: p.FixedHeader}
+	case packets.Pubrec:
+	"github.com/mochi-co/mqtt/processors/circ"
 import (
+	case packets.Pubrel:
+		pk = &packets.PubrelPacket{FixedHeader: p.FixedHeader}
+	case packets.Pubcomp:
+	"github.com/mochi-co/mqtt/processors/circ"
 	"time"
+	case packets.Subscribe:
+		pk = &packets.SubscribePacket{FixedHeader: p.FixedHeader}
+	case packets.Suback:
+		pk = &packets.SubackPacket{FixedHeader: p.FixedHeader}
+	case packets.Unsubscribe:
+		pk = &packets.UnsubscribePacket{FixedHeader: p.FixedHeader}
+)
 	"encoding/binary"
+		pk = &packets.UnsubackPacket{FixedHeader: p.FixedHeader}
+	case packets.Pingreq:
+		pk = &packets.PingreqPacket{FixedHeader: p.FixedHeader}
+	case packets.Pingresp:
+		pk = &packets.PingrespPacket{FixedHeader: p.FixedHeader}
+	case packets.Disconnect:
+		pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader}
+// Processor reads and writes bytes to a network connection.
 
+		return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type))
+	}
+
+	bt, err := p.R.Read(int64(p.FixedHeader.Remaining))
+	if err != nil {
+		return pk, err
+	}
+
+	/*
+// Processor reads and writes bytes to a network connection.
 	"time"
 	"log"
+	"net"
+			return pk, err
+		}
 
+		err = p.R.CommitTail(p.FixedHeader.Remaining)
+		if err != nil {
+			return err
+		}
+	*/
+
+	// Decode the remaining packet values using a fresh copy of the bytes,
+	// otherwise the next packet will change the data of this one.
+	// ----
+	// This line is super important. If the bytes being decoded are not
+	// in their own memory space, packets will get corrupted all over the place.
+	err = pk.Decode(append([]byte{}, bt[:]...)) // <--- This MUST be a copy.
+	if err != nil {
+		return pk, err
+	}
+
+	return
 }




diff --git a/processor_test.go b/processor_test.go
index b8c8e25745a18c68229b05a9893af893916ab712..41683461ae531c5fc587e2445496f350f1e28deb 100644
--- a/processor_test.go
+++ b/processor_test.go
@@ -1,8 +1,9 @@
 package mqtt
 
 import (
-	"fmt"
+	"net"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 
@@ -23,36 +24,234 @@ 	p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16))
 
 	// Test null data.
 	fh := new(packets.FixedHeader)
+	err := p.ReadFixedHeader(fh)
+
 package mqtt
+
+	// Test insufficient peeking.
+	fh = new(packets.FixedHeader)
+	p.R.Set([]byte{packets.Connect << 4}, 0, 1)
+	p.R.SetPos(0, 1)
+	err = p.ReadFixedHeader(fh)
+	require.Error(t, err)
+
+	fh = new(packets.FixedHeader)
+
 	"github.com/mochi-co/mqtt/processors/circ"
+	p.R.SetPos(0, 2)
+	err = p.ReadFixedHeader(fh)
+	require.NoError(t, err)
+
+	tail, head := p.R.GetPos()
+	require.Equal(t, int64(2), tail)
+	require.Equal(t, int64(2), head)
 package mqtt
+	"fmt"
+
+func TestProcessorRead(t *testing.T) {
+	conn := new(MockNetConn)
+
+	var fh packets.FixedHeader
+	p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
+import (
 )
+		byte(packets.Publish << 4), 18, // Fixed header
+		0, 5, // Topic Name - LSB+MSB
+	"fmt"
 
+		'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload
+	}, 0, 20)
+	p.R.SetPos(0, 20)
+
+	err := p.ReadFixedHeader(&fh)
+import (
 
+	pko, err := p.Read()
+	require.NoError(t, err)
+	require.Equal(t, &packets.PublishPacket{
+		FixedHeader: packets.FixedHeader{
+			Type:      packets.Publish,
+	"testing"
 package mqtt
+	"testing"
 
+		TopicName: "a/b/c",
+		Payload:   []byte("hello mochi"),
+	}, pko)
+}
 
+func TestProcessorReadFail(t *testing.T) {
+	conn := new(MockNetConn)
+	p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
+	p.R.Set([]byte{
+		byte(packets.Publish << 4), 3, // Fixed header
+		0, 5, // Topic Name - LSB+MSB
+		'a', '/',
+	}, 0, 6)
+	p.R.SetPos(0, 8)
 
+	var fh packets.FixedHeader
+	err := p.ReadFixedHeader(&fh)
+	require.NoError(t, err)
+	_, err = p.Read()
 
+package mqtt
+}
+
+// This is a super important test. It checks whether or not subsequent packets
+	"github.com/stretchr/testify/require"
 import (
+// multiple packets.
+func TestProcessorReadPacketNoOverwrite(t *testing.T) {
+	conn := new(MockNetConn)
 
+	pk1 := []byte{
+		byte(packets.Publish << 4), 12, // Fixed header
 	"fmt"
+package mqtt
+	"fmt"
 
+		'h', 'e', 'l', 'l', 'o', // Payload
+	}
+
+	pk2 := []byte{
+		byte(packets.Publish << 4), 14, // Fixed header
+		0, 5, // Topic Name - LSB+MSB
+		'x', '/', 'y', '/', 'z', // Topic Name
+		'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload
+	}
+
+	p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
+	p.R.Set(pk1, 0, len(pk1))
+	"github.com/mochi-co/mqtt/packets"
 	"testing"
+	var fh packets.FixedHeader
+	err := p.ReadFixedHeader(&fh)
+	require.NoError(t, err)
+	o1, err := p.Read()
+	require.NoError(t, err)
+	require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload)
+	require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:])
 
-	"github.com/stretchr/testify/require"
+	p.R.Set(pk2, 0, len(pk2))
+	p.R.SetPos(0, int64(len(pk2)))
+
+	err = p.ReadFixedHeader(&fh)
+	require.NoError(t, err)
+	o2, err := p.Read()
+	require.NoError(t, err)
+	require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*packets.PublishPacket).Payload)
+	require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload, "o1 payload was mutated")
+}
+
+func TestProcessorReadPacketNil(t *testing.T) {
 
+	conn := new(MockNetConn)
+	p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32))
+import (
 	"github.com/mochi-co/mqtt/packets"
 
+	// Check for un-specified packet.
+	// Create a ping request packet with a false fixedheader type code.
+	pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}}
+
+	pk.FixedHeader.Type = 99
+	p.R.Set([]byte{0, 0}, 0, 2)
+	p.R.SetPos(0, 2)
+
+	err := p.ReadFixedHeader(&fh)
+	require.NoError(t, err)
+	"github.com/stretchr/testify/require"
 package mqtt
 
+package mqtt
 
+}
+
+func TestProcessorReadPacketReadOverflow(t *testing.T) {
+	conn := new(MockNetConn)
 import (
+	"github.com/mochi-co/mqtt/processors/circ"
+	var fh packets.FixedHeader
 
+	// Check for un-specified packet.
+	// Create a ping request packet with a false fixedheader type code.
+	"github.com/mochi-co/mqtt/processors/circ"
 	"github.com/mochi-co/mqtt/processors/circ"
 
+	"github.com/mochi-co/mqtt/processors/circ"
 )
+)
 
+	p.R.SetPos(0, 2)
+
+	err := p.ReadFixedHeader(&fh)
+	require.NoError(t, err)
+
+	p.FixedHeader.Remaining = 999999 // overflow buffer
+	_, err = p.Read()
+	require.Error(t, err)
+}
+
+// MockNetConn satisfies the net.Conn interface.
+type MockNetConn struct {
+	ID       string
+)
 	"github.com/mochi-co/mqtt/packets"
+}
+
+// Read reads bytes from the net io.reader.
+func (m *MockNetConn) Read(b []byte) (n int, err error) {
+	return 0, nil
+}
+
+// Read writes bytes to the net io.writer.
+func (m *MockNetConn) Write(b []byte) (n int, err error) {
+	return 0, nil
+}
+
+func TestNewProcessor(t *testing.T) {
 import (
+func (m *MockNetConn) Close() error {
+	return nil
+}
+
+package mqtt
+	"github.com/stretchr/testify/require"
+func (m *MockNetConn) LocalAddr() net.Addr {
+	return new(MockNetAddr)
+}
+
+// RemoteAddr returns the remove address of the request.
+func (m *MockNetConn) RemoteAddr() net.Addr {
+	return new(MockNetAddr)
+}
+
+// SetDeadline sets the request deadline.
+func (m *MockNetConn) SetDeadline(t time.Time) error {
+	m.Deadline = t
+	return nil
+}
+
+// SetReadDeadline sets the read deadline.
+func (m *MockNetConn) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+// SetWriteDeadline sets the write deadline.
+func (m *MockNetConn) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+// MockNetAddr satisfies net.Addr interface.
+type MockNetAddr struct{}
+
+// Network returns the network protocol.
+func (m *MockNetAddr) Network() string {
+	return "tcp"
+}
+
+// String returns the network address.
+func (m *MockNetAddr) String() string {
+	return "127.0.0.1"
 }




diff --git a/processors/circ/buffer.go b/processors/circ/buffer.go
index 2cd569d65cbd34376d5471d78f89e9b7c544294f..b88c4598dc85d262f217b75102116d3b11c24626 100644
--- a/processors/circ/buffer.go
+++ b/processors/circ/buffer.go
@@ -51,6 +51,11 @@ 		wcond: sync.NewCond(new(sync.Mutex)),
 	}
 }
 
+// Get will return the tail and head positions of the buffer.
+func (b *buffer) GetPos() (int64, int64) {
+	return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head)
+}
+
 // Set writes bytes to a byte buffer. This method should only be used for testing
 // and will panic if out of range.
 func (b *buffer) Set(p []byte, start, end int) {
@@ -90,7 +95,7 @@ 	return
 }
 
 // awaitFilled will hold until there are at least n bytes waiting between the
-// tail and head, before returning
+// tail and head.
 func (b *buffer) awaitFilled(n int64) (tail int64, err error) {
 	head := atomic.LoadInt64(&b.head)
 	tail = atomic.LoadInt64(&b.tail)
@@ -107,3 +112,31 @@ 	b.wcond.L.Unlock()
 
 	return
 }
+
+// CommitHead moves the head position of the buffer n bytes. If there is not enough
+// capacity, the method will wait until there is.
+func (b *buffer) CommitHead(n int64) error {
+	return nil
+}
+
+// CommitTail moves the tail position of the buffer n bytes, and will wait until
+// there is enough capacity for at least n bytes.
+func (b *buffer) CommitTail(n int64) error {
+	_, err := b.awaitFilled(n)
+	if err != nil {
+		return err
+	}
+
+	tail := atomic.LoadInt64(&b.tail)
+	if tail+n < b.size {
+		atomic.StoreInt64(&b.tail, tail+n)
+	} else {
+		atomic.StoreInt64(&b.tail, (tail+n)%b.size)
+	}
+
+	b.rcond.L.Lock()
+	b.rcond.Broadcast()
+	b.rcond.L.Unlock()
+
+	return nil
+}




diff --git a/processors/circ/buffer_test.go b/processors/circ/buffer_test.go
index 43538e6e2e32558ba33582dd42e81954fcc67596..e8a48ba5e4275f84907be2401dc48a18afd381d4 100644
--- a/processors/circ/buffer_test.go
+++ b/processors/circ/buffer_test.go
@@ -34,6 +34,26 @@ 	buf.Set(p, 2, 6)
 	require.Equal(t, p, buf.buf[2:6])
 }
 
+func TestGetPos(t *testing.T) {
+	buf := newBuffer(8)
+	buf.tail, buf.head = 1, 3
+	tail, head := buf.GetPos()
+	require.Equal(t, int64(1), tail)
+	require.Equal(t, int64(3), head)
+}
+
+func TestSetPos(t *testing.T) {
+	buf := newBuffer(8)
+	buf.SetPos(1, 3)
+	require.Equal(t, int64(3), buf.head)
+	require.Equal(t, int64(1), buf.tail)
+}
+
+func TestCommitHead(t *testing.T) {
+	//buf := newBuffer(16)
+
+}
+
 func TestAwaitCapacity(t *testing.T) {
 	tests := []struct {
 		tail  int64
@@ -123,3 +143,38 @@ 		require.Equal(t, tt.next, buf.head, "Head-next mismatch [i:%d] %s", i, tt.desc)
 		require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc)
 	}
 }
+
+func TestCommitTail(t *testing.T) {
+	tests := []struct {
+		tail  int64
+		head  int64
+		next  int64
+		await int
+		desc  string
+	}{
+		{0, 5, 4, 0, "OK 0, 4"},
+		{6, 10, 10, 0, "OK 6, 10"},
+		{14, 2, 2, 0, "OK 14, 2 wrapped"},
+		{6, 8, 10, 2, "Wait 6, 8"},
+		{14, 1, 2, 2, "Wait 14, 1 wrapped"},
+	}
+
+	buf := newBuffer(16)
+	for i, tt := range tests {
+		buf.tail, buf.head = tt.tail, tt.head
+		o := make(chan error)
+		go func() {
+			o <- buf.CommitTail(4)
+		}()
+
+		time.Sleep(time.Millisecond)
+		for j := 0; j < tt.await; j++ {
+			atomic.AddInt64(&buf.head, 1)
+			buf.wcond.L.Lock()
+			buf.wcond.Broadcast()
+			buf.wcond.L.Unlock()
+		}
+		require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc)
+		require.Equal(t, tt.next, buf.tail, "Tail-next mismatch [i:%d] %s", i, tt.desc)
+	}
+}




diff --git a/processors/circ/reader.go b/processors/circ/reader.go
index 4bb649b0f845356e769a8f2987d2197fa6273726..52fa7ca9c5f873c6853ffa974850fa06642c0c75 100644
--- a/processors/circ/reader.go
+++ b/processors/circ/reader.go
@@ -134,6 +134,10 @@
 // Read reads the next n bytes from the buffer. If n bytes are not
 // available, read will wait until there is enough.
 func (b *Reader) Read(n int64) (p []byte, err error) {
+	if n > b.size {
+		err = ErrInsufficientBytes
+		return
+	}
 
 	// Wait until there's at least len(p) bytes to read.
 	tail, err := b.awaitFilled(n)
@@ -147,7 +151,7 @@ 	if atomic.LoadInt64(&b.head) < tail {
 		b.tmp = b.buf[tail:b.size]
 		b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...)
 		return b.tmp, nil
-	} else {
+	} else if tail+n < b.size {
 		return b.buf[tail : tail+n], nil
 	}