Skip to content

Commit

Permalink
Implement Interceptors
Browse files Browse the repository at this point in the history
Provide API so that handling around RTP can be easily defined by the
user. See the design doc here[0]

[0] pion/webrtc-v3-design#34
  • Loading branch information
masterada authored and Sean-Der committed Nov 18, 2020
1 parent 1ffa87e commit 3b8d4ed
Show file tree
Hide file tree
Showing 14 changed files with 382 additions and 41 deletions.
18 changes: 18 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
type API struct {
settingEngine *SettingEngine
mediaEngine *MediaEngine
interceptors []Interceptor
interceptor Interceptor
}

// NewAPI Creates a new API object for keeping semi-global settings to WebRTC objects
Expand All @@ -35,6 +37,10 @@ func NewAPI(options ...func(*API)) *API {
a.mediaEngine = &MediaEngine{}
}

if len(a.interceptors) > 0 {
a.interceptor = newInterceptorChain(a.interceptors, a.settingEngine.LoggerFactory.NewLogger("interceptor"))
}

return a
}

Expand All @@ -53,3 +59,15 @@ func WithSettingEngine(s SettingEngine) func(a *API) {
a.settingEngine = &s
}
}

func WithInterceptor(interceptor Interceptor) func(a *API) {
return func(a *API) {
a.interceptors = append(a.interceptors, interceptor)
}
}

func ClearInterceptors() func(a *API) {
return func(a *API) {
a.interceptors = nil
}
}
54 changes: 54 additions & 0 deletions interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
// +build !js

package webrtc

import (
"io"

"github.com/pion/rtcp"
"github.com/pion/rtp"
)

type WriteRTP func(p *rtp.Packet, attributes map[interface{}]interface{}) (int, error)
type ReadRTP func() (*rtp.Packet, map[interface{}]interface{}, error)
type WriteRTCP func(pkts []rtcp.Packet, attributes map[interface{}]interface{}) (int, error)
type ReadRTCP func() ([]rtcp.Packet, map[interface{}]interface{}, error)

type Interceptor interface {
BindReadRTCP(read ReadRTCP) ReadRTCP // TODO: call this
BindWriteRTCP(write WriteRTCP) WriteRTCP // TODO: call this

BindLocalTrack(ctx *TrackLocalContext, write WriteRTP) WriteRTP
UnbindLocalTrack(ctx *TrackLocalContext)

BindRemoteTrack(ctx *TrackRemoteContext, read ReadRTP) ReadRTP
UnbindRemoteTrack(ctx *TrackRemoteContext)

io.Closer
}

type DummyInterceptor struct{}

func (d *DummyInterceptor) BindReadRTCP(read ReadRTCP) ReadRTCP {
return read
}

func (d *DummyInterceptor) BindWriteRTCP(write WriteRTCP) WriteRTCP {
return write
}

func (d *DummyInterceptor) BindLocalTrack(_ *TrackLocalContext, write WriteRTP) WriteRTP {
return write
}

func (d *DummyInterceptor) UnbindLocalTrack(_ *TrackLocalContext) {}

func (d *DummyInterceptor) BindRemoteTrack(_ *TrackRemoteContext, read ReadRTP) ReadRTP {
return read
}

func (d *DummyInterceptor) UnbindRemoteTrack(_ *TrackRemoteContext) {}

func (d *DummyInterceptor) Close() error {
return nil
}
75 changes: 75 additions & 0 deletions interceptor_chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// +build !js

package webrtc

import (
"github.com/pion/logging"
)

type interceptorChain struct {
interceptors []Interceptor
log logging.LeveledLogger
}

func newInterceptorChain(interceptors []Interceptor, log logging.LeveledLogger) Interceptor {
return &interceptorChain{interceptors: interceptors, log: log}
}

func (i *interceptorChain) BindReadRTCP(read ReadRTCP) ReadRTCP {
for _, interceptor := range i.interceptors {
read = interceptor.BindReadRTCP(read)
}

return read
}

func (i *interceptorChain) BindWriteRTCP(write WriteRTCP) WriteRTCP {
for _, interceptor := range i.interceptors {
write = interceptor.BindWriteRTCP(write)
}

return write
}

func (i *interceptorChain) BindLocalTrack(ctx *TrackLocalContext, write WriteRTP) WriteRTP {
for _, interceptor := range i.interceptors {
write = interceptor.BindLocalTrack(ctx, write)
}

return write
}

func (i *interceptorChain) UnbindLocalTrack(ctx *TrackLocalContext) {
for _, interceptor := range i.interceptors {
interceptor.UnbindLocalTrack(ctx)
}
}

