diff --git a/server_play_test.go b/server_play_test.go index f54c354e..5305ae54 100644 --- a/server_play_test.go +++ b/server_play_test.go @@ -319,6 +319,7 @@ func TestServerPlaySetupErrors(t *testing.T) { "different paths", "double setup", "closed stream", + "different protocols", } { t.Run(ca, func(t *testing.T) { var stream *ServerStream @@ -336,6 +337,9 @@ func TestServerPlaySetupErrors(t *testing.T) { case "closed stream": require.EqualError(t, ctx.Error, "stream is closed") + + case "different protocols": + require.EqualError(t, ctx.Error, "can't setup medias with different protocols") } close(nconnClosed) }, @@ -350,7 +354,9 @@ func TestServerPlaySetupErrors(t *testing.T) { }, stream, nil }, }, - RTSPAddress: "localhost:8554", + RTSPAddress: "localhost:8554", + UDPRTPAddress: "127.0.0.1:8000", + UDPRTCPAddress: "127.0.0.1:8001", } err := s.Start() @@ -372,10 +378,10 @@ func TestServerPlaySetupErrors(t *testing.T) { desc := doDescribe(t, conn) th := &headers.Transport{ - Protocol: headers.TransportProtocolTCP, - Delivery: deliveryPtr(headers.TransportDeliveryUnicast), - Mode: transportModePtr(headers.TransportModePlay), - InterleavedIDs: &[2]int{0, 1}, + Protocol: headers.TransportProtocolUDP, + Delivery: deliveryPtr(headers.TransportDeliveryUnicast), + Mode: transportModePtr(headers.TransportModePlay), + ClientPorts: &[2]int{35466, 35467}, } res, err := writeReqReadRes(conn, base.Request{ @@ -387,14 +393,16 @@ func TestServerPlaySetupErrors(t *testing.T) { }, }) - switch ca { - case "different paths": + if ca != "closed stream" { require.NoError(t, err) require.Equal(t, base.StatusOK, res.StatusCode) + } + switch ca { + case "different paths": session := readSession(t, res) - th.InterleavedIDs = &[2]int{2, 3} + th.ClientPorts = &[2]int{35468, 35469} res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, @@ -409,12 +417,9 @@ func TestServerPlaySetupErrors(t *testing.T) { require.Equal(t, base.StatusBadRequest, res.StatusCode) case "double setup": - require.NoError(t, err) - require.Equal(t, base.StatusOK, res.StatusCode) - session := readSession(t, res) - th.InterleavedIDs = &[2]int{2, 3} + th.ClientPorts = &[2]int{35468, 35469} res, err = writeReqReadRes(conn, base.Request{ Method: base.Setup, @@ -431,6 +436,24 @@ func TestServerPlaySetupErrors(t *testing.T) { case "closed stream": require.NoError(t, err) require.Equal(t, base.StatusBadRequest, res.StatusCode) + + case "different protocols": + session := readSession(t, res) + + th.Protocol = headers.TransportProtocolTCP + th.InterleavedIDs = &[2]int{0, 1} + + res, err = writeReqReadRes(conn, base.Request{ + Method: base.Setup, + URL: mediaURL(t, desc.BaseURL, desc.Medias[0]), + Header: base.Header{ + "CSeq": base.HeaderValue{"4"}, + "Transport": th.Marshal(), + "Session": base.HeaderValue{session}, + }, + }) + require.NoError(t, err) + require.Equal(t, base.StatusBadRequest, res.StatusCode) } <-nconnClosed diff --git a/server_session.go b/server_session.go index fc213b41..b48a36d0 100644 --- a/server_session.go +++ b/server_session.go @@ -662,15 +662,15 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( }, err } - var inTSH headers.Transports - err = inTSH.Unmarshal(req.Header["Transport"]) + var transportHeaders headers.Transports + err = transportHeaders.Unmarshal(req.Header["Transport"]) if err != nil { return &base.Response{ StatusCode: base.StatusBadRequest, }, liberrors.ErrServerTransportHeaderInvalid{Err: err} } - inTH := findFirstSupportedTransportHeader(ss.s, inTSH) + inTH := findFirstSupportedTransportHeader(ss.s, transportHeaders) if inTH == nil { return &base.Response{ StatusCode: base.StatusUnsupportedTransport, @@ -706,16 +706,26 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( transport = TransportUDPMulticast } else { transport = TransportUDP - - if inTH.ClientPorts == nil { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerTransportHeaderNoClientPorts{} - } } } else { transport = TransportTCP + } + + if ss.setuppedTransport != nil && *ss.setuppedTransport != transport { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerMediasDifferentProtocols{} + } + switch transport { + case TransportUDP: + if inTH.ClientPorts == nil { + return &base.Response{ + StatusCode: base.StatusBadRequest, + }, liberrors.ErrServerTransportHeaderNoClientPorts{} + } + + case TransportTCP: if inTH.InterleavedIDs != nil { if (inTH.InterleavedIDs[0] + 1) != inTH.InterleavedIDs[1] { return &base.Response{ @@ -731,12 +741,6 @@ func (ss *ServerSession) handleRequestInner(sc *ServerConn, req *base.Request) ( } } - if ss.setuppedTransport != nil && *ss.setuppedTransport != transport { - return &base.Response{ - StatusCode: base.StatusBadRequest, - }, liberrors.ErrServerMediasDifferentProtocols{} - } - switch ss.state { case ServerSessionStateInitial, ServerSessionStatePrePlay: // play if inTH.Mode != nil && *inTH.Mode != headers.TransportModePlay {