diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go
index 7e10bcf5..b6ea0d03 100644
--- a/pkg/errors/errors.go
+++ b/pkg/errors/errors.go
@@ -43,6 +43,8 @@ var (
 	ErrIngressClosing     = psrpc.NewErrorf(psrpc.Unavailable, "ingress closing")
 	ErrMissingStreamKey   = psrpc.NewErrorf(psrpc.InvalidArgument, "missing stream key")
 	ErrPrerollBufferReset = psrpc.NewErrorf(psrpc.Internal, "preroll buffer reset")
+	ErrInvalidSimulcast   = psrpc.NewErrorf(psrpc.NotAcceptable, "invalid simulcast configuration")
+	ErrSimulcastTranscode = psrpc.NewErrorf(psrpc.NotAcceptable, "simulcast is not supported when transcoding")
 )
 
 func New(err string) error {
diff --git a/pkg/whip/sdk_media_sink.go b/pkg/whip/sdk_media_sink.go
index 622a366d..a019a0c5 100644
--- a/pkg/whip/sdk_media_sink.go
+++ b/pkg/whip/sdk_media_sink.go
@@ -19,6 +19,7 @@ import (
 	"context"
 	"io"
 	"strings"
+	"sync"
 	"time"
 
 	"github.com/Eyevinn/mp4ff/avc"
@@ -41,17 +42,30 @@ var (
 	ErrParamsUnavailable = psrpc.NewErrorf(psrpc.InvalidArgument, "codec parameters unavailable in sample")
 )
 
+type SDKMediaSinkTrack struct {
+	readySamples chan *sample
+	writePLI     func()
+
+	quality       livekit.VideoQuality
+	width, height uint
+
+	sink *SDKMediaSink
+}
+
 type SDKMediaSink struct {
-	logger     logger.Logger
-	params     *params.Params
-	writePLI   func()
-	track      *webrtc.TrackRemote
-	outputSync *utils.TrackOutputSynchronizer
-	sdkOutput  *lksdk_output.LKSDKOutput
-
-	readySamples     chan *sample
-	fuse             core.Fuse
-	trackInitialized bool
+	logger          logger.Logger
+	params          *params.Params
+	outputSync      *utils.TrackOutputSynchronizer
+	sdkOutput       *lksdk_output.LKSDKOutput
+	sinkInitialized bool
+
+	codecParameters webrtc.RTPCodecParameters
+	streamKind      types.StreamKind
+
+	tracksLock sync.Mutex
+	tracks     []*SDKMediaSinkTrack
+
+	fuse core.Fuse
 }
 
 type sample struct {
@@ -59,150 +73,193 @@ type sample struct {
 	s  *media.Sample
 	ts time.Duration
 }
 
-func NewSDKMediaSink(l logger.Logger, p *params.Params, sdkOutput *lksdk_output.LKSDKOutput, track *webrtc.TrackRemote, outputSync *utils.TrackOutputSynchronizer, writePLI func()) *SDKMediaSink {
-	s := &SDKMediaSink{
-		logger:       l,
-		params:       p,
-		writePLI:     writePLI,
-		track:        track,
-		outputSync:   outputSync,
-		sdkOutput:    sdkOutput,
+func NewSDKMediaSink(
+	l logger.Logger, p *params.Params, sdkOutput *lksdk_output.LKSDKOutput,
+	codecParameters webrtc.RTPCodecParameters, streamKind types.StreamKind,
+	outputSync *utils.TrackOutputSynchronizer,
+) *SDKMediaSink {
+	return &SDKMediaSink{
+		logger:          l,
+		params:          p,
+		outputSync:      outputSync,
+		sdkOutput:       sdkOutput,
+		fuse:            core.NewFuse(),
+		tracks:          []*SDKMediaSinkTrack{},
+		streamKind:      streamKind,
+		codecParameters: codecParameters,
+	}
+}
+
+func (sp *SDKMediaSink) AddTrack(quality livekit.VideoQuality) {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
+
+	sp.tracks = append(sp.tracks, &SDKMediaSinkTrack{
 		readySamples: make(chan *sample, 15),
-		fuse:         core.NewFuse(),
+		sink:         sp,
+		quality:      quality,
+	})
+}
+
+func (sp *SDKMediaSink) SetWritePLI(quality livekit.VideoQuality, writePLI func()) *SDKMediaSinkTrack {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
+
+	for i := range sp.tracks {
+		if sp.tracks[i].quality == quality {
+			sp.tracks[i].writePLI = writePLI
+			return sp.tracks[i]
+		}
 	}
-	return s
+	return nil
 }
 
-func (sp *SDKMediaSink) PushSample(s *media.Sample, ts time.Duration) error {
-	if sp.fuse.IsBroken() {
-		return io.EOF
+func (sp *SDKMediaSink) Close() error {
+	sp.fuse.Break()
+	sp.outputSync.Close()
+
+	return nil
+}
+
+func (sp *SDKMediaSink) ensureTracksInitialized(s *media.Sample, t *SDKMediaSinkTrack) (bool, error) {
+	sp.tracksLock.Lock()
+	defer sp.tracksLock.Unlock()
+
+	if sp.sinkInitialized {
+		return sp.sinkInitialized, nil
 	}
 
-	err := sp.ensureTrackInitialized(s)
-	if err != nil {
-		return err
+	if sp.streamKind == types.Audio {
+		stereo := strings.Contains(sp.codecParameters.SDPFmtpLine, "sprop-stereo=1")
+		audioState := getAudioState(sp.codecParameters.MimeType, stereo, sp.codecParameters.ClockRate)
+		sp.params.SetInputAudioState(context.Background(), audioState, true)
+
+		sp.logger.Infow("adding audio track", "stereo", stereo, "codec", sp.codecParameters.MimeType)
+		if err := sp.sdkOutput.AddAudioTrack(t, sp.codecParameters.MimeType, false, stereo); err != nil {
+			return false, err
+		}
+		sp.sinkInitialized = true
+		return sp.sinkInitialized, nil
 	}
-	if !sp.trackInitialized {
-		// Drop the sample
-		return nil
+
+	var err error
+	t.width, t.height, err = getVideoParams(sp.codecParameters.MimeType, s)
+	switch err {
+	case nil:
+		// continue
+	case ErrParamsUnavailable:
+		return false, nil
+	default:
+		return false, err
 	}
 
-	// Synchronize the outputs before the network jitter buffer to avoid old samples stuck
-	// in the channel from increasing the whole pipeline delay.
-	drop, err := sp.outputSync.WaitForMediaTime(ts)
-	if err != nil {
-		return err
+	layers := []*livekit.VideoLayer{}
+	sampleProviders := []lksdk_output.VideoSampleProvider{}
+
+	for _, track := range sp.tracks {
+		if track.width != 0 && track.height != 0 {
+			layers = append(layers, &livekit.VideoLayer{
+				Width:   uint32(track.width),
+				Height:  uint32(track.height),
+				Quality: track.quality,
+			})
+			sampleProviders = append(sampleProviders, track)
+		}
 	}
-	if drop {
-		sp.logger.Debugw("dropping sample", "timestamp", ts)
-		return nil
+
+	if len(layers) == 0 && len(sp.tracks) != 1 {
+		return false, nil
+	} else if len(layers) != len(sp.tracks) {
+		return false, nil
 	}
 
-	select {
-	case <-sp.fuse.Watch():
-		return io.EOF
-	case sp.readySamples <- &sample{s, ts}:
-	default:
-		// drop the sample if the output queue is full. This is needed if we are reconnecting.
+	if len(layers) != 0 {
+		videoState := getVideoState(sp.codecParameters.MimeType, uint(layers[0].Width), uint(layers[0].Height))
+		sp.params.SetInputVideoState(context.Background(), videoState, true)
 	}
 
-	return nil
+	if err := sp.sdkOutput.AddVideoTrack(sampleProviders, layers, sp.codecParameters.MimeType); err != nil {
+		return false, err
+	}
+
+	for _, l := range layers {
+		sp.logger.Infow("adding video track", "width", l.Width, "height", l.Height, "codec", sp.codecParameters.MimeType)
+	}
+	sp.sinkInitialized = true
+
+	return sp.sinkInitialized, nil
 }
 
-func (sp *SDKMediaSink) NextSample(ctx context.Context) (media.Sample, error) {
+func (t *SDKMediaSinkTrack) NextSample(ctx context.Context) (media.Sample, error) {
 	for {
 		select {
-		case <-sp.fuse.Watch():
+		case <-t.sink.fuse.Watch():
 		case <-ctx.Done():
 			return media.Sample{}, io.EOF
-		case s := <-sp.readySamples:
+		case s := <-t.readySamples:
 			return *s.s, nil
 		}
 	}
 }
 
-func (sp *SDKMediaSink) OnBind() error {
-	sp.logger.Infow("media sink bound")
-
-	return nil
-}
+func (t *SDKMediaSinkTrack) PushSample(s *media.Sample, ts time.Duration) error {
+	if t.sink.fuse.IsBroken() {
+		return io.EOF
+	}
 
-func (sp *SDKMediaSink) OnUnbind() error {
-	sp.logger.Infow("media sink unbound")
+	tracksInitialized, err := t.sink.ensureTracksInitialized(s, t)
+	if err != nil {
+		return err
+	} else if !tracksInitialized {
+		// Drop the sample
+		return nil
+	}
 
-	return nil
-}
+	// Synchronize the outputs before the network jitter buffer to avoid old samples stuck
+	// in the channel from increasing the whole pipeline delay.
+	drop, err := t.sink.outputSync.WaitForMediaTime(ts)
+	if err != nil {
+		return err
+	}
+	if drop {
+		t.sink.logger.Debugw("dropping sample", "timestamp", ts)
+		return nil
+	}
 
-func (sp *SDKMediaSink) ForceKeyFrame() error {
-	if sp.writePLI != nil {
-		sp.writePLI()
+	select {
+	case <-t.sink.fuse.Watch():
+		return io.EOF
+	case t.readySamples <- &sample{s, ts}:
+	default:
+		// drop the sample if the output queue is full. This is needed if we are reconnecting.
 	}
 
 	return nil
 }
 
-func (sp *SDKMediaSink) SetWriter(w io.WriteCloser) error {
-	return psrpc.Unimplemented
+func (t *SDKMediaSinkTrack) Close() error {
+	return t.sink.Close()
 }
 
-func (sp *SDKMediaSink) Close() error {
-	sp.fuse.Break()
-	sp.outputSync.Close()
-
+func (t *SDKMediaSinkTrack) OnBind() error {
+	t.sink.logger.Infow("media sink bound")
 	return nil
 }
 
-func (sp *SDKMediaSink) ensureTrackInitialized(s *media.Sample) error {
-	if sp.trackInitialized {
-		return nil
-	}
-
-	kind := streamKindFromCodecType(sp.track.Kind())
-	mimeType := sp.track.Codec().MimeType
-
-	switch kind {
-	case types.Audio:
-		stereo := parseAudioFmtp(sp.track.Codec().SDPFmtpLine)
-		audioState := getAudioState(sp.track.Codec().MimeType, stereo, sp.track.Codec().ClockRate)
-		sp.params.SetInputAudioState(context.Background(), audioState, true)
-
-		sp.logger.Infow("adding audio track", "stereo", stereo, "codec", mimeType)
-		sp.sdkOutput.AddAudioTrack(sp, mimeType, false, stereo)
-	case types.Video:
-		w, h, err := getVideoParams(mimeType, s)
-		switch err {
-		case nil:
-			// continue
-		case ErrParamsUnavailable:
-			return nil
-		default:
-			return err
-		}
-
-		layers := []*livekit.VideoLayer{
-			&livekit.VideoLayer{Width: uint32(w), Height: uint32(h), Quality: livekit.VideoQuality_HIGH},
-		}
-		s := []lksdk_output.VideoSampleProvider{
-			sp,
-		}
-
-		videoState := getVideoState(sp.track.Codec().MimeType, w, h)
-		sp.params.SetInputVideoState(context.Background(), videoState, true)
+func (t *SDKMediaSinkTrack) OnUnbind() error {
+	t.sink.logger.Infow("media sink unbound")
+	return nil
+}
 
-		sp.logger.Infow("adding video track", "width", w, "height", h, "codec", mimeType)
-		sp.sdkOutput.AddVideoTrack(s, layers, mimeType)
+func (t *SDKMediaSinkTrack) ForceKeyFrame() error {
+	if t.writePLI != nil {
+		t.writePLI()
 	}
 
-	sp.trackInitialized = true
-
 	return nil
 }
 
-func parseAudioFmtp(audioFmtp string) bool {
-	return strings.Index(audioFmtp, "sprop-stereo=1") >= 0
-}
-
 func getVideoParams(mimeType string, s *media.Sample) (uint, uint, error) {
 	switch strings.ToLower(mimeType) {
 	case strings.ToLower(webrtc.MimeTypeH264):
diff --git a/pkg/whip/whip_handler.go b/pkg/whip/whip_handler.go
index bad92953..00fcab71 100644
--- a/pkg/whip/whip_handler.go
+++ b/pkg/whip/whip_handler.go
@@ -17,12 +17,14 @@ package whip
 import (
 	"context"
 	"io"
+	"strings"
 	"sync"
 	"time"
 
 	"github.com/pion/dtls/v2/pkg/crypto/elliptic"
 	"github.com/pion/interceptor"
 	"github.com/pion/rtcp"
+	"github.com/pion/sdp/v3"
 	"github.com/pion/webrtc/v3"
 
 	google_protobuf2 "google.golang.org/protobuf/types/known/emptypb"
@@ -64,11 +66,16 @@ type whipHandler struct {
 	result    chan error
 	closeOnce sync.Once
 
-	trackLock           sync.Mutex
-	tracks              map[string]*webrtc.TrackRemote
-	trackHandlers       map[types.StreamKind]*whipTrackHandler
+	trackLock       sync.Mutex
+	simulcastLayers []string
+	tracks          map[string]*webrtc.TrackRemote
+	trackHandlers   []*whipTrackHandler
+	trackAddedChan  chan *webrtc.TrackRemote
+
+	trackSDKMediaSinkLock sync.Mutex
+	trackSDKMediaSink     map[types.StreamKind]*SDKMediaSink
+
 	trackRelayMediaSink map[types.StreamKind]*RelayMediaSink // only for transcoding mode
-	trackAddedChan      chan *webrtc.TrackRemote
 }
 
 func NewWHIPHandler(webRTCConfig *rtcconfig.WebRTCConfig) *whipHandler {
@@ -81,8 +88,9 @@ func NewWHIPHandler(webRTCConfig *rtcconfig.WebRTCConfig) *whipHandler {
 		outputSync:          utils.NewOutputSynchronizer(),
 		result:              make(chan error, 1),
 		tracks:              make(map[string]*webrtc.TrackRemote),
-		trackHandlers:       make(map[types.StreamKind]*whipTrackHandler),
+		trackHandlers:       []*whipTrackHandler{},
 		trackRelayMediaSink: make(map[types.StreamKind]*RelayMediaSink),
+		trackSDKMediaSink:   make(map[types.StreamKind]*SDKMediaSink),
 	}
 }
 
@@ -94,22 +102,25 @@ func (h *whipHandler) Init(ctx context.Context, p *params.Params, sdpOffer strin
 
 	h.updateSettings()
 
+	offer := &webrtc.SessionDescription{
+		Type: webrtc.SDPTypeOffer,
+		SDP:  sdpOffer,
+	}
+	h.expectedTrackCount, err = h.validateOfferAndGetExpectedTrackCount(offer)
+	if err != nil {
+		return "", err
+	}
+
 	if p.BypassTranscoding {
 		h.sdkOutput, err = lksdk_output.NewLKSDKOutput(ctx, p)
 		if err != nil {
 			return "", err
 		}
+	} else if len(h.simulcastLayers) != 0 {
+		return "", errors.ErrSimulcastTranscode
 	}
 
-	offer := &webrtc.SessionDescription{
-		Type: webrtc.SDPTypeOffer,
-		SDP:  sdpOffer,
-	}
-	h.expectedTrackCount, err = validateOfferAndGetExpectedTrackCount(offer)
 	h.trackAddedChan = make(chan *webrtc.TrackRemote, h.expectedTrackCount)
-	if err != nil {
-		return "", err
-	}
 
 	m, err := newMediaEngine()
 	if err != nil {
@@ -351,9 +362,6 @@ func (h *whipHandler) getSDPAnswer(ctx context.Context, offer *webrtc.SessionDes
 	if err != nil {
 		return "", err
 	}
-	if len(parsedAnswer.MediaDescriptions) != h.expectedTrackCount {
-		return "", errors.ErrUnsupportedDecodeFormat
-	}
 	for _, m := range parsedAnswer.MediaDescriptions {
 		// Pion puts a media description with fmt = 0 and no attributes for unsupported codecs
 		if len(m.Attributes) == 0 {
@@ -379,7 +387,21 @@ func (h *whipHandler) addTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPRe
 
 	sync := h.sync.AddTrack(track, whipIdentity)
 
-	mediaSink, err := h.newMediaSink(track)
+	trackQuality := livekit.VideoQuality_HIGH
+	if track.RID() != "" {
+		for i, expectedRid := range h.simulcastLayers {
+			if expectedRid == track.RID() {
+				switch i {
+				case 1:
+					trackQuality = livekit.VideoQuality_MEDIUM
+				case 2:
+					trackQuality = livekit.VideoQuality_LOW
+				}
+			}
+		}
+	}
+
+	mediaSink, err := h.getMediaSink(track, trackQuality)
 	if err != nil {
 		logger.Warnw("failed creating whip media handler", err)
 		return
@@ -390,7 +412,7 @@ func (h *whipHandler) addTrack(track *webrtc.TrackRemote, receiver *webrtc.RTPRe
 		logger.Warnw("failed creating whip track handler", err)
 		return
 	}
-	h.trackHandlers[kind] = th
+	h.trackHandlers = append(h.trackHandlers, th)
 
 	select {
 	case h.trackAddedChan <- track:
@@ -399,12 +421,27 @@
 	}
 }
 
-func (h *whipHandler) newMediaSink(track *webrtc.TrackRemote) (MediaSink, error) {
+func (h *whipHandler) getMediaSink(track *webrtc.TrackRemote, trackQuality livekit.VideoQuality) (MediaSink, error) {
 	kind := streamKindFromCodecType(track.Kind())
 
 	if h.sdkOutput != nil {
-		// pasthrough
-		return NewSDKMediaSink(h.logger, h.params, h.sdkOutput, track, h.outputSync.AddTrack(), func() {
+		h.trackSDKMediaSinkLock.Lock()
+		defer h.trackSDKMediaSinkLock.Unlock()
+
+		if _, ok := h.trackSDKMediaSink[kind]; !ok {
+			h.trackSDKMediaSink[kind] = NewSDKMediaSink(h.logger, h.params, h.sdkOutput, track.Codec(), streamKindFromCodecType(track.Kind()), h.outputSync.AddTrack())
+
+			layers := []livekit.VideoQuality{livekit.VideoQuality_HIGH}
+			if kind == types.Video && len(h.simulcastLayers) == 3 {
+				layers = []livekit.VideoQuality{livekit.VideoQuality_HIGH, livekit.VideoQuality_MEDIUM, livekit.VideoQuality_LOW}
+			}
+
+			for _, layer := range layers {
+				h.trackSDKMediaSink[kind].AddTrack(layer)
+			}
+		}
+
+		return h.trackSDKMediaSink[kind].SetWritePLI(trackQuality, func() {
 			h.writePLI(track.SSRC())
 		}), nil
 	} else {
@@ -437,22 +474,54 @@ func streamKindFromCodecType(typ webrtc.RTPCodecType) types.StreamKind {
 	}
 }
 
-func validateOfferAndGetExpectedTrackCount(offer *webrtc.SessionDescription) (int, error) {
+func (h *whipHandler) validateOfferAndGetExpectedTrackCount(offer *webrtc.SessionDescription) (int, error) {
 	parsed, err := offer.Unmarshal()
 	if err != nil {
 		return 0, err
 	}
 
-	mediaTypes := make(map[string]struct{})
+	audioCount, videoCount := 0, 0
+
 	for _, m := range parsed.MediaDescriptions {
-		if _, ok := mediaTypes[m.MediaName.Media]; ok {
+		if types.StreamKind(m.MediaName.Media) == types.Audio {
 			// Duplicate track for a given type. Forbidden by the RFC
-			return 0, errors.ErrDuplicateTrack
+			if audioCount != 0 {
+				return 0, errors.ErrDuplicateTrack
+			}
+
+			audioCount++
+
+		} else if types.StreamKind(m.MediaName.Media) == types.Video {
+			// Duplicate track for a given type. Forbidden by the RFC
+			if videoCount != 0 {
+				return 0, errors.ErrDuplicateTrack
+			}
+
+			for _, a := range m.Attributes {
+				if a.Key == "simulcast" {
+					spaceSplit := strings.Split(a.Value, " ")
+					if len(spaceSplit) != 2 || spaceSplit[0] != "send" {
+						return 0, errors.ErrInvalidSimulcast
+					}
+
+					layersSplit := strings.Split(spaceSplit[1], ";")
+					if len(layersSplit) != 3 {
+						return 0, errors.ErrInvalidSimulcast
+					}
+
+					h.simulcastLayers = layersSplit
+					videoCount += 3
+				}
+			}
+
+			// No Simulcast
+			if videoCount == 0 {
+				videoCount++
+			}
 		}
-		mediaTypes[m.MediaName.Media] = struct{}{}
 	}
 
-	return len(parsed.MediaDescriptions), nil
+	return audioCount + videoCount, nil
 }
 
 func newMediaEngine() (*webrtc.MediaEngine, error) {
@@ -490,6 +559,14 @@ func newMediaEngine() (*webrtc.MediaEngine, error) {
 		}
 	}
 
+	if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESMidURI}, webrtc.RTPCodecTypeVideo); err != nil {
+		return nil, err
+	}
+
+	if err := m.RegisterHeaderExtension(webrtc.RTPHeaderExtensionCapability{URI: sdp.SDESRTPStreamIDURI}, webrtc.RTPCodecTypeVideo); err != nil {
+		return nil, err
+	}
+
 	return m, nil
 }
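
For reviewers, a minimal sketch (not part of the patch) of the offer shape the new validateOfferAndGetExpectedTrackCount accepts in simulcast mode: at most one audio and one video media section, where the video section may carry an `a=simulcast:send <rid1>;<rid2>;<rid3>` attribute with exactly three layers. The Example function, the SDP body, and the rid names ("hi;mid;low") are hypothetical, and the snippet assumes pion's SDP parser accepts this minimal session description:

```go
package whip

import (
	"fmt"

	"github.com/pion/webrtc/v3"
)

// Hypothetical test sketch: feeds a three-layer simulcast offer through the
// validator. The rid order matters — addTrack maps index 0/1/2 of
// h.simulcastLayers to VideoQuality_HIGH/MEDIUM/LOW.
func ExampleSimulcastOfferValidation() {
	offerSDP := "v=0\r\n" +
		"o=- 0 0 IN IP4 127.0.0.1\r\n" +
		"s=-\r\n" +
		"t=0 0\r\n" +
		"m=audio 9 UDP/TLS/RTP/SAVPF 111\r\n" +
		"a=rtpmap:111 opus/48000/2\r\n" +
		"m=video 9 UDP/TLS/RTP/SAVPF 96\r\n" +
		"a=rtpmap:96 H264/90000\r\n" +
		"a=simulcast:send hi;mid;low\r\n"

	h := &whipHandler{}
	count, err := h.validateOfferAndGetExpectedTrackCount(&webrtc.SessionDescription{
		Type: webrtc.SDPTypeOffer,
		SDP:  offerSDP,
	})

	// One audio section plus three simulcast layers -> 4 expected tracks,
	// with h.simulcastLayers recording the rid order for quality mapping.
	fmt.Println(count, err, h.simulcastLayers)
}
```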