func (i *interceptorChain) BindRemoteTrack(ctx *TrackRemoteContext, read ReadRTP) ReadRTP {
for _, interceptor := range i.interceptors {
read = interceptor.BindRemoteTrack(ctx, read)
}

return read
}

func (i *interceptorChain) UnbindRemoteTrack(ctx *TrackRemoteContext) {
for _, interceptor := range i.interceptors {
interceptor.UnbindRemoteTrack(ctx)
}
}

func (i *interceptorChain) Close() error {
var err error
for _, interceptor := range i.interceptors {
if err2 := interceptor.Close(); err2 != nil {
if err == nil {
err = err2
} else {
i.log.Warnf("failed closing interceptor: %+v", err2)
}
}
}

return err
}
24 changes: 24 additions & 0 deletions interceptor_peerconnection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// +build !js

package webrtc

import (
"github.com/pion/rtcp"
)

func WrapPeerConnection(pc *PeerConnection, interceptor Interceptor) *PeerConnection {
if interceptor == nil {
return pc
}

writeRTCP := interceptor.BindWriteRTCP(func(pkts []rtcp.Packet, attributes map[interface{}]interface{}) (int, error) {
return pc.writeRTCP(pkts)
})

pc.interceptorWriteRTCP = func(pkts []rtcp.Packet) error {
_, err := writeRTCP(pkts, make(map[interface{}]interface{}))
return err
}

return pc
}
46 changes: 46 additions & 0 deletions interceptor_track_local.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// +build !js

package webrtc

import (
"github.com/pion/rtp"
)

type interceptorChainTrackLocalWrapper struct {
TrackLocal
interceptor Interceptor
}

type interceptorTrackLocalWriter struct {
TrackLocalWriter
writeRTP WriteRTP
}

func WrapTrackLocal(track TrackLocal, interceptor Interceptor) TrackLocal {
if interceptor == nil {
return track
}

return &interceptorChainTrackLocalWrapper{TrackLocal: track, interceptor: interceptor}
}

func (i *interceptorTrackLocalWriter) WriteRTP(header *rtp.Header, payload []byte) (int, error) {
return i.writeRTP(&rtp.Packet{Header: *header, Payload: payload}, make(map[interface{}]interface{}))
}

func (i *interceptorChainTrackLocalWrapper) Bind(context TrackLocalContext) error {
trackWriteStream := context.WriteStream()
writeRTP := i.interceptor.BindLocalTrack(&context, func(p *rtp.Packet, attributes map[interface{}]interface{}) (int, error) {
return trackWriteStream.WriteRTP(&p.Header, p.Payload)
})

context.writeStream = &interceptorTrackLocalWriter{TrackLocalWriter: trackWriteStream, writeRTP: writeRTP}

return i.TrackLocal.Bind(context)
}

func (i *interceptorChainTrackLocalWrapper) Unbind(context TrackLocalContext) error {
i.interceptor.UnbindLocalTrack(&context)

return i.TrackLocal.Unbind(context)
}
28 changes: 28 additions & 0 deletions interceptor_track_remote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// +build !js

package webrtc

import (
"github.com/pion/rtp"
)

func WrapTrackRemote(track *TrackRemote, interceptor Interceptor) *TrackRemote {
if interceptor == nil {
return track
}

readRTP := interceptor.BindRemoteTrack(track.context(), func() (*rtp.Packet, map[interface{}]interface{}, error) {
p, err := track.readRTP()
if err != nil {
return nil, nil, err
}
return p, make(map[interface{}]interface{}), nil
})

track.interceptorReadRTP = func() (*rtp.Packet, error) {
p, _, err := readRTP()
return p, err
}

return track
}
46 changes: 39 additions & 7 deletions mediaengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,19 @@ func (m *MediaEngine) GetHeaderExtensionID(extension RTPHeaderExtensionCapabilit
return
}

