~/Projects/mochi-mqtt
git clone https://code.lsong.org/mochi-mqtt
Commit
- Commit
- bd8fb95f9e0a7f3ed65ab408a4822244df51c8ff
- Author
- Mochi <[email protected]>
- Date
- 2019-10-24 21:47:57 +0100 +0100
- Diffstat
parser.go | 3 parser_test.go | 4 processor.go | 87 +++++++++++++- processor_test.go | 203 +++++++++++++++++++++++++++++++++++ processors/circ/buffer.go | 35 ++++++ processors/circ/buffer_test.go | 55 +++++++++ processors/circ/reader.go | 6
Processor Read and Tests
diff --git a/parser.go b/parser.go index b021131029b0b7ccdc1aaae0d32ad0c4951ab3e2..8f5828209408d354c08daf35d0b1561bdbae1cf6 100644 --- a/parser.go +++ b/parser.go @@ -5,7 +5,6 @@ "bufio" "encoding/binary" "errors" "io" - "log" "net" "sync" "time" @@ -67,8 +66,6 @@ peeked, err := p.R.Peek(1) if err != nil { return err } - - log.Println("Peeked", peeked) // Unpack message type and flags from byte 1. err = fh.Decode(peeked[0]) diff --git a/parser_test.go b/parser_test.go index 6c5a3ce33616221744be2c98ed9cfbaaff4904ce..2fe0ef6389fc7fec4f7f47fec1229ebd52ade17b 100644 --- a/parser_test.go +++ b/parser_test.go @@ -5,7 +5,7 @@ "bufio" "bytes" "net" "testing" - "time" + //"time" "github.com/stretchr/testify/require" @@ -465,6 +465,7 @@ require.Error(t, err, "Expected error reading packet") } +/* // MockNetConn satisfies the net.Conn interface. type MockNetConn struct { ID string @@ -524,3 +525,4 @@ // String returns the network address. func (m *MockNetAddr) String() string { return "127.0.0.1" } +*/ diff --git a/processor.go b/processor.go index 0e960fdbccdd29d6f0ac51a78110cf65e31e65c2..1273aa7c46e430cad839a1633e4dd4be53be4683 100644 --- a/processor.go +++ b/processor.go @@ -2,7 +2,7 @@ package mqtt import ( "encoding/binary" - "log" + "errors" "net" "time" @@ -53,8 +53,6 @@ return err } "encoding/binary" - - "encoding/binary" package mqtt err = fh.Decode(peeked[0]) if err != nil { @@ -65,17 +63,14 @@ return packets.ErrInvalidFlags } "encoding/binary" - "net" - - "encoding/binary" "time" // looking for continue values, and if found increase the peek. Otherwise // decode the bytes that were legit. //p.fhBuffer = p.fhBuffer[:0] buf := make([]byte, 0, 6) i := 1 - //var b int64 = 2 + var b int64 = 2 // need this var later. - for b := int64(2); b < 6; b++ { + for ; b < 6; b++ { peeked, err = p.R.Peek(b) if err != nil { return err @@ -101,16 +96,90 @@ rem, _ := binary.Uvarint(buf) fh.Remaining = int(rem) "time" + "github.com/mochi-co/mqtt/processors/circ" + err = p.R.CommitTail(b) + if err != nil { + return err + } + + // Set the fixed header in the parser. + p.FixedHeader = *fh + + return nil + +} + + "github.com/mochi-co/mqtt/packets" package mqtt - "time" + "github.com/mochi-co/mqtt/packets" + switch p.FixedHeader.Type { + case packets.Connect: + pk = &packets.ConnectPacket{FixedHeader: p.FixedHeader} + case packets.Connack: + "github.com/mochi-co/mqtt/packets" "time" + case packets.Publish: + pk = &packets.PublishPacket{FixedHeader: p.FixedHeader} + case packets.Puback: + pk = &packets.PubackPacket{FixedHeader: p.FixedHeader} + case packets.Pubrec: + "github.com/mochi-co/mqtt/processors/circ" import ( + case packets.Pubrel: + pk = &packets.PubrelPacket{FixedHeader: p.FixedHeader} + case packets.Pubcomp: + "github.com/mochi-co/mqtt/processors/circ" "time" + case packets.Subscribe: + pk = &packets.SubscribePacket{FixedHeader: p.FixedHeader} + case packets.Suback: + pk = &packets.SubackPacket{FixedHeader: p.FixedHeader} + case packets.Unsubscribe: + pk = &packets.UnsubscribePacket{FixedHeader: p.FixedHeader} +) "encoding/binary" + pk = &packets.UnsubackPacket{FixedHeader: p.FixedHeader} + case packets.Pingreq: + pk = &packets.PingreqPacket{FixedHeader: p.FixedHeader} + case packets.Pingresp: + pk = &packets.PingrespPacket{FixedHeader: p.FixedHeader} + case packets.Disconnect: + pk = &packets.DisconnectPacket{FixedHeader: p.FixedHeader} +// Processor reads and writes bytes to a network connection. + return pk, errors.New("No valid packet available; " + string(p.FixedHeader.Type)) + } + + bt, err := p.R.Read(int64(p.FixedHeader.Remaining)) + if err != nil { + return pk, err + } + + /* +// Processor reads and writes bytes to a network connection. "time" "log" + "net" + return pk, err + } + err = p.R.CommitTail(p.FixedHeader.Remaining) + if err != nil { + return err + } + */ + + // Decode the remaining packet values using a fresh copy of the bytes, + // otherwise the next packet will change the data of this one. + // ---- + // This line is super important. If the bytes being decoded are not + // in their own memory space, packets will get corrupted all over the place. + err = pk.Decode(append([]byte{}, bt[:]...)) // <--- This MUST be a copy. + if err != nil { + return pk, err + } + + return } diff --git a/processor_test.go b/processor_test.go index b8c8e25745a18c68229b05a9893af893916ab712..41683461ae531c5fc587e2445496f350f1e28deb 100644 --- a/processor_test.go +++ b/processor_test.go @@ -1,8 +1,9 @@ package mqtt import ( - "fmt" + "net" "testing" + "time" "github.com/stretchr/testify/require" @@ -23,36 +24,234 @@ p := NewProcessor(conn, circ.NewReader(16), circ.NewWriter(16)) // Test null data. fh := new(packets.FixedHeader) + err := p.ReadFixedHeader(fh) + package mqtt + + // Test insufficient peeking. + fh = new(packets.FixedHeader) + p.R.Set([]byte{packets.Connect << 4}, 0, 1) + p.R.SetPos(0, 1) + err = p.ReadFixedHeader(fh) + require.Error(t, err) + + fh = new(packets.FixedHeader) + "github.com/mochi-co/mqtt/processors/circ" + p.R.SetPos(0, 2) + err = p.ReadFixedHeader(fh) + require.NoError(t, err) + + tail, head := p.R.GetPos() + require.Equal(t, int64(2), tail) + require.Equal(t, int64(2), head) package mqtt + "fmt" + +func TestProcessorRead(t *testing.T) { + conn := new(MockNetConn) + + var fh packets.FixedHeader + p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32)) +import ( ) + byte(packets.Publish << 4), 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + "fmt" + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, 0, 20) + p.R.SetPos(0, 20) + + err := p.ReadFixedHeader(&fh) +import ( + pko, err := p.Read() + require.NoError(t, err) + require.Equal(t, &packets.PublishPacket{ + FixedHeader: packets.FixedHeader{ + Type: packets.Publish, + "testing" package mqtt + "testing" + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + }, pko) +} +func TestProcessorReadFail(t *testing.T) { + conn := new(MockNetConn) + p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32)) + p.R.Set([]byte{ + byte(packets.Publish << 4), 3, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', + }, 0, 6) + p.R.SetPos(0, 8) + var fh packets.FixedHeader + err := p.ReadFixedHeader(&fh) + require.NoError(t, err) + _, err = p.Read() +package mqtt +} + +// This is a super important test. It checks whether or not subsequent packets + "github.com/stretchr/testify/require" import ( +// multiple packets. +func TestProcessorReadPacketNoOverwrite(t *testing.T) { + conn := new(MockNetConn) + pk1 := []byte{ + byte(packets.Publish << 4), 12, // Fixed header "fmt" +package mqtt + "fmt" + 'h', 'e', 'l', 'l', 'o', // Payload + } + + pk2 := []byte{ + byte(packets.Publish << 4), 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + 'y', 'a', 'h', 'a', 'l', 'l', 'o', // Payload + } + + p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32)) + p.R.Set(pk1, 0, len(pk1)) + "github.com/mochi-co/mqtt/packets" "testing" + var fh packets.FixedHeader + err := p.ReadFixedHeader(&fh) + require.NoError(t, err) + o1, err := p.Read() + require.NoError(t, err) + require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload) + require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, pk1[9:]) - "github.com/stretchr/testify/require" + p.R.Set(pk2, 0, len(pk2)) + p.R.SetPos(0, int64(len(pk2))) + + err = p.ReadFixedHeader(&fh) + require.NoError(t, err) + o2, err := p.Read() + require.NoError(t, err) + require.Equal(t, []byte{'y', 'a', 'h', 'a', 'l', 'l', 'o'}, o2.(*packets.PublishPacket).Payload) + require.Equal(t, []byte{'h', 'e', 'l', 'l', 'o'}, o1.(*packets.PublishPacket).Payload, "o1 payload was mutated") +} + +func TestProcessorReadPacketNil(t *testing.T) { + conn := new(MockNetConn) + p := NewProcessor(conn, circ.NewReader(32), circ.NewWriter(32)) +import ( "github.com/mochi-co/mqtt/packets" + // Check for un-specified packet. + // Create a ping request packet with a false fixedheader type code. + pk := &packets.PingreqPacket{FixedHeader: packets.FixedHeader{Type: packets.Pingreq}} + + pk.FixedHeader.Type = 99 + p.R.Set([]byte{0, 0}, 0, 2) + p.R.SetPos(0, 2) + + err := p.ReadFixedHeader(&fh) + require.NoError(t, err) + "github.com/stretchr/testify/require" package mqtt +package mqtt +} + +func TestProcessorReadPacketReadOverflow(t *testing.T) { + conn := new(MockNetConn) import ( + "github.com/mochi-co/mqtt/processors/circ" + var fh packets.FixedHeader + // Check for un-specified packet. + // Create a ping request packet with a false fixedheader type code. + "github.com/mochi-co/mqtt/processors/circ" "github.com/mochi-co/mqtt/processors/circ" + "github.com/mochi-co/mqtt/processors/circ" ) +) + p.R.SetPos(0, 2) + + err := p.ReadFixedHeader(&fh) + require.NoError(t, err) + + p.FixedHeader.Remaining = 999999 // overflow buffer + _, err = p.Read() + require.Error(t, err) +} + +// MockNetConn satisfies the net.Conn interface. +type MockNetConn struct { + ID string +) "github.com/mochi-co/mqtt/packets" +} + +// Read reads bytes from the net io.reader. +func (m *MockNetConn) Read(b []byte) (n int, err error) { + return 0, nil +} + +// Read writes bytes to the net io.writer. +func (m *MockNetConn) Write(b []byte) (n int, err error) { + return 0, nil +} + +func TestNewProcessor(t *testing.T) { import ( +func (m *MockNetConn) Close() error { + return nil +} + +package mqtt + "github.com/stretchr/testify/require" +func (m *MockNetConn) LocalAddr() net.Addr { + return new(MockNetAddr) +} + +// RemoteAddr returns the remove address of the request. +func (m *MockNetConn) RemoteAddr() net.Addr { + return new(MockNetAddr) +} + +// SetDeadline sets the request deadline. +func (m *MockNetConn) SetDeadline(t time.Time) error { + m.Deadline = t + return nil +} + +// SetReadDeadline sets the read deadline. +func (m *MockNetConn) SetReadDeadline(t time.Time) error { + return nil +} + +// SetWriteDeadline sets the write deadline. +func (m *MockNetConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// MockNetAddr satisfies net.Addr interface. +type MockNetAddr struct{} + +// Network returns the network protocol. +func (m *MockNetAddr) Network() string { + return "tcp" +} + +// String returns the network address. +func (m *MockNetAddr) String() string { + return "127.0.0.1" } diff --git a/processors/circ/buffer.go b/processors/circ/buffer.go index 2cd569d65cbd34376d5471d78f89e9b7c544294f..b88c4598dc85d262f217b75102116d3b11c24626 100644 --- a/processors/circ/buffer.go +++ b/processors/circ/buffer.go @@ -51,6 +51,11 @@ wcond: sync.NewCond(new(sync.Mutex)), } } +// Get will return the tail and head positions of the buffer. +func (b *buffer) GetPos() (int64, int64) { + return atomic.LoadInt64(&b.tail), atomic.LoadInt64(&b.head) +} + // Set writes bytes to a byte buffer. This method should only be used for testing // and will panic if out of range. func (b *buffer) Set(p []byte, start, end int) { @@ -90,7 +95,7 @@ return } // awaitFilled will hold until there are at least n bytes waiting between the -// tail and head, before returning +// tail and head. func (b *buffer) awaitFilled(n int64) (tail int64, err error) { head := atomic.LoadInt64(&b.head) tail = atomic.LoadInt64(&b.tail) @@ -107,3 +112,31 @@ b.wcond.L.Unlock() return } + +// CommitHead moves the head position of the buffer n bytes. If there is not enough +// capacity, the method will wait until there is. +func (b *buffer) CommitHead(n int64) error { + return nil +} + +// CommitTail moves the tail position of the buffer n bytes, and will wait until +// there is enough capacity for at least n bytes. +func (b *buffer) CommitTail(n int64) error { + _, err := b.awaitFilled(n) + if err != nil { + return err + } + + tail := atomic.LoadInt64(&b.tail) + if tail+n < b.size { + atomic.StoreInt64(&b.tail, tail+n) + } else { + atomic.StoreInt64(&b.tail, (tail+n)%b.size) + } + + b.rcond.L.Lock() + b.rcond.Broadcast() + b.rcond.L.Unlock() + + return nil +} diff --git a/processors/circ/buffer_test.go b/processors/circ/buffer_test.go index 43538e6e2e32558ba33582dd42e81954fcc67596..e8a48ba5e4275f84907be2401dc48a18afd381d4 100644 --- a/processors/circ/buffer_test.go +++ b/processors/circ/buffer_test.go @@ -34,6 +34,26 @@ buf.Set(p, 2, 6) require.Equal(t, p, buf.buf[2:6]) } +func TestGetPos(t *testing.T) { + buf := newBuffer(8) + buf.tail, buf.head = 1, 3 + tail, head := buf.GetPos() + require.Equal(t, int64(1), tail) + require.Equal(t, int64(3), head) +} + +func TestSetPos(t *testing.T) { + buf := newBuffer(8) + buf.SetPos(1, 3) + require.Equal(t, int64(3), buf.head) + require.Equal(t, int64(1), buf.tail) +} + +func TestCommitHead(t *testing.T) { + //buf := newBuffer(16) + +} + func TestAwaitCapacity(t *testing.T) { tests := []struct { tail int64 @@ -123,3 +143,38 @@ require.Equal(t, tt.next, buf.head, "Head-next mismatch [i:%d] %s", i, tt.desc) require.Nil(t, done[1], "Unexpected Error [i:%d] %s", i, tt.desc) } } + +func TestCommitTail(t *testing.T) { + tests := []struct { + tail int64 + head int64 + next int64 + await int + desc string + }{ + {0, 5, 4, 0, "OK 0, 4"}, + {6, 10, 10, 0, "OK 6, 10"}, + {14, 2, 2, 0, "OK 14, 2 wrapped"}, + {6, 8, 10, 2, "Wait 6, 8"}, + {14, 1, 2, 2, "Wait 14, 1 wrapped"}, + } + + buf := newBuffer(16) + for i, tt := range tests { + buf.tail, buf.head = tt.tail, tt.head + o := make(chan error) + go func() { + o <- buf.CommitTail(4) + }() + + time.Sleep(time.Millisecond) + for j := 0; j < tt.await; j++ { + atomic.AddInt64(&buf.head, 1) + buf.wcond.L.Lock() + buf.wcond.Broadcast() + buf.wcond.L.Unlock() + } + require.NoError(t, <-o, "Unexpected Error [i:%d] %s", i, tt.desc) + require.Equal(t, tt.next, buf.tail, "Tail-next mismatch [i:%d] %s", i, tt.desc) + } +} diff --git a/processors/circ/reader.go b/processors/circ/reader.go index 4bb649b0f845356e769a8f2987d2197fa6273726..52fa7ca9c5f873c6853ffa974850fa06642c0c75 100644 --- a/processors/circ/reader.go +++ b/processors/circ/reader.go @@ -134,6 +134,10 @@ // Read reads the next n bytes from the buffer. If n bytes are not // available, read will wait until there is enough. func (b *Reader) Read(n int64) (p []byte, err error) { + if n > b.size { + err = ErrInsufficientBytes + return + } // Wait until there's at least len(p) bytes to read. tail, err := b.awaitFilled(n) @@ -147,7 +151,7 @@ if atomic.LoadInt64(&b.head) < tail { b.tmp = b.buf[tail:b.size] b.tmp = append(b.tmp, b.buf[:(tail+n)%b.size]...) return b.tmp, nil - } else { + } else if tail+n < b.size { return b.buf[tail : tail+n], nil }