~/Projects/mochi-mqtt
git clone https://code.lsong.org/mochi-mqtt
Commit
- Commit
- 7c7b8d58fee321b8cb3e83aecc50a9c179c16eed
- Author
- mochi <[email protected]>
- Date
- 2022-01-05 18:10:24 +0000 +0000
- Diffstat
server/internal/clients/clients.go | 2 +- server/internal/clients/clients_test.go | 4 ++-- server/internal/topics/trie.go | 2 +- server/internal/topics/trie_test.go | 2 +- | 0 | 0 | 0 | 0 | 0 | 0 | 0 server/server_test.go | 2 +-
Return packets to internal Now that we can alias types, there's no compelling reason to expose the packets library
diff --git a/server/internal/clients/clients.go b/server/internal/clients/clients.go index 931542c7333388bfd8b2c539c18423eb89832e04..24d0272c2c824daab5d2636d995172494e0a24fc 100644 --- a/server/internal/clients/clients.go +++ b/server/internal/clients/clients.go @@ -13,9 +13,9 @@ "github.com/rs/xid" "github.com/mochi-co/mqtt/server/internal/circ" + "github.com/mochi-co/mqtt/server/internal/packets" "github.com/mochi-co/mqtt/server/internal/topics" "github.com/mochi-co/mqtt/server/listeners/auth" - "github.com/mochi-co/mqtt/server/packets" "github.com/mochi-co/mqtt/server/system" ) diff --git a/server/internal/clients/clients_test.go b/server/internal/clients/clients_test.go index 52c73665fd8e62369e2de25dbd75b3a1dc31dd62..f239b19af780353763d08e629d8d82d230d2031d 100644 --- a/server/internal/clients/clients_test.go +++ b/server/internal/clients/clients_test.go @@ -10,10 +10,10 @@ "testing" "time" "github.com/mochi-co/mqtt/server/internal/circ" -package clients +func TestClientsLen(t *testing.T) { package clients -import ( + "github.com/mochi-co/mqtt/server/system" "github.com/stretchr/testify/require" ) diff --git a/server/internal/packets/codec.go b/server/internal/packets/codec.go new file mode 100644 index 0000000000000000000000000000000000000000..7ab4cdc8ea185232038225af99e4e45a2780a987 --- /dev/null +++ b/server/internal/packets/codec.go @@ -0,0 +1,114 @@ +package packets + +import ( + "encoding/binary" + "unicode/utf8" + "unsafe" +) + +// bytesToString provides a zero-alloc, no-copy byte to string conversion. +// via https://github.com/golang/go/issues/25484#issuecomment-391415660 +func bytesToString(bs []byte) string { + return *(*string)(unsafe.Pointer(&bs)) +} + +// decodeUint16 extracts the value of two bytes from a byte array. +func decodeUint16(buf []byte, offset int) (uint16, int, error) { + if len(buf) < offset+2 { + return 0, 0, ErrOffsetUintOutOfRange + } + + return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil +} + +// decodeString extracts a string from a byte array, beginning at an offset. +func decodeString(buf []byte, offset int) (string, int, error) { + b, n, err := decodeBytes(buf, offset) + if err != nil { + return "", 0, err + } + + return bytesToString(b), n, nil +} + +// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads. +func decodeBytes(buf []byte, offset int) ([]byte, int, error) { + length, next, err := decodeUint16(buf, offset) + if err != nil { + return make([]byte, 0, 0), 0, err + } + + if next+int(length) > len(buf) { + return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange + } + + if !validUTF8(buf[next : next+int(length)]) { + return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8 + } + + return buf[next : next+int(length)], next + int(length), nil +} + +// decodeByte extracts the value of a byte from a byte array. +func decodeByte(buf []byte, offset int) (byte, int, error) { + if len(buf) <= offset { + return 0, 0, ErrOffsetByteOutOfRange + } + return buf[offset], offset + 1, nil +} + +// decodeByteBool extracts the value of a byte from a byte array and returns a bool. +func decodeByteBool(buf []byte, offset int) (bool, int, error) { + if len(buf) <= offset { + return false, 0, ErrOffsetBoolOutOfRange + } + return 1&buf[offset] > 0, offset + 1, nil +} + +// encodeBool returns a byte instead of a bool. +func encodeBool(b bool) byte { + if b { + return 1 + } + return 0 +} + +// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads. +func encodeBytes(val []byte) []byte { + // In many circumstances the number of bytes being encoded is small. + // Setting the cap to a low amount allows us to account for those without + // triggering allocation growth on append unless we need to. + buf := make([]byte, 2, 32) + binary.BigEndian.PutUint16(buf, uint16(len(val))) + return append(buf, val...) +} + +// encodeUint16 encodes a uint16 value to a byte array. +func encodeUint16(val uint16) []byte { + buf := make([]byte, 2) + binary.BigEndian.PutUint16(buf, val) + return buf +} + +// encodeString encodes a string to a byte array. +func encodeString(val string) []byte { + // Like encodeBytes, we set the cap to a small number to avoid + // triggering allocation growth on append unless we absolutely need to. + buf := make([]byte, 2, 32) + binary.BigEndian.PutUint16(buf, uint16(len(val))) + return append(buf, []byte(val)...) +} + +// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically +// conforming to the MQTT specification requirements. +func validUTF8(b []byte) bool { + // [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8... + if !utf8.Valid(b) { + return false + } + + // [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000... + // ... + return true + +} diff --git a/server/internal/packets/codec_test.go b/server/internal/packets/codec_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2cd0438cc2928b630462dedda5af95504d1442df --- /dev/null +++ b/server/internal/packets/codec_test.go @@ -0,0 +1,383 @@ +package packets + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBytesToString(t *testing.T) { + b := []byte{'a', 'b', 'c'} + require.Equal(t, "abc", bytesToString(b)) +} + +func BenchmarkBytesToString(b *testing.B) { + for n := 0; n < b.N; n++ { + bytesToString([]byte{'a', 'b', 'c'}) + } +} + +func TestDecodeString(t *testing.T) { + expect := []struct { + rawBytes []byte + result []string + offset int + shouldFail bool + }{ + { + offset: 0, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: []string{"a/b/c/d", "a"}, + }, + { + offset: 14, + rawBytes: []byte{ + byte(Connect << 4), 17, // Fixed header + 0, 6, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name + 3, // Protocol Version + 0, // Packet Flags + 0, 30, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'h', 'e', 'y', // Client ID "zen"}, + }, + result: []string{"hey"}, + }, + + { + offset: 2, + rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, + result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"}, + }, + { + offset: 0, + rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, + result: []string{"x/y/z", "!@#$%^&"}, + }, + { + offset: 0, + rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, + result: []string{"a/b/c/d", "z"}, + shouldFail: true, + }, + { + offset: 5, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'}, + result: []string{"a/b/c/d", "x"}, + shouldFail: true, + }, + { + offset: 9, + rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'}, + result: []string{"a/b/c/d", "y"}, + shouldFail: true, + }, + { + offset: 17, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 6, // Will Topic - MSB+LSB + 'l', + }, + result: []string{"lwt"}, + shouldFail: true, + }, + } + + for i, wanted := range expect { + result, _, err := decodeString(wanted.rawBytes, wanted.offset) + if wanted.shouldFail { + require.Error(t, err, "Expected error decoding string [i:%d]", i) + continue + } + + require.NoError(t, err, "Error decoding string [i:%d]", i) + require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i) + } +} + +func BenchmarkDecodeString(b *testing.B) { + in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} + for n := 0; n < b.N; n++ { + decodeString(in, 0) + } +} + +func TestDecodeBytes(t *testing.T) { + expect := []struct { + rawBytes []byte + result []uint8 + next int + offset int + shouldFail bool + }{ + { + rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session) + result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), + next: 6, + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start + result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), + next: 6, + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81}, + result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), + offset: 0, + shouldFail: true, + }, + { + rawBytes: []byte{0, 4, 77, 81}, + result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), + offset: 8, + shouldFail: true, + }, + { + rawBytes: []byte{0, 4, 77, 81}, + result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), + offset: 0, + shouldFail: true, + }, + } + + for i, wanted := range expect { + result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) + if wanted.shouldFail { + require.Error(t, err, "Expected error decoding bytes [i:%d]", i) + continue + } + + require.NoError(t, err, "Error decoding bytes [i:%d]", i) + require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) + } +} + +func BenchmarkDecodeBytes(b *testing.B) { + in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52} + for n := 0; n < b.N; n++ { + decodeBytes(in, 0) + } +} + +func TestDecodeByte(t *testing.T) { + expect := []struct { + rawBytes []byte + result uint8 + offset int + shouldFail bool + }{ + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes + result: uint8(0x00), + offset: 0, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x04), + offset: 1, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x4d), + offset: 2, + }, + { + rawBytes: []byte{0, 4, 77, 81, 84, 84}, + result: uint8(0x51), + offset: 3, + }, + { + rawBytes: []byte{0, 4, 77, 80, 82, 84}, + result: uint8(0x00), + offset: 8, + shouldFail: true, + }, + } + + for i, wanted := range expect { + result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) + if wanted.shouldFail { + require.Error(t, err, "Expected error decoding byte [i:%d]", i) + continue + } + + require.NoError(t, err, "Error decoding byte [i:%d]", i) + require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) + require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i) + } +} + +func BenchmarkDecodeByte(b *testing.B) { + in := []byte{0, 4, 77, 81, 84, 84} + for n := 0; n < b.N; n++ { + decodeByte(in, 0) + } +} + +func TestDecodeUint16(t *testing.T) { + expect := []struct { + rawBytes []byte + result uint16 + offset int + shouldFail bool + }{ + { + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: uint16(0x07), + offset: 0, + }, + { + rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, + result: uint16(0x761), + offset: 1, + }, + { + rawBytes: []byte{0, 7, 255, 47}, + result: uint16(0x761), + offset: 8, + shouldFail: true, + }, + } + + for i, wanted := range expect { + result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) + if wanted.shouldFail { + require.Error(t, err, "Expected error decoding uint16 [i:%d]", i) + continue + } + + require.NoError(t, err, "Error decoding uint16 [i:%d]", i) + require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) + require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i) + } +} + +func BenchmarkDecodeUint16(b *testing.B) { + in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} + for n := 0; n < b.N; n++ { + decodeUint16(in, 0) + } +} + +func TestDecodeByteBool(t *testing.T) { + expect := []struct { + rawBytes []byte + result bool + offset int + shouldFail bool + }{ + { + rawBytes: []byte{0x00, 0x00}, + result: false, + }, + { + rawBytes: []byte{0x01, 0x00}, + result: true, + }, + { + rawBytes: []byte{0x01, 0x00}, + offset: 5, + shouldFail: true, + }, + } + + for i, wanted := range expect { + result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) + if wanted.shouldFail { + require.Error(t, err, "Expected error decoding byte bool [i:%d]", i) + continue + } + + require.NoError(t, err, "Error decoding byte bool [i:%d]", i) + require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) + require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i) + } +} + +func BenchmarkDecodeByteBool(b *testing.B) { + in := []byte{0x00, 0x00} + for n := 0; n < b.N; n++ { + decodeByteBool(in, 0) + } +} + +func TestEncodeBool(t *testing.T) { + result := encodeBool(true) + require.Equal(t, byte(1), result, "Incorrect encoded value; not true") + + result = encodeBool(false) + require.Equal(t, byte(0), result, "Incorrect encoded value; not false") + + // Check failure. + result = encodeBool(false) + require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value") +} + +func BenchmarkEncodeBool(b *testing.B) { + for n := 0; n < b.N; n++ { + encodeBool(true) + } +} + +func TestEncodeBytes(t *testing.T) { + result := encodeBytes([]byte("testing")) + require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value") + + result = encodeBytes([]byte("testing")) + require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value") +} + +func BenchmarkEncodeBytes(b *testing.B) { + bb := []byte("testing") + for n := 0; n < b.N; n++ { + encodeBytes(bb) + } +} + +func TestEncodeUint16(t *testing.T) { + result := encodeUint16(0) + require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0") + + result = encodeUint16(32767) + require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767") + + result = encodeUint16(65535) + require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535") +} + +func BenchmarkEncodeUint16(b *testing.B) { + for n := 0; n < b.N; n++ { + encodeUint16(32767) + } +} + +func TestEncodeString(t *testing.T) { + result := encodeString("testing") + require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing") + + result = encodeString("") + require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null") + + result = encodeString("a") + require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a") + + result = encodeString("b") + require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b") + +} + +func BenchmarkEncodeString(b *testing.B) { + for n := 0; n < b.N; n++ { + encodeString("benchmarking") + } +} diff --git a/server/internal/packets/fixedheader.go b/server/internal/packets/fixedheader.go new file mode 100644 index 0000000000000000000000000000000000000000..a159143bffeb1b91ddaf38bb1b66dd8722056a6f --- /dev/null +++ b/server/internal/packets/fixedheader.go @@ -0,0 +1,59 @@ +package packets + +import ( + "bytes" +) + +// FixedHeader contains the values of the fixed header portion of the MQTT packet. +type FixedHeader struct { + Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1). + Dup bool // indicates if the packet was already sent at an earlier time. + Qos byte // indicates the quality of service expected. + Retain bool // whether the message should be retained. + Remaining int // the number of remaining bytes in the payload. +} + +// Encode encodes the FixedHeader and returns a bytes buffer. +func (fh *FixedHeader) Encode(buf *bytes.Buffer) { + buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) + encodeLength(buf, fh.Remaining) +} + +// decode extracts the specification bits from the header byte. +func (fh *FixedHeader) Decode(headerByte byte) error { + fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes. + + switch fh.Type { + case Publish: + fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate. + fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag. + fh.Retain = headerByte&0x01 > 0 // Extract retain flag. + case Pubrel: + fh.Qos = (headerByte >> 1) & 0x03 + case Subscribe: + fh.Qos = (headerByte >> 1) & 0x03 + case Unsubscribe: + fh.Qos = (headerByte >> 1) & 0x03 + default: + if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 { + return ErrInvalidFlags + } + } + + return nil +} + +// encodeLength writes length bits for the header. +func encodeLength(buf *bytes.Buffer, length int) { + for { + digit := byte(length % 128) + length /= 128 + if length > 0 { + digit |= 0x80 + } + buf.WriteByte(digit) + if length == 0 { + break + } + } +} diff --git a/server/internal/packets/fixedheader_test.go b/server/internal/packets/fixedheader_test.go new file mode 100644 index 0000000000000000000000000000000000000000..7ce4e3314e1fd161dc8f85487044f8effb023b92 --- /dev/null +++ b/server/internal/packets/fixedheader_test.go @@ -0,0 +1,220 @@ +package packets + +import ( + "bytes" + "math" + "testing" + + "github.com/stretchr/testify/require" +) + +type fixedHeaderTable struct { + rawBytes []byte + header FixedHeader + packetError bool + flagError bool +} + +var fixedHeaderExpected = []fixedHeaderTable{ + { + rawBytes: []byte{Connect << 4, 0x00}, + header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int + }, + { + rawBytes: []byte{Connack << 4, 0x00}, + header: FixedHeader{Connack, false, 0, false, 0}, + }, + { + rawBytes: []byte{Publish << 4, 0x00}, + header: FixedHeader{Publish, false, 0, false, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, + header: FixedHeader{Publish, false, 1, false, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, + header: FixedHeader{Publish, false, 1, true, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, + header: FixedHeader{Publish, false, 2, false, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, + header: FixedHeader{Publish, false, 2, true, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, + header: FixedHeader{Publish, true, 0, false, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, + header: FixedHeader{Publish, true, 0, true, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, + header: FixedHeader{Publish, true, 1, true, 0}, + }, + { + rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, + header: FixedHeader{Publish, true, 2, true, 0}, + }, + { + rawBytes: []byte{Puback << 4, 0x00}, + header: FixedHeader{Puback, false, 0, false, 0}, + }, + { + rawBytes: []byte{Pubrec << 4, 0x00}, + header: FixedHeader{Pubrec, false, 0, false, 0}, + }, + { + rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, + header: FixedHeader{Pubrel, false, 1, false, 0}, + }, + { + rawBytes: []byte{Pubcomp << 4, 0x00}, + header: FixedHeader{Pubcomp, false, 0, false, 0}, + }, + { + rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, + header: FixedHeader{Subscribe, false, 1, false, 0}, + }, + { + rawBytes: []byte{Suback << 4, 0x00}, + header: FixedHeader{Suback, false, 0, false, 0}, + }, + { + rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, + header: FixedHeader{Unsubscribe, false, 1, false, 0}, + }, + { + rawBytes: []byte{Unsuback << 4, 0x00}, + header: FixedHeader{Unsuback, false, 0, false, 0}, + }, + { + rawBytes: []byte{Pingreq << 4, 0x00}, + header: FixedHeader{Pingreq, false, 0, false, 0}, + }, + { + rawBytes: []byte{Pingresp << 4, 0x00}, + header: FixedHeader{Pingresp, false, 0, false, 0}, + }, + { + rawBytes: []byte{Disconnect << 4, 0x00}, + header: FixedHeader{Disconnect, false, 0, false, 0}, + }, + + // remaining length + { + rawBytes: []byte{Publish << 4, 0x0a}, + header: FixedHeader{Publish, false, 0, false, 10}, + }, + { + rawBytes: []byte{Publish << 4, 0x80, 0x04}, + header: FixedHeader{Publish, false, 0, false, 512}, + }, + { + rawBytes: []byte{Publish << 4, 0xd2, 0x07}, + header: FixedHeader{Publish, false, 0, false, 978}, + }, + { + rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, + header: FixedHeader{Publish, false, 0, false, 20102}, + }, + { + rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, + header: FixedHeader{Publish, false, 0, false, 333333333}, + packetError: true, + }, + + // Invalid flags for packet + { + rawBytes: []byte{Connect<<4 | 1<<3, 0x00}, + header: FixedHeader{Connect, true, 0, false, 0}, + flagError: true, + }, + { + rawBytes: []byte{Connect<<4 | 1<<1, 0x00}, + header: FixedHeader{Connect, false, 1, false, 0}, + flagError: true, + }, + { + rawBytes: []byte{Connect<<4 | 1, 0x00}, + header: FixedHeader{Connect, false, 0, true, 0}, + flagError: true, + }, +} + +func TestFixedHeaderEncode(t *testing.T) { + for i, wanted := range fixedHeaderExpected { + buf := new(bytes.Buffer) + wanted.header.Encode(buf) + if wanted.flagError == false { + require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes) + require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes) + } + } +} + +func BenchmarkFixedHeaderEncode(b *testing.B) { + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + fixedHeaderExpected[0].header.Encode(buf) + } +} + +func TestFixedHeaderDecode(t *testing.T) { + for i, wanted := range fixedHeaderExpected { + fh := new(FixedHeader) + err := fh.Decode(wanted.rawBytes[0]) + if wanted.flagError { + require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes) + } else { + require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes) + require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes) + require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes) + require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes) + require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes) + } + } +} + +func BenchmarkFixedHeaderDecode(b *testing.B) { + fh := new(FixedHeader) + for n := 0; n < b.N; n++ { + err := fh.Decode(fixedHeaderExpected[0].rawBytes[0]) + if err != nil { + panic(err) + } + } +} + +func TestEncodeLength(t *testing.T) { + tt := []struct { + have int + want []byte + }{ + { + 120, + []byte{0x78}, + }, + { + math.MaxInt64, + []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, + }, + } + + for i, wanted := range tt { + buf := new(bytes.Buffer) + encodeLength(buf, wanted.have) + require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have) + } +} + +func BenchmarkEncodeLength(b *testing.B) { + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + encodeLength(buf, 120) + } +} diff --git a/server/internal/packets/packets.go b/server/internal/packets/packets.go new file mode 100644 index 0000000000000000000000000000000000000000..7211d8638e4a0199c7cbbde184fa466ca2d27e3e --- /dev/null +++ b/server/internal/packets/packets.go @@ -0,0 +1,673 @@ +package packets + +import ( + "bytes" + "errors" +) + +// All of the valid packet types and their packet identifier. +const ( + Reserved byte = iota + Connect // 1 + Connack // 2 + Publish // 3 + Puback // 4 + Pubrec // 5 + Pubrel // 6 + Pubcomp // 7 + Subscribe // 8 + Suback // 9 + Unsubscribe // 10 + Unsuback // 11 + Pingreq // 12 + Pingresp // 13 + Disconnect // 14 + + Accepted byte = 0x00 + Failed byte = 0xFF + CodeConnectBadProtocolVersion byte = 0x01 + CodeConnectBadClientID byte = 0x02 + CodeConnectServerUnavailable byte = 0x03 + CodeConnectBadAuthValues byte = 0x04 + CodeConnectNotAuthorised byte = 0x05 + CodeConnectNetworkError byte = 0xFE + CodeConnectProtocolViolation byte = 0xFF + ErrSubAckNetworkError byte = 0x80 +) + +var ( + // CONNECT + ErrMalformedProtocolName = errors.New("malformed packet: protocol name") + ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version") + ErrMalformedFlags = errors.New("malformed packet: flags") + ErrMalformedKeepalive = errors.New("malformed packet: keepalive") + ErrMalformedClientID = errors.New("malformed packet: client id") + ErrMalformedWillTopic = errors.New("malformed packet: will topic") + ErrMalformedWillMessage = errors.New("malformed packet: will message") + ErrMalformedUsername = errors.New("malformed packet: username") + ErrMalformedPassword = errors.New("malformed packet: password") + + // CONNACK + ErrMalformedSessionPresent = errors.New("malformed packet: session present") + ErrMalformedReturnCode = errors.New("malformed packet: return code") + + // PUBLISH + ErrMalformedTopic = errors.New("malformed packet: topic name") + ErrMalformedPacketID = errors.New("malformed packet: packet id") + + // SUBSCRIBE + ErrMalformedQoS = errors.New("malformed packet: qos") + + // PACKETS + ErrProtocolViolation = errors.New("protocol violation") + ErrOffsetStrOutOfRange = errors.New("offset string out of range") + ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range") + ErrOffsetByteOutOfRange = errors.New("offset byte out of range") + ErrOffsetBoolOutOfRange = errors.New("offset bool out of range") + ErrOffsetUintOutOfRange = errors.New("offset uint out of range") + ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8") + ErrInvalidFlags = errors.New("invalid flags set for packet") + ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator") + ErrMissingPacketID = errors.New("missing packet id") + ErrSurplusPacketID = errors.New("surplus packet id") +) + +// Packet is an MQTT packet. Instead of providing a packet interface and variant +// packet structs, this is a single concrete packet type to cover all packet +// types, which allows us to take advantage of various compiler optimizations. +type Packet struct { + FixedHeader FixedHeader + + PacketID uint16 + + // Connect + ProtocolName []byte + ProtocolVersion byte + CleanSession bool + WillFlag bool + WillQos byte + WillRetain bool + UsernameFlag bool + PasswordFlag bool + ReservedBit byte + Keepalive uint16 + ClientIdentifier string + WillTopic string + WillMessage []byte + Username []byte + Password []byte + + // Connack + SessionPresent bool + ReturnCode byte + + // Publish + TopicName string + Payload []byte + + // Subscribe, Unsubscribe + Topics []string + Qoss []byte + + ReturnCodes []byte // Suback +} + +// ConnectEncode encodes a connect packet. +func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { + + protoName := encodeBytes(pk.ProtocolName) + protoVersion := pk.ProtocolVersion + flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7 + keepalive := encodeUint16(pk.Keepalive) + clientID := encodeString(pk.ClientIdentifier) + + var willTopic, willFlag, usernameFlag, passwordFlag []byte + + // If will flag is set, add topic and message. + if pk.WillFlag { + willTopic = encodeString(pk.WillTopic) + willFlag = encodeBytes(pk.WillMessage) + } + + // If username flag is set, add username. + if pk.UsernameFlag { + usernameFlag = encodeBytes(pk.Username) + } + + // If password flag is set, add password. + if pk.PasswordFlag { + passwordFlag = encodeBytes(pk.Password) + } + + // Get a length for the connect header. This is not super pretty, but it works. + pk.FixedHeader.Remaining = + len(protoName) + 1 + 1 + len(keepalive) + len(clientID) + + len(willTopic) + len(willFlag) + + len(usernameFlag) + len(passwordFlag) + + pk.FixedHeader.Encode(buf) + + // Eschew magic for readability. + buf.Write(protoName) + buf.WriteByte(protoVersion) + buf.WriteByte(flag) + buf.Write(keepalive) + buf.Write(clientID) + buf.Write(willTopic) + buf.Write(willFlag) + buf.Write(usernameFlag) + buf.Write(passwordFlag) + + return nil +} + +// ConnectDecode decodes a connect packet. +func (pk *Packet) ConnectDecode(buf []byte) error { + var offset int + var err error + + // Unpack protocol name and version. + pk.ProtocolName, offset, err = decodeBytes(buf, 0) + if err != nil { + return ErrMalformedProtocolName + } + + pk.ProtocolVersion, offset, err = decodeByte(buf, offset) + if err != nil { + return ErrMalformedProtocolVersion + } + // Unpack flags byte. + flags, offset, err := decodeByte(buf, offset) + if err != nil { + return ErrMalformedFlags + } + pk.ReservedBit = 1 & flags + pk.CleanSession = 1&(flags>>1) > 0 + pk.WillFlag = 1&(flags>>2) > 0 + pk.WillQos = 3 & (flags >> 3) // this one is not a bool + pk.WillRetain = 1&(flags>>5) > 0 + pk.PasswordFlag = 1&(flags>>6) > 0 + pk.UsernameFlag = 1&(flags>>7) > 0 + + // Get keepalive interval. + pk.Keepalive, offset, err = decodeUint16(buf, offset) + if err != nil { + return ErrMalformedKeepalive + } + + // Get client ID. + pk.ClientIdentifier, offset, err = decodeString(buf, offset) + if err != nil { + return ErrMalformedClientID + } + + // Get Last Will and Testament topic and message if applicable. + if pk.WillFlag { + pk.WillTopic, offset, err = decodeString(buf, offset) + if err != nil { + return ErrMalformedWillTopic + } + + pk.WillMessage, offset, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedWillMessage + } + } + + // Get username and password if applicable. + if pk.UsernameFlag { + pk.Username, offset, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedUsername + } + } + + if pk.PasswordFlag { + pk.Password, offset, err = decodeBytes(buf, offset) + if err != nil { + return ErrMalformedPassword + } + } + + return nil + +} + +// ConnectValidate ensures the connect packet is compliant. +func (pk *Packet) ConnectValidate() (b byte, err error) { + + // End if protocol name is bad. + if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 && + bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 { + return CodeConnectProtocolViolation, ErrProtocolViolation + } + + // End if protocol version is bad. + if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) || + (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) { + return CodeConnectBadProtocolVersion, ErrProtocolViolation + } + + // End if reserved bit is not 0. + if pk.ReservedBit != 0 { + return CodeConnectProtocolViolation, ErrProtocolViolation + } + + // End if ClientID is too long. + if len(pk.ClientIdentifier) > 65535 { + return CodeConnectProtocolViolation, ErrProtocolViolation + } + + // End if password flag is set without a username. + if pk.PasswordFlag && !pk.UsernameFlag { + return CodeConnectProtocolViolation, ErrProtocolViolation + } + + // End if Username or Password is too long. + if len(pk.Username) > 65535 || len(pk.Password) > 65535 { + return CodeConnectProtocolViolation, ErrProtocolViolation + } + + // End if client id isn't set and clean session is false. + if !pk.CleanSession && len(pk.ClientIdentifier) == 0 { + return CodeConnectBadClientID, ErrProtocolViolation + } + + return Accepted, nil +} + +// ConnackEncode encodes a Connack packet. +func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.WriteByte(encodeBool(pk.SessionPresent)) + buf.WriteByte(pk.ReturnCode) + return nil +} + +// ConnackDecode decodes a Connack packet. +func (pk *Packet) ConnackDecode(buf []byte) error { + var offset int + var err error + + pk.SessionPresent, offset, err = decodeByteBool(buf, 0) + if err != nil { + return ErrMalformedSessionPresent + } + + pk.ReturnCode, offset, err = decodeByte(buf, offset) + if err != nil { + return ErrMalformedReturnCode + } + + return nil +} + +// DisconnectEncode encodes a Disconnect packet. +func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Encode(buf) + return nil +} + +// PingreqEncode encodes a Pingreq packet. +func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Encode(buf) + return nil +} + +// PingrespEncode encodes a Pingresp packet. +func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Encode(buf) + return nil +} + +// PubackEncode encodes a Puback packet. +func (pk *Packet) PubackEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil +} + +// PubackDecode decodes a Puback packet. +func (pk *Packet) PubackDecode(buf []byte) error { + var err error + pk.PacketID, _, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + return nil +} + +// PubcompEncode encodes a Pubcomp packet. +func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil +} + +// PubcompDecode decodes a Pubcomp packet. +func (pk *Packet) PubcompDecode(buf []byte) error { + var err error + pk.PacketID, _, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + return nil +} + +// PublishEncode encodes a Publish packet. +func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { + topicName := encodeString(pk.TopicName) + var packetID []byte + + // Add PacketID if QOS is set. + // [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. + if pk.FixedHeader.Qos > 0 { + + // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.PacketID == 0 { + return ErrMissingPacketID + } + + packetID = encodeUint16(pk.PacketID) + } + + pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload) + pk.FixedHeader.Encode(buf) + buf.Write(topicName) + buf.Write(packetID) + buf.Write(pk.Payload) + + return nil +} + +// PublishDecode extracts the data values from the packet. +func (pk *Packet) PublishDecode(buf []byte) error { + var offset int + var err error + + pk.TopicName, offset, err = decodeString(buf, 0) + if err != nil { + return ErrMalformedTopic + } + + // If QOS decode Packet ID. + if pk.FixedHeader.Qos > 0 { + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return ErrMalformedPacketID + } + } + + pk.Payload = buf[offset:] + + return nil +} + +// PublishCopy creates a new instance of Publish packet bearing the +// same payload and destination topic, but with an empty header for +// inheriting new QoS flags, etc. +func (pk *Packet) PublishCopy() Packet { + return Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Retain: pk.FixedHeader.Retain, + }, + TopicName: pk.TopicName, + Payload: pk.Payload, + } +} + +// PublishValidate validates a publish packet. +func (pk *Packet) PublishValidate() (byte, error) { + + // @SPEC [MQTT-2.3.1-1] + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return Failed, ErrMissingPacketID + } + + // @SPEC [MQTT-2.3.1-5] + // A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. + if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 { + return Failed, ErrSurplusPacketID + } + + return Accepted, nil +} + +// PubrecEncode encodes a Pubrec packet. +func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil +} + +// PubrecDecode decodes a Pubrec packet. +func (pk *Packet) PubrecDecode(buf []byte) error { + var err error + pk.PacketID, _, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + + return nil +} + +// PubrelEncode encodes a Pubrel packet. +func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil +} + +// PubrelDecode decodes a Pubrel packet. +func (pk *Packet) PubrelDecode(buf []byte) error { + var err error + pk.PacketID, _, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + return nil +} + +// SubackEncode encodes a Suback packet. +func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { + packetID := encodeUint16(pk.PacketID) + pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length. + pk.FixedHeader.Encode(buf) + + buf.Write(packetID) // Encode Packet ID. + buf.Write(pk.ReturnCodes) // Encode granted QOS flags. + + return nil +} + +// SubackDecode decodes a Suback packet. +func (pk *Packet) SubackDecode(buf []byte) error { + var offset int + var err error + + // Get Packet ID. + pk.PacketID, offset, err = decodeUint16(buf, offset) + if err != nil { + return ErrMalformedPacketID + } + + // Get Granted QOS flags. + pk.ReturnCodes = buf[offset:] + + return nil +} + +// SubscribeEncode encodes a Subscribe packet. +func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { + + // Add the Packet ID. + // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.PacketID == 0 { + return ErrMissingPacketID + } + + packetID := encodeUint16(pk.PacketID) + + // Count topics lengths and associated QOS flags. + var topicsLen int + for _, topic := range pk.Topics { + topicsLen += len(encodeString(topic)) + 1 + } + + pk.FixedHeader.Remaining = len(packetID) + topicsLen + pk.FixedHeader.Encode(buf) + buf.Write(packetID) + + // Add all provided topic names and associated QOS flags. + for i, topic := range pk.Topics { + buf.Write(encodeString(topic)) + buf.WriteByte(pk.Qoss[i]) + } + + return nil +} + +// SubscribeDecode decodes a Subscribe packet. +func (pk *Packet) SubscribeDecode(buf []byte) error { + var offset int + var err error + + // Get the Packet ID. + pk.PacketID, offset, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + + // Keep decoding until there's no space left. + for offset < len(buf) { + + // Decode Topic Name. + var topic string + topic, offset, err = decodeString(buf, offset) + if err != nil { + return ErrMalformedTopic + } + pk.Topics = append(pk.Topics, topic) + + // Decode QOS flag. + var qos byte + qos, offset, err = decodeByte(buf, offset) + if err != nil { + return ErrMalformedQoS + } + + // Ensure QoS byte is within range. + if !(qos >= 0 && qos <= 2) { + //if !validateQoS(qos) { + return ErrMalformedQoS + } + + pk.Qoss = append(pk.Qoss, qos) + } + + return nil +} + +// SubscribeValidate ensures the packet is compliant. +func (pk *Packet) SubscribeValidate() (byte, error) { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return Failed, ErrMissingPacketID + } + + return Accepted, nil +} + +// UnsubackEncode encodes an Unsuback packet. +func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { + pk.FixedHeader.Remaining = 2 + pk.FixedHeader.Encode(buf) + buf.Write(encodeUint16(pk.PacketID)) + return nil +} + +// UnsubackDecode decodes an Unsuback packet. +func (pk *Packet) UnsubackDecode(buf []byte) error { + var err error + pk.PacketID, _, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + return nil +} + +// UnsubscribeEncode encodes an Unsubscribe packet. +func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { + + // Add the Packet ID. + // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.PacketID == 0 { + return ErrMissingPacketID + } + + packetID := encodeUint16(pk.PacketID) + + // Count topics lengths. + var topicsLen int + for _, topic := range pk.Topics { + topicsLen += len(encodeString(topic)) + } + + pk.FixedHeader.Remaining = len(packetID) + topicsLen + pk.FixedHeader.Encode(buf) + buf.Write(packetID) + + // Add all provided topic names. + for _, topic := range pk.Topics { + buf.Write(encodeString(topic)) + } + + return nil +} + +// UnsubscribeDecode decodes an Unsubscribe packet. +func (pk *Packet) UnsubscribeDecode(buf []byte) error { + var offset int + var err error + + // Get the Packet ID. + pk.PacketID, offset, err = decodeUint16(buf, 0) + if err != nil { + return ErrMalformedPacketID + } + + // Keep decoding until there's no space left. + for offset < len(buf) { + var t string + t, offset, err = decodeString(buf, offset) // Decode Topic Name. + if err != nil { + return ErrMalformedTopic + } + + if len(t) > 0 { + pk.Topics = append(pk.Topics, t) + } + } + + return nil + +} + +// UnsubscribeValidate validates an Unsubscribe packet. +func (pk *Packet) UnsubscribeValidate() (byte, error) { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { + return Failed, ErrMissingPacketID + } + + return Accepted, nil +} diff --git a/server/internal/packets/packets_tables_test.go b/server/internal/packets/packets_tables_test.go new file mode 100644 index 0000000000000000000000000000000000000000..6a5e8cf1767acf59249e92e3162a4913e19deac7 --- /dev/null +++ b/server/internal/packets/packets_tables_test.go @@ -0,0 +1,1416 @@ +package packets + +type packetTestData struct { + group string // group specifies a group that should run the test, blank for all + rawBytes []byte // the bytes that make the packet + actualBytes []byte // the actual byte array that is created in the event of a byte mutation (eg. MQTT-2.3.1-1 qos/packet id) + packet *Packet // the packet that is expected + desc string // a description of the test + failFirst interface{} // expected fail result to be run immediately after the method is called + expect interface{} // generic expected fail result to be checked + isolate bool // isolate can be used to isolate a test + primary bool // primary is a test that should be run using readPackets + meta interface{} // meta conains a metadata value used in testing on a case-by-case basis. + code byte // code is an expected validation return code +} + +func encodeTestOK(wanted packetTestData) bool { + if wanted.rawBytes == nil { + return false + } + if wanted.group != "" && wanted.group != "encode" { + return false + } + return true +} + +func decodeTestOK(wanted packetTestData) bool { + if wanted.group != "" && wanted.group != "decode" { + return false + } + return true +} + +var expectedPackets = map[byte][]packetTestData{ + Connect: { + { + desc: "MQTT 3.1", + primary: true, + rawBytes: []byte{ + byte(Connect << 4), 17, // Fixed header + 0, 6, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name + 3, // Protocol Version + 0, // Packet Flags + 0, 30, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 17, + }, + ProtocolName: []byte("MQIsdp"), + ProtocolVersion: 3, + CleanSession: false, + Keepalive: 30, + ClientIdentifier: "zen", + }, + }, + + { + desc: "MQTT 3.1.1", + primary: true, + rawBytes: []byte{ + byte(Connect << 4), 16, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Packet Flags + 0, 60, // Keepalive + 0, 4, // Client ID - MSB+LSB + 'z', 'e', 'n', '3', // Client ID "zen" + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 16, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: false, + Keepalive: 60, + ClientIdentifier: "zen3", + }, + }, + { + desc: "MQTT 3.1.1, Clean Session", + rawBytes: []byte{ + byte(Connect << 4), 15, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 2, // Packet Flags + 0, 45, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 15, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: true, + Keepalive: 45, + ClientIdentifier: "zen", + }, + }, + { + desc: "MQTT 3.1.1, Clean Session, LWT", + rawBytes: []byte{ + byte(Connect << 4), 31, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 14, // Packet Flags + 0, 27, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 31, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: true, + Keepalive: 27, + ClientIdentifier: "zen", + WillFlag: true, + WillTopic: "lwt", + WillMessage: []byte("not again"), + WillQos: 1, + }, + }, + { + desc: "MQTT 3.1.1, Username, Password", + rawBytes: []byte{ + byte(Connect << 4), 28, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 194, // Packet Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', '/', ';', + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 28, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: true, + Keepalive: 20, + ClientIdentifier: "zen", + UsernameFlag: true, + PasswordFlag: true, + Username: []byte("mochi"), + Password: []byte(",./;"), + }, + }, + { + desc: "MQTT 3.1.1, Username, Password, LWT", + primary: true, + rawBytes: []byte{ + byte(Connect << 4), 44, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Packet Flags + 0, 120, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', '/', ';', + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 44, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: true, + Keepalive: 120, + ClientIdentifier: "zen", + UsernameFlag: true, + PasswordFlag: true, + Username: []byte("mochi"), + Password: []byte(",./;"), + WillFlag: true, + WillTopic: "lwt", + WillMessage: []byte("not again"), + WillQos: 1, + }, + }, + + // Fail States + { + desc: "Malformed Connect - protocol name", + group: "decode", + failFirst: ErrMalformedProtocolName, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 7, // Protocol Name - MSB+LSB + 'M', 'Q', 'I', 's', 'd', // Protocol Name + }, + }, + + { + desc: "Malformed Connect - protocol version", + group: "decode", + failFirst: ErrMalformedProtocolVersion, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + }, + }, + + { + desc: "Malformed Connect - flags", + group: "decode", + failFirst: ErrMalformedFlags, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + + }, + }, + { + desc: "Malformed Connect - keepalive", + group: "decode", + failFirst: ErrMalformedKeepalive, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + }, + }, + { + desc: "Malformed Connect - client id", + group: "decode", + failFirst: ErrMalformedClientID, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', // Client ID "zen" + }, + }, + { + desc: "Malformed Connect - will topic", + group: "decode", + failFirst: ErrMalformedWillTopic, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 14, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 6, // Will Topic - MSB+LSB + 'l', + }, + }, + { + desc: "Malformed Connect - will flag", + group: "decode", + failFirst: ErrMalformedWillMessage, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 14, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', + }, + }, + { + desc: "Malformed Connect - username", + group: "decode", + failFirst: ErrMalformedUsername, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', + }, + }, + { + desc: "Malformed Connect - password", + group: "decode", + failFirst: ErrMalformedPassword, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 206, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'z', 'e', 'n', // Client ID "zen" + 0, 3, // Will Topic - MSB+LSB + 'l', 'w', 't', + 0, 9, // Will Message MSB+LSB + 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', + 0, 5, // Username MSB+LSB + 'm', 'o', 'c', 'h', 'i', + 0, 4, // Password MSB+LSB + ',', '.', + }, + }, + + // Validation Tests + { + desc: "Invalid Protocol Name", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("stuff"), + }, + }, + { + desc: "Invalid Protocol Version", + group: "validate", + code: CodeConnectBadProtocolVersion, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 2, + }, + }, + { + desc: "Invalid Protocol Version", + group: "validate", + code: CodeConnectBadProtocolVersion, + packet: &Packet{ + ProtocolName: []byte("MQIsdp"), + ProtocolVersion: 2, + }, + }, + { + desc: "Reserved bit not 0", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + ReservedBit: 1, + }, + }, + { + desc: "Client ID too long", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + ClientIdentifier: func() string { + return string(make([]byte, 65536)) + }(), + }, + }, + { + desc: "Has Password Flag but no Username flag", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + PasswordFlag: true, + }, + }, + { + desc: "Username too long", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + UsernameFlag: true, + Username: func() []byte { + return make([]byte, 65536) + }(), + }, + }, + { + desc: "Password too long", + group: "validate", + code: CodeConnectProtocolViolation, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + UsernameFlag: true, + Username: []byte{}, + PasswordFlag: true, + Password: func() []byte { + return make([]byte, 65536) + }(), + }, + }, + { + desc: "Clean session false and client id not set", + group: "validate", + code: CodeConnectBadClientID, + packet: &Packet{ + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + CleanSession: false, + }, + }, + + // Spec Tests + { + // @SPEC [MQTT-1.4.0-1] + // The character data in a UTF-8 encoded string MUST be well-formed UTF-8 + // as defined by the Unicode specification [Unicode] and restated in RFC 3629 [RFC 3629]. + // In particular this data MUST NOT include encodings of code points between U+D800 and U+DFFF. + desc: "Invalid UTF8 string (a) - Code point U+D800.", + group: "decode", + failFirst: ErrMalformedClientID, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 4, // Client ID - MSB+LSB + 'e', 0xed, 0xa0, 0x80, // Client id bearing U+D800 + }, + }, + { + desc: "Invalid UTF8 string (b) - Code point U+DFFF.", + group: "decode", + failFirst: ErrMalformedClientID, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 4, // Client ID - MSB+LSB + 'e', 0xed, 0xa3, 0xbf, // Client id bearing U+D8FF + }, + }, + + // @SPEC [MQTT-1.4.0-2] + // A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000. + { + desc: "Invalid UTF8 string (c) - Code point U+0000.", + group: "decode", + failFirst: ErrMalformedClientID, + rawBytes: []byte{ + byte(Connect << 4), 0, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 3, // Client ID - MSB+LSB + 'e', 0xc0, 0x80, // Client id bearing U+0000 + }, + }, + + // @ SPEC [MQTT-1.4.0-3] + // A UTF-8 encoded sequence 0xEF 0xBB 0xBF is always to be interpreted to mean U+FEFF ("ZERO WIDTH NO-BREAK SPACE") + // wherever it appears in a string and MUST NOT be skipped over or stripped off by a packet receiver. + { + desc: "UTF8 string must not skip or strip code point U+FEFF.", + //group: "decode", + //failFirst: ErrMalformedClientID, + rawBytes: []byte{ + byte(Connect << 4), 18, // Fixed header + 0, 4, // Protocol Name - MSB+LSB + 'M', 'Q', 'T', 'T', // Protocol Name + 4, // Protocol Version + 0, // Flags + 0, 20, // Keepalive + 0, 6, // Client ID - MSB+LSB + 'e', 'b', 0xEF, 0xBB, 0xBF, 'd', // Client id bearing U+FEFF + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connect, + Remaining: 16, + }, + ProtocolName: []byte("MQTT"), + ProtocolVersion: 4, + Keepalive: 20, + ClientIdentifier: string([]byte{'e', 'b', 0xEF, 0xBB, 0xBF, 'd'}), + }, + }, + }, + Connack: { + { + desc: "Accepted, No Session", + primary: true, + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 0, // No existing session + Accepted, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: false, + ReturnCode: Accepted, + }, + }, + { + desc: "Accepted, Session Exists", + primary: true, + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + Accepted, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: Accepted, + }, + }, + { + desc: "Bad Protocol Version", + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + CodeConnectBadProtocolVersion, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: CodeConnectBadProtocolVersion, + }, + }, + { + desc: "Bad Client ID", + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + CodeConnectBadClientID, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: CodeConnectBadClientID, + }, + }, + { + desc: "Server Unavailable", + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + CodeConnectServerUnavailable, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: CodeConnectServerUnavailable, + }, + }, + { + desc: "Bad Username or Password", + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + CodeConnectBadAuthValues, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: CodeConnectBadAuthValues, + }, + }, + { + desc: "Not Authorised", + rawBytes: []byte{ + byte(Connack << 4), 2, // fixed header + 1, // Session present + CodeConnectNotAuthorised, + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Connack, + Remaining: 2, + }, + SessionPresent: true, + ReturnCode: CodeConnectNotAuthorised, + }, + }, + + // Fail States + { + desc: "Malformed Connack - session present", + group: "decode", + failFirst: ErrMalformedSessionPresent, + rawBytes: []byte{ + byte(Connect << 4), 2, // Fixed header + }, + }, + { + desc: "Malformed Connack - bad return code", + group: "decode", + //primary: true, + failFirst: ErrMalformedReturnCode, + rawBytes: []byte{ + byte(Connect << 4), 2, // Fixed header + 0, + }, + }, + }, + + Publish: { + { + desc: "Publish - No payload", + primary: true, + rawBytes: []byte{ + byte(Publish << 4), 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 7, + }, + TopicName: "a/b/c", + Payload: []byte{}, + }, + }, + { + desc: "Publish - basic", + primary: true, + rawBytes: []byte{ + byte(Publish << 4), 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 18, + }, + TopicName: "a/b/c", + Payload: []byte("hello mochi"), + }, + }, + { + desc: "Publish - QoS:1, Packet ID", + primary: true, + rawBytes: []byte{ + byte(Publish<<4) | 2, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 7, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + Remaining: 14, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 7, + }, + meta: byte(2), + }, + { + desc: "Publish - QoS:1, Packet ID, No payload", + primary: true, + rawBytes: []byte{ + byte(Publish<<4) | 2, 9, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'y', '/', 'u', '/', 'i', // Topic Name + 0, 8, // Packet ID - LSB+MSB + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + Remaining: 9, + }, + TopicName: "y/u/i", + PacketID: 8, + Payload: []byte{}, + }, + meta: byte(2), + }, + { + desc: "Publish - Retain", + rawBytes: []byte{ + byte(Publish<<4) | 1, 10, // Fixed header + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Retain: true, + }, + TopicName: "a/b", + Payload: []byte("hello"), + }, + meta: byte(1), + }, + { + desc: "Publish - Dup", + rawBytes: []byte{ + byte(Publish<<4) | 8, 10, // Fixed header + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Dup: true, + }, + TopicName: "a/b", + Payload: []byte("hello"), + }, + meta: byte(8), + }, + + // Fail States + { + desc: "Malformed Publish - topic name", + group: "decode", + failFirst: ErrMalformedTopic, + rawBytes: []byte{ + byte(Publish << 4), 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', + 0, 11, // Packet ID - LSB+MSB + }, + }, + + { + desc: "Malformed Publish - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Publish<<4) | 2, 7, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + 0, // Packet ID - LSB+MSB + }, + }, + + // Copy tests + { + desc: "Publish - basic copyable", + group: "copy", + rawBytes: []byte{ + byte(Publish << 4), 18, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'z', '/', 'e', '/', 'n', // Topic Name + 'm', 'o', 'c', 'h', 'i', ' ', 'm', 'o', 'c', 'h', 'i', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Dup: true, + Retain: true, + Qos: 1, + }, + TopicName: "z/e/n", + Payload: []byte("mochi mochi"), + }, + }, + + // Spec tests + { + // @SPEC [MQTT-2.3.1-5] + // A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. + desc: "[MQTT-2.3.1-5] Packet ID must be 0 if QoS is 0 (a)", + group: "encode", + // this version tests for correct byte array mutuation. + // this does not check if -incoming- packets are parsed as correct, + // it is impossible for the parser to determine if the payload start is incorrect. + rawBytes: []byte{ + byte(Publish << 4), 12, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 3, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + actualBytes: []byte{ + byte(Publish << 4), 12, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + // Packet ID is removed. + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 12, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + }, + }, + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", + group: "encode", + expect: ErrMissingPacketID, + code: Failed, + rawBytes: []byte{ + byte(Publish<<4) | 2, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 0, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 0, + }, + meta: byte(2), + }, + /* + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", + group: "validate", + //primary: true, + expect: ErrMissingPacketID, + code: Failed, + rawBytes: []byte{ + byte(Publish<<4) | 2, 14, // Fixed header + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 0, 0, // Packet ID - LSB+MSB + 'h', 'e', 'l', 'l', 'o', // Payload + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 0, + }, + meta: byte(2), + }, + + */ + + // Validation Tests + { + // @SPEC [MQTT-2.3.1-5] + desc: "[MQTT-2.3.1-5] Packet ID must be 0 if QoS is 0 (b)", + group: "validate", + expect: ErrSurplusPacketID, + code: Failed, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Remaining: 12, + Qos: 0, + }, + TopicName: "a/b/c", + Payload: []byte("hello"), + PacketID: 3, + }, + }, + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", + group: "validate", + expect: ErrMissingPacketID, + code: Failed, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Publish, + Qos: 1, + }, + PacketID: 0, + }, + }, + }, + + Puback: { + { + desc: "Puback", + primary: true, + rawBytes: []byte{ + byte(Puback << 4), 2, // Fixed header + 0, 11, // Packet ID - LSB+MSB + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Puback, + Remaining: 2, + }, + PacketID: 11, + }, + }, + + // Fail states + { + desc: "Malformed Puback - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Puback << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + Pubrec: { + { + desc: "Pubrec", + primary: true, + rawBytes: []byte{ + byte(Pubrec << 4), 2, // Fixed header + 0, 12, // Packet ID - LSB+MSB + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrec, + Remaining: 2, + }, + PacketID: 12, + }, + }, + + // Fail states + { + desc: "Malformed Pubrec - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Pubrec << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + Pubrel: { + { + desc: "Pubrel", + primary: true, + rawBytes: []byte{ + byte(Pubrel<<4) | 2, 2, // Fixed header + 0, 12, // Packet ID - LSB+MSB + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubrel, + Remaining: 2, + Qos: 1, + }, + PacketID: 12, + }, + meta: byte(2), + }, + + // Fail states + { + desc: "Malformed Pubrel - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Pubrel << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + Pubcomp: { + { + desc: "Pubcomp", + primary: true, + rawBytes: []byte{ + byte(Pubcomp << 4), 2, // Fixed header + 0, 14, // Packet ID - LSB+MSB + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pubcomp, + Remaining: 2, + }, + PacketID: 14, + }, + }, + + // Fail states + { + desc: "Malformed Pubcomp - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Pubcomp << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + Subscribe: { + { + desc: "Subscribe", + primary: true, + rawBytes: []byte{ + byte(Subscribe << 4), 30, // Fixed header + 0, 15, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + 0, // QoS + + 0, 11, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name + 1, // QoS + + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + 2, // QoS + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Remaining: 30, + }, + PacketID: 15, + Topics: []string{ + "a/b", + "d/e/f/g/h/i", + "x/y/z", + }, + Qoss: []byte{0, 1, 2}, + }, + }, + + // Fail states + { + desc: "Malformed Subscribe - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Subscribe << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + desc: "Malformed Subscribe - topic", + group: "decode", + failFirst: ErrMalformedTopic, + rawBytes: []byte{ + byte(Subscribe << 4), 2, // Fixed header + 0, 21, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', + }, + }, + { + desc: "Malformed Subscribe - qos", + group: "decode", + failFirst: ErrMalformedQoS, + rawBytes: []byte{ + byte(Subscribe << 4), 2, // Fixed header + 0, 22, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'j', '/', 'b', // Topic Name + + }, + }, + { + desc: "Malformed Subscribe - qos out of range", + group: "decode", + failFirst: ErrMalformedQoS, + rawBytes: []byte{ + byte(Subscribe << 4), 2, // Fixed header + 0, 22, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'c', '/', 'd', // Topic Name + 5, // QoS + + }, + }, + + // Validation + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", + group: "validate", + expect: ErrMissingPacketID, + code: Failed, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + }, + PacketID: 0, + }, + }, + + // Spec tests + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", + group: "encode", + code: Failed, + expect: ErrMissingPacketID, + rawBytes: []byte{ + byte(Subscribe<<4) | 1<<1, 10, // Fixed header + 0, 0, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + 1, // QoS + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Subscribe, + Qos: 1, + Remaining: 10, + }, + Topics: []string{ + "a/b/c", + }, + Qoss: []byte{1}, + PacketID: 0, + }, + meta: byte(2), + }, + }, + Suback: { + { + desc: "Suback", + primary: true, + rawBytes: []byte{ + byte(Suback << 4), 6, // Fixed header + 0, 17, // Packet ID - LSB+MSB + 0, // Return Code QoS 0 + 1, // Return Code QoS 1 + 2, // Return Code QoS 2 + 0x80, // Return Code fail + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Suback, + Remaining: 6, + }, + PacketID: 17, + ReturnCodes: []byte{0, 1, 2, 0x80}, + }, + }, + + // Fail states + { + desc: "Malformed Suback - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Subscribe << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + + Unsubscribe: { + { + desc: "Unsubscribe", + primary: true, + rawBytes: []byte{ + byte(Unsubscribe << 4), 27, // Fixed header + 0, 35, // Packet ID - LSB+MSB + + 0, 3, // Topic Name - LSB+MSB + 'a', '/', 'b', // Topic Name + + 0, 11, // Topic Name - LSB+MSB + 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name + + 0, 5, // Topic Name - LSB+MSB + 'x', '/', 'y', '/', 'z', // Topic Name + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Remaining: 27, + }, + PacketID: 35, + Topics: []string{ + "a/b", + "d/e/f/g/h/i", + "x/y/z", + }, + }, + }, + // Fail states + { + desc: "Malformed Unsubscribe - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Unsubscribe << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + { + desc: "Malformed Unsubscribe - topic", + group: "decode", + failFirst: ErrMalformedTopic, + rawBytes: []byte{ + byte(Unsubscribe << 4), 2, // Fixed header + 0, 21, // Packet ID - LSB+MSB + 0, 3, // Topic Name - LSB+MSB + 'a', '/', + }, + }, + + // Validation + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", + group: "validate", + expect: ErrMissingPacketID, + code: Failed, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Qos: 1, + }, + PacketID: 0, + }, + }, + + // Spec tests + { + // @SPEC [MQTT-2.3.1-1]. + // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. + desc: "[MQTT-2.3.1-1] Unsubscribe No Packet ID with QOS > 0", + group: "encode", + code: Failed, + expect: ErrMissingPacketID, + rawBytes: []byte{ + byte(Unsubscribe<<4) | 1<<1, 9, // Fixed header + 0, 0, // Packet ID - LSB+MSB + 0, 5, // Topic Name - LSB+MSB + 'a', '/', 'b', '/', 'c', // Topic Name + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsubscribe, + Qos: 1, + Remaining: 9, + }, + Topics: []string{ + "a/b/c", + }, + PacketID: 0, + }, + meta: byte(2), + }, + }, + Unsuback: { + { + desc: "Unsuback", + primary: true, + rawBytes: []byte{ + byte(Unsuback << 4), 2, // Fixed header + 0, 37, // Packet ID - LSB+MSB + + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Unsuback, + Remaining: 2, + }, + PacketID: 37, + }, + }, + + // Fail states + { + desc: "Malformed Unsuback - Packet ID", + group: "decode", + failFirst: ErrMalformedPacketID, + rawBytes: []byte{ + byte(Unsuback << 4), 2, // Fixed header + 0, // Packet ID - LSB+MSB + }, + }, + }, + + Pingreq: { + { + desc: "Ping request", + primary: true, + rawBytes: []byte{ + byte(Pingreq << 4), 0, // fixed header + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pingreq, + Remaining: 0, + }, + }, + }, + }, + Pingresp: { + { + desc: "Ping response", + primary: true, + rawBytes: []byte{ + byte(Pingresp << 4), 0, // fixed header + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Pingresp, + Remaining: 0, + }, + }, + }, + }, + + Disconnect: { + { + desc: "Disconnect", + primary: true, + rawBytes: []byte{ + byte(Disconnect << 4), 0, // fixed header + }, + packet: &Packet{ + FixedHeader: FixedHeader{ + Type: Disconnect, + Remaining: 0, + }, + }, + }, + }, +} diff --git a/server/internal/packets/packets_test.go b/server/internal/packets/packets_test.go new file mode 100644 index 0000000000000000000000000000000000000000..aa5d7b5c1057a06284bf9ba7a81847599d49b479 --- /dev/null +++ b/server/internal/packets/packets_test.go @@ -0,0 +1,1082 @@ +package packets + +import ( + "bytes" + "testing" + + "github.com/jinzhu/copier" + "github.com/stretchr/testify/require" +) + +func TestConnectEncode(t *testing.T) { + require.Contains(t, expectedPackets, Connect) + for i, wanted := range expectedPackets[Connect] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Connect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + pk.ConnectEncode(buf) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Connect<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + ok, _ := pk.ConnectValidate() + require.Equal(t, byte(Accepted), ok, "Connect packet didn't validate - %v", ok) + + require.Equal(t, wanted.packet.FixedHeader.Type, pk.FixedHeader.Type, "Mismatched packet fixed header type [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched packet fixed header dup [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched packet fixed header qos [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched packet fixed header retain [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.ProtocolVersion, pk.ProtocolVersion, "Mismatched packet protocol version [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ProtocolName, pk.ProtocolName, "Mismatched packet protocol name [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.CleanSession, pk.CleanSession, "Mismatched packet cleansession [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ClientIdentifier, pk.ClientIdentifier, "Mismatched packet client id [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Keepalive, pk.Keepalive, "Mismatched keepalive value [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.UsernameFlag, pk.UsernameFlag, "Mismatched packet username flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Username, pk.Username, "Mismatched packet username [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PasswordFlag, pk.PasswordFlag, "Mismatched packet password flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Password, pk.Password, "Mismatched packet password [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.WillFlag, pk.WillFlag, "Mismatched packet will flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillTopic, pk.WillTopic, "Mismatched packet will topic [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillMessage, pk.WillMessage, "Mismatched packet will message [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillQos, pk.WillQos, "Mismatched packet will qos [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillRetain, pk.WillRetain, "Mismatched packet will retain [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkConnectEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Connect][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.ConnectEncode(buf) + } +} + +func TestConnectDecode(t *testing.T) { + require.Contains(t, expectedPackets, Connect) + for i, wanted := range expectedPackets[Connect] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Connect}} + err := pk.ConnectDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.FixedHeader.Type, pk.FixedHeader.Type, "Mismatched packet fixed header type [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched packet fixed header dup [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched packet fixed header qos [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched packet fixed header retain [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.ProtocolVersion, pk.ProtocolVersion, "Mismatched packet protocol version [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ProtocolName, pk.ProtocolName, "Mismatched packet protocol name [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.CleanSession, pk.CleanSession, "Mismatched packet cleansession [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ClientIdentifier, pk.ClientIdentifier, "Mismatched packet client id [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Keepalive, pk.Keepalive, "Mismatched keepalive value [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.UsernameFlag, pk.UsernameFlag, "Mismatched packet username flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Username, pk.Username, "Mismatched packet username [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PasswordFlag, pk.PasswordFlag, "Mismatched packet password flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Password, pk.Password, "Mismatched packet password [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.WillFlag, pk.WillFlag, "Mismatched packet will flag [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillTopic, pk.WillTopic, "Mismatched packet will topic [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillMessage, pk.WillMessage, "Mismatched packet will message [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillQos, pk.WillQos, "Mismatched packet will qos [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.WillRetain, pk.WillRetain, "Mismatched packet will retain [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkConnectDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Connect}} + pk.FixedHeader.Decode(expectedPackets[Connect][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.ConnectDecode(expectedPackets[Connect][0].rawBytes[2:]) + } +} + +func TestConnectValidate(t *testing.T) { + require.Contains(t, expectedPackets, Connect) + for i, wanted := range expectedPackets[Connect] { + if wanted.group == "validate" { + pk := wanted.packet + ok, _ := pk.ConnectValidate() + require.Equal(t, wanted.code, ok, "Connect packet didn't validate [i:%d] %s", i, wanted.desc) + } + } +} + +func TestConnackEncode(t *testing.T) { + require.Contains(t, expectedPackets, Connack) + for i, wanted := range expectedPackets[Connack] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Connack, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.ConnackEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Connack<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.ReturnCode, pk.ReturnCode, "Mismatched return code [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkConnackEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Connack][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.ConnackEncode(buf) + } +} + +func TestConnackDecode(t *testing.T) { + require.Contains(t, expectedPackets, Connack) + for i, wanted := range expectedPackets[Connack] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Connack}} + err := pk.ConnackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.ReturnCode, pk.ReturnCode, "Mismatched return code [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkConnackDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Connack}} + pk.FixedHeader.Decode(expectedPackets[Connack][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.ConnackDecode(expectedPackets[Connack][0].rawBytes[2:]) + } +} + +func TestDisconnectEncode(t *testing.T) { + require.Contains(t, expectedPackets, Disconnect) + for i, wanted := range expectedPackets[Disconnect] { + require.Equal(t, uint8(14), Disconnect, "Incorrect Packet Type [i:%d]", i) + + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Disconnect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) + + buf := new(bytes.Buffer) + err := pk.DisconnectEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkDisconnectEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Disconnect][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.DisconnectEncode(buf) + } +} + +func TestPingreqEncode(t *testing.T) { + require.Contains(t, expectedPackets, Pingreq) + for i, wanted := range expectedPackets[Pingreq] { + require.Equal(t, uint8(12), Pingreq, "Incorrect Packet Type [i:%d]", i) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Pingreq, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) + + buf := new(bytes.Buffer) + err := pk.PingreqEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkPingreqEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Pingreq][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PingreqEncode(buf) + } +} + +func TestPingrespEncode(t *testing.T) { + require.Contains(t, expectedPackets, Pingresp) + for i, wanted := range expectedPackets[Pingresp] { + require.Equal(t, uint8(13), Pingresp, "Incorrect Packet Type [i:%d]", i) + + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Pingresp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) + + buf := new(bytes.Buffer) + err := pk.PingrespEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) + } +} + +func BenchmarkPingrespEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Pingresp][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PingrespEncode(buf) + } +} + +func TestPubackEncode(t *testing.T) { + require.Contains(t, expectedPackets, Puback) + for i, wanted := range expectedPackets[Puback] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(4), Puback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Puback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.PubackEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Puback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubackEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Puback][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PubackEncode(buf) + } +} + +func TestPubackDecode(t *testing.T) { + require.Contains(t, expectedPackets, Puback) + for i, wanted := range expectedPackets[Puback] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(4), Puback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Puback}} + err := pk.PubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubackDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Puback}} + pk.FixedHeader.Decode(expectedPackets[Puback][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PubackDecode(expectedPackets[Puback][0].rawBytes[2:]) + } +} + +func TestPubcompEncode(t *testing.T) { + require.Contains(t, expectedPackets, Pubcomp) + for i, wanted := range expectedPackets[Pubcomp] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(7), Pubcomp, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Pubcomp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.PubcompEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Pubcomp<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubcompEncode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} + copier.Copy(pk, expectedPackets[Pubcomp][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PubcompEncode(buf) + } +} + +func TestPubcompDecode(t *testing.T) { + require.Contains(t, expectedPackets, Pubcomp) + for i, wanted := range expectedPackets[Pubcomp] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(7), Pubcomp, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} + err := pk.PubcompDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubcompDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} + pk.FixedHeader.Decode(expectedPackets[Pubcomp][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PubcompDecode(expectedPackets[Pubcomp][0].rawBytes[2:]) + } +} + +func TestPublishEncode(t *testing.T) { + require.Contains(t, expectedPackets, Publish) + for i, wanted := range expectedPackets[Publish] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(3), Publish, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Publish, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.PublishEncode(buf) + encoded := buf.Bytes() + + if wanted.expect != nil { + require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) + } else { + + // If actualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc). + if len(wanted.actualBytes) > 0 { + wanted.rawBytes = wanted.actualBytes + } + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Publish<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Publish<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } + + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + } + } +} + +func BenchmarkPublishEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Publish][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PublishEncode(buf) + } +} + +func TestPublishDecode(t *testing.T) { + require.Contains(t, expectedPackets, Publish) + for i, wanted := range expectedPackets[Publish] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(3), Publish, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} + pk.FixedHeader.Decode(wanted.rawBytes[0]) + + err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected fh error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + + } +} + +func BenchmarkPublishDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} + pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PublishDecode(expectedPackets[Publish][1].rawBytes[2:]) + } +} + +func TestPublishCopy(t *testing.T) { + require.Contains(t, expectedPackets, Publish) + for i, wanted := range expectedPackets[Publish] { + if wanted.group == "copy" { + + pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} + err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + copied := pk.PublishCopy() + + require.Equal(t, byte(0), copied.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) + require.Equal(t, false, copied.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) + require.Equal(t, false, copied.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) + + require.Equal(t, pk.Payload, copied.Payload, "Mismatched Payload [i:%d] %s", i, wanted.desc) + require.Equal(t, pk.TopicName, copied.TopicName, "Mismatched Topic Name [i:%d] %s", i, wanted.desc) + + } + } +} + +func BenchmarkPublishCopy(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} + pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PublishCopy() + } +} + +func TestPublishValidate(t *testing.T) { + require.Contains(t, expectedPackets, Publish) + for i, wanted := range expectedPackets[Publish] { + if wanted.group == "validate" || i == 0 { + pk := wanted.packet + ok, err := pk.PublishValidate() + + if i == 0 { + require.NoError(t, err, "Publish should have validated - error incorrect [i:%d] %s", i, wanted.desc) + require.Equal(t, Accepted, ok, "Publish should have validated - code incorrect [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, Failed, ok, "Publish packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) + if err != nil { + require.Equal(t, wanted.expect, err, "Publish packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) + } + } + } + } +} + +func BenchmarkPublishValidate(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} + pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) + + for n := 0; n < b.N; n++ { + _, err := pk.PublishValidate() + if err != nil { + panic(err) + } + } +} + +func TestPubrecEncode(t *testing.T) { + require.Contains(t, expectedPackets, Pubrec) + for i, wanted := range expectedPackets[Pubrec] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(5), Pubrec, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Pubrec, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.PubrecEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + require.Equal(t, byte(Pubrec<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + + } +} + +func BenchmarkPubrecEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Pubrec][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PubrecEncode(buf) + } +} + +func TestPubrecDecode(t *testing.T) { + require.Contains(t, expectedPackets, Pubrec) + for i, wanted := range expectedPackets[Pubrec] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(5), Pubrec, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Pubrec}} + err := pk.PubrecDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubrecDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Pubrec}} + pk.FixedHeader.Decode(expectedPackets[Pubrec][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PubrecDecode(expectedPackets[Pubrec][0].rawBytes[2:]) + } +} + +func TestPubrelEncode(t *testing.T) { + require.Contains(t, expectedPackets, Pubrel) + for i, wanted := range expectedPackets[Pubrel] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(6), Pubrel, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Pubrel, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.PubrelEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Pubrel<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Pubrel<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubrelEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Pubrel][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.PubrelEncode(buf) + } +} + +func TestPubrelDecode(t *testing.T) { + require.Contains(t, expectedPackets, Pubrel) + for i, wanted := range expectedPackets[Pubrel] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(6), Pubrel, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}} + err := pk.PubrelDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkPubrelDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}} + pk.FixedHeader.Decode(expectedPackets[Pubrel][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.PubrelDecode(expectedPackets[Pubrel][0].rawBytes[2:]) + } +} + +func TestSubackEncode(t *testing.T) { + require.Contains(t, expectedPackets, Suback) + for i, wanted := range expectedPackets[Suback] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(9), Suback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Suback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.SubackEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Suback<<4)|wanted.meta.(byte), encoded[0], "Mismatched mod fixed header packets [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Suback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + } + + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ReturnCodes, pk.ReturnCodes, "Mismatched Return Codes [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkSubackEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Suback][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.SubackEncode(buf) + } +} + +func TestSubackDecode(t *testing.T) { + require.Contains(t, expectedPackets, Suback) + for i, wanted := range expectedPackets[Suback] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(9), Suback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Suback}} + err := pk.SubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.ReturnCodes, pk.ReturnCodes, "Mismatched Return Codes [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkSubackDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Suback}} + pk.FixedHeader.Decode(expectedPackets[Suback][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.SubackDecode(expectedPackets[Suback][0].rawBytes[2:]) + } +} + +func TestSubscribeEncode(t *testing.T) { + require.Contains(t, expectedPackets, Subscribe) + for i, wanted := range expectedPackets[Subscribe] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(8), Subscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Subscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.SubscribeEncode(buf) + encoded := buf.Bytes() + + if wanted.expect != nil { + require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) + } else { + require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Subscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Subscribe<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } + + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Qoss, pk.Qoss, "Mismatched Qoss slice [i:%d] %s", i, wanted.desc) + } + } +} + +func BenchmarkSubscribeEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Subscribe][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.SubscribeEncode(buf) + } +} + +func TestSubscribeDecode(t *testing.T) { + require.Contains(t, expectedPackets, Subscribe) + for i, wanted := range expectedPackets[Subscribe] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(8), Subscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} + err := pk.SubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Qoss, pk.Qoss, "Mismatched Qoss slice [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkSubscribeDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} + pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.SubscribeDecode(expectedPackets[Subscribe][0].rawBytes[2:]) + } +} + +func TestSubscribeValidate(t *testing.T) { + require.Contains(t, expectedPackets, Subscribe) + for i, wanted := range expectedPackets[Subscribe] { + if wanted.group == "validate" || i == 0 { + pk := wanted.packet + ok, err := pk.SubscribeValidate() + + if i == 0 { + require.NoError(t, err, "Subscribe should have validated - error incorrect [i:%d] %s", i, wanted.desc) + require.Equal(t, Accepted, ok, "Subscribe should have validated - code incorrect [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, Failed, ok, "Subscribe packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) + if err != nil { + require.Equal(t, wanted.expect, err, "Subscribe packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) + } + } + } + } +} + +func BenchmarkSubscribeValidate(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} + pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.SubscribeValidate() + } +} + +func TestUnsubackEncode(t *testing.T) { + require.Contains(t, expectedPackets, Unsuback) + for i, wanted := range expectedPackets[Unsuback] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(11), Unsuback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Unsuback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.UnsubackEncode(buf) + require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) + encoded := buf.Bytes() + + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Unsuback<<4)|wanted.meta.(byte), encoded[0], "Mismatched mod fixed header packets [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Unsuback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) + } + + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkUnsubackEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Unsuback][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.UnsubackEncode(buf) + } +} + +func TestUnsubackDecode(t *testing.T) { + require.Contains(t, expectedPackets, Unsuback) + for i, wanted := range expectedPackets[Unsuback] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(11), Unsuback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Unsuback}} + err := pk.UnsubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkUnsubackDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Unsuback}} + pk.FixedHeader.Decode(expectedPackets[Unsuback][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.UnsubackDecode(expectedPackets[Unsuback][0].rawBytes[2:]) + } +} + +func TestUnsubscribeEncode(t *testing.T) { + require.Contains(t, expectedPackets, Unsubscribe) + for i, wanted := range expectedPackets[Unsubscribe] { + if !encodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(10), Unsubscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + pk := new(Packet) + copier.Copy(pk, wanted.packet) + + require.Equal(t, Unsubscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) + + buf := new(bytes.Buffer) + err := pk.UnsubscribeEncode(buf) + encoded := buf.Bytes() + if wanted.expect != nil { + require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) + } else { + require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) + if wanted.meta != nil { + require.Equal(t, byte(Unsubscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, byte(Unsubscribe<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) + } + + require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) + require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) + + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) + } + } +} + +func BenchmarkUnsubscribeEncode(b *testing.B) { + pk := new(Packet) + copier.Copy(pk, expectedPackets[Unsubscribe][0].packet) + + buf := new(bytes.Buffer) + for n := 0; n < b.N; n++ { + pk.UnsubscribeEncode(buf) + } +} + +func TestUnsubscribeDecode(t *testing.T) { + require.Contains(t, expectedPackets, Unsubscribe) + for i, wanted := range expectedPackets[Unsubscribe] { + if !decodeTestOK(wanted) { + continue + } + + require.Equal(t, uint8(10), Unsubscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) + + pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} + err := pk.UnsubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. + if wanted.failFirst != nil { + require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) + continue + } + + require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) + require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) + } +} + +func BenchmarkUnsubscribeDecode(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} + pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.UnsubscribeDecode(expectedPackets[Unsubscribe][0].rawBytes[2:]) + } +} + +func TestUnsubscribeValidate(t *testing.T) { + require.Contains(t, expectedPackets, Unsubscribe) + for i, wanted := range expectedPackets[Unsubscribe] { + if wanted.group == "validate" || i == 0 { + pk := wanted.packet + ok, err := pk.UnsubscribeValidate() + if i == 0 { + require.NoError(t, err, "Unsubscribe should have validated - error incorrect [i:%d] %s", i, wanted.desc) + require.Equal(t, Accepted, ok, "Unsubscribe should have validated - code incorrect [i:%d] %s", i, wanted.desc) + } else { + require.Equal(t, Failed, ok, "Unsubscribe packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) + if err != nil { + require.Equal(t, wanted.expect, err, "Unsubscribe packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) + } + } + } + } +} + +func BenchmarkUnsubscribeValidate(b *testing.B) { + pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} + pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0]) + + for n := 0; n < b.N; n++ { + pk.UnsubscribeValidate() + } +} diff --git a/server/internal/topics/trie.go b/server/internal/topics/trie.go index a9a47ac2279b5c01191e49a6e9928a81e001e1a8..533330dfb8ee0657e54eff137b22b011533f9207 100644 --- a/server/internal/topics/trie.go +++ b/server/internal/topics/trie.go @@ -4,7 +4,7 @@ import ( "strings" "sync" - "github.com/mochi-co/mqtt/server/packets" + "github.com/mochi-co/mqtt/server/internal/packets" ) // Subscriptions is a map of subscriptions keyed on client. diff --git a/server/internal/topics/trie_test.go b/server/internal/topics/trie_test.go index 2170f6af62bd16e54f18d4eec3a5021d8f94ecb0..1c159e30fc5e27569a8f3a3e8663d9ca2849a309 100644 --- a/server/internal/topics/trie_test.go +++ b/server/internal/topics/trie_test.go @@ -5,7 +5,7 @@ "testing" "github.com/stretchr/testify/require" - "github.com/mochi-co/mqtt/server/packets" + "github.com/mochi-co/mqtt/server/internal/packets" ) func TestNew(t *testing.T) { diff --git a/server/packets/codec.go b/server/packets/codec.go deleted file mode 100644 index 7ab4cdc8ea185232038225af99e4e45a2780a987..0000000000000000000000000000000000000000 --- a/server/packets/codec.go +++ /dev/null @@ -1,114 +0,0 @@ -package packets - -import ( - "encoding/binary" - "unicode/utf8" - "unsafe" -) - -// bytesToString provides a zero-alloc, no-copy byte to string conversion. -// via https://github.com/golang/go/issues/25484#issuecomment-391415660 -func bytesToString(bs []byte) string { - return *(*string)(unsafe.Pointer(&bs)) -} - -// decodeUint16 extracts the value of two bytes from a byte array. -func decodeUint16(buf []byte, offset int) (uint16, int, error) { - if len(buf) < offset+2 { - return 0, 0, ErrOffsetUintOutOfRange - } - - return binary.BigEndian.Uint16(buf[offset : offset+2]), offset + 2, nil -} - -// decodeString extracts a string from a byte array, beginning at an offset. -func decodeString(buf []byte, offset int) (string, int, error) { - b, n, err := decodeBytes(buf, offset) - if err != nil { - return "", 0, err - } - - return bytesToString(b), n, nil -} - -// decodeBytes extracts a byte array from a byte array, beginning at an offset. Used primarily for message payloads. -func decodeBytes(buf []byte, offset int) ([]byte, int, error) { - length, next, err := decodeUint16(buf, offset) - if err != nil { - return make([]byte, 0, 0), 0, err - } - - if next+int(length) > len(buf) { - return make([]byte, 0, 0), 0, ErrOffsetStrOutOfRange - } - - if !validUTF8(buf[next : next+int(length)]) { - return make([]byte, 0, 0), 0, ErrOffsetStrInvalidUTF8 - } - - return buf[next : next+int(length)], next + int(length), nil -} - -// decodeByte extracts the value of a byte from a byte array. -func decodeByte(buf []byte, offset int) (byte, int, error) { - if len(buf) <= offset { - return 0, 0, ErrOffsetByteOutOfRange - } - return buf[offset], offset + 1, nil -} - -// decodeByteBool extracts the value of a byte from a byte array and returns a bool. -func decodeByteBool(buf []byte, offset int) (bool, int, error) { - if len(buf) <= offset { - return false, 0, ErrOffsetBoolOutOfRange - } - return 1&buf[offset] > 0, offset + 1, nil -} - -// encodeBool returns a byte instead of a bool. -func encodeBool(b bool) byte { - if b { - return 1 - } - return 0 -} - -// encodeBytes encodes a byte array to a byte array. Used primarily for message payloads. -func encodeBytes(val []byte) []byte { - // In many circumstances the number of bytes being encoded is small. - // Setting the cap to a low amount allows us to account for those without - // triggering allocation growth on append unless we need to. - buf := make([]byte, 2, 32) - binary.BigEndian.PutUint16(buf, uint16(len(val))) - return append(buf, val...) -} - -// encodeUint16 encodes a uint16 value to a byte array. -func encodeUint16(val uint16) []byte { - buf := make([]byte, 2) - binary.BigEndian.PutUint16(buf, val) - return buf -} - -// encodeString encodes a string to a byte array. -func encodeString(val string) []byte { - // Like encodeBytes, we set the cap to a small number to avoid - // triggering allocation growth on append unless we absolutely need to. - buf := make([]byte, 2, 32) - binary.BigEndian.PutUint16(buf, uint16(len(val))) - return append(buf, []byte(val)...) -} - -// validUTF8 checks if the byte array contains valid UTF-8 characters, specifically -// conforming to the MQTT specification requirements. -func validUTF8(b []byte) bool { - // [MQTT-1.4.0-1] The character data in a UTF-8 encoded string MUST be well-formed UTF-8... - if !utf8.Valid(b) { - return false - } - - // [MQTT-1.4.0-2] A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000... - // ... - return true - -} diff --git a/server/packets/codec_test.go b/server/packets/codec_test.go deleted file mode 100644 index 2cd0438cc2928b630462dedda5af95504d1442df..0000000000000000000000000000000000000000 --- a/server/packets/codec_test.go +++ /dev/null @@ -1,383 +0,0 @@ -package packets - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestBytesToString(t *testing.T) { - b := []byte{'a', 'b', 'c'} - require.Equal(t, "abc", bytesToString(b)) -} - -func BenchmarkBytesToString(b *testing.B) { - for n := 0; n < b.N; n++ { - bytesToString([]byte{'a', 'b', 'c'}) - } -} - -func TestDecodeString(t *testing.T) { - expect := []struct { - rawBytes []byte - result []string - offset int - shouldFail bool - }{ - { - offset: 0, - rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, - result: []string{"a/b/c/d", "a"}, - }, - { - offset: 14, - rawBytes: []byte{ - byte(Connect << 4), 17, // Fixed header - 0, 6, // Protocol Name - MSB+LSB - 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name - 3, // Protocol Version - 0, // Packet Flags - 0, 30, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'h', 'e', 'y', // Client ID "zen"}, - }, - result: []string{"hey"}, - }, - - { - offset: 2, - rawBytes: []byte{0, 0, 0, 23, 49, 47, 50, 47, 51, 47, 52, 47, 97, 47, 98, 47, 99, 47, 100, 47, 101, 47, 94, 47, 64, 47, 33, 97}, - result: []string{"1/2/3/4/a/b/c/d/e/^/@/!", "a"}, - }, - { - offset: 0, - rawBytes: []byte{0, 5, 120, 47, 121, 47, 122, 33, 64, 35, 36, 37, 94, 38}, - result: []string{"x/y/z", "!@#$%^&"}, - }, - { - offset: 0, - rawBytes: []byte{0, 9, 'a', '/', 'b', '/', 'c', '/', 'd', 'z'}, - result: []string{"a/b/c/d", "z"}, - shouldFail: true, - }, - { - offset: 5, - rawBytes: []byte{0, 7, 97, 47, 98, 47, 'x'}, - result: []string{"a/b/c/d", "x"}, - shouldFail: true, - }, - { - offset: 9, - rawBytes: []byte{0, 7, 97, 47, 98, 47, 'y'}, - result: []string{"a/b/c/d", "y"}, - shouldFail: true, - }, - { - offset: 17, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 6, // Will Topic - MSB+LSB - 'l', - }, - result: []string{"lwt"}, - shouldFail: true, - }, - } - - for i, wanted := range expect { - result, _, err := decodeString(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding string [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding string [i:%d]", i) - require.Equal(t, wanted.result[0], result, "Incorrect decoded value [i:%d]", i) - } -} - -func BenchmarkDecodeString(b *testing.B) { - in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} - for n := 0; n < b.N; n++ { - decodeString(in, 0) - } -} - -func TestDecodeBytes(t *testing.T) { - expect := []struct { - rawBytes []byte - result []uint8 - next int - offset int - shouldFail bool - }{ - { - rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52}, // ... truncated connect packet (clean session) - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - next: 6, - offset: 0, - }, - { - rawBytes: []byte{0, 4, 77, 81, 84, 84, 4, 192, 0, 50, 0, 36, 49, 53, 52, 50}, // ... truncated connect packet, only checking start - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - next: 6, - offset: 0, - }, - { - rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - offset: 0, - shouldFail: true, - }, - { - rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - offset: 8, - shouldFail: true, - }, - { - rawBytes: []byte{0, 4, 77, 81}, - result: []uint8([]byte{0x4d, 0x51, 0x54, 0x54}), - offset: 0, - shouldFail: true, - }, - } - - for i, wanted := range expect { - result, _, err := decodeBytes(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding bytes [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding bytes [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - } -} - -func BenchmarkDecodeBytes(b *testing.B) { - in := []byte{0, 4, 77, 81, 84, 84, 4, 194, 0, 50, 0, 36, 49, 53, 52} - for n := 0; n < b.N; n++ { - decodeBytes(in, 0) - } -} - -func TestDecodeByte(t *testing.T) { - expect := []struct { - rawBytes []byte - result uint8 - offset int - shouldFail bool - }{ - { - rawBytes: []byte{0, 4, 77, 81, 84, 84}, // nonsense slice of bytes - result: uint8(0x00), - offset: 0, - }, - { - rawBytes: []byte{0, 4, 77, 81, 84, 84}, - result: uint8(0x04), - offset: 1, - }, - { - rawBytes: []byte{0, 4, 77, 81, 84, 84}, - result: uint8(0x4d), - offset: 2, - }, - { - rawBytes: []byte{0, 4, 77, 81, 84, 84}, - result: uint8(0x51), - offset: 3, - }, - { - rawBytes: []byte{0, 4, 77, 80, 82, 84}, - result: uint8(0x00), - offset: 8, - shouldFail: true, - }, - } - - for i, wanted := range expect { - result, offset, err := decodeByte(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding byte [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding byte [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, i+1, offset, "Incorrect offset value [i:%d]", i) - } -} - -func BenchmarkDecodeByte(b *testing.B) { - in := []byte{0, 4, 77, 81, 84, 84} - for n := 0; n < b.N; n++ { - decodeByte(in, 0) - } -} - -func TestDecodeUint16(t *testing.T) { - expect := []struct { - rawBytes []byte - result uint16 - offset int - shouldFail bool - }{ - { - rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, - result: uint16(0x07), - offset: 0, - }, - { - rawBytes: []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97}, - result: uint16(0x761), - offset: 1, - }, - { - rawBytes: []byte{0, 7, 255, 47}, - result: uint16(0x761), - offset: 8, - shouldFail: true, - }, - } - - for i, wanted := range expect { - result, offset, err := decodeUint16(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding uint16 [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding uint16 [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, i+2, offset, "Incorrect offset value [i:%d]", i) - } -} - -func BenchmarkDecodeUint16(b *testing.B) { - in := []byte{0, 7, 97, 47, 98, 47, 99, 47, 100, 97} - for n := 0; n < b.N; n++ { - decodeUint16(in, 0) - } -} - -func TestDecodeByteBool(t *testing.T) { - expect := []struct { - rawBytes []byte - result bool - offset int - shouldFail bool - }{ - { - rawBytes: []byte{0x00, 0x00}, - result: false, - }, - { - rawBytes: []byte{0x01, 0x00}, - result: true, - }, - { - rawBytes: []byte{0x01, 0x00}, - offset: 5, - shouldFail: true, - }, - } - - for i, wanted := range expect { - result, offset, err := decodeByteBool(wanted.rawBytes, wanted.offset) - if wanted.shouldFail { - require.Error(t, err, "Expected error decoding byte bool [i:%d]", i) - continue - } - - require.NoError(t, err, "Error decoding byte bool [i:%d]", i) - require.Equal(t, wanted.result, result, "Incorrect decoded value [i:%d]", i) - require.Equal(t, 1, offset, "Incorrect offset value [i:%d]", i) - } -} - -func BenchmarkDecodeByteBool(b *testing.B) { - in := []byte{0x00, 0x00} - for n := 0; n < b.N; n++ { - decodeByteBool(in, 0) - } -} - -func TestEncodeBool(t *testing.T) { - result := encodeBool(true) - require.Equal(t, byte(1), result, "Incorrect encoded value; not true") - - result = encodeBool(false) - require.Equal(t, byte(0), result, "Incorrect encoded value; not false") - - // Check failure. - result = encodeBool(false) - require.NotEqual(t, byte(1), result, "Expected failure, incorrect encoded value") -} - -func BenchmarkEncodeBool(b *testing.B) { - for n := 0; n < b.N; n++ { - encodeBool(true) - } -} - -func TestEncodeBytes(t *testing.T) { - result := encodeBytes([]byte("testing")) - require.Equal(t, []uint8{0, 7, 116, 101, 115, 116, 105, 110, 103}, result, "Incorrect encoded value") - - result = encodeBytes([]byte("testing")) - require.NotEqual(t, []uint8{0, 7, 113, 101, 115, 116, 105, 110, 103}, result, "Expected failure, incorrect encoded value") -} - -func BenchmarkEncodeBytes(b *testing.B) { - bb := []byte("testing") - for n := 0; n < b.N; n++ { - encodeBytes(bb) - } -} - -func TestEncodeUint16(t *testing.T) { - result := encodeUint16(0) - require.Equal(t, []byte{0x00, 0x00}, result, "Incorrect encoded value, 0") - - result = encodeUint16(32767) - require.Equal(t, []byte{0x7f, 0xff}, result, "Incorrect encoded value, 32767") - - result = encodeUint16(65535) - require.Equal(t, []byte{0xff, 0xff}, result, "Incorrect encoded value, 65535") -} - -func BenchmarkEncodeUint16(b *testing.B) { - for n := 0; n < b.N; n++ { - encodeUint16(32767) - } -} - -func TestEncodeString(t *testing.T) { - result := encodeString("testing") - require.Equal(t, []uint8{0x00, 0x07, 0x74, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x67}, result, "Incorrect encoded value, testing") - - result = encodeString("") - require.Equal(t, []uint8{0x00, 0x00}, result, "Incorrect encoded value, null") - - result = encodeString("a") - require.Equal(t, []uint8{0x00, 0x01, 0x61}, result, "Incorrect encoded value, a") - - result = encodeString("b") - require.NotEqual(t, []uint8{0x00, 0x00}, result, "Expected failure, incorrect encoded value, b") - -} - -func BenchmarkEncodeString(b *testing.B) { - for n := 0; n < b.N; n++ { - encodeString("benchmarking") - } -} diff --git a/server/packets/fixedheader.go b/server/packets/fixedheader.go deleted file mode 100644 index a159143bffeb1b91ddaf38bb1b66dd8722056a6f..0000000000000000000000000000000000000000 --- a/server/packets/fixedheader.go +++ /dev/null @@ -1,59 +0,0 @@ -package packets - -import ( - "bytes" -) - -// FixedHeader contains the values of the fixed header portion of the MQTT packet. -type FixedHeader struct { - Type byte // the type of the packet (PUBLISH, SUBSCRIBE, etc) from bits 7 - 4 (byte 1). - Dup bool // indicates if the packet was already sent at an earlier time. - Qos byte // indicates the quality of service expected. - Retain bool // whether the message should be retained. - Remaining int // the number of remaining bytes in the payload. -} - -// Encode encodes the FixedHeader and returns a bytes buffer. -func (fh *FixedHeader) Encode(buf *bytes.Buffer) { - buf.WriteByte(fh.Type<<4 | encodeBool(fh.Dup)<<3 | fh.Qos<<1 | encodeBool(fh.Retain)) - encodeLength(buf, fh.Remaining) -} - -// decode extracts the specification bits from the header byte. -func (fh *FixedHeader) Decode(headerByte byte) error { - fh.Type = headerByte >> 4 // Get the message type from the first 4 bytes. - - switch fh.Type { - case Publish: - fh.Dup = (headerByte>>3)&0x01 > 0 // Extract flags. Check if message is duplicate. - fh.Qos = (headerByte >> 1) & 0x03 // Extract QoS flag. - fh.Retain = headerByte&0x01 > 0 // Extract retain flag. - case Pubrel: - fh.Qos = (headerByte >> 1) & 0x03 - case Subscribe: - fh.Qos = (headerByte >> 1) & 0x03 - case Unsubscribe: - fh.Qos = (headerByte >> 1) & 0x03 - default: - if (headerByte>>3)&0x01 > 0 || (headerByte>>1)&0x03 > 0 || headerByte&0x01 > 0 { - return ErrInvalidFlags - } - } - - return nil -} - -// encodeLength writes length bits for the header. -func encodeLength(buf *bytes.Buffer, length int) { - for { - digit := byte(length % 128) - length /= 128 - if length > 0 { - digit |= 0x80 - } - buf.WriteByte(digit) - if length == 0 { - break - } - } -} diff --git a/server/packets/fixedheader_test.go b/server/packets/fixedheader_test.go deleted file mode 100644 index 7ce4e3314e1fd161dc8f85487044f8effb023b92..0000000000000000000000000000000000000000 --- a/server/packets/fixedheader_test.go +++ /dev/null @@ -1,220 +0,0 @@ -package packets - -import ( - "bytes" - "math" - "testing" - - "github.com/stretchr/testify/require" -) - -type fixedHeaderTable struct { - rawBytes []byte - header FixedHeader - packetError bool - flagError bool -} - -var fixedHeaderExpected = []fixedHeaderTable{ - { - rawBytes: []byte{Connect << 4, 0x00}, - header: FixedHeader{Connect, false, 0, false, 0}, // Type byte, Dup bool, Qos byte, Retain bool, Remaining int - }, - { - rawBytes: []byte{Connack << 4, 0x00}, - header: FixedHeader{Connack, false, 0, false, 0}, - }, - { - rawBytes: []byte{Publish << 4, 0x00}, - header: FixedHeader{Publish, false, 0, false, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<1, 0x00}, - header: FixedHeader{Publish, false, 1, false, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<1 | 1, 0x00}, - header: FixedHeader{Publish, false, 1, true, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 2<<1, 0x00}, - header: FixedHeader{Publish, false, 2, false, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 2<<1 | 1, 0x00}, - header: FixedHeader{Publish, false, 2, true, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<3, 0x00}, - header: FixedHeader{Publish, true, 0, false, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<3 | 1, 0x00}, - header: FixedHeader{Publish, true, 0, true, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<3 | 1<<1 | 1, 0x00}, - header: FixedHeader{Publish, true, 1, true, 0}, - }, - { - rawBytes: []byte{Publish<<4 | 1<<3 | 2<<1 | 1, 0x00}, - header: FixedHeader{Publish, true, 2, true, 0}, - }, - { - rawBytes: []byte{Puback << 4, 0x00}, - header: FixedHeader{Puback, false, 0, false, 0}, - }, - { - rawBytes: []byte{Pubrec << 4, 0x00}, - header: FixedHeader{Pubrec, false, 0, false, 0}, - }, - { - rawBytes: []byte{Pubrel<<4 | 1<<1, 0x00}, - header: FixedHeader{Pubrel, false, 1, false, 0}, - }, - { - rawBytes: []byte{Pubcomp << 4, 0x00}, - header: FixedHeader{Pubcomp, false, 0, false, 0}, - }, - { - rawBytes: []byte{Subscribe<<4 | 1<<1, 0x00}, - header: FixedHeader{Subscribe, false, 1, false, 0}, - }, - { - rawBytes: []byte{Suback << 4, 0x00}, - header: FixedHeader{Suback, false, 0, false, 0}, - }, - { - rawBytes: []byte{Unsubscribe<<4 | 1<<1, 0x00}, - header: FixedHeader{Unsubscribe, false, 1, false, 0}, - }, - { - rawBytes: []byte{Unsuback << 4, 0x00}, - header: FixedHeader{Unsuback, false, 0, false, 0}, - }, - { - rawBytes: []byte{Pingreq << 4, 0x00}, - header: FixedHeader{Pingreq, false, 0, false, 0}, - }, - { - rawBytes: []byte{Pingresp << 4, 0x00}, - header: FixedHeader{Pingresp, false, 0, false, 0}, - }, - { - rawBytes: []byte{Disconnect << 4, 0x00}, - header: FixedHeader{Disconnect, false, 0, false, 0}, - }, - - // remaining length - { - rawBytes: []byte{Publish << 4, 0x0a}, - header: FixedHeader{Publish, false, 0, false, 10}, - }, - { - rawBytes: []byte{Publish << 4, 0x80, 0x04}, - header: FixedHeader{Publish, false, 0, false, 512}, - }, - { - rawBytes: []byte{Publish << 4, 0xd2, 0x07}, - header: FixedHeader{Publish, false, 0, false, 978}, - }, - { - rawBytes: []byte{Publish << 4, 0x86, 0x9d, 0x01}, - header: FixedHeader{Publish, false, 0, false, 20102}, - }, - { - rawBytes: []byte{Publish << 4, 0xd5, 0x86, 0xf9, 0x9e, 0x01}, - header: FixedHeader{Publish, false, 0, false, 333333333}, - packetError: true, - }, - - // Invalid flags for packet - { - rawBytes: []byte{Connect<<4 | 1<<3, 0x00}, - header: FixedHeader{Connect, true, 0, false, 0}, - flagError: true, - }, - { - rawBytes: []byte{Connect<<4 | 1<<1, 0x00}, - header: FixedHeader{Connect, false, 1, false, 0}, - flagError: true, - }, - { - rawBytes: []byte{Connect<<4 | 1, 0x00}, - header: FixedHeader{Connect, false, 0, true, 0}, - flagError: true, - }, -} - -func TestFixedHeaderEncode(t *testing.T) { - for i, wanted := range fixedHeaderExpected { - buf := new(bytes.Buffer) - wanted.header.Encode(buf) - if wanted.flagError == false { - require.Equal(t, len(wanted.rawBytes), len(buf.Bytes()), "Mismatched fixedheader length [i:%d] %v", i, wanted.rawBytes) - require.EqualValues(t, wanted.rawBytes, buf.Bytes(), "Mismatched byte values [i:%d] %v", i, wanted.rawBytes) - } - } -} - -func BenchmarkFixedHeaderEncode(b *testing.B) { - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - fixedHeaderExpected[0].header.Encode(buf) - } -} - -func TestFixedHeaderDecode(t *testing.T) { - for i, wanted := range fixedHeaderExpected { - fh := new(FixedHeader) - err := fh.Decode(wanted.rawBytes[0]) - if wanted.flagError { - require.Error(t, err, "Expected error reading fixedheader [i:%d] %v", i, wanted.rawBytes) - } else { - require.NoError(t, err, "Error reading fixedheader [i:%d] %v", i, wanted.rawBytes) - require.Equal(t, wanted.header.Type, fh.Type, "Mismatched fixedheader type [i:%d] %v", i, wanted.rawBytes) - require.Equal(t, wanted.header.Dup, fh.Dup, "Mismatched fixedheader dup [i:%d] %v", i, wanted.rawBytes) - require.Equal(t, wanted.header.Qos, fh.Qos, "Mismatched fixedheader qos [i:%d] %v", i, wanted.rawBytes) - require.Equal(t, wanted.header.Retain, fh.Retain, "Mismatched fixedheader retain [i:%d] %v", i, wanted.rawBytes) - } - } -} - -func BenchmarkFixedHeaderDecode(b *testing.B) { - fh := new(FixedHeader) - for n := 0; n < b.N; n++ { - err := fh.Decode(fixedHeaderExpected[0].rawBytes[0]) - if err != nil { - panic(err) - } - } -} - -func TestEncodeLength(t *testing.T) { - tt := []struct { - have int - want []byte - }{ - { - 120, - []byte{0x78}, - }, - { - math.MaxInt64, - []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x7f}, - }, - } - - for i, wanted := range tt { - buf := new(bytes.Buffer) - encodeLength(buf, wanted.have) - require.Equal(t, wanted.want, buf.Bytes(), "Returned bytes should match length [i:%d] %s", i, wanted.have) - } -} - -func BenchmarkEncodeLength(b *testing.B) { - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - encodeLength(buf, 120) - } -} diff --git a/server/packets/packets.go b/server/packets/packets.go deleted file mode 100644 index 7211d8638e4a0199c7cbbde184fa466ca2d27e3e..0000000000000000000000000000000000000000 --- a/server/packets/packets.go +++ /dev/null @@ -1,673 +0,0 @@ -package packets - -import ( - "bytes" - "errors" -) - -// All of the valid packet types and their packet identifier. -const ( - Reserved byte = iota - Connect // 1 - Connack // 2 - Publish // 3 - Puback // 4 - Pubrec // 5 - Pubrel // 6 - Pubcomp // 7 - Subscribe // 8 - Suback // 9 - Unsubscribe // 10 - Unsuback // 11 - Pingreq // 12 - Pingresp // 13 - Disconnect // 14 - - Accepted byte = 0x00 - Failed byte = 0xFF - CodeConnectBadProtocolVersion byte = 0x01 - CodeConnectBadClientID byte = 0x02 - CodeConnectServerUnavailable byte = 0x03 - CodeConnectBadAuthValues byte = 0x04 - CodeConnectNotAuthorised byte = 0x05 - CodeConnectNetworkError byte = 0xFE - CodeConnectProtocolViolation byte = 0xFF - ErrSubAckNetworkError byte = 0x80 -) - -var ( - // CONNECT - ErrMalformedProtocolName = errors.New("malformed packet: protocol name") - ErrMalformedProtocolVersion = errors.New("malformed packet: protocol version") - ErrMalformedFlags = errors.New("malformed packet: flags") - ErrMalformedKeepalive = errors.New("malformed packet: keepalive") - ErrMalformedClientID = errors.New("malformed packet: client id") - ErrMalformedWillTopic = errors.New("malformed packet: will topic") - ErrMalformedWillMessage = errors.New("malformed packet: will message") - ErrMalformedUsername = errors.New("malformed packet: username") - ErrMalformedPassword = errors.New("malformed packet: password") - - // CONNACK - ErrMalformedSessionPresent = errors.New("malformed packet: session present") - ErrMalformedReturnCode = errors.New("malformed packet: return code") - - // PUBLISH - ErrMalformedTopic = errors.New("malformed packet: topic name") - ErrMalformedPacketID = errors.New("malformed packet: packet id") - - // SUBSCRIBE - ErrMalformedQoS = errors.New("malformed packet: qos") - - // PACKETS - ErrProtocolViolation = errors.New("protocol violation") - ErrOffsetStrOutOfRange = errors.New("offset string out of range") - ErrOffsetBytesOutOfRange = errors.New("offset bytes out of range") - ErrOffsetByteOutOfRange = errors.New("offset byte out of range") - ErrOffsetBoolOutOfRange = errors.New("offset bool out of range") - ErrOffsetUintOutOfRange = errors.New("offset uint out of range") - ErrOffsetStrInvalidUTF8 = errors.New("offset string invalid utf8") - ErrInvalidFlags = errors.New("invalid flags set for packet") - ErrOversizedLengthIndicator = errors.New("protocol violation: oversized length indicator") - ErrMissingPacketID = errors.New("missing packet id") - ErrSurplusPacketID = errors.New("surplus packet id") -) - -// Packet is an MQTT packet. Instead of providing a packet interface and variant -// packet structs, this is a single concrete packet type to cover all packet -// types, which allows us to take advantage of various compiler optimizations. -type Packet struct { - FixedHeader FixedHeader - - PacketID uint16 - - // Connect - ProtocolName []byte - ProtocolVersion byte - CleanSession bool - WillFlag bool - WillQos byte - WillRetain bool - UsernameFlag bool - PasswordFlag bool - ReservedBit byte - Keepalive uint16 - ClientIdentifier string - WillTopic string - WillMessage []byte - Username []byte - Password []byte - - // Connack - SessionPresent bool - ReturnCode byte - - // Publish - TopicName string - Payload []byte - - // Subscribe, Unsubscribe - Topics []string - Qoss []byte - - ReturnCodes []byte // Suback -} - -// ConnectEncode encodes a connect packet. -func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { - - protoName := encodeBytes(pk.ProtocolName) - protoVersion := pk.ProtocolVersion - flag := encodeBool(pk.CleanSession)<<1 | encodeBool(pk.WillFlag)<<2 | pk.WillQos<<3 | encodeBool(pk.WillRetain)<<5 | encodeBool(pk.PasswordFlag)<<6 | encodeBool(pk.UsernameFlag)<<7 - keepalive := encodeUint16(pk.Keepalive) - clientID := encodeString(pk.ClientIdentifier) - - var willTopic, willFlag, usernameFlag, passwordFlag []byte - - // If will flag is set, add topic and message. - if pk.WillFlag { - willTopic = encodeString(pk.WillTopic) - willFlag = encodeBytes(pk.WillMessage) - } - - // If username flag is set, add username. - if pk.UsernameFlag { - usernameFlag = encodeBytes(pk.Username) - } - - // If password flag is set, add password. - if pk.PasswordFlag { - passwordFlag = encodeBytes(pk.Password) - } - - // Get a length for the connect header. This is not super pretty, but it works. - pk.FixedHeader.Remaining = - len(protoName) + 1 + 1 + len(keepalive) + len(clientID) + - len(willTopic) + len(willFlag) + - len(usernameFlag) + len(passwordFlag) - - pk.FixedHeader.Encode(buf) - - // Eschew magic for readability. - buf.Write(protoName) - buf.WriteByte(protoVersion) - buf.WriteByte(flag) - buf.Write(keepalive) - buf.Write(clientID) - buf.Write(willTopic) - buf.Write(willFlag) - buf.Write(usernameFlag) - buf.Write(passwordFlag) - - return nil -} - -// ConnectDecode decodes a connect packet. -func (pk *Packet) ConnectDecode(buf []byte) error { - var offset int - var err error - - // Unpack protocol name and version. - pk.ProtocolName, offset, err = decodeBytes(buf, 0) - if err != nil { - return ErrMalformedProtocolName - } - - pk.ProtocolVersion, offset, err = decodeByte(buf, offset) - if err != nil { - return ErrMalformedProtocolVersion - } - // Unpack flags byte. - flags, offset, err := decodeByte(buf, offset) - if err != nil { - return ErrMalformedFlags - } - pk.ReservedBit = 1 & flags - pk.CleanSession = 1&(flags>>1) > 0 - pk.WillFlag = 1&(flags>>2) > 0 - pk.WillQos = 3 & (flags >> 3) // this one is not a bool - pk.WillRetain = 1&(flags>>5) > 0 - pk.PasswordFlag = 1&(flags>>6) > 0 - pk.UsernameFlag = 1&(flags>>7) > 0 - - // Get keepalive interval. - pk.Keepalive, offset, err = decodeUint16(buf, offset) - if err != nil { - return ErrMalformedKeepalive - } - - // Get client ID. - pk.ClientIdentifier, offset, err = decodeString(buf, offset) - if err != nil { - return ErrMalformedClientID - } - - // Get Last Will and Testament topic and message if applicable. - if pk.WillFlag { - pk.WillTopic, offset, err = decodeString(buf, offset) - if err != nil { - return ErrMalformedWillTopic - } - - pk.WillMessage, offset, err = decodeBytes(buf, offset) - if err != nil { - return ErrMalformedWillMessage - } - } - - // Get username and password if applicable. - if pk.UsernameFlag { - pk.Username, offset, err = decodeBytes(buf, offset) - if err != nil { - return ErrMalformedUsername - } - } - - if pk.PasswordFlag { - pk.Password, offset, err = decodeBytes(buf, offset) - if err != nil { - return ErrMalformedPassword - } - } - - return nil - -} - -// ConnectValidate ensures the connect packet is compliant. -func (pk *Packet) ConnectValidate() (b byte, err error) { - - // End if protocol name is bad. - if bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) != 0 && - bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) != 0 { - return CodeConnectProtocolViolation, ErrProtocolViolation - } - - // End if protocol version is bad. - if (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'I', 's', 'd', 'p'}) == 0 && pk.ProtocolVersion != 3) || - (bytes.Compare(pk.ProtocolName, []byte{'M', 'Q', 'T', 'T'}) == 0 && pk.ProtocolVersion != 4) { - return CodeConnectBadProtocolVersion, ErrProtocolViolation - } - - // End if reserved bit is not 0. - if pk.ReservedBit != 0 { - return CodeConnectProtocolViolation, ErrProtocolViolation - } - - // End if ClientID is too long. - if len(pk.ClientIdentifier) > 65535 { - return CodeConnectProtocolViolation, ErrProtocolViolation - } - - // End if password flag is set without a username. - if pk.PasswordFlag && !pk.UsernameFlag { - return CodeConnectProtocolViolation, ErrProtocolViolation - } - - // End if Username or Password is too long. - if len(pk.Username) > 65535 || len(pk.Password) > 65535 { - return CodeConnectProtocolViolation, ErrProtocolViolation - } - - // End if client id isn't set and clean session is false. - if !pk.CleanSession && len(pk.ClientIdentifier) == 0 { - return CodeConnectBadClientID, ErrProtocolViolation - } - - return Accepted, nil -} - -// ConnackEncode encodes a Connack packet. -func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.WriteByte(encodeBool(pk.SessionPresent)) - buf.WriteByte(pk.ReturnCode) - return nil -} - -// ConnackDecode decodes a Connack packet. -func (pk *Packet) ConnackDecode(buf []byte) error { - var offset int - var err error - - pk.SessionPresent, offset, err = decodeByteBool(buf, 0) - if err != nil { - return ErrMalformedSessionPresent - } - - pk.ReturnCode, offset, err = decodeByte(buf, offset) - if err != nil { - return ErrMalformedReturnCode - } - - return nil -} - -// DisconnectEncode encodes a Disconnect packet. -func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Encode(buf) - return nil -} - -// PingreqEncode encodes a Pingreq packet. -func (pk *Packet) PingreqEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Encode(buf) - return nil -} - -// PingrespEncode encodes a Pingresp packet. -func (pk *Packet) PingrespEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Encode(buf) - return nil -} - -// PubackEncode encodes a Puback packet. -func (pk *Packet) PubackEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.Write(encodeUint16(pk.PacketID)) - return nil -} - -// PubackDecode decodes a Puback packet. -func (pk *Packet) PubackDecode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - return nil -} - -// PubcompEncode encodes a Pubcomp packet. -func (pk *Packet) PubcompEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.Write(encodeUint16(pk.PacketID)) - return nil -} - -// PubcompDecode decodes a Pubcomp packet. -func (pk *Packet) PubcompDecode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - return nil -} - -// PublishEncode encodes a Publish packet. -func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { - topicName := encodeString(pk.TopicName) - var packetID []byte - - // Add PacketID if QOS is set. - // [MQTT-2.3.1-5] A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. - if pk.FixedHeader.Qos > 0 { - - // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.PacketID == 0 { - return ErrMissingPacketID - } - - packetID = encodeUint16(pk.PacketID) - } - - pk.FixedHeader.Remaining = len(topicName) + len(packetID) + len(pk.Payload) - pk.FixedHeader.Encode(buf) - buf.Write(topicName) - buf.Write(packetID) - buf.Write(pk.Payload) - - return nil -} - -// PublishDecode extracts the data values from the packet. -func (pk *Packet) PublishDecode(buf []byte) error { - var offset int - var err error - - pk.TopicName, offset, err = decodeString(buf, 0) - if err != nil { - return ErrMalformedTopic - } - - // If QOS decode Packet ID. - if pk.FixedHeader.Qos > 0 { - pk.PacketID, offset, err = decodeUint16(buf, offset) - if err != nil { - return ErrMalformedPacketID - } - } - - pk.Payload = buf[offset:] - - return nil -} - -// PublishCopy creates a new instance of Publish packet bearing the -// same payload and destination topic, but with an empty header for -// inheriting new QoS flags, etc. -func (pk *Packet) PublishCopy() Packet { - return Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Retain: pk.FixedHeader.Retain, - }, - TopicName: pk.TopicName, - Payload: pk.Payload, - } -} - -// PublishValidate validates a publish packet. -func (pk *Packet) PublishValidate() (byte, error) { - - // @SPEC [MQTT-2.3.1-1] - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { - return Failed, ErrMissingPacketID - } - - // @SPEC [MQTT-2.3.1-5] - // A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. - if pk.FixedHeader.Qos == 0 && pk.PacketID > 0 { - return Failed, ErrSurplusPacketID - } - - return Accepted, nil -} - -// PubrecEncode encodes a Pubrec packet. -func (pk *Packet) PubrecEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.Write(encodeUint16(pk.PacketID)) - return nil -} - -// PubrecDecode decodes a Pubrec packet. -func (pk *Packet) PubrecDecode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - - return nil -} - -// PubrelEncode encodes a Pubrel packet. -func (pk *Packet) PubrelEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.Write(encodeUint16(pk.PacketID)) - return nil -} - -// PubrelDecode decodes a Pubrel packet. -func (pk *Packet) PubrelDecode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - return nil -} - -// SubackEncode encodes a Suback packet. -func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { - packetID := encodeUint16(pk.PacketID) - pk.FixedHeader.Remaining = len(packetID) + len(pk.ReturnCodes) // Set length. - pk.FixedHeader.Encode(buf) - - buf.Write(packetID) // Encode Packet ID. - buf.Write(pk.ReturnCodes) // Encode granted QOS flags. - - return nil -} - -// SubackDecode decodes a Suback packet. -func (pk *Packet) SubackDecode(buf []byte) error { - var offset int - var err error - - // Get Packet ID. - pk.PacketID, offset, err = decodeUint16(buf, offset) - if err != nil { - return ErrMalformedPacketID - } - - // Get Granted QOS flags. - pk.ReturnCodes = buf[offset:] - - return nil -} - -// SubscribeEncode encodes a Subscribe packet. -func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { - - // Add the Packet ID. - // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.PacketID == 0 { - return ErrMissingPacketID - } - - packetID := encodeUint16(pk.PacketID) - - // Count topics lengths and associated QOS flags. - var topicsLen int - for _, topic := range pk.Topics { - topicsLen += len(encodeString(topic)) + 1 - } - - pk.FixedHeader.Remaining = len(packetID) + topicsLen - pk.FixedHeader.Encode(buf) - buf.Write(packetID) - - // Add all provided topic names and associated QOS flags. - for i, topic := range pk.Topics { - buf.Write(encodeString(topic)) - buf.WriteByte(pk.Qoss[i]) - } - - return nil -} - -// SubscribeDecode decodes a Subscribe packet. -func (pk *Packet) SubscribeDecode(buf []byte) error { - var offset int - var err error - - // Get the Packet ID. - pk.PacketID, offset, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - - // Keep decoding until there's no space left. - for offset < len(buf) { - - // Decode Topic Name. - var topic string - topic, offset, err = decodeString(buf, offset) - if err != nil { - return ErrMalformedTopic - } - pk.Topics = append(pk.Topics, topic) - - // Decode QOS flag. - var qos byte - qos, offset, err = decodeByte(buf, offset) - if err != nil { - return ErrMalformedQoS - } - - // Ensure QoS byte is within range. - if !(qos >= 0 && qos <= 2) { - //if !validateQoS(qos) { - return ErrMalformedQoS - } - - pk.Qoss = append(pk.Qoss, qos) - } - - return nil -} - -// SubscribeValidate ensures the packet is compliant. -func (pk *Packet) SubscribeValidate() (byte, error) { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { - return Failed, ErrMissingPacketID - } - - return Accepted, nil -} - -// UnsubackEncode encodes an Unsuback packet. -func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { - pk.FixedHeader.Remaining = 2 - pk.FixedHeader.Encode(buf) - buf.Write(encodeUint16(pk.PacketID)) - return nil -} - -// UnsubackDecode decodes an Unsuback packet. -func (pk *Packet) UnsubackDecode(buf []byte) error { - var err error - pk.PacketID, _, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - return nil -} - -// UnsubscribeEncode encodes an Unsubscribe packet. -func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { - - // Add the Packet ID. - // [MQTT-2.3.1-1] SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.PacketID == 0 { - return ErrMissingPacketID - } - - packetID := encodeUint16(pk.PacketID) - - // Count topics lengths. - var topicsLen int - for _, topic := range pk.Topics { - topicsLen += len(encodeString(topic)) - } - - pk.FixedHeader.Remaining = len(packetID) + topicsLen - pk.FixedHeader.Encode(buf) - buf.Write(packetID) - - // Add all provided topic names. - for _, topic := range pk.Topics { - buf.Write(encodeString(topic)) - } - - return nil -} - -// UnsubscribeDecode decodes an Unsubscribe packet. -func (pk *Packet) UnsubscribeDecode(buf []byte) error { - var offset int - var err error - - // Get the Packet ID. - pk.PacketID, offset, err = decodeUint16(buf, 0) - if err != nil { - return ErrMalformedPacketID - } - - // Keep decoding until there's no space left. - for offset < len(buf) { - var t string - t, offset, err = decodeString(buf, offset) // Decode Topic Name. - if err != nil { - return ErrMalformedTopic - } - - if len(t) > 0 { - pk.Topics = append(pk.Topics, t) - } - } - - return nil - -} - -// UnsubscribeValidate validates an Unsubscribe packet. -func (pk *Packet) UnsubscribeValidate() (byte, error) { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - if pk.FixedHeader.Qos > 0 && pk.PacketID == 0 { - return Failed, ErrMissingPacketID - } - - return Accepted, nil -} diff --git a/server/packets/packets_tables_test.go b/server/packets/packets_tables_test.go deleted file mode 100644 index 6a5e8cf1767acf59249e92e3162a4913e19deac7..0000000000000000000000000000000000000000 --- a/server/packets/packets_tables_test.go +++ /dev/null @@ -1,1416 +0,0 @@ -package packets - -type packetTestData struct { - group string // group specifies a group that should run the test, blank for all - rawBytes []byte // the bytes that make the packet - actualBytes []byte // the actual byte array that is created in the event of a byte mutation (eg. MQTT-2.3.1-1 qos/packet id) - packet *Packet // the packet that is expected - desc string // a description of the test - failFirst interface{} // expected fail result to be run immediately after the method is called - expect interface{} // generic expected fail result to be checked - isolate bool // isolate can be used to isolate a test - primary bool // primary is a test that should be run using readPackets - meta interface{} // meta conains a metadata value used in testing on a case-by-case basis. - code byte // code is an expected validation return code -} - -func encodeTestOK(wanted packetTestData) bool { - if wanted.rawBytes == nil { - return false - } - if wanted.group != "" && wanted.group != "encode" { - return false - } - return true -} - -func decodeTestOK(wanted packetTestData) bool { - if wanted.group != "" && wanted.group != "decode" { - return false - } - return true -} - -var expectedPackets = map[byte][]packetTestData{ - Connect: { - { - desc: "MQTT 3.1", - primary: true, - rawBytes: []byte{ - byte(Connect << 4), 17, // Fixed header - 0, 6, // Protocol Name - MSB+LSB - 'M', 'Q', 'I', 's', 'd', 'p', // Protocol Name - 3, // Protocol Version - 0, // Packet Flags - 0, 30, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 17, - }, - ProtocolName: []byte("MQIsdp"), - ProtocolVersion: 3, - CleanSession: false, - Keepalive: 30, - ClientIdentifier: "zen", - }, - }, - - { - desc: "MQTT 3.1.1", - primary: true, - rawBytes: []byte{ - byte(Connect << 4), 16, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Packet Flags - 0, 60, // Keepalive - 0, 4, // Client ID - MSB+LSB - 'z', 'e', 'n', '3', // Client ID "zen" - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 16, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: false, - Keepalive: 60, - ClientIdentifier: "zen3", - }, - }, - { - desc: "MQTT 3.1.1, Clean Session", - rawBytes: []byte{ - byte(Connect << 4), 15, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 2, // Packet Flags - 0, 45, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 15, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: true, - Keepalive: 45, - ClientIdentifier: "zen", - }, - }, - { - desc: "MQTT 3.1.1, Clean Session, LWT", - rawBytes: []byte{ - byte(Connect << 4), 31, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 14, // Packet Flags - 0, 27, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 3, // Will Topic - MSB+LSB - 'l', 'w', 't', - 0, 9, // Will Message MSB+LSB - 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 31, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: true, - Keepalive: 27, - ClientIdentifier: "zen", - WillFlag: true, - WillTopic: "lwt", - WillMessage: []byte("not again"), - WillQos: 1, - }, - }, - { - desc: "MQTT 3.1.1, Username, Password", - rawBytes: []byte{ - byte(Connect << 4), 28, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 194, // Packet Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 5, // Username MSB+LSB - 'm', 'o', 'c', 'h', 'i', - 0, 4, // Password MSB+LSB - ',', '.', '/', ';', - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 28, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: true, - Keepalive: 20, - ClientIdentifier: "zen", - UsernameFlag: true, - PasswordFlag: true, - Username: []byte("mochi"), - Password: []byte(",./;"), - }, - }, - { - desc: "MQTT 3.1.1, Username, Password, LWT", - primary: true, - rawBytes: []byte{ - byte(Connect << 4), 44, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 206, // Packet Flags - 0, 120, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 3, // Will Topic - MSB+LSB - 'l', 'w', 't', - 0, 9, // Will Message MSB+LSB - 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', - 0, 5, // Username MSB+LSB - 'm', 'o', 'c', 'h', 'i', - 0, 4, // Password MSB+LSB - ',', '.', '/', ';', - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 44, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: true, - Keepalive: 120, - ClientIdentifier: "zen", - UsernameFlag: true, - PasswordFlag: true, - Username: []byte("mochi"), - Password: []byte(",./;"), - WillFlag: true, - WillTopic: "lwt", - WillMessage: []byte("not again"), - WillQos: 1, - }, - }, - - // Fail States - { - desc: "Malformed Connect - protocol name", - group: "decode", - failFirst: ErrMalformedProtocolName, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 7, // Protocol Name - MSB+LSB - 'M', 'Q', 'I', 's', 'd', // Protocol Name - }, - }, - - { - desc: "Malformed Connect - protocol version", - group: "decode", - failFirst: ErrMalformedProtocolVersion, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - }, - }, - - { - desc: "Malformed Connect - flags", - group: "decode", - failFirst: ErrMalformedFlags, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - - }, - }, - { - desc: "Malformed Connect - keepalive", - group: "decode", - failFirst: ErrMalformedKeepalive, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - }, - }, - { - desc: "Malformed Connect - client id", - group: "decode", - failFirst: ErrMalformedClientID, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', // Client ID "zen" - }, - }, - { - desc: "Malformed Connect - will topic", - group: "decode", - failFirst: ErrMalformedWillTopic, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 14, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 6, // Will Topic - MSB+LSB - 'l', - }, - }, - { - desc: "Malformed Connect - will flag", - group: "decode", - failFirst: ErrMalformedWillMessage, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 14, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 3, // Will Topic - MSB+LSB - 'l', 'w', 't', - 0, 9, // Will Message MSB+LSB - 'n', 'o', 't', ' ', 'a', - }, - }, - { - desc: "Malformed Connect - username", - group: "decode", - failFirst: ErrMalformedUsername, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 206, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 3, // Will Topic - MSB+LSB - 'l', 'w', 't', - 0, 9, // Will Message MSB+LSB - 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', - 0, 5, // Username MSB+LSB - 'm', 'o', 'c', - }, - }, - { - desc: "Malformed Connect - password", - group: "decode", - failFirst: ErrMalformedPassword, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 206, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'z', 'e', 'n', // Client ID "zen" - 0, 3, // Will Topic - MSB+LSB - 'l', 'w', 't', - 0, 9, // Will Message MSB+LSB - 'n', 'o', 't', ' ', 'a', 'g', 'a', 'i', 'n', - 0, 5, // Username MSB+LSB - 'm', 'o', 'c', 'h', 'i', - 0, 4, // Password MSB+LSB - ',', '.', - }, - }, - - // Validation Tests - { - desc: "Invalid Protocol Name", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("stuff"), - }, - }, - { - desc: "Invalid Protocol Version", - group: "validate", - code: CodeConnectBadProtocolVersion, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 2, - }, - }, - { - desc: "Invalid Protocol Version", - group: "validate", - code: CodeConnectBadProtocolVersion, - packet: &Packet{ - ProtocolName: []byte("MQIsdp"), - ProtocolVersion: 2, - }, - }, - { - desc: "Reserved bit not 0", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - ReservedBit: 1, - }, - }, - { - desc: "Client ID too long", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - ClientIdentifier: func() string { - return string(make([]byte, 65536)) - }(), - }, - }, - { - desc: "Has Password Flag but no Username flag", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - PasswordFlag: true, - }, - }, - { - desc: "Username too long", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - UsernameFlag: true, - Username: func() []byte { - return make([]byte, 65536) - }(), - }, - }, - { - desc: "Password too long", - group: "validate", - code: CodeConnectProtocolViolation, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - UsernameFlag: true, - Username: []byte{}, - PasswordFlag: true, - Password: func() []byte { - return make([]byte, 65536) - }(), - }, - }, - { - desc: "Clean session false and client id not set", - group: "validate", - code: CodeConnectBadClientID, - packet: &Packet{ - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - CleanSession: false, - }, - }, - - // Spec Tests - { - // @SPEC [MQTT-1.4.0-1] - // The character data in a UTF-8 encoded string MUST be well-formed UTF-8 - // as defined by the Unicode specification [Unicode] and restated in RFC 3629 [RFC 3629]. - // In particular this data MUST NOT include encodings of code points between U+D800 and U+DFFF. - desc: "Invalid UTF8 string (a) - Code point U+D800.", - group: "decode", - failFirst: ErrMalformedClientID, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 4, // Client ID - MSB+LSB - 'e', 0xed, 0xa0, 0x80, // Client id bearing U+D800 - }, - }, - { - desc: "Invalid UTF8 string (b) - Code point U+DFFF.", - group: "decode", - failFirst: ErrMalformedClientID, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 4, // Client ID - MSB+LSB - 'e', 0xed, 0xa3, 0xbf, // Client id bearing U+D8FF - }, - }, - - // @SPEC [MQTT-1.4.0-2] - // A UTF-8 encoded string MUST NOT include an encoding of the null character U+0000. - { - desc: "Invalid UTF8 string (c) - Code point U+0000.", - group: "decode", - failFirst: ErrMalformedClientID, - rawBytes: []byte{ - byte(Connect << 4), 0, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 3, // Client ID - MSB+LSB - 'e', 0xc0, 0x80, // Client id bearing U+0000 - }, - }, - - // @ SPEC [MQTT-1.4.0-3] - // A UTF-8 encoded sequence 0xEF 0xBB 0xBF is always to be interpreted to mean U+FEFF ("ZERO WIDTH NO-BREAK SPACE") - // wherever it appears in a string and MUST NOT be skipped over or stripped off by a packet receiver. - { - desc: "UTF8 string must not skip or strip code point U+FEFF.", - //group: "decode", - //failFirst: ErrMalformedClientID, - rawBytes: []byte{ - byte(Connect << 4), 18, // Fixed header - 0, 4, // Protocol Name - MSB+LSB - 'M', 'Q', 'T', 'T', // Protocol Name - 4, // Protocol Version - 0, // Flags - 0, 20, // Keepalive - 0, 6, // Client ID - MSB+LSB - 'e', 'b', 0xEF, 0xBB, 0xBF, 'd', // Client id bearing U+FEFF - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connect, - Remaining: 16, - }, - ProtocolName: []byte("MQTT"), - ProtocolVersion: 4, - Keepalive: 20, - ClientIdentifier: string([]byte{'e', 'b', 0xEF, 0xBB, 0xBF, 'd'}), - }, - }, - }, - Connack: { - { - desc: "Accepted, No Session", - primary: true, - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 0, // No existing session - Accepted, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: false, - ReturnCode: Accepted, - }, - }, - { - desc: "Accepted, Session Exists", - primary: true, - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - Accepted, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: Accepted, - }, - }, - { - desc: "Bad Protocol Version", - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - CodeConnectBadProtocolVersion, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: CodeConnectBadProtocolVersion, - }, - }, - { - desc: "Bad Client ID", - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - CodeConnectBadClientID, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: CodeConnectBadClientID, - }, - }, - { - desc: "Server Unavailable", - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - CodeConnectServerUnavailable, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: CodeConnectServerUnavailable, - }, - }, - { - desc: "Bad Username or Password", - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - CodeConnectBadAuthValues, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: CodeConnectBadAuthValues, - }, - }, - { - desc: "Not Authorised", - rawBytes: []byte{ - byte(Connack << 4), 2, // fixed header - 1, // Session present - CodeConnectNotAuthorised, - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Connack, - Remaining: 2, - }, - SessionPresent: true, - ReturnCode: CodeConnectNotAuthorised, - }, - }, - - // Fail States - { - desc: "Malformed Connack - session present", - group: "decode", - failFirst: ErrMalformedSessionPresent, - rawBytes: []byte{ - byte(Connect << 4), 2, // Fixed header - }, - }, - { - desc: "Malformed Connack - bad return code", - group: "decode", - //primary: true, - failFirst: ErrMalformedReturnCode, - rawBytes: []byte{ - byte(Connect << 4), 2, // Fixed header - 0, - }, - }, - }, - - Publish: { - { - desc: "Publish - No payload", - primary: true, - rawBytes: []byte{ - byte(Publish << 4), 7, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Remaining: 7, - }, - TopicName: "a/b/c", - Payload: []byte{}, - }, - }, - { - desc: "Publish - basic", - primary: true, - rawBytes: []byte{ - byte(Publish << 4), 18, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 'h', 'e', 'l', 'l', 'o', ' ', 'm', 'o', 'c', 'h', 'i', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Remaining: 18, - }, - TopicName: "a/b/c", - Payload: []byte("hello mochi"), - }, - }, - { - desc: "Publish - QoS:1, Packet ID", - primary: true, - rawBytes: []byte{ - byte(Publish<<4) | 2, 14, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 0, 7, // Packet ID - LSB+MSB - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Qos: 1, - Remaining: 14, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 7, - }, - meta: byte(2), - }, - { - desc: "Publish - QoS:1, Packet ID, No payload", - primary: true, - rawBytes: []byte{ - byte(Publish<<4) | 2, 9, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'y', '/', 'u', '/', 'i', // Topic Name - 0, 8, // Packet ID - LSB+MSB - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Qos: 1, - Remaining: 9, - }, - TopicName: "y/u/i", - PacketID: 8, - Payload: []byte{}, - }, - meta: byte(2), - }, - { - desc: "Publish - Retain", - rawBytes: []byte{ - byte(Publish<<4) | 1, 10, // Fixed header - 0, 3, // Topic Name - LSB+MSB - 'a', '/', 'b', // Topic Name - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Retain: true, - }, - TopicName: "a/b", - Payload: []byte("hello"), - }, - meta: byte(1), - }, - { - desc: "Publish - Dup", - rawBytes: []byte{ - byte(Publish<<4) | 8, 10, // Fixed header - 0, 3, // Topic Name - LSB+MSB - 'a', '/', 'b', // Topic Name - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Dup: true, - }, - TopicName: "a/b", - Payload: []byte("hello"), - }, - meta: byte(8), - }, - - // Fail States - { - desc: "Malformed Publish - topic name", - group: "decode", - failFirst: ErrMalformedTopic, - rawBytes: []byte{ - byte(Publish << 4), 7, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', - 0, 11, // Packet ID - LSB+MSB - }, - }, - - { - desc: "Malformed Publish - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Publish<<4) | 2, 7, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'x', '/', 'y', '/', 'z', // Topic Name - 0, // Packet ID - LSB+MSB - }, - }, - - // Copy tests - { - desc: "Publish - basic copyable", - group: "copy", - rawBytes: []byte{ - byte(Publish << 4), 18, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'z', '/', 'e', '/', 'n', // Topic Name - 'm', 'o', 'c', 'h', 'i', ' ', 'm', 'o', 'c', 'h', 'i', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Dup: true, - Retain: true, - Qos: 1, - }, - TopicName: "z/e/n", - Payload: []byte("mochi mochi"), - }, - }, - - // Spec tests - { - // @SPEC [MQTT-2.3.1-5] - // A PUBLISH Packet MUST NOT contain a Packet Identifier if its QoS value is set to 0. - desc: "[MQTT-2.3.1-5] Packet ID must be 0 if QoS is 0 (a)", - group: "encode", - // this version tests for correct byte array mutuation. - // this does not check if -incoming- packets are parsed as correct, - // it is impossible for the parser to determine if the payload start is incorrect. - rawBytes: []byte{ - byte(Publish << 4), 12, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 0, 3, // Packet ID - LSB+MSB - 'h', 'e', 'l', 'l', 'o', // Payload - }, - actualBytes: []byte{ - byte(Publish << 4), 12, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - // Packet ID is removed. - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Remaining: 12, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - }, - }, - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", - group: "encode", - expect: ErrMissingPacketID, - code: Failed, - rawBytes: []byte{ - byte(Publish<<4) | 2, 14, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 0, 0, // Packet ID - LSB+MSB - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Qos: 1, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 0, - }, - meta: byte(2), - }, - /* - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", - group: "validate", - //primary: true, - expect: ErrMissingPacketID, - code: Failed, - rawBytes: []byte{ - byte(Publish<<4) | 2, 14, // Fixed header - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 0, 0, // Packet ID - LSB+MSB - 'h', 'e', 'l', 'l', 'o', // Payload - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Qos: 1, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 0, - }, - meta: byte(2), - }, - - */ - - // Validation Tests - { - // @SPEC [MQTT-2.3.1-5] - desc: "[MQTT-2.3.1-5] Packet ID must be 0 if QoS is 0 (b)", - group: "validate", - expect: ErrSurplusPacketID, - code: Failed, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Remaining: 12, - Qos: 0, - }, - TopicName: "a/b/c", - Payload: []byte("hello"), - PacketID: 3, - }, - }, - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] No Packet ID with QOS > 0", - group: "validate", - expect: ErrMissingPacketID, - code: Failed, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Publish, - Qos: 1, - }, - PacketID: 0, - }, - }, - }, - - Puback: { - { - desc: "Puback", - primary: true, - rawBytes: []byte{ - byte(Puback << 4), 2, // Fixed header - 0, 11, // Packet ID - LSB+MSB - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Puback, - Remaining: 2, - }, - PacketID: 11, - }, - }, - - // Fail states - { - desc: "Malformed Puback - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Puback << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - Pubrec: { - { - desc: "Pubrec", - primary: true, - rawBytes: []byte{ - byte(Pubrec << 4), 2, // Fixed header - 0, 12, // Packet ID - LSB+MSB - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Pubrec, - Remaining: 2, - }, - PacketID: 12, - }, - }, - - // Fail states - { - desc: "Malformed Pubrec - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Pubrec << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - Pubrel: { - { - desc: "Pubrel", - primary: true, - rawBytes: []byte{ - byte(Pubrel<<4) | 2, 2, // Fixed header - 0, 12, // Packet ID - LSB+MSB - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Pubrel, - Remaining: 2, - Qos: 1, - }, - PacketID: 12, - }, - meta: byte(2), - }, - - // Fail states - { - desc: "Malformed Pubrel - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Pubrel << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - Pubcomp: { - { - desc: "Pubcomp", - primary: true, - rawBytes: []byte{ - byte(Pubcomp << 4), 2, // Fixed header - 0, 14, // Packet ID - LSB+MSB - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Pubcomp, - Remaining: 2, - }, - PacketID: 14, - }, - }, - - // Fail states - { - desc: "Malformed Pubcomp - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Pubcomp << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - Subscribe: { - { - desc: "Subscribe", - primary: true, - rawBytes: []byte{ - byte(Subscribe << 4), 30, // Fixed header - 0, 15, // Packet ID - LSB+MSB - - 0, 3, // Topic Name - LSB+MSB - 'a', '/', 'b', // Topic Name - 0, // QoS - - 0, 11, // Topic Name - LSB+MSB - 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name - 1, // QoS - - 0, 5, // Topic Name - LSB+MSB - 'x', '/', 'y', '/', 'z', // Topic Name - 2, // QoS - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Subscribe, - Remaining: 30, - }, - PacketID: 15, - Topics: []string{ - "a/b", - "d/e/f/g/h/i", - "x/y/z", - }, - Qoss: []byte{0, 1, 2}, - }, - }, - - // Fail states - { - desc: "Malformed Subscribe - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Subscribe << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - { - desc: "Malformed Subscribe - topic", - group: "decode", - failFirst: ErrMalformedTopic, - rawBytes: []byte{ - byte(Subscribe << 4), 2, // Fixed header - 0, 21, // Packet ID - LSB+MSB - - 0, 3, // Topic Name - LSB+MSB - 'a', '/', - }, - }, - { - desc: "Malformed Subscribe - qos", - group: "decode", - failFirst: ErrMalformedQoS, - rawBytes: []byte{ - byte(Subscribe << 4), 2, // Fixed header - 0, 22, // Packet ID - LSB+MSB - - 0, 3, // Topic Name - LSB+MSB - 'j', '/', 'b', // Topic Name - - }, - }, - { - desc: "Malformed Subscribe - qos out of range", - group: "decode", - failFirst: ErrMalformedQoS, - rawBytes: []byte{ - byte(Subscribe << 4), 2, // Fixed header - 0, 22, // Packet ID - LSB+MSB - - 0, 3, // Topic Name - LSB+MSB - 'c', '/', 'd', // Topic Name - 5, // QoS - - }, - }, - - // Validation - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", - group: "validate", - expect: ErrMissingPacketID, - code: Failed, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Subscribe, - Qos: 1, - }, - PacketID: 0, - }, - }, - - // Spec tests - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", - group: "encode", - code: Failed, - expect: ErrMissingPacketID, - rawBytes: []byte{ - byte(Subscribe<<4) | 1<<1, 10, // Fixed header - 0, 0, // Packet ID - LSB+MSB - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - 1, // QoS - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Subscribe, - Qos: 1, - Remaining: 10, - }, - Topics: []string{ - "a/b/c", - }, - Qoss: []byte{1}, - PacketID: 0, - }, - meta: byte(2), - }, - }, - Suback: { - { - desc: "Suback", - primary: true, - rawBytes: []byte{ - byte(Suback << 4), 6, // Fixed header - 0, 17, // Packet ID - LSB+MSB - 0, // Return Code QoS 0 - 1, // Return Code QoS 1 - 2, // Return Code QoS 2 - 0x80, // Return Code fail - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Suback, - Remaining: 6, - }, - PacketID: 17, - ReturnCodes: []byte{0, 1, 2, 0x80}, - }, - }, - - // Fail states - { - desc: "Malformed Suback - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Subscribe << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - - Unsubscribe: { - { - desc: "Unsubscribe", - primary: true, - rawBytes: []byte{ - byte(Unsubscribe << 4), 27, // Fixed header - 0, 35, // Packet ID - LSB+MSB - - 0, 3, // Topic Name - LSB+MSB - 'a', '/', 'b', // Topic Name - - 0, 11, // Topic Name - LSB+MSB - 'd', '/', 'e', '/', 'f', '/', 'g', '/', 'h', '/', 'i', // Topic Name - - 0, 5, // Topic Name - LSB+MSB - 'x', '/', 'y', '/', 'z', // Topic Name - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Unsubscribe, - Remaining: 27, - }, - PacketID: 35, - Topics: []string{ - "a/b", - "d/e/f/g/h/i", - "x/y/z", - }, - }, - }, - // Fail states - { - desc: "Malformed Unsubscribe - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Unsubscribe << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - { - desc: "Malformed Unsubscribe - topic", - group: "decode", - failFirst: ErrMalformedTopic, - rawBytes: []byte{ - byte(Unsubscribe << 4), 2, // Fixed header - 0, 21, // Packet ID - LSB+MSB - 0, 3, // Topic Name - LSB+MSB - 'a', '/', - }, - }, - - // Validation - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] Subscribe No Packet ID with QOS > 0", - group: "validate", - expect: ErrMissingPacketID, - code: Failed, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Unsubscribe, - Qos: 1, - }, - PacketID: 0, - }, - }, - - // Spec tests - { - // @SPEC [MQTT-2.3.1-1]. - // SUBSCRIBE, UNSUBSCRIBE, and PUBLISH (in cases where QoS > 0) Control Packets MUST contain a non-zero 16-bit Packet Identifier. - desc: "[MQTT-2.3.1-1] Unsubscribe No Packet ID with QOS > 0", - group: "encode", - code: Failed, - expect: ErrMissingPacketID, - rawBytes: []byte{ - byte(Unsubscribe<<4) | 1<<1, 9, // Fixed header - 0, 0, // Packet ID - LSB+MSB - 0, 5, // Topic Name - LSB+MSB - 'a', '/', 'b', '/', 'c', // Topic Name - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Unsubscribe, - Qos: 1, - Remaining: 9, - }, - Topics: []string{ - "a/b/c", - }, - PacketID: 0, - }, - meta: byte(2), - }, - }, - Unsuback: { - { - desc: "Unsuback", - primary: true, - rawBytes: []byte{ - byte(Unsuback << 4), 2, // Fixed header - 0, 37, // Packet ID - LSB+MSB - - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Unsuback, - Remaining: 2, - }, - PacketID: 37, - }, - }, - - // Fail states - { - desc: "Malformed Unsuback - Packet ID", - group: "decode", - failFirst: ErrMalformedPacketID, - rawBytes: []byte{ - byte(Unsuback << 4), 2, // Fixed header - 0, // Packet ID - LSB+MSB - }, - }, - }, - - Pingreq: { - { - desc: "Ping request", - primary: true, - rawBytes: []byte{ - byte(Pingreq << 4), 0, // fixed header - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Pingreq, - Remaining: 0, - }, - }, - }, - }, - Pingresp: { - { - desc: "Ping response", - primary: true, - rawBytes: []byte{ - byte(Pingresp << 4), 0, // fixed header - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Pingresp, - Remaining: 0, - }, - }, - }, - }, - - Disconnect: { - { - desc: "Disconnect", - primary: true, - rawBytes: []byte{ - byte(Disconnect << 4), 0, // fixed header - }, - packet: &Packet{ - FixedHeader: FixedHeader{ - Type: Disconnect, - Remaining: 0, - }, - }, - }, - }, -} diff --git a/server/packets/packets_test.go b/server/packets/packets_test.go deleted file mode 100644 index aa5d7b5c1057a06284bf9ba7a81847599d49b479..0000000000000000000000000000000000000000 --- a/server/packets/packets_test.go +++ /dev/null @@ -1,1082 +0,0 @@ -package packets - -import ( - "bytes" - "testing" - - "github.com/jinzhu/copier" - "github.com/stretchr/testify/require" -) - -func TestConnectEncode(t *testing.T) { - require.Contains(t, expectedPackets, Connect) - for i, wanted := range expectedPackets[Connect] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Connect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - pk.ConnectEncode(buf) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - require.Equal(t, byte(Connect<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - ok, _ := pk.ConnectValidate() - require.Equal(t, byte(Accepted), ok, "Connect packet didn't validate - %v", ok) - - require.Equal(t, wanted.packet.FixedHeader.Type, pk.FixedHeader.Type, "Mismatched packet fixed header type [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched packet fixed header dup [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched packet fixed header qos [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched packet fixed header retain [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.ProtocolVersion, pk.ProtocolVersion, "Mismatched packet protocol version [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ProtocolName, pk.ProtocolName, "Mismatched packet protocol name [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.CleanSession, pk.CleanSession, "Mismatched packet cleansession [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ClientIdentifier, pk.ClientIdentifier, "Mismatched packet client id [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Keepalive, pk.Keepalive, "Mismatched keepalive value [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.UsernameFlag, pk.UsernameFlag, "Mismatched packet username flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Username, pk.Username, "Mismatched packet username [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PasswordFlag, pk.PasswordFlag, "Mismatched packet password flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Password, pk.Password, "Mismatched packet password [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.WillFlag, pk.WillFlag, "Mismatched packet will flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillTopic, pk.WillTopic, "Mismatched packet will topic [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillMessage, pk.WillMessage, "Mismatched packet will message [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillQos, pk.WillQos, "Mismatched packet will qos [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillRetain, pk.WillRetain, "Mismatched packet will retain [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkConnectEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Connect][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.ConnectEncode(buf) - } -} - -func TestConnectDecode(t *testing.T) { - require.Contains(t, expectedPackets, Connect) - for i, wanted := range expectedPackets[Connect] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(1), Connect, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - require.Equal(t, true, (len(wanted.rawBytes) > 2), "Insufficent bytes in packet [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Connect}} - err := pk.ConnectDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.FixedHeader.Type, pk.FixedHeader.Type, "Mismatched packet fixed header type [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched packet fixed header dup [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched packet fixed header qos [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched packet fixed header retain [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.ProtocolVersion, pk.ProtocolVersion, "Mismatched packet protocol version [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ProtocolName, pk.ProtocolName, "Mismatched packet protocol name [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.CleanSession, pk.CleanSession, "Mismatched packet cleansession [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ClientIdentifier, pk.ClientIdentifier, "Mismatched packet client id [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Keepalive, pk.Keepalive, "Mismatched keepalive value [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.UsernameFlag, pk.UsernameFlag, "Mismatched packet username flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Username, pk.Username, "Mismatched packet username [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PasswordFlag, pk.PasswordFlag, "Mismatched packet password flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Password, pk.Password, "Mismatched packet password [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.WillFlag, pk.WillFlag, "Mismatched packet will flag [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillTopic, pk.WillTopic, "Mismatched packet will topic [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillMessage, pk.WillMessage, "Mismatched packet will message [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillQos, pk.WillQos, "Mismatched packet will qos [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.WillRetain, pk.WillRetain, "Mismatched packet will retain [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkConnectDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Connect}} - pk.FixedHeader.Decode(expectedPackets[Connect][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.ConnectDecode(expectedPackets[Connect][0].rawBytes[2:]) - } -} - -func TestConnectValidate(t *testing.T) { - require.Contains(t, expectedPackets, Connect) - for i, wanted := range expectedPackets[Connect] { - if wanted.group == "validate" { - pk := wanted.packet - ok, _ := pk.ConnectValidate() - require.Equal(t, wanted.code, ok, "Connect packet didn't validate [i:%d] %s", i, wanted.desc) - } - } -} - -func TestConnackEncode(t *testing.T) { - require.Contains(t, expectedPackets, Connack) - for i, wanted := range expectedPackets[Connack] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Connack, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.ConnackEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - require.Equal(t, byte(Connack<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.ReturnCode, pk.ReturnCode, "Mismatched return code [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkConnackEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Connack][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.ConnackEncode(buf) - } -} - -func TestConnackDecode(t *testing.T) { - require.Contains(t, expectedPackets, Connack) - for i, wanted := range expectedPackets[Connack] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(2), Connack, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Connack}} - err := pk.ConnackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.ReturnCode, pk.ReturnCode, "Mismatched return code [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.SessionPresent, pk.SessionPresent, "Mismatched session present bool [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkConnackDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Connack}} - pk.FixedHeader.Decode(expectedPackets[Connack][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.ConnackDecode(expectedPackets[Connack][0].rawBytes[2:]) - } -} - -func TestDisconnectEncode(t *testing.T) { - require.Contains(t, expectedPackets, Disconnect) - for i, wanted := range expectedPackets[Disconnect] { - require.Equal(t, uint8(14), Disconnect, "Incorrect Packet Type [i:%d]", i) - - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Disconnect, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - - buf := new(bytes.Buffer) - err := pk.DisconnectEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) - } -} - -func BenchmarkDisconnectEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Disconnect][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.DisconnectEncode(buf) - } -} - -func TestPingreqEncode(t *testing.T) { - require.Contains(t, expectedPackets, Pingreq) - for i, wanted := range expectedPackets[Pingreq] { - require.Equal(t, uint8(12), Pingreq, "Incorrect Packet Type [i:%d]", i) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Pingreq, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - - buf := new(bytes.Buffer) - err := pk.PingreqEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) - } -} - -func BenchmarkPingreqEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Pingreq][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PingreqEncode(buf) - } -} - -func TestPingrespEncode(t *testing.T) { - require.Contains(t, expectedPackets, Pingresp) - for i, wanted := range expectedPackets[Pingresp] { - require.Equal(t, uint8(13), Pingresp, "Incorrect Packet Type [i:%d]", i) - - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Pingresp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d]", i) - - buf := new(bytes.Buffer) - err := pk.PingrespEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d]", i) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d]", i) - } -} - -func BenchmarkPingrespEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Pingresp][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PingrespEncode(buf) - } -} - -func TestPubackEncode(t *testing.T) { - require.Contains(t, expectedPackets, Puback) - for i, wanted := range expectedPackets[Puback] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(4), Puback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Puback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.PubackEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - require.Equal(t, byte(Puback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubackEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Puback][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PubackEncode(buf) - } -} - -func TestPubackDecode(t *testing.T) { - require.Contains(t, expectedPackets, Puback) - for i, wanted := range expectedPackets[Puback] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(4), Puback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Puback}} - err := pk.PubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubackDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Puback}} - pk.FixedHeader.Decode(expectedPackets[Puback][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PubackDecode(expectedPackets[Puback][0].rawBytes[2:]) - } -} - -func TestPubcompEncode(t *testing.T) { - require.Contains(t, expectedPackets, Pubcomp) - for i, wanted := range expectedPackets[Pubcomp] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(7), Pubcomp, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Pubcomp, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.PubcompEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - require.Equal(t, byte(Pubcomp<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubcompEncode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} - copier.Copy(pk, expectedPackets[Pubcomp][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PubcompEncode(buf) - } -} - -func TestPubcompDecode(t *testing.T) { - require.Contains(t, expectedPackets, Pubcomp) - for i, wanted := range expectedPackets[Pubcomp] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(7), Pubcomp, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} - err := pk.PubcompDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubcompDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Pubcomp}} - pk.FixedHeader.Decode(expectedPackets[Pubcomp][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PubcompDecode(expectedPackets[Pubcomp][0].rawBytes[2:]) - } -} - -func TestPublishEncode(t *testing.T) { - require.Contains(t, expectedPackets, Publish) - for i, wanted := range expectedPackets[Publish] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(3), Publish, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Publish, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.PublishEncode(buf) - encoded := buf.Bytes() - - if wanted.expect != nil { - require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) - } else { - - // If actualBytes is set, compare mutated version of byte string instead (to avoid length mismatches, etc). - if len(wanted.actualBytes) > 0 { - wanted.rawBytes = wanted.actualBytes - } - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Publish<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Publish<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } - - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - } - } -} - -func BenchmarkPublishEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Publish][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PublishEncode(buf) - } -} - -func TestPublishDecode(t *testing.T) { - require.Contains(t, expectedPackets, Publish) - for i, wanted := range expectedPackets[Publish] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(3), Publish, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} - pk.FixedHeader.Decode(wanted.rawBytes[0]) - - err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected fh error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Qos, pk.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Dup, pk.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.FixedHeader.Retain, pk.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - - } -} - -func BenchmarkPublishDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} - pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PublishDecode(expectedPackets[Publish][1].rawBytes[2:]) - } -} - -func TestPublishCopy(t *testing.T) { - require.Contains(t, expectedPackets, Publish) - for i, wanted := range expectedPackets[Publish] { - if wanted.group == "copy" { - - pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} - err := pk.PublishDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - copied := pk.PublishCopy() - - require.Equal(t, byte(0), copied.FixedHeader.Qos, "Mismatched QOS [i:%d] %s", i, wanted.desc) - require.Equal(t, false, copied.FixedHeader.Dup, "Mismatched Dup [i:%d] %s", i, wanted.desc) - require.Equal(t, false, copied.FixedHeader.Retain, "Mismatched Retain [i:%d] %s", i, wanted.desc) - - require.Equal(t, pk.Payload, copied.Payload, "Mismatched Payload [i:%d] %s", i, wanted.desc) - require.Equal(t, pk.TopicName, copied.TopicName, "Mismatched Topic Name [i:%d] %s", i, wanted.desc) - - } - } -} - -func BenchmarkPublishCopy(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} - pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PublishCopy() - } -} - -func TestPublishValidate(t *testing.T) { - require.Contains(t, expectedPackets, Publish) - for i, wanted := range expectedPackets[Publish] { - if wanted.group == "validate" || i == 0 { - pk := wanted.packet - ok, err := pk.PublishValidate() - - if i == 0 { - require.NoError(t, err, "Publish should have validated - error incorrect [i:%d] %s", i, wanted.desc) - require.Equal(t, Accepted, ok, "Publish should have validated - code incorrect [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, Failed, ok, "Publish packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) - if err != nil { - require.Equal(t, wanted.expect, err, "Publish packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) - } - } - } - } -} - -func BenchmarkPublishValidate(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Publish}} - pk.FixedHeader.Decode(expectedPackets[Publish][1].rawBytes[0]) - - for n := 0; n < b.N; n++ { - _, err := pk.PublishValidate() - if err != nil { - panic(err) - } - } -} - -func TestPubrecEncode(t *testing.T) { - require.Contains(t, expectedPackets, Pubrec) - for i, wanted := range expectedPackets[Pubrec] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(5), Pubrec, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Pubrec, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.PubrecEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - require.Equal(t, byte(Pubrec<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - - } -} - -func BenchmarkPubrecEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Pubrec][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PubrecEncode(buf) - } -} - -func TestPubrecDecode(t *testing.T) { - require.Contains(t, expectedPackets, Pubrec) - for i, wanted := range expectedPackets[Pubrec] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(5), Pubrec, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Pubrec}} - err := pk.PubrecDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubrecDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Pubrec}} - pk.FixedHeader.Decode(expectedPackets[Pubrec][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PubrecDecode(expectedPackets[Pubrec][0].rawBytes[2:]) - } -} - -func TestPubrelEncode(t *testing.T) { - require.Contains(t, expectedPackets, Pubrel) - for i, wanted := range expectedPackets[Pubrel] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(6), Pubrel, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Pubrel, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.PubrelEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Pubrel<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Pubrel<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubrelEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Pubrel][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.PubrelEncode(buf) - } -} - -func TestPubrelDecode(t *testing.T) { - require.Contains(t, expectedPackets, Pubrel) - for i, wanted := range expectedPackets[Pubrel] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(6), Pubrel, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}} - err := pk.PubrelDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkPubrelDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Pubrel, Qos: 1}} - pk.FixedHeader.Decode(expectedPackets[Pubrel][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.PubrelDecode(expectedPackets[Pubrel][0].rawBytes[2:]) - } -} - -func TestSubackEncode(t *testing.T) { - require.Contains(t, expectedPackets, Suback) - for i, wanted := range expectedPackets[Suback] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(9), Suback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Suback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.SubackEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Suback<<4)|wanted.meta.(byte), encoded[0], "Mismatched mod fixed header packets [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Suback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - } - - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ReturnCodes, pk.ReturnCodes, "Mismatched Return Codes [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkSubackEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Suback][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.SubackEncode(buf) - } -} - -func TestSubackDecode(t *testing.T) { - require.Contains(t, expectedPackets, Suback) - for i, wanted := range expectedPackets[Suback] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(9), Suback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Suback}} - err := pk.SubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.ReturnCodes, pk.ReturnCodes, "Mismatched Return Codes [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkSubackDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Suback}} - pk.FixedHeader.Decode(expectedPackets[Suback][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.SubackDecode(expectedPackets[Suback][0].rawBytes[2:]) - } -} - -func TestSubscribeEncode(t *testing.T) { - require.Contains(t, expectedPackets, Subscribe) - for i, wanted := range expectedPackets[Subscribe] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(8), Subscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Subscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.SubscribeEncode(buf) - encoded := buf.Bytes() - - if wanted.expect != nil { - require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) - } else { - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Subscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Subscribe<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } - - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Qoss, pk.Qoss, "Mismatched Qoss slice [i:%d] %s", i, wanted.desc) - } - } -} - -func BenchmarkSubscribeEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Subscribe][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.SubscribeEncode(buf) - } -} - -func TestSubscribeDecode(t *testing.T) { - require.Contains(t, expectedPackets, Subscribe) - for i, wanted := range expectedPackets[Subscribe] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(8), Subscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} - err := pk.SubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Qoss, pk.Qoss, "Mismatched Qoss slice [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkSubscribeDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} - pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.SubscribeDecode(expectedPackets[Subscribe][0].rawBytes[2:]) - } -} - -func TestSubscribeValidate(t *testing.T) { - require.Contains(t, expectedPackets, Subscribe) - for i, wanted := range expectedPackets[Subscribe] { - if wanted.group == "validate" || i == 0 { - pk := wanted.packet - ok, err := pk.SubscribeValidate() - - if i == 0 { - require.NoError(t, err, "Subscribe should have validated - error incorrect [i:%d] %s", i, wanted.desc) - require.Equal(t, Accepted, ok, "Subscribe should have validated - code incorrect [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, Failed, ok, "Subscribe packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) - if err != nil { - require.Equal(t, wanted.expect, err, "Subscribe packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) - } - } - } - } -} - -func BenchmarkSubscribeValidate(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Subscribe, Qos: 1}} - pk.FixedHeader.Decode(expectedPackets[Subscribe][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.SubscribeValidate() - } -} - -func TestUnsubackEncode(t *testing.T) { - require.Contains(t, expectedPackets, Unsuback) - for i, wanted := range expectedPackets[Unsuback] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(11), Unsuback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Unsuback, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.UnsubackEncode(buf) - require.NoError(t, err, "Expected no error writing buffer [i:%d] %s", i, wanted.desc) - encoded := buf.Bytes() - - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Unsuback<<4)|wanted.meta.(byte), encoded[0], "Mismatched mod fixed header packets [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Unsuback<<4), encoded[0], "Mismatched fixed header packets [i:%d] %s", i, wanted.desc) - } - - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkUnsubackEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Unsuback][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.UnsubackEncode(buf) - } -} - -func TestUnsubackDecode(t *testing.T) { - require.Contains(t, expectedPackets, Unsuback) - for i, wanted := range expectedPackets[Unsuback] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(11), Unsuback, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Unsuback}} - err := pk.UnsubackDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkUnsubackDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Unsuback}} - pk.FixedHeader.Decode(expectedPackets[Unsuback][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.UnsubackDecode(expectedPackets[Unsuback][0].rawBytes[2:]) - } -} - -func TestUnsubscribeEncode(t *testing.T) { - require.Contains(t, expectedPackets, Unsubscribe) - for i, wanted := range expectedPackets[Unsubscribe] { - if !encodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(10), Unsubscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - pk := new(Packet) - copier.Copy(pk, wanted.packet) - - require.Equal(t, Unsubscribe, pk.FixedHeader.Type, "Mismatched FixedHeader Type [i:%d] %s", i, wanted.desc) - - buf := new(bytes.Buffer) - err := pk.UnsubscribeEncode(buf) - encoded := buf.Bytes() - if wanted.expect != nil { - require.Error(t, err, "Expected error writing buffer [i:%d] %s", i, wanted.desc) - } else { - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, len(wanted.rawBytes), len(encoded), "Mismatched packet length [i:%d] %s", i, wanted.desc) - if wanted.meta != nil { - require.Equal(t, byte(Unsubscribe<<4)|wanted.meta.(byte), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, byte(Unsubscribe<<4), encoded[0], "Mismatched fixed header bytes [i:%d] %s", i, wanted.desc) - } - - require.NoError(t, err, "Error writing buffer [i:%d] %s", i, wanted.desc) - require.EqualValues(t, wanted.rawBytes, encoded, "Mismatched byte values [i:%d] %s", i, wanted.desc) - - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) - } - } -} - -func BenchmarkUnsubscribeEncode(b *testing.B) { - pk := new(Packet) - copier.Copy(pk, expectedPackets[Unsubscribe][0].packet) - - buf := new(bytes.Buffer) - for n := 0; n < b.N; n++ { - pk.UnsubscribeEncode(buf) - } -} - -func TestUnsubscribeDecode(t *testing.T) { - require.Contains(t, expectedPackets, Unsubscribe) - for i, wanted := range expectedPackets[Unsubscribe] { - if !decodeTestOK(wanted) { - continue - } - - require.Equal(t, uint8(10), Unsubscribe, "Incorrect Packet Type [i:%d] %s", i, wanted.desc) - - pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} - err := pk.UnsubscribeDecode(wanted.rawBytes[2:]) // Unpack skips fixedheader. - if wanted.failFirst != nil { - require.Error(t, err, "Expected error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.failFirst, err, "Expected fail state; %v [i:%d] %s", err.Error(), i, wanted.desc) - continue - } - - require.NoError(t, err, "Error unpacking buffer [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.PacketID, pk.PacketID, "Mismatched Packet ID [i:%d] %s", i, wanted.desc) - require.Equal(t, wanted.packet.Topics, pk.Topics, "Mismatched Topics slice [i:%d] %s", i, wanted.desc) - } -} - -func BenchmarkUnsubscribeDecode(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} - pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.UnsubscribeDecode(expectedPackets[Unsubscribe][0].rawBytes[2:]) - } -} - -func TestUnsubscribeValidate(t *testing.T) { - require.Contains(t, expectedPackets, Unsubscribe) - for i, wanted := range expectedPackets[Unsubscribe] { - if wanted.group == "validate" || i == 0 { - pk := wanted.packet - ok, err := pk.UnsubscribeValidate() - if i == 0 { - require.NoError(t, err, "Unsubscribe should have validated - error incorrect [i:%d] %s", i, wanted.desc) - require.Equal(t, Accepted, ok, "Unsubscribe should have validated - code incorrect [i:%d] %s", i, wanted.desc) - } else { - require.Equal(t, Failed, ok, "Unsubscribe packet didn't validate - code incorrect [i:%d] %s", i, wanted.desc) - if err != nil { - require.Equal(t, wanted.expect, err, "Unsubscribe packet didn't validate - error incorrect [i:%d] %s", i, wanted.desc) - } - } - } - } -} - -func BenchmarkUnsubscribeValidate(b *testing.B) { - pk := &Packet{FixedHeader: FixedHeader{Type: Unsubscribe, Qos: 1}} - pk.FixedHeader.Decode(expectedPackets[Unsubscribe][0].rawBytes[0]) - - for n := 0; n < b.N; n++ { - pk.UnsubscribeValidate() - } -} diff --git a/server/server_test.go b/server/server_test.go index fdec2ee6f4690b34ff3eb895125fcf2e9db75274..d8785993a2388bbbbb42c05f00eba701ee23e1fd 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -13,10 +13,10 @@ "github.com/stretchr/testify/require" "github.com/mochi-co/mqtt/server/internal/circ" "github.com/mochi-co/mqtt/server/internal/clients" + "github.com/mochi-co/mqtt/server/internal/packets" "github.com/mochi-co/mqtt/server/internal/topics" "github.com/mochi-co/mqtt/server/listeners" "github.com/mochi-co/mqtt/server/listeners/auth" - "github.com/mochi-co/mqtt/server/packets" "github.com/mochi-co/mqtt/server/persistence" "github.com/mochi-co/mqtt/server/system" )