~/Projects/mqtt-go
git clone https://code.lsong.org/mqtt-go
Commit
- Commit
- c77d1c03313f098d1ef5eb535ca469bcdc3cab45
- Author
- thedevop <60499013+[email protected]>
- Date
- 2023-07-13 13:19:22 -0400 -0400
- Diffstat
server.go | 4 ++++ server_test.go | 31 ++++++++++++++++++++++++++++++-
Ensure msg doesn't exceed subscription QoS (#253) Co-authored-by: JB <28275108+[email protected]>
diff --git a/server.go b/server.go index cdd037e9f78fb7e403277416d5a5de635430d9cf..d32bf33261c08c9e091edc2b2d23d7c252dc48f4 100644 --- a/server.go +++ b/server.go @@ -818,6 +818,10 @@ } sort.Ints(out.Properties.SubscriptionIdentifier) } + if out.FixedHeader.Qos > sub.Qos { + out.FixedHeader.Qos = sub.Qos + } + if out.FixedHeader.Qos > s.Options.Capabilities.MaximumQos { out.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9] } diff --git a/server_test.go b/server_test.go index f5a778b904f9bd3065e1a5ad9b91a35d161352ac..3ce8e5a71ce3f4b24a7be1b4cda33fb5205417f0 100644 --- a/server_test.go +++ b/server_test.go @@ -1566,6 +1566,35 @@ require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf) } +func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) { + s := newServer() + s.Options.Capabilities.MaximumQos = 2 + + cl, r, w := newTestClient() + s.Clients.Add(cl) + + _, ok := cl.State.Inflight.Get(1) + require.False(t, ok) + cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures + + go func() { + pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet + pkx.FixedHeader.Qos = 2 + s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx) + time.Sleep(time.Microsecond * 100) + w.Close() + }() + + receiverBuf := make(chan []byte) + go func() { + buf, err := io.ReadAll(r) + require.NoError(t, err) + receiverBuf <- buf + }() + + require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf) +} + func TestPublishToClientExceedClientWritesPending(t *testing.T) { s := newServer() @@ -1630,7 +1659,7 @@ for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ { cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)}) } - _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) + _, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet) require.Error(t, err) require.ErrorIs(t, err, packets.ErrQuotaExceeded) }