Skip to content

Commit

Permalink
interceptor POC
Browse files Browse the repository at this point in the history
  • Loading branch information
masterada committed Nov 18, 2020
1 parent 1ffa87e commit 471e014
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 7 deletions.
123 changes: 123 additions & 0 deletions interceptor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package webrtc

import (
"context"
"errors"
"io"

"github.com/pion/rtcp"
"github.com/pion/rtp"
)

type Interceptor interface {
Intercept(*PeerConnection, ReadWriter) ReadWriter
}

// Reader is an interface to handle incoming RTP stream.
type ReadWriter interface {
ReadRTP(context.Context) (*rtp.Packet, map[interface{}]interface{}, error)
WriteRTP(context.Context, *rtp.Packet, map[interface{}]interface{}) error
ReadRTCP(context.Context) ([]rtcp.Packet, error)
WriteRTCP(context.Context, []rtcp.Packet) error
io.Closer
}

type contextReadWriter struct{}

type interceptorChain struct {
readWriter ReadWriter
}

type keyReadRTP struct{}
type keyReadRTCP struct{}
type keyWriteRTP struct{}
type keyWriteRTCP struct{}

type writeRTP func(packet *rtp.Packet)
type writeRTCP func(packets []rtcp.Packet)

func (c *contextReadWriter) ReadRTP(ctx context.Context) (*rtp.Packet, map[interface{}]interface{}, error) {
p, ok := ctx.Value(keyReadRTP{}).(*rtp.Packet)
if !ok {
return nil, nil, errors.New("packet not found in context")
}

return p, make(map[interface{}]interface{}), nil
}

func (c *contextReadWriter) WriteRTP(ctx context.Context, packet *rtp.Packet, _ map[interface{}]interface{}) error {
writeRTP, ok := ctx.Value(keyWriteRTP{}).(writeRTP)
if !ok {
return errors.New("callback not found in context")
}
writeRTP(packet)

return nil
}

func (c *contextReadWriter) ReadRTCP(ctx context.Context) ([]rtcp.Packet, error) {
p, ok := ctx.Value(keyReadRTCP{}).([]rtcp.Packet)
if !ok {
return nil, errors.New("packets not found in context")
}
return p, nil
}

func (c *contextReadWriter) WriteRTCP(ctx context.Context, packets []rtcp.Packet) error {
writeRTCP, ok := ctx.Value(keyWriteRTCP{}).(writeRTCP)
if !ok {
return errors.New("callback not found in context")
}
writeRTCP(packets)

return nil
}

func (c *contextReadWriter) Close() error {
return nil
}

func newInterceptorChain(pc *PeerConnection, interceptors []Interceptor) *interceptorChain {
var readWriter ReadWriter = &contextReadWriter{}
for _, interceptor := range interceptors {
readWriter = interceptor.Intercept(pc, readWriter)
}
return &interceptorChain{readWriter: readWriter}
}

func (i *interceptorChain) wrapReadRTP(packet *rtp.Packet) (*rtp.Packet, error) {
ctx := context.WithValue(context.Background(), keyReadRTP{}, packet)
p, _, err := i.readWriter.ReadRTP(ctx)
return p, err
}

func (i *interceptorChain) wrapWriteRTP(packet *rtp.Packet) (*rtp.Packet, error) {
var p *rtp.Packet
ctx := context.WithValue(context.Background(), keyWriteRTP{}, func(p2 *rtp.Packet) {
p = p2
})
err := i.readWriter.WriteRTP(ctx, packet, make(map[interface{}]interface{}))
if err != nil {
return nil, err
}

return p, nil
}

func (i *interceptorChain) wrapReadRTCP(packets []rtcp.Packet) ([]rtcp.Packet, error) {
ctx := context.WithValue(context.Background(), keyReadRTCP{}, packets)
return i.readWriter.ReadRTCP(ctx)
}

func (i *interceptorChain) wrapWriteRTCP(packet *rtp.Packet) (*rtp.Packet, error) {
var p *rtp.Packet
ctx := context.WithValue(context.Background(), keyWriteRTP{}, func(p2 *rtp.Packet) {
p = p2
})
err := i.readWriter.WriteRTP(ctx, packet, make(map[interface{}]interface{}))
if err != nil {
return nil, err
}

return p, nil
}
122 changes: 122 additions & 0 deletions interceptor_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package webrtc

import (
"context"
"testing"

"github.com/pion/rtp"
)

type testInterceptor1 struct {
}

type testInterceptor2 struct {
t *testing.T
}

type testReadWriter1 struct {
ReadWriter
}

type testReadWriter2 struct {
ReadWriter
t *testing.T
}

type testInterceptorKey struct{}

func (t *testInterceptor1) Intercept(_ *PeerConnection, readWriter ReadWriter) ReadWriter {
return &testReadWriter1{ReadWriter: readWriter}
}

func (t *testInterceptor2) Intercept(_ *PeerConnection, readWriter ReadWriter) ReadWriter {
return &testReadWriter2{ReadWriter: readWriter, t: t.t}
}