func (m *MediaEngine) getCodecByPayload(payloadType PayloadType) (RTPCodecParameters, error) {
func (m *MediaEngine) getCodecByPayload(payloadType PayloadType) (RTPCodecParameters, RTPCodecType, error) {
for _, codec := range m.negotiatedVideoCodecs {
if codec.PayloadType == payloadType {
return codec, nil
return codec, RTPCodecTypeVideo, nil
}
}
for _, codec := range m.negotiatedAudioCodecs {
if codec.PayloadType == payloadType {
return codec, nil
return codec, RTPCodecTypeAudio, nil
}
}

return RTPCodecParameters{}, ErrCodecNotFound
return RTPCodecParameters{}, 0, ErrCodecNotFound
}

func (m *MediaEngine) collectStats(collector *statsReportCollector) {
Expand Down Expand Up @@ -309,7 +309,7 @@ func (m *MediaEngine) updateCodecParameters(remoteCodec RTPCodecParameters, typ
return err
}

if _, err = m.getCodecByPayload(PayloadType(payloadType)); err != nil {
if _, _, err = m.getCodecByPayload(PayloadType(payloadType)); err != nil {
return nil // not an error, we just ignore this codec we don't support
}
}
Expand Down Expand Up @@ -378,8 +378,8 @@ func (m *MediaEngine) updateFromRemoteDescription(desc sdp.SessionDescription) e
return err
}

for id, extension := range extensions {
if err = m.updateHeaderExtension(extension, id, typ); err != nil {
for extension, id := range extensions {
if err = m.updateHeaderExtension(id, extension, typ); err != nil {
return err
}
}
Expand All @@ -405,6 +405,38 @@ func (m *MediaEngine) getCodecsByKind(typ RTPCodecType) []RTPCodecParameters {
return nil
}

func (m *MediaEngine) getRTPParametersByKind(typ RTPCodecType) RTPParameters {
headerExtensions := make([]RTPHeaderExtensionParameter, 0)
for id, e := range m.negotiatedHeaderExtensions {
if e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo {
headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri})
}
}

return RTPParameters{
HeaderExtensions: headerExtensions,
Codecs: m.getCodecsByKind(typ),
}
}

func (m *MediaEngine) getRTPParametersByPayloadType(payloadType PayloadType) (RTPParameters, error) {
codec, typ, err := m.getCodecByPayload(payloadType)
if err != nil {
return RTPParameters{}, err
}

headerExtensions := make([]RTPHeaderExtensionParameter, 0)
for id, e := range m.negotiatedHeaderExtensions {
if e.isAudio && typ == RTPCodecTypeAudio || e.isVideo && typ == RTPCodecTypeVideo {
headerExtensions = append(headerExtensions, RTPHeaderExtensionParameter{ID: id, URI: e.uri})
}
}

return RTPParameters{
HeaderExtensions: headerExtensions,
Codecs: []RTPCodecParameters{codec},
}, nil
}
func (m *MediaEngine) negotiatedHeaderExtensionsForType(typ RTPCodecType) map[int]mediaEngineHeaderExtension {
headerExtensions := map[int]mediaEngineHeaderExtension{}
for id, e := range m.negotiatedHeaderExtensions {
Expand Down
10 changes: 5 additions & 5 deletions mediaengine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ a=fmtp:111 minptime=10; useinbandfec=1
assert.False(t, m.negotiatedVideo)
assert.True(t, m.negotiatedAudio)

opusCodec, err := m.getCodecByPayload(111)
opusCodec, _, err := m.getCodecByPayload(111)
assert.NoError(t, err)
assert.Equal(t, opusCodec.MimeType, mimeTypeOpus)
})
Expand All @@ -85,10 +85,10 @@ a=fmtp:112 minptime=10; useinbandfec=1
assert.False(t, m.negotiatedVideo)
assert.True(t, m.negotiatedAudio)

_, err := m.getCodecByPayload(111)
_, _, err := m.getCodecByPayload(111)
assert.Error(t, err)

opusCodec, err := m.getCodecByPayload(112)
opusCodec, _, err := m.getCodecByPayload(112)
assert.NoError(t, err)
assert.Equal(t, opusCodec.MimeType, mimeTypeOpus)
})
Expand All @@ -110,7 +110,7 @@ a=fmtp:111 minptime=10; useinbandfec=1
assert.False(t, m.negotiatedVideo)
assert.True(t, m.negotiatedAudio)

opusCodec, err := m.getCodecByPayload(111)
opusCodec, _, err := m.getCodecByPayload(111)
assert.NoError(t, err)
assert.Equal(t, opusCodec.MimeType, "audio/OPUS")
})
Expand All @@ -131,7 +131,7 @@ a=rtpmap:111 opus/48000/2
assert.False(t, m.negotiatedVideo)
assert.True(t, m.negotiatedAudio)

opusCodec, err := m.getCodecByPayload(111)
opusCodec, _, err := m.getCodecByPayload(111)
assert.NoError(t, err)
assert.Equal(t, opusCodec.MimeType, mimeTypeOpus)
})
Expand Down
Loading

0 comments on commit 3b8d4ed

Please sign in to comment.