Liu Song’s Projects


~/Projects/mqtt-go

git clone https://code.lsong.org/mqtt-go

Commit

Commit
c77d1c03313f098d1ef5eb535ca469bcdc3cab45
Author
thedevop <60499013+[email protected]>
Date
2023-07-13 13:19:22 -0400 -0400
Diffstat
 server.go | 4 ++++
 server_test.go | 31 ++++++++++++++++++++++++++++++-

Ensure msg doesn't exceed subscription QoS (#253)

Co-authored-by: JB <28275108+[email protected]>


diff --git a/server.go b/server.go
index cdd037e9f78fb7e403277416d5a5de635430d9cf..d32bf33261c08c9e091edc2b2d23d7c252dc48f4 100644
--- a/server.go
+++ b/server.go
@@ -818,6 +818,10 @@ 		}
 		sort.Ints(out.Properties.SubscriptionIdentifier)
 	}
 
+	if out.FixedHeader.Qos > sub.Qos {
+		out.FixedHeader.Qos = sub.Qos
+	}
+
 	if out.FixedHeader.Qos > s.Options.Capabilities.MaximumQos {
 		out.FixedHeader.Qos = s.Options.Capabilities.MaximumQos // [MQTT-3.2.2-9]
 	}




diff --git a/server_test.go b/server_test.go
index f5a778b904f9bd3065e1a5ad9b91a35d161352ac..3ce8e5a71ce3f4b24a7be1b4cda33fb5205417f0 100644
--- a/server_test.go
+++ b/server_test.go
@@ -1566,6 +1566,35 @@
 	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
 }
 
+func TestPublishToClientSubscriptionDowngradeQos(t *testing.T) {
+	s := newServer()
+	s.Options.Capabilities.MaximumQos = 2
+
+	cl, r, w := newTestClient()
+	s.Clients.Add(cl)
+
+	_, ok := cl.State.Inflight.Get(1)
+	require.False(t, ok)
+	cl.State.packetID = 6 // just to match the same packet id (7) in the fixtures
+
+	go func() {
+		pkx := *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet
+		pkx.FixedHeader.Qos = 2
+		s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, pkx)
+		time.Sleep(time.Microsecond * 100)
+		w.Close()
+	}()
+
+	receiverBuf := make(chan []byte)
+	go func() {
+		buf, err := io.ReadAll(r)
+		require.NoError(t, err)
+		receiverBuf <- buf
+	}()
+
+	require.Equal(t, packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).RawBytes, <-receiverBuf)
+}
+
 func TestPublishToClientExceedClientWritesPending(t *testing.T) {
 	s := newServer()
 
@@ -1630,7 +1659,7 @@ 	for i := uint32(0); i <= cl.ops.options.Capabilities.maximumPacketID; i++ {
 		cl.State.Inflight.Set(packets.Packet{PacketID: uint16(i)})
 	}
 
-	_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c"}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
+	_, err := s.publishToClient(cl, packets.Subscription{Filter: "a/b/c", Qos: 1}, *packets.TPacketData[packets.Publish].Get(packets.TPublishQos1).Packet)
 	require.Error(t, err)
 	require.ErrorIs(t, err, packets.ErrQuotaExceeded)
 }