func (t *testReadWriter1) ReadRTP(ctx context.Context) (*rtp.Packet, map[interface{}]interface{}, error) {
p, m, err := t.ReadWriter.ReadRTP(ctx)
if err != nil {
return nil, nil, err
}

p.SSRC = 1
m[testInterceptorKey{}] = "read1"

return p, m, nil
}

func (t *testReadWriter1) WriteRTP(ctx context.Context, p *rtp.Packet, m map[interface{}]interface{}) error {
p.SSRC = 1
m[testInterceptorKey{}] = "read1"

return t.ReadWriter.WriteRTP(ctx, p, m)
}

func (t *testReadWriter2) ReadRTP(ctx context.Context) (*rtp.Packet, map[interface{}]interface{}, error) {
p, m, err := t.ReadWriter.ReadRTP(ctx)
if err != nil {
return nil, nil, err
}

if p.SSRC != 1 {
t.t.Errorf("expected SSRC to be 1, got: %d", p.SSRC)
}
metaVal := m[testInterceptorKey{}]
if metaVal != "read1" {
t.t.Errorf("expected meta to be set to read1, got: %s", metaVal)
}

// test replacing the packet
p2 := &rtp.Packet{}
p2.SSRC = 2

return p2, m, nil
}

func (t *testReadWriter2) WriteRTP(ctx context.Context, p *rtp.Packet, m map[interface{}]interface{}) error {
if p.SSRC != 1 {
t.t.Errorf("expected SSRC to be 1, got: %d", p.SSRC)
}
metaVal := m[testInterceptorKey{}]
if metaVal != "read1" {
t.t.Errorf("expected meta to be set to read1, got: %s", metaVal)
}

// test replacing the packet
p2 := &rtp.Packet{}
p2.SSRC = 2

return t.ReadWriter.WriteRTP(ctx, p2, m)
}

func TestInterceptorChainReadRTP(t *testing.T) {
chain := newInterceptorChain(nil, []Interceptor{
&testInterceptor1{},
&testInterceptor2{t: t},
})

p, err := chain.wrapReadRTP(&rtp.Packet{})
if err != nil {
t.Fatalf("%+v", err)
}

if p.SSRC != 2 {
t.Errorf("expected SSRC to be 2, got: %d", p.SSRC)
}
}

func TestInterceptorChainWriteRTP(t *testing.T) {
chain := newInterceptorChain(nil, []Interceptor{
&testInterceptor1{},
&testInterceptor2{t: t},
})

p, err := chain.wrapWriteRTP(&rtp.Packet{})
if err != nil {
t.Fatalf("%+v", err)
}

if p.SSRC != 2 {
t.Errorf("expected SSRC to be 2, got: %d", p.SSRC)
}
}
4 changes: 4 additions & 0 deletions peerconnection.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ type PeerConnection struct {
dtlsTransport *DTLSTransport
sctpTransport *SCTPTransport

interceptorChain *interceptorChain

// A reference to the associated API state used by this connection
api *API
log logging.LeveledLogger
Expand Down Expand Up @@ -119,6 +121,8 @@ func (api *API) NewPeerConnection(configuration Configuration) (*PeerConnection,
log: api.settingEngine.LoggerFactory.NewLogger("pc"),
}

pc.interceptorChain = newInterceptorChain(pc, api.settingEngine.interceptors)

var err error
if err = pc.initConfiguration(configuration); err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion rtpreceiveparameters.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ package webrtc

// RTPReceiveParameters contains the RTP stack settings used by receivers
type RTPReceiveParameters struct {
Encodings []RTPDecodingParameters
Encodings []RTPDecodingParameters
interceptorChain *interceptorChain
}
14 changes: 8 additions & 6 deletions rtpreceiver.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,10 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
if len(parameters.Encodings) == 1 && parameters.Encodings[0].SSRC != 0 {
t := trackStreams{
track: &TrackRemote{
kind: r.kind,
ssrc: parameters.Encodings[0].SSRC,
receiver: r,
kind: r.kind,
ssrc: parameters.Encodings[0].SSRC,
receiver: r,
interceptorChain: parameters.interceptorChain,
},
}

Expand All @@ -112,9 +113,10 @@ func (r *RTPReceiver) Receive(parameters RTPReceiveParameters) error {
for _, encoding := range parameters.Encodings {
r.tracks = append(r.tracks, trackStreams{
track: &TrackRemote{
kind: r.kind,
rid: encoding.RID,
receiver: r,
kind: r.kind,
rid: encoding.RID,
receiver: r,
interceptorChain: parameters.interceptorChain,
},
})
}
Expand Down
1 change: 1 addition & 0 deletions settingengine.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type SettingEngine struct {
LoggerFactory logging.LoggerFactory
iceTCPMux ice.TCPMux
iceProxyDialer proxy.Dialer
interceptors []Interceptor
}

// DetachDataChannels enables detaching data channels. When enabled
Expand Down

0 comments on commit 471e014

Please sign in to comment.