diff --git a/proxy/.gitignore b/proxy/.gitignore new file mode 100644 index 0000000000..c20f4b6782 --- /dev/null +++ b/proxy/.gitignore @@ -0,0 +1,4 @@ +.idea +srs-proxy +.env +.go-formarted \ No newline at end of file diff --git a/proxy/Makefile b/proxy/Makefile new file mode 100644 index 0000000000..29084d5b76 --- /dev/null +++ b/proxy/Makefile @@ -0,0 +1,23 @@ +.PHONY: all build test fmt clean run + +all: build + +build: fmt ./srs-proxy + +./srs-proxy: *.go + go build -o srs-proxy . + +test: + go test ./... + +fmt: ./.go-formarted + +./.go-formarted: *.go + touch .go-formarted + go fmt ./... + +clean: + rm -f srs-proxy .go-formarted + +run: fmt + go run . diff --git a/proxy/api.go b/proxy/api.go new file mode 100644 index 0000000000..04baa92526 --- /dev/null +++ b/proxy/api.go @@ -0,0 +1,272 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "net/http" + "os" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPAPIServer is the proxy for SRS HTTP API, to proxy the WebRTC HTTP API like WHIP and WHEP, +// to proxy other HTTP API of SRS like the streams and clients, etc. +type srsHTTPAPIServer struct { + // The underlayer HTTP server. + server *http.Server + // The WebRTC server. + rtc *srsWebRTCServer + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSHTTPAPIServer(opts ...func(*srsHTTPAPIServer)) *srsHTTPAPIServer { + v := &srsHTTPAPIServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPAPIServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPAPIServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The WebRTC WHIP API handler. + logger.Df(ctx, "Handle /rtc/v1/whip/ by %v", addr) + mux.HandleFunc("/rtc/v1/whip/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHIP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // The WebRTC WHEP API handler. + logger.Df(ctx, "Handle /rtc/v1/whep/ by %v", addr) + mux.HandleFunc("/rtc/v1/whep/", func(w http.ResponseWriter, r *http.Request) { + if err := v.rtc.HandleApiForWHEP(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } + }) + + // Run HTTP API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP API accept err %+v", err) + } else { + logger.Df(ctx, "HTTP API server done") + } + } + }() + + return nil +} + +// systemAPI is the system HTTP API of the proxy server, for SRS media server to register the service +// to proxy server. It also provides some other system APIs like the status of proxy server, like exporter +// for Prometheus metrics. +type systemAPI struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSystemAPI(opts ...func(*systemAPI)) *systemAPI { + v := &systemAPI{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *systemAPI) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *systemAPI) Run(ctx context.Context) error { + // Parse address to listen. + addr := envSystemAPI() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "System API server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + apiResponse(ctx, w, r, map[string]string{ + "signature": Signature(), + "version": Version(), + }) + }) + + // The register service for SRS media servers. + logger.Df(ctx, "Handle /api/v1/srs/register by %v", addr) + mux.HandleFunc("/api/v1/srs/register", func(w http.ResponseWriter, r *http.Request) { + if err := func() error { + var deviceID, ip, serverID, serviceID, pid string + var rtmp, stream, api, srt, rtc []string + if err := ParseBody(r.Body, &struct { + // The IP of SRS, mandatory. + IP *string `json:"ip"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID *string `json:"server"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID *string `json:"service"` + // The process id of SRS, always change when restarted, mandatory. + PID *string `json:"pid"` + // The RTMP listen endpoints, mandatory. + RTMP *[]string `json:"rtmp"` + // The HTTP Stream listen endpoints, optional. + HTTP *[]string `json:"http"` + // The API listen endpoints, optional. + API *[]string `json:"api"` + // The SRT listen endpoints, optional. + SRT *[]string `json:"srt"` + // The RTC listen endpoints, optional. + RTC *[]string `json:"rtc"` + // The device id of SRS, optional. + DeviceID *string `json:"device_id"` + }{ + IP: &ip, DeviceID: &deviceID, + ServerID: &serverID, ServiceID: &serviceID, PID: &pid, + RTMP: &rtmp, HTTP: &stream, API: &api, SRT: &srt, RTC: &rtc, + }); err != nil { + return errors.Wrapf(err, "parse body") + } + + if ip == "" { + return errors.Errorf("empty ip") + } + if serverID == "" { + return errors.Errorf("empty server") + } + if serviceID == "" { + return errors.Errorf("empty service") + } + if pid == "" { + return errors.Errorf("empty pid") + } + if len(rtmp) == 0 { + return errors.Errorf("empty rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP, srs.DeviceID = ip, deviceID + srs.ServerID, srs.ServiceID, srs.PID = serverID, serviceID, pid + srs.RTMP, srs.HTTP, srs.API = rtmp, stream, api + srs.SRT, srs.RTC = srt, rtc + srs.UpdatedAt = time.Now() + }) + if err := srsLoadBalancer.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update SRS server %+v", server) + } + + logger.Df(ctx, "Register SRS media server, %+v", server) + return nil + }(); err != nil { + apiError(ctx, w, r, err) + } + + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + } + + apiResponse(ctx, w, r, &Response{ + Code: 0, PID: fmt.Sprintf("%v", os.Getpid()), + }) + }) + + // Run System API server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If System API server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "System API accept err %+v", err) + } else { + logger.Df(ctx, "System API server done") + } + } + }() + + return nil +} diff --git a/proxy/debug.go b/proxy/debug.go new file mode 100644 index 0000000000..3a389b8bbd --- /dev/null +++ b/proxy/debug.go @@ -0,0 +1,20 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "net/http" + + "srs-proxy/logger" +) + +func handleGoPprof(ctx context.Context) { + if addr := envGoPprof(); addr != "" { + go func() { + logger.Df(ctx, "Start Go pprof at %v", addr) + http.ListenAndServe(addr, nil) + }() + } +} diff --git a/proxy/env.go b/proxy/env.go new file mode 100644 index 0000000000..0c201bb1d6 --- /dev/null +++ b/proxy/env.go @@ -0,0 +1,197 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "path" + + "github.com/joho/godotenv" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// loadEnvFile loads the environment variables from file. Note that we only use .env file. +func loadEnvFile(ctx context.Context) error { + if workDir, err := os.Getwd(); err != nil { + return errors.Wrapf(err, "getpwd") + } else { + envFile := path.Join(workDir, ".env") + if _, err := os.Stat(envFile); err == nil { + if err := godotenv.Load(envFile); err != nil { + return errors.Wrapf(err, "load %v", envFile) + } + } + } + + return nil +} + +// buildDefaultEnvironmentVariables setups the default environment variables. +func buildDefaultEnvironmentVariables(ctx context.Context) { + // Whether enable the Go pprof. + setEnvDefault("GO_PPROF", "") + // Force shutdown timeout. + setEnvDefault("PROXY_FORCE_QUIT_TIMEOUT", "30s") + // Graceful quit timeout. + setEnvDefault("PROXY_GRACE_QUIT_TIMEOUT", "20s") + + // The HTTP API server. + setEnvDefault("PROXY_HTTP_API", "11985") + // The HTTP web server. + setEnvDefault("PROXY_HTTP_SERVER", "18080") + // The RTMP media server. + setEnvDefault("PROXY_RTMP_SERVER", "11935") + // The WebRTC media server, via UDP protocol. + setEnvDefault("PROXY_WEBRTC_SERVER", "18000") + // The SRT media server, via UDP protocol. + setEnvDefault("PROXY_SRT_SERVER", "20080") + // The API server of proxy itself. + setEnvDefault("PROXY_SYSTEM_API", "12025") + // The static directory for web server. + setEnvDefault("PROXY_STATIC_FILES", "../trunk/research") + + // The load balancer, use redis or memory. + setEnvDefault("PROXY_LOAD_BALANCER_TYPE", "memory") + // The redis server host. + setEnvDefault("PROXY_REDIS_HOST", "127.0.0.1") + // The redis server port. + setEnvDefault("PROXY_REDIS_PORT", "6379") + // The redis server password. + setEnvDefault("PROXY_REDIS_PASSWORD", "") + // The redis server db. + setEnvDefault("PROXY_REDIS_DB", "0") + + // Whether enable the default backend server, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_ENABLED", "off") + // Default backend server IP, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_IP", "127.0.0.1") + // Default backend server port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTMP", "1935") + // Default backend api port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_API", "1985") + // Default backend udp rtc port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_RTC", "8000") + // Default backend udp srt port, for debugging. + setEnvDefault("PROXY_DEFAULT_BACKEND_SRT", "10080") + + logger.Df(ctx, "load .env as GO_PPROF=%v, "+ + "PROXY_FORCE_QUIT_TIMEOUT=%v, PROXY_GRACE_QUIT_TIMEOUT=%v, "+ + "PROXY_HTTP_API=%v, PROXY_HTTP_SERVER=%v, PROXY_RTMP_SERVER=%v, "+ + "PROXY_WEBRTC_SERVER=%v, PROXY_SRT_SERVER=%v, "+ + "PROXY_SYSTEM_API=%v, PROXY_STATIC_FILES=%v, PROXY_DEFAULT_BACKEND_ENABLED=%v, "+ + "PROXY_DEFAULT_BACKEND_IP=%v, PROXY_DEFAULT_BACKEND_RTMP=%v, "+ + "PROXY_DEFAULT_BACKEND_HTTP=%v, PROXY_DEFAULT_BACKEND_API=%v, "+ + "PROXY_DEFAULT_BACKEND_RTC=%v, PROXY_DEFAULT_BACKEND_SRT=%v, "+ + "PROXY_LOAD_BALANCER_TYPE=%v, PROXY_REDIS_HOST=%v, PROXY_REDIS_PORT=%v, "+ + "PROXY_REDIS_PASSWORD=%v, PROXY_REDIS_DB=%v", + envGoPprof(), + envForceQuitTimeout(), envGraceQuitTimeout(), + envHttpAPI(), envHttpServer(), envRtmpServer(), + envWebRTCServer(), envSRTServer(), + envSystemAPI(), envStaticFiles(), envDefaultBackendEnabled(), + envDefaultBackendIP(), envDefaultBackendRTMP(), + envDefaultBackendHttp(), envDefaultBackendAPI(), + envDefaultBackendRTC(), envDefaultBackendSRT(), + envLoadBalancerType(), envRedisHost(), envRedisPort(), + envRedisPassword(), envRedisDB(), + ) +} + +func envStaticFiles() string { + return os.Getenv("PROXY_STATIC_FILES") +} + +func envDefaultBackendSRT() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_SRT") +} + +func envDefaultBackendRTC() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTC") +} + +func envDefaultBackendAPI() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_API") +} + +func envSRTServer() string { + return os.Getenv("PROXY_SRT_SERVER") +} + +func envWebRTCServer() string { + return os.Getenv("PROXY_WEBRTC_SERVER") +} + +func envDefaultBackendHttp() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_HTTP") +} + +func envRedisDB() string { + return os.Getenv("PROXY_REDIS_DB") +} + +func envRedisPassword() string { + return os.Getenv("PROXY_REDIS_PASSWORD") +} + +func envRedisPort() string { + return os.Getenv("PROXY_REDIS_PORT") +} + +func envRedisHost() string { + return os.Getenv("PROXY_REDIS_HOST") +} + +func envLoadBalancerType() string { + return os.Getenv("PROXY_LOAD_BALANCER_TYPE") +} + +func envDefaultBackendRTMP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_RTMP") +} + +func envDefaultBackendIP() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_IP") +} + +func envDefaultBackendEnabled() string { + return os.Getenv("PROXY_DEFAULT_BACKEND_ENABLED") +} + +func envGraceQuitTimeout() string { + return os.Getenv("PROXY_GRACE_QUIT_TIMEOUT") +} + +func envForceQuitTimeout() string { + return os.Getenv("PROXY_FORCE_QUIT_TIMEOUT") +} + +func envGoPprof() string { + return os.Getenv("GO_PPROF") +} + +func envSystemAPI() string { + return os.Getenv("PROXY_SYSTEM_API") +} + +func envRtmpServer() string { + return os.Getenv("PROXY_RTMP_SERVER") +} + +func envHttpServer() string { + return os.Getenv("PROXY_HTTP_SERVER") +} + +func envHttpAPI() string { + return os.Getenv("PROXY_HTTP_API") +} + +// setEnvDefault set env key=value if not set. +func setEnvDefault(key, value string) { + if os.Getenv(key) == "" { + os.Setenv(key, value) + } +} diff --git a/proxy/errors/errors.go b/proxy/errors/errors.go new file mode 100644 index 0000000000..257bc3ccda --- /dev/null +++ b/proxy/errors/errors.go @@ -0,0 +1,270 @@ +// Package errors provides simple error handling primitives. +// +// The traditional error handling idiom in Go is roughly akin to +// +// if err != nil { +// return err +// } +// +// which applied recursively up the call stack results in error reports +// without context or debugging information. The errors package allows +// programmers to add context to the failure path in their code in a way +// that does not destroy the original value of the error. +// +// Adding context to an error +// +// The errors.Wrap function returns a new error that adds context to the +// original error by recording a stack trace at the point Wrap is called, +// and the supplied message. For example +// +// _, err := ioutil.ReadAll(r) +// if err != nil { +// return errors.Wrap(err, "read failed") +// } +// +// If additional control is required the errors.WithStack and errors.WithMessage +// functions destructure errors.Wrap into its component operations of annotating +// an error with a stack trace and an a message, respectively. +// +// Retrieving the cause of an error +// +// Using errors.Wrap constructs a stack of errors, adding context to the +// preceding error. Depending on the nature of the error it may be necessary +// to reverse the operation of errors.Wrap to retrieve the original error +// for inspection. Any error value which implements this interface +// +// type causer interface { +// Cause() error +// } +// +// can be inspected by errors.Cause. errors.Cause will recursively retrieve +// the topmost error which does not implement causer, which is assumed to be +// the original cause. For example: +// +// switch err := errors.Cause(err).(type) { +// case *MyError: +// // handle specifically +// default: +// // unknown error +// } +// +// causer interface is not exported by this package, but is considered a part +// of stable public API. +// +// Formatted printing of errors +// +// All error values returned from this package implement fmt.Formatter and can +// be formatted by the fmt package. The following verbs are supported +// +// %s print the error. If the error has a Cause it will be +// printed recursively +// %v see %s +// %+v extended format. Each Frame of the error's StackTrace will +// be printed in detail. +// +// Retrieving the stack trace of an error or wrapper +// +// New, Errorf, Wrap, and Wrapf record a stack trace at the point they are +// invoked. This information can be retrieved with the following interface. +// +// type stackTracer interface { +// StackTrace() errors.StackTrace +// } +// +// Where errors.StackTrace is defined as +// +// type StackTrace []Frame +// +// The Frame type represents a call site in the stack trace. Frame supports +// the fmt.Formatter interface that can be used for printing information about +// the stack trace of this error. For example: +// +// if err, ok := err.(stackTracer); ok { +// for _, f := range err.StackTrace() { +// fmt.Printf("%+s:%d", f) +// } +// } +// +// stackTracer interface is not exported by this package, but is considered a part +// of stable public API. +// +// See the documentation for Frame.Format for more details. +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" +) + +// New returns an error with the supplied message. +// New also records the stack trace at the point it was called. +func New(message string) error { + return &fundamental{ + msg: message, + stack: callers(), + } +} + +// Errorf formats according to a format specifier and returns the string +// as a value that satisfies error. +// Errorf also records the stack trace at the point it was called. +func Errorf(format string, args ...interface{}) error { + return &fundamental{ + msg: fmt.Sprintf(format, args...), + stack: callers(), + } +} + +// fundamental is an error that has a message and a stack, but no caller. +type fundamental struct { + msg string + *stack +} + +func (f *fundamental) Error() string { return f.msg } + +func (f *fundamental) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + io.WriteString(s, f.msg) + f.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, f.msg) + case 'q': + fmt.Fprintf(s, "%q", f.msg) + } +} + +// WithStack annotates err with a stack trace at the point WithStack was called. +// If err is nil, WithStack returns nil. +func WithStack(err error) error { + if err == nil { + return nil + } + return &withStack{ + err, + callers(), + } +} + +type withStack struct { + error + *stack +} + +func (w *withStack) Cause() error { return w.error } + +func (w *withStack) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v", w.Cause()) + w.stack.Format(s, verb) + return + } + fallthrough + case 's': + io.WriteString(s, w.Error()) + case 'q': + fmt.Fprintf(s, "%q", w.Error()) + } +} + +// Wrap returns an error annotating err with a stack trace +// at the point Wrap is called, and the supplied message. +// If err is nil, Wrap returns nil. +func Wrap(err error, message string) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: message, + } + return &withStack{ + err, + callers(), + } +} + +// Wrapf returns an error annotating err with a stack trace +// at the point Wrapf is call, and the format specifier. +// If err is nil, Wrapf returns nil. +func Wrapf(err error, format string, args ...interface{}) error { + if err == nil { + return nil + } + err = &withMessage{ + cause: err, + msg: fmt.Sprintf(format, args...), + } + return &withStack{ + err, + callers(), + } +} + +// WithMessage annotates err with a new message. +// If err is nil, WithMessage returns nil. +func WithMessage(err error, message string) error { + if err == nil { + return nil + } + return &withMessage{ + cause: err, + msg: message, + } +} + +type withMessage struct { + cause error + msg string +} + +func (w *withMessage) Error() string { return w.msg + ": " + w.cause.Error() } +func (w *withMessage) Cause() error { return w.cause } + +func (w *withMessage) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + if s.Flag('+') { + fmt.Fprintf(s, "%+v\n", w.Cause()) + io.WriteString(s, w.msg) + return + } + fallthrough + case 's', 'q': + io.WriteString(s, w.Error()) + } +} + +// Cause returns the underlying cause of the error, if possible. +// An error value has a cause if it implements the following +// interface: +// +// type causer interface { +// Cause() error +// } +// +// If the error does not implement Cause, the original error will +// be returned. If the error is nil, nil will be returned without further +// investigation. +func Cause(err error) error { + type causer interface { + Cause() error + } + + for err != nil { + cause, ok := err.(causer) + if !ok { + break + } + err = cause.Cause() + } + return err +} diff --git a/proxy/errors/stack.go b/proxy/errors/stack.go new file mode 100644 index 0000000000..6c42db5a85 --- /dev/null +++ b/proxy/errors/stack.go @@ -0,0 +1,187 @@ +// Fork from https://github.com/pkg/errors +package errors + +import ( + "fmt" + "io" + "path" + "runtime" + "strings" +) + +// Frame represents a program counter inside a stack frame. +type Frame uintptr + +// pc returns the program counter for this frame; +// multiple frames may have the same PC value. +func (f Frame) pc() uintptr { return uintptr(f) - 1 } + +// file returns the full path to the file that contains the +// function for this Frame's pc. +func (f Frame) file() string { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return "unknown" + } + file, _ := fn.FileLine(f.pc()) + return file +} + +// line returns the line number of source code of the +// function for this Frame's pc. +func (f Frame) line() int { + fn := runtime.FuncForPC(f.pc()) + if fn == nil { + return 0 + } + _, line := fn.FileLine(f.pc()) + return line +} + +// Format formats the frame according to the fmt.Formatter interface. +// +// %s source file +// %d source line +// %n function name +// %v equivalent to %s:%d +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+s path of source file relative to the compile time GOPATH +// %+v equivalent to %+s:%d +func (f Frame) Format(s fmt.State, verb rune) { + switch verb { + case 's': + switch { + case s.Flag('+'): + pc := f.pc() + fn := runtime.FuncForPC(pc) + if fn == nil { + io.WriteString(s, "unknown") + } else { + file, _ := fn.FileLine(pc) + fmt.Fprintf(s, "%s\n\t%s", fn.Name(), file) + } + default: + io.WriteString(s, path.Base(f.file())) + } + case 'd': + fmt.Fprintf(s, "%d", f.line()) + case 'n': + name := runtime.FuncForPC(f.pc()).Name() + io.WriteString(s, funcname(name)) + case 'v': + f.Format(s, 's') + io.WriteString(s, ":") + f.Format(s, 'd') + } +} + +// StackTrace is stack of Frames from innermost (newest) to outermost (oldest). +type StackTrace []Frame + +// Format formats the stack of Frames according to the fmt.Formatter interface. +// +// %s lists source files for each Frame in the stack +// %v lists the source file and line number for each Frame in the stack +// +// Format accepts flags that alter the printing of some verbs, as follows: +// +// %+v Prints filename, function, and line number for each Frame in the stack. +func (st StackTrace) Format(s fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case s.Flag('+'): + for _, f := range st { + fmt.Fprintf(s, "\n%+v", f) + } + case s.Flag('#'): + fmt.Fprintf(s, "%#v", []Frame(st)) + default: + fmt.Fprintf(s, "%v", []Frame(st)) + } + case 's': + fmt.Fprintf(s, "%s", []Frame(st)) + } +} + +// stack represents a stack of program counters. +type stack []uintptr + +func (s *stack) Format(st fmt.State, verb rune) { + switch verb { + case 'v': + switch { + case st.Flag('+'): + for _, pc := range *s { + f := Frame(pc) + fmt.Fprintf(st, "\n%+v", f) + } + } + } +} + +func (s *stack) StackTrace() StackTrace { + f := make([]Frame, len(*s)) + for i := 0; i < len(f); i++ { + f[i] = Frame((*s)[i]) + } + return f +} + +func callers() *stack { + const depth = 32 + var pcs [depth]uintptr + n := runtime.Callers(3, pcs[:]) + var st stack = pcs[0:n] + return &st +} + +// funcname removes the path prefix component of a function's name reported by func.Name(). +func funcname(name string) string { + i := strings.LastIndex(name, "/") + name = name[i+1:] + i = strings.Index(name, ".") + return name[i+1:] +} + +func trimGOPATH(name, file string) string { + // Here we want to get the source file path relative to the compile time + // GOPATH. As of Go 1.6.x there is no direct way to know the compiled + // GOPATH at runtime, but we can infer the number of path segments in the + // GOPATH. We note that fn.Name() returns the function name qualified by + // the import path, which does not include the GOPATH. Thus we can trim + // segments from the beginning of the file path until the number of path + // separators remaining is one more than the number of path separators in + // the function name. For example, given: + // + // GOPATH /home/user + // file /home/user/src/pkg/sub/file.go + // fn.Name() pkg/sub.Type.Method + // + // We want to produce: + // + // pkg/sub/file.go + // + // From this we can easily see that fn.Name() has one less path separator + // than our desired output. We count separators from the end of the file + // path until it finds two more than in the function name and then move + // one character forward to preserve the initial path segment without a + // leading separator. + const sep = "/" + goal := strings.Count(name, sep) + 2 + i := len(file) + for n := 0; n < goal; n++ { + i = strings.LastIndex(file[:i], sep) + if i == -1 { + // not enough separators found, set i so that the slice expression + // below leaves file unmodified + i = -len(sep) + break + } + } + // get back to 0 or trim the leading separator + file = file[i+len(sep):] + return file +} diff --git a/proxy/go.mod b/proxy/go.mod new file mode 100644 index 0000000000..2e2a17ab34 --- /dev/null +++ b/proxy/go.mod @@ -0,0 +1,13 @@ +module srs-proxy + +go 1.18 + +require ( + github.com/go-redis/redis/v8 v8.11.5 + github.com/joho/godotenv v1.5.1 +) + +require ( + github.com/cespare/xxhash/v2 v2.1.2 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect +) diff --git a/proxy/go.sum b/proxy/go.sum new file mode 100644 index 0000000000..1efc5318ed --- /dev/null +++ b/proxy/go.sum @@ -0,0 +1,17 @@ +github.com/cespare/xxhash/v2 v2.1.2 h1:YRXhKfTDauu4ajMg1TPgFO5jnlC2HCbmLXMcTG5cbYE= +github.com/cespare/xxhash/v2 v2.1.2/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +golang.org/x/net v0.0.0-20210428140749-89ef3d95e781 h1:DzZ89McO9/gWPsQXS/FVKAlG02ZjaQ6AlZRBimEYOd0= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e h1:fLOSk5Q00efkSvAm+4xcoXD+RRmLmmulPn5I3Y9F2EM= +golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= diff --git a/proxy/http.go b/proxy/http.go new file mode 100644 index 0000000000..f02af02a30 --- /dev/null +++ b/proxy/http.go @@ -0,0 +1,419 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "io" + "io/ioutil" + "net/http" + "os" + "strconv" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +// srsHTTPStreamServer is the proxy server for SRS HTTP stream server, for HTTP-FLV, HTTP-TS, +// HLS, etc. The proxy server will figure out which SRS origin server to proxy to, then proxy +// the request to the origin server. +type srsHTTPStreamServer struct { + // The underlayer HTTP server. + server *http.Server + // The gracefully quit timeout, wait server to quit. + gracefulQuitTimeout time.Duration + // The wait group for all goroutines. + wg stdSync.WaitGroup +} + +func NewSRSHTTPStreamServer(opts ...func(*srsHTTPStreamServer)) *srsHTTPStreamServer { + v := &srsHTTPStreamServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsHTTPStreamServer) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + v.server.Shutdown(ctx) + + v.wg.Wait() + return nil +} + +func (v *srsHTTPStreamServer) Run(ctx context.Context) error { + // Parse address to listen. + addr := envHttpServer() + if !strings.Contains(addr, ":") { + addr = ":" + addr + } + + // Create server and handler. + mux := http.NewServeMux() + v.server = &http.Server{Addr: addr, Handler: mux} + logger.Df(ctx, "HTTP Stream server listen at %v", addr) + + // Shutdown the server gracefully when quiting. + go func() { + ctxParent := ctx + <-ctxParent.Done() + + ctx, cancel := context.WithTimeout(context.Background(), v.gracefulQuitTimeout) + defer cancel() + + v.server.Shutdown(ctx) + }() + + // The basic version handler, also can be used as health check API. + logger.Df(ctx, "Handle /api/v1/versions by %v", addr) + mux.HandleFunc("/api/v1/versions", func(w http.ResponseWriter, r *http.Request) { + type Response struct { + Code int `json:"code"` + PID string `json:"pid"` + Data struct { + Major int `json:"major"` + Minor int `json:"minor"` + Revision int `json:"revision"` + Version string `json:"version"` + } `json:"data"` + } + + res := Response{Code: 0, PID: fmt.Sprintf("%v", os.Getpid())} + res.Data.Major = VersionMajor() + res.Data.Minor = VersionMinor() + res.Data.Revision = VersionRevision() + res.Data.Version = Version() + + apiResponse(ctx, w, r, &res) + }) + + // The static web server, for the web pages. + var staticServer http.Handler + if staticFiles := envStaticFiles(); staticFiles != "" { + if _, err := os.Stat(staticFiles); err != nil { + return errors.Wrapf(err, "invalid static files %v", staticFiles) + } + + staticServer = http.FileServer(http.Dir(staticFiles)) + logger.Df(ctx, "Handle static files at %v", staticFiles) + } + + // The default handler, for both static web server and streaming server. + logger.Df(ctx, "Handle / by %v", addr) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + // For HLS streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".m3u8") { + unifiedURL, fullURL := convertURLToStreamURL(r) + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + http.Error(w, fmt.Sprintf("build stream url by %v from %v", unifiedURL, fullURL), http.StatusBadRequest) + return + } + + stream, _ := srsLoadBalancer.LoadOrStoreHLS(ctx, streamURL, NewHLSPlayStream(func(s *HLSPlayStream) { + s.SRSProxyBackendHLSID = logger.GenerateContextID() + s.StreamURL, s.FullURL = streamURL, fullURL + })) + + stream.Initialize(ctx).ServeHTTP(w, r) + return + } + + // For HTTP streaming, we will proxy the request to the streaming server. + if strings.HasSuffix(r.URL.Path, ".flv") || + strings.HasSuffix(r.URL.Path, ".ts") { + // If SPBHID is specified, it must be a HLS stream client. + if srsProxyBackendID := r.URL.Query().Get("spbhid"); srsProxyBackendID != "" { + if stream, err := srsLoadBalancer.LoadHLSBySPBHID(ctx, srsProxyBackendID); err != nil { + http.Error(w, fmt.Sprintf("load stream by spbhid %v", srsProxyBackendID), http.StatusBadRequest) + } else { + stream.Initialize(ctx).ServeHTTP(w, r) + } + return + } + + // Use HTTP pseudo streaming to proxy the request. + NewHTTPFlvTsConnection(func(c *HTTPFlvTsConnection) { + c.ctx = ctx + }).ServeHTTP(w, r) + return + } + + // Serve by static server. + if staticServer != nil { + staticServer.ServeHTTP(w, r) + return + } + + http.NotFound(w, r) + }) + + // Run HTTP server. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + err := v.server.ListenAndServe() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If HTTP Stream server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "HTTP Stream accept err %+v", err) + } else { + logger.Df(ctx, "HTTP Stream server done") + } + } + }() + + return nil +} + +// HTTPFlvTsConnection is an HTTP pseudo streaming connection, such as an HTTP-FLV or HTTP-TS +// connection. There is no state need to be sync between proxy servers. +// +// When we got an HTTP FLV or TS request, we will parse the stream URL from the HTTP request, +// then proxy to the corresponding backend server. All state is in the HTTP request, so this +// connection is stateless. +type HTTPFlvTsConnection struct { + // The context for HTTP streaming. + ctx context.Context +} + +func NewHTTPFlvTsConnection(opts ...func(*HTTPFlvTsConnection)) *HTTPFlvTsConnection { + v := &HTTPFlvTsConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HTTPFlvTsConnection) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + ctx := logger.WithContext(v.ctx) + + if err := v.serve(ctx, w, r); err != nil { + apiError(ctx, w, r, err) + } else { + logger.Df(ctx, "HTTP client done") + } +} + +func (v *HTTPFlvTsConnection) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got HTTP client from %v for %v", r.RemoteAddr, fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HTTPFlvTsConnection) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no http stream server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Wrapf(err, "do request to %v", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + logger.Df(ctx, "HTTP start streaming") + + // Proxy the stream from backend to client. + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil +} + +// HLSPlayStream is an HLS stream proxy, which represents the stream level object. This means multiple HLS +// clients will share this object, and they do not use the same ctx among proxy servers. +// +// Unlike the HTTP FLV or TS connection, HLS client may request the m3u8 or ts via different HTTP connections. +// Especially for requesting ts, we need to identify the stream URl or backend server for it. So we create +// the spbhid which can be seen as the hash of stream URL or backend server. The spbhid enable us to convert +// to the stream URL and then query the backend server to serve it. +type HLSPlayStream struct { + // The context for HLS streaming. + ctx context.Context + + // The spbhid, used to identify the backend server. + SRSProxyBackendHLSID string `json:"spbhid"` + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The full request URL for HLS streaming + FullURL string `json:"full_url"` +} + +func NewHLSPlayStream(opts ...func(*HLSPlayStream)) *HLSPlayStream { + v := &HLSPlayStream{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *HLSPlayStream) Initialize(ctx context.Context) *HLSPlayStream { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + return v +} + +func (v *HLSPlayStream) ServeHTTP(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + + if err := v.serve(v.ctx, w, r); err != nil { + apiError(v.ctx, w, r, err) + } else { + logger.Df(v.ctx, "HLS client %v for %v with %v done", + v.SRSProxyBackendHLSID, v.StreamURL, r.URL.Path) + } +} + +func (v *HLSPlayStream) serve(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + ctx, streamURL, fullURL := v.ctx, v.StreamURL, v.FullURL + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.serveByBackend(ctx, w, r, backend); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *HLSPlayStream) serveByBackend(ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer) error { + // Parse HTTP port from backend. + if len(backend.HTTP) == 0 { + return errors.Errorf("no rtmp server") + } + + var httpPort int + if iv, err := strconv.ParseInt(backend.HTTP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.HTTP[0]) + } else { + httpPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, httpPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, nil) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.Errorf("proxy stream to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // For TS file, directly copy it. + if !strings.HasSuffix(r.URL.Path, ".m3u8") { + if _, err := io.Copy(w, resp.Body); err != nil { + return errors.Wrapf(err, "copy stream to client, backend=%v", backendURL) + } + + return nil + } + + // Read all content of m3u8, append the stream ID to ts URL. Note that we only append stream ID to ts + // URL, to identify the stream to specified backend server. The spbhid is the SRS Proxy Backend HLS ID. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + m3u8 := string(b) + if strings.Contains(m3u8, ".ts?") { + m3u8 = strings.ReplaceAll(m3u8, ".ts?", fmt.Sprintf(".ts?spbhid=%v&&", v.SRSProxyBackendHLSID)) + } else { + m3u8 = strings.ReplaceAll(m3u8, ".ts", fmt.Sprintf(".ts?spbhid=%v", v.SRSProxyBackendHLSID)) + } + + if _, err := io.Copy(w, strings.NewReader(m3u8)); err != nil { + return errors.Wrapf(err, "proxy m3u8 client to %v", backendURL) + } + + return nil +} diff --git a/proxy/logger/context.go b/proxy/logger/context.go new file mode 100644 index 0000000000..ef15a7d4fb --- /dev/null +++ b/proxy/logger/context.go @@ -0,0 +1,43 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "crypto/rand" + "crypto/sha256" + "encoding/hex" +) + +type key string + +var cidKey key = "cid.proxy.ossrs.org" + +// generateContextID generates a random context id in string. +func GenerateContextID() string { + randomBytes := make([]byte, 32) + _, _ = rand.Read(randomBytes) + hash := sha256.Sum256(randomBytes) + hashString := hex.EncodeToString(hash[:]) + cid := hashString[:7] + return cid +} + +// WithContext creates a new context with cid, which will be used for log. +func WithContext(ctx context.Context) context.Context { + return WithContextID(ctx, GenerateContextID()) +} + +// WithContextID creates a new context with cid, which will be used for log. +func WithContextID(ctx context.Context, cid string) context.Context { + return context.WithValue(ctx, cidKey, cid) +} + +// ContextID returns the cid in context, or empty string if not set. +func ContextID(ctx context.Context) string { + if cid, ok := ctx.Value(cidKey).(string); ok { + return cid + } + return "" +} diff --git a/proxy/logger/log.go b/proxy/logger/log.go new file mode 100644 index 0000000000..debbe1a847 --- /dev/null +++ b/proxy/logger/log.go @@ -0,0 +1,87 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package logger + +import ( + "context" + "io/ioutil" + stdLog "log" + "os" +) + +type logger interface { + Printf(ctx context.Context, format string, v ...any) +} + +type loggerPlus struct { + logger *stdLog.Logger + level string +} + +func newLoggerPlus(opts ...func(*loggerPlus)) *loggerPlus { + v := &loggerPlus{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *loggerPlus) Printf(ctx context.Context, f string, a ...interface{}) { + format, args := f, a + if cid := ContextID(ctx); cid != "" { + format, args = "[%v][%v][%v] "+format, append([]interface{}{v.level, os.Getpid(), cid}, a...) + } + + v.logger.Printf(format, args...) +} + +var verboseLogger logger + +func Vf(ctx context.Context, format string, a ...interface{}) { + verboseLogger.Printf(ctx, format, a...) +} + +var debugLogger logger + +func Df(ctx context.Context, format string, a ...interface{}) { + debugLogger.Printf(ctx, format, a...) +} + +var warnLogger logger + +func Wf(ctx context.Context, format string, a ...interface{}) { + warnLogger.Printf(ctx, format, a...) +} + +var errorLogger logger + +func Ef(ctx context.Context, format string, a ...interface{}) { + errorLogger.Printf(ctx, format, a...) +} + +const ( + logVerboseLabel = "verb" + logDebugLabel = "debug" + logWarnLabel = "warn" + logErrorLabel = "error" +) + +func init() { + verboseLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(ioutil.Discard, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logVerboseLabel + }) + debugLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stdout, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logDebugLabel + }) + warnLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logWarnLabel + }) + errorLogger = newLoggerPlus(func(logger *loggerPlus) { + logger.logger = stdLog.New(os.Stderr, "", stdLog.Ldate|stdLog.Ltime|stdLog.Lmicroseconds) + logger.level = logErrorLabel + }) +} diff --git a/proxy/main.go b/proxy/main.go new file mode 100644 index 0000000000..6327a7cf80 --- /dev/null +++ b/proxy/main.go @@ -0,0 +1,121 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func main() { + ctx := logger.WithContext(context.Background()) + logger.Df(ctx, "%v/%v started", Signature(), Version()) + + // Install signals. + ctx, cancel := context.WithCancel(ctx) + installSignals(ctx, cancel) + + // Start the main loop, ignore the user cancel error. + err := doMain(ctx) + if err != nil && ctx.Err() != context.Canceled { + logger.Ef(ctx, "main: %+v", err) + os.Exit(-1) + } + + logger.Df(ctx, "%v done", Signature()) +} + +func doMain(ctx context.Context) error { + // Setup the environment variables. + if err := loadEnvFile(ctx); err != nil { + return errors.Wrapf(err, "load env") + } + + buildDefaultEnvironmentVariables(ctx) + + // When cancelled, the program is forced to exit due to a timeout. Normally, this doesn't occur + // because the main thread exits after the context is cancelled. However, sometimes the main thread + // may be blocked for some reason, so a forced exit is necessary to ensure the program terminates. + if err := installForceQuit(ctx); err != nil { + return errors.Wrapf(err, "install force quit") + } + + // Start the Go pprof if enabled. + handleGoPprof(ctx) + + // Initialize SRS load balancers. + switch lbType := envLoadBalancerType(); lbType { + case "memory": + srsLoadBalancer = NewMemoryLoadBalancer() + case "redis": + srsLoadBalancer = NewRedisLoadBalancer() + default: + return errors.Errorf("invalid load balancer %v", lbType) + } + + if err := srsLoadBalancer.Initialize(ctx); err != nil { + return errors.Wrapf(err, "initialize srs load balancer") + } + + // Parse the gracefully quit timeout. + gracefulQuitTimeout, err := parseGracefullyQuitTimeout() + if err != nil { + return errors.Wrapf(err, "parse gracefully quit timeout") + } + + // Start the RTMP server. + srsRTMPServer := NewSRSRTMPServer() + defer srsRTMPServer.Close() + if err := srsRTMPServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtmp server") + } + + // Start the WebRTC server. + srsWebRTCServer := NewSRSWebRTCServer() + defer srsWebRTCServer.Close() + if err := srsWebRTCServer.Run(ctx); err != nil { + return errors.Wrapf(err, "rtc server") + } + + // Start the HTTP API server. + srsHTTPAPIServer := NewSRSHTTPAPIServer(func(server *srsHTTPAPIServer) { + server.gracefulQuitTimeout, server.rtc = gracefulQuitTimeout, srsWebRTCServer + }) + defer srsHTTPAPIServer.Close() + if err := srsHTTPAPIServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http api server") + } + + // Start the SRT server. + srsSRTServer := NewSRSSRTServer() + defer srsSRTServer.Close() + if err := srsSRTServer.Run(ctx); err != nil { + return errors.Wrapf(err, "srt server") + } + + // Start the System API server. + systemAPI := NewSystemAPI(func(server *systemAPI) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer systemAPI.Close() + if err := systemAPI.Run(ctx); err != nil { + return errors.Wrapf(err, "system api server") + } + + // Start the HTTP web server. + srsHTTPStreamServer := NewSRSHTTPStreamServer(func(server *srsHTTPStreamServer) { + server.gracefulQuitTimeout = gracefulQuitTimeout + }) + defer srsHTTPStreamServer.Close() + if err := srsHTTPStreamServer.Run(ctx); err != nil { + return errors.Wrapf(err, "http server") + } + + // Wait for the main loop to quit. + <-ctx.Done() + return nil +} diff --git a/proxy/rtc.go b/proxy/rtc.go new file mode 100644 index 0000000000..5a7d9936c7 --- /dev/null +++ b/proxy/rtc.go @@ -0,0 +1,515 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "fmt" + "io/ioutil" + "net" + "net/http" + "strconv" + "strings" + stdSync "sync" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsWebRTCServer is the proxy for SRS WebRTC server via WHIP or WHEP protocol. It will figure out +// which backend server to proxy to. It will also replace the UDP port to the proxy server's in the +// SDP answer. +type srsWebRTCServer struct { + // The UDP listener for WebRTC server. + listener *net.UDPConn + + // Fast cache for the username to identify the connection. + // The key is username, the value is the UDP address. + usernames sync.Map[string, *RTCConnection] + // Fast cache for the udp address to identify the connection. + // The key is UDP address, the value is the username. + // TODO: Support fast earch by uint64 address. + addresses sync.Map[string, *RTCConnection] + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSWebRTCServer(opts ...func(*srsWebRTCServer)) *srsWebRTCServer { + v := &srsWebRTCServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsWebRTCServer) Close() error { + if v.listener != nil { + _ = v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHIP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHIP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) HandleApiForWHEP(ctx context.Context, w http.ResponseWriter, r *http.Request) error { + defer r.Body.Close() + ctx = logger.WithContext(ctx) + + // Always allow CORS for all requests. + if ok := apiCORS(ctx, w, r); ok { + return nil + } + + // Read remote SDP offer from body. + remoteSDPOffer, err := ioutil.ReadAll(r.Body) + if err != nil { + return errors.Wrapf(err, "read remote sdp offer") + } + + // Build the stream URL in vhost/app/stream schema. + unifiedURL, fullURL := convertURLToStreamURL(r) + logger.Df(ctx, "Got WebRTC WHEP from %v with %vB offer for %v", r.RemoteAddr, len(remoteSDPOffer), fullURL) + + streamURL, err := buildStreamURL(unifiedURL) + if err != nil { + return errors.Wrapf(err, "build stream url %v", unifiedURL) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + if err = v.proxyApiToBackend(ctx, w, r, backend, string(remoteSDPOffer), streamURL); err != nil { + return errors.Wrapf(err, "serve %v with %v by backend %+v", fullURL, streamURL, backend) + } + + return nil +} + +func (v *srsWebRTCServer) proxyApiToBackend( + ctx context.Context, w http.ResponseWriter, r *http.Request, backend *SRSServer, + remoteSDPOffer string, streamURL string, +) error { + // Parse HTTP port from backend. + if len(backend.API) == 0 { + return errors.Errorf("no http api server") + } + + var apiPort int + if iv, err := strconv.ParseInt(backend.API[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse http port %v", backend.API[0]) + } else { + apiPort = int(iv) + } + + // Connect to backend SRS server via HTTP client. + backendURL := fmt.Sprintf("http://%v:%v%s", backend.IP, apiPort, r.URL.Path) + if r.URL.RawQuery != "" { + backendURL += "?" + r.URL.RawQuery + } + + req, err := http.NewRequestWithContext(ctx, r.Method, backendURL, strings.NewReader(remoteSDPOffer)) + if err != nil { + return errors.Wrapf(err, "create request to %v", backendURL) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return errors.Errorf("do request to %v EOF", backendURL) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return errors.Errorf("proxy api to %v failed, status=%v", backendURL, resp.Status) + } + + // Copy all headers from backend to client. + w.WriteHeader(resp.StatusCode) + for k, v := range resp.Header { + for _, vv := range v { + w.Header().Add(k, vv) + } + } + + // Parse the local SDP answer from backend. + b, err := ioutil.ReadAll(resp.Body) + if err != nil { + return errors.Wrapf(err, "read stream from %v", backendURL) + } + + // Replace the WebRTC UDP port in answer. + localSDPAnswer := string(b) + for _, endpoint := range backend.RTC { + _, _, port, err := parseListenEndpoint(endpoint) + if err != nil { + return errors.Wrapf(err, "parse endpoint %v", endpoint) + } + + from := fmt.Sprintf(" %v typ host", port) + to := fmt.Sprintf(" %v typ host", envWebRTCServer()) + localSDPAnswer = strings.Replace(localSDPAnswer, from, to, -1) + } + + // Fetch the ice-ufrag and ice-pwd from local SDP answer. + remoteICEUfrag, remoteICEPwd, err := parseIceUfragPwd(remoteSDPOffer) + if err != nil { + return errors.Wrapf(err, "parse remote sdp offer") + } + + localICEUfrag, localICEPwd, err := parseIceUfragPwd(localSDPAnswer) + if err != nil { + return errors.Wrapf(err, "parse local sdp answer") + } + + // Save the new WebRTC connection to LB. + icePair := &RTCICEPair{ + RemoteICEUfrag: remoteICEUfrag, RemoteICEPwd: remoteICEPwd, + LocalICEUfrag: localICEUfrag, LocalICEPwd: localICEPwd, + } + if err := srsLoadBalancer.StoreWebRTC(ctx, streamURL, NewRTCConnection(func(c *RTCConnection) { + c.StreamURL, c.Ufrag = streamURL, icePair.Ufrag() + c.Initialize(ctx, v.listener) + + // Cache the connection for fast search by username. + v.usernames.Store(c.Ufrag, c) + })); err != nil { + return errors.Wrapf(err, "load or store webrtc %v", streamURL) + } + + // Response client with local answer. + if _, err = w.Write([]byte(localSDPAnswer)); err != nil { + return errors.Wrapf(err, "write local sdp answer %v", localSDPAnswer) + } + + logger.Df(ctx, "Create WebRTC connection with local answer %vB with ice-ufrag=%v, ice-pwd=%vB", + len(localSDPAnswer), localICEUfrag, len(localICEPwd)) + return nil +} + +func (v *srsWebRTCServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envWebRTCServer() + if !strings.Contains(endpoint, ":") { + endpoint = fmt.Sprintf(":%v", endpoint) + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "WebRTC server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := listener.ReadFromUDP(buf) + if err != nil { + // TODO: If WebRTC server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsWebRTCServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + var connection *RTCConnection + + // If STUN binding request, parse the ufrag and identify the connection. + if err := func() error { + if rtcIsRTPOrRTCP(data) || !rtcIsSTUN(data) { + return nil + } + + var pkt RTCStunPacket + if err := pkt.UnmarshalBinary(data); err != nil { + return errors.Wrapf(err, "unmarshal stun packet") + } + + // Search the connection in fast cache. + if s, ok := v.usernames.Load(pkt.Username); ok { + connection = s + return nil + } + + // Load connection by username. + if s, err := srsLoadBalancer.LoadWebRTCByUfrag(ctx, pkt.Username); err != nil { + return errors.Wrapf(err, "load webrtc by ufrag %v", pkt.Username) + } else { + connection = s.Initialize(ctx, v.listener) + logger.Df(ctx, "Create WebRTC connection by ufrag=%v, stream=%v", pkt.Username, connection.StreamURL) + } + + // Cache connection for fast search. + if connection != nil { + v.usernames.Store(pkt.Username, connection) + } + return nil + }(); err != nil { + return err + } + + // Search the connection by addr. + if s, ok := v.addresses.Load(addr.String()); ok { + connection = s + } else if connection != nil { + // Cache the address for fast search. + v.addresses.Store(addr.String(), connection) + } + + // If connection is not found, ignore the packet. + if connection == nil { + // TODO: Should logging the dropped packet, only logging the first one for each address. + return nil + } + + // Proxy the packet to backend. + if err := connection.HandlePacket(addr, data); err != nil { + return errors.Wrapf(err, "proxy %vB for %v", len(data), connection.StreamURL) + } + + return nil +} + +// RTCConnection is a WebRTC connection proxy, for both WHIP and WHEP. It represents a WebRTC +// connection, identify by the ufrag in sdp offer/answer and ICE binding request. +// +// It's not like RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is +// in the client request. The RTCConnection is stateful, and need to sync the ufrag between +// proxy servers. +// +// The media transport is UDP, which is also a special thing for WebRTC. So if the client switch +// to another UDP address, it may connect to another WebRTC proxy, then we should discover the +// RTCConnection by the ufrag from the ICE binding request. +type RTCConnection struct { + // The stream context for WebRTC streaming. + ctx context.Context + + // The stream URL in vhost/app/stream schema. + StreamURL string `json:"stream_url"` + // The ufrag for this WebRTC connection. + Ufrag string `json:"ufrag"` + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The client UDP address. Note that it may change. + clientUDP *net.UDPAddr + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn +} + +func NewRTCConnection(opts ...func(*RTCConnection)) *RTCConnection { + v := &RTCConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTCConnection) Initialize(ctx context.Context, listener *net.UDPConn) *RTCConnection { + if v.ctx == nil { + v.ctx = logger.WithContext(ctx) + } + if listener != nil { + v.listenerUDP = listener + } + return v +} + +func (v *RTCConnection) HandlePacket(addr *net.UDPAddr, data []byte) error { + ctx := v.ctx + + // Update the current UDP address. + v.clientUDP = addr + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx); err != nil { + return errors.Wrapf(err, "connect backend for %v", v.StreamURL) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return nil + } + + // Proxy all messages from backend to client. + go func() { + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, _, err := v.backendUDP.ReadFromUDP(buf) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + break + } + + if _, err = v.listenerUDP.WriteToUDP(buf[:n], v.clientUDP); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + break + } + } + }() + + if _, err := v.backendUDP.Write(data); err != nil { + return errors.Wrapf(err, "write to backend %v", v.StreamURL) + } + + return nil +} + +func (v *RTCConnection) connectBackend(ctx context.Context) error { + if v.backendUDP != nil { + return nil + } + + // Pick a backend SRS server to proxy the RTC stream. + backend, err := srsLoadBalancer.Pick(ctx, v.StreamURL) + if err != nil { + return errors.Wrapf(err, "pick backend") + } + + // Parse UDP port from backend. + if len(backend.RTC) == 0 { + return errors.Errorf("no udp server") + } + + _, _, udpPort, err := parseListenEndpoint(backend.RTC[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.RTC[0], backend, v.StreamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or DTLS alert. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v", backendAddr) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +type RTCICEPair struct { + // The remote ufrag, used for ICE username and session id. + RemoteICEUfrag string `json:"remote_ufrag"` + // The remote pwd, used for ICE password. + RemoteICEPwd string `json:"remote_pwd"` + // The local ufrag, used for ICE username and session id. + LocalICEUfrag string `json:"local_ufrag"` + // The local pwd, used for ICE password. + LocalICEPwd string `json:"local_pwd"` +} + +// Generate the ICE ufrag for the WebRTC streaming, format is remote-ufrag:local-ufrag. +func (v *RTCICEPair) Ufrag() string { + return fmt.Sprintf("%v:%v", v.LocalICEUfrag, v.RemoteICEUfrag) +} + +type RTCStunPacket struct { + // The stun message type. + MessageType uint16 + // The stun username, or ufrag. + Username string +} + +func (v *RTCStunPacket) UnmarshalBinary(data []byte) error { + if len(data) < 20 { + return errors.Errorf("stun packet too short %v", len(data)) + } + + p := data + v.MessageType = binary.BigEndian.Uint16(p) + messageLen := binary.BigEndian.Uint16(p[2:]) + //magicCookie := p[:8] + //transactionID := p[:20] + p = p[20:] + + if len(p) != int(messageLen) { + return errors.Errorf("stun packet length invalid %v != %v", len(data), messageLen) + } + + for len(p) > 0 { + typ := binary.BigEndian.Uint16(p) + length := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(length) { + return errors.Errorf("stun attribute length invalid %v < %v", len(p), length) + } + + value := p[:length] + p = p[length:] + + if length%4 != 0 { + p = p[4-length%4:] + } + + switch typ { + case 0x0006: + v.Username = string(value) + } + } + + return nil +} diff --git a/proxy/rtmp.go b/proxy/rtmp.go new file mode 100644 index 0000000000..d93f04b3a6 --- /dev/null +++ b/proxy/rtmp.go @@ -0,0 +1,655 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "fmt" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/rtmp" +) + +// srsRTMPServer is the proxy for SRS RTMP server, to proxy the RTMP stream to backend SRS +// server. It will figure out the backend server to proxy to. Unlike the edge server, it will +// not cache the stream, but just proxy the stream to backend. +type srsRTMPServer struct { + // The TCP listener for RTMP server. + listener *net.TCPListener + // The random number generator. + rd *rand.Rand + // The wait group for all goroutines. + wg sync.WaitGroup +} + +func NewSRSRTMPServer(opts ...func(*srsRTMPServer)) *srsRTMPServer { + v := &srsRTMPServer{ + rd: rand.New(rand.NewSource(time.Now().UnixNano())), + } + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsRTMPServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsRTMPServer) Run(ctx context.Context) error { + endpoint := envRtmpServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + addr, err := net.ResolveTCPAddr("tcp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve rtmp addr %v", endpoint) + } + + listener, err := net.ListenTCP("tcp", addr) + if err != nil { + return errors.Wrapf(err, "listen rtmp addr %v", addr) + } + v.listener = listener + logger.Df(ctx, "RTMP server listen at %v", addr) + + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for { + conn, err := v.listener.AcceptTCP() + if err != nil { + if ctx.Err() != context.Canceled { + // TODO: If RTMP server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "RTMP server accept err %+v", err) + } else { + logger.Df(ctx, "RTMP server done") + } + return + } + + v.wg.Add(1) + go func(ctx context.Context, conn *net.TCPConn) { + defer v.wg.Done() + defer conn.Close() + + handleErr := func(err error) { + if isPeerClosedError(err) { + logger.Df(ctx, "RTMP peer is closed") + } else { + logger.Wf(ctx, "RTMP serve err %+v", err) + } + } + + rc := NewRTMPConnection(func(client *RTMPConnection) { + client.rd = v.rd + }) + if err := rc.serve(ctx, conn); err != nil { + handleErr(err) + } else { + logger.Df(ctx, "RTMP client done") + } + }(logger.WithContext(ctx), conn) + } + }() + + return nil +} + +// RTMPConnection is an RTMP streaming connection. There is no state need to be sync between +// proxy servers. +// +// When we got an RTMP request, we will parse the stream URL from the RTMP publish or play request, +// then proxy to the corresponding backend server. All state is in the RTMP request, so this +// connection is stateless. +type RTMPConnection struct { + // The random number generator. + rd *rand.Rand +} + +func NewRTMPConnection(opts ...func(*RTMPConnection)) *RTMPConnection { + v := &RTMPConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPConnection) serve(ctx context.Context, conn *net.TCPConn) error { + logger.Df(ctx, "Got RTMP client from %v", conn.RemoteAddr()) + + // If any goroutine quit, cancel another one. + parentCtx := ctx + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var backend *RTMPClientToBackend + if true { + go func() { + <-ctx.Done() + conn.Close() + if backend != nil { + backend.Close() + } + }() + } + + // Simple handshake with client. + hs := rtmp.NewHandshake(v.rd) + if _, err := hs.ReadC0S0(conn); err != nil { + return errors.Wrapf(err, "read c0") + } + if _, err := hs.ReadC1S1(conn); err != nil { + return errors.Wrapf(err, "read c1") + } + if err := hs.WriteC0S0(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC1S1(conn); err != nil { + return errors.Wrapf(err, "write s1") + } + if err := hs.WriteC2S2(conn, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write s2") + } + if _, err := hs.ReadC2S2(conn); err != nil { + return errors.Wrapf(err, "read c2") + } + + client := rtmp.NewProtocol(conn) + logger.Df(ctx, "RTMP simple handshake done") + + // Expect RTMP connect command with tcUrl. + var connectReq *rtmp.ConnectAppPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectReq); err != nil { + return errors.Wrapf(err, "expect connect req") + } + + if true { + ack := rtmp.NewWindowAcknowledgementSize() + ack.AckSize = 2500000 + if err := client.WritePacket(ctx, ack, 0); err != nil { + return errors.Wrapf(err, "write set ack size") + } + } + if true { + chunk := rtmp.NewSetChunkSize() + chunk.ChunkSize = 128 + if err := client.WritePacket(ctx, chunk, 0); err != nil { + return errors.Wrapf(err, "write set chunk size") + } + } + + connectRes := rtmp.NewConnectAppResPacket(connectReq.TransactionID) + connectRes.CommandObject.Set("fmsVer", rtmp.NewAmf0String("FMS/3,5,3,888")) + connectRes.CommandObject.Set("capabilities", rtmp.NewAmf0Number(127)) + connectRes.CommandObject.Set("mode", rtmp.NewAmf0Number(1)) + connectRes.Args.Set("level", rtmp.NewAmf0String("status")) + connectRes.Args.Set("code", rtmp.NewAmf0String("NetConnection.Connect.Success")) + connectRes.Args.Set("description", rtmp.NewAmf0String("Connection succeeded")) + connectRes.Args.Set("objectEncoding", rtmp.NewAmf0Number(0)) + connectResData := rtmp.NewAmf0EcmaArray() + connectResData.Set("version", rtmp.NewAmf0String("3,5,3,888")) + connectResData.Set("srs_version", rtmp.NewAmf0String(Version())) + connectResData.Set("srs_id", rtmp.NewAmf0String(logger.ContextID(ctx))) + connectRes.Args.Set("data", connectResData) + if err := client.WritePacket(ctx, connectRes, 0); err != nil { + return errors.Wrapf(err, "write connect res") + } + + tcUrl := connectReq.TcUrl() + logger.Df(ctx, "RTMP connect app %v", tcUrl) + + // Expect RTMP command to identify the client, a publisher or viewer. + var currentStreamID, nextStreamID int + var streamName string + var clientType RTMPClientType + for clientType == "" { + var identifyReq rtmp.Packet + if _, err := rtmp.ExpectPacket(ctx, client, &identifyReq); err != nil { + return errors.Wrapf(err, "expect identify req") + } + + var response rtmp.Packet + switch pkt := identifyReq.(type) { + case *rtmp.CallPacket: + if pkt.CommandName == "createStream" { + identifyRes := rtmp.NewCreateStreamResPacket(pkt.TransactionID) + response = identifyRes + + nextStreamID = 1 + identifyRes.StreamID = *rtmp.NewAmf0Number(float64(nextStreamID)) + } else if pkt.CommandName == "getStreamLength" { + // Ignore and do not reply these packets. + } else { + // For releaseStream, FCPublish, etc. + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.TransactionID = pkt.TransactionID + identifyRes.CommandName = "_result" + identifyRes.CommandObject = rtmp.NewAmf0Null() + identifyRes.Args = rtmp.NewAmf0Undefined() + } + case *rtmp.PublishPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypePublisher + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onFCPublish" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + identifyRes.Args = data + case *rtmp.PlayPacket: + streamName = string(pkt.StreamName) + clientType = RTMPClientTypeViewer + + identifyRes := rtmp.NewCallPacket() + response = identifyRes + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Reset")) + data.Set("description", rtmp.NewAmf0String("Playing and resetting stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + } + + if response != nil { + if err := client.WritePacket(ctx, response, currentStreamID); err != nil { + return errors.Wrapf(err, "write identify res for req=%v, stream=%v", + identifyReq, currentStreamID) + } + } + + // Update the stream ID for next request. + currentStreamID = nextStreamID + } + logger.Df(ctx, "RTMP identify tcUrl=%v, stream=%v, id=%v, type=%v", + tcUrl, streamName, currentStreamID, clientType) + + // Find a backend SRS server to proxy the RTMP stream. + backend = NewRTMPClientToBackend(func(client *RTMPClientToBackend) { + client.rd, client.typ = v.rd, clientType + }) + defer backend.Close() + + if err := backend.Connect(ctx, tcUrl, streamName); err != nil { + return errors.Wrapf(err, "connect backend, tcUrl=%v, stream=%v", tcUrl, streamName) + } + + // Start the streaming. + if clientType == RTMPClientTypePublisher { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Publish.Start")) + data.Set("description", rtmp.NewAmf0String("Started publishing stream.")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start publish") + } + } else if clientType == RTMPClientTypeViewer { + identifyRes := rtmp.NewCallPacket() + + identifyRes.CommandName = "onStatus" + identifyRes.CommandObject = rtmp.NewAmf0Null() + + data := rtmp.NewAmf0Object() + data.Set("level", rtmp.NewAmf0String("status")) + data.Set("code", rtmp.NewAmf0String("NetStream.Play.Start")) + data.Set("description", rtmp.NewAmf0String("Started playing stream.")) + data.Set("details", rtmp.NewAmf0String("stream")) + data.Set("clientid", rtmp.NewAmf0String("ASAICiss")) + identifyRes.Args = data + + if err := client.WritePacket(ctx, identifyRes, currentStreamID); err != nil { + return errors.Wrapf(err, "start play") + } + } + logger.Df(ctx, "RTMP start streaming") + + // For all proxy goroutines. + var wg sync.WaitGroup + defer wg.Wait() + + // Proxy all message from backend to client. + wg.Add(1) + var r0 error + go func() { + defer wg.Done() + defer cancel() + + r0 = func() error { + for { + m, err := backend.client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client<- %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Proxy all messages from client to backend. + wg.Add(1) + var r1 error + go func() { + defer wg.Done() + defer cancel() + + r1 = func() error { + for { + m, err := client.ReadMessage(ctx) + if err != nil { + return errors.Wrapf(err, "read message") + } + //logger.Df(ctx, "client-> %v %v %vB", m.MessageType, m.Timestamp, len(m.Payload)) + + // TODO: Update the stream ID if not the same. + if err := backend.client.WriteMessage(ctx, m); err != nil { + return errors.Wrapf(err, "write message") + } + } + }() + }() + + // Wait until all goroutine quit. + wg.Wait() + + // Reset the error if caused by another goroutine. + if r0 != nil { + return errors.Wrapf(r0, "proxy backend->client") + } + if r1 != nil { + return errors.Wrapf(r1, "proxy client->backend") + } + + return parentCtx.Err() +} + +type RTMPClientType string + +const ( + RTMPClientTypePublisher RTMPClientType = "publisher" + RTMPClientTypeViewer RTMPClientType = "viewer" +) + +// RTMPClientToBackend is a RTMP client to proxy the RTMP stream to backend. +type RTMPClientToBackend struct { + // The random number generator. + rd *rand.Rand + // The underlayer tcp client. + tcpConn *net.TCPConn + // The RTMP protocol client. + client *rtmp.Protocol + // The stream type. + typ RTMPClientType +} + +func NewRTMPClientToBackend(opts ...func(*RTMPClientToBackend)) *RTMPClientToBackend { + v := &RTMPClientToBackend{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *RTMPClientToBackend) Close() error { + if v.tcpConn != nil { + v.tcpConn.Close() + } + return nil +} + +func (v *RTMPClientToBackend) Connect(ctx context.Context, tcUrl, streamName string) error { + // Build the stream URL in vhost/app/stream schema. + streamURL, err := buildStreamURL(fmt.Sprintf("%v/%v", tcUrl, streamName)) + if err != nil { + return errors.Wrapf(err, "build stream url %v/%v", tcUrl, streamName) + } + + // Pick a backend SRS server to proxy the RTMP stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse RTMP port from backend. + if len(backend.RTMP) == 0 { + return errors.Errorf("no rtmp server %+v for %v", backend, streamURL) + } + + var rtmpPort int + if iv, err := strconv.ParseInt(backend.RTMP[0], 10, 64); err != nil { + return errors.Wrapf(err, "parse backend %+v rtmp port %v", backend, backend.RTMP[0]) + } else { + rtmpPort = int(iv) + } + + // Connect to backend SRS server via TCP client. + addr := &net.TCPAddr{IP: net.ParseIP(backend.IP), Port: rtmpPort} + c, err := net.DialTCP("tcp", nil, addr) + if err != nil { + return errors.Wrapf(err, "dial backend addr=%v, srs=%v", addr, backend) + } + v.tcpConn = c + + hs := rtmp.NewHandshake(v.rd) + client := rtmp.NewProtocol(c) + v.client = client + + // Simple RTMP handshake with server. + if err := hs.WriteC0S0(c); err != nil { + return errors.Wrapf(err, "write c0") + } + if err := hs.WriteC1S1(c); err != nil { + return errors.Wrapf(err, "write c1") + } + + if _, err = hs.ReadC0S0(c); err != nil { + return errors.Wrapf(err, "read s0") + } + if _, err := hs.ReadC1S1(c); err != nil { + return errors.Wrapf(err, "read s1") + } + if _, err = hs.ReadC2S2(c); err != nil { + return errors.Wrapf(err, "read c2") + } + logger.Df(ctx, "backend simple handshake done, server=%v", addr) + + if err := hs.WriteC2S2(c, hs.C1S1()); err != nil { + return errors.Wrapf(err, "write c2") + } + + // Connect RTMP app on tcUrl with server. + if true { + connectApp := rtmp.NewConnectAppPacket() + connectApp.CommandObject.Set("tcUrl", rtmp.NewAmf0String(tcUrl)) + if err := client.WritePacket(ctx, connectApp, 1); err != nil { + return errors.Wrapf(err, "write connect app") + } + } + + if true { + var connectAppRes *rtmp.ConnectAppResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &connectAppRes); err != nil { + return errors.Wrapf(err, "expect connect app res") + } + logger.Df(ctx, "backend connect RTMP app, tcUrl=%v, id=%v", tcUrl, connectAppRes.SrsID()) + } + + // Play or view RTMP stream with server. + if v.typ == RTMPClientTypeViewer { + return v.play(ctx, client, streamName) + } + + // Publish RTMP stream with server. + return v.publish(ctx, client, streamName) +} + +func (v *RTMPClientToBackend) publish(ctx context.Context, client *rtmp.Protocol, streamName string) error { + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "releaseStream" + identifyReq.TransactionID = 2 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "releaseStream") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + if true { + identifyReq := rtmp.NewCallPacket() + identifyReq.CommandName = "FCPublish" + identifyReq.TransactionID = 3 + identifyReq.CommandObject = rtmp.NewAmf0Null() + identifyReq.Args = rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, identifyReq, 0); err != nil { + return errors.Wrapf(err, "FCPublish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect FCPublish res") + } + if identifyRes.CommandName == "_result" { + break + } + } + + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + if true { + publishStream := rtmp.NewPublishPacket() + publishStream.TransactionID = 5 + publishStream.CommandObject = rtmp.NewAmf0Null() + publishStream.StreamName = *rtmp.NewAmf0String(streamName) + publishStream.StreamType = *rtmp.NewAmf0String("live") + if err := client.WritePacket(ctx, publishStream, currentStreamID); err != nil { + return errors.Wrapf(err, "publish") + } + } + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect publish res") + } + // Ignore onFCPublish, expect onStatus(NetStream.Publish.Start). + if identifyRes.CommandName == "onStatus" { + if data := rtmp.NewAmf0Converter(identifyRes.Args).ToObject(); data == nil { + return errors.Errorf("onStatus args not object") + } else if code := rtmp.NewAmf0Converter(data.Get("code")).ToString(); code == nil { + return errors.Errorf("onStatus code not string") + } else if *code != "NetStream.Publish.Start" { + return errors.Errorf("onStatus code=%v not NetStream.Publish.Start", *code) + } + break + } + } + logger.Df(ctx, "backend publish stream=%v, sid=%v", streamName, currentStreamID) + + return nil +} + +func (v *RTMPClientToBackend) play(ctx context.Context, client *rtmp.Protocol, streamName string) error { + var currentStreamID int + if true { + createStream := rtmp.NewCreateStreamPacket() + createStream.TransactionID = 4 + createStream.CommandObject = rtmp.NewAmf0Null() + if err := client.WritePacket(ctx, createStream, 0); err != nil { + return errors.Wrapf(err, "createStream") + } + } + for { + var identifyRes *rtmp.CreateStreamResPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect createStream res") + } + if sid := identifyRes.StreamID; sid != 0 { + currentStreamID = int(sid) + break + } + } + + playStream := rtmp.NewPlayPacket() + playStream.StreamName = *rtmp.NewAmf0String(streamName) + if err := client.WritePacket(ctx, playStream, currentStreamID); err != nil { + return errors.Wrapf(err, "play") + } + + for { + var identifyRes *rtmp.CallPacket + if _, err := rtmp.ExpectPacket(ctx, client, &identifyRes); err != nil { + return errors.Wrapf(err, "expect releaseStream res") + } + if identifyRes.CommandName == "onStatus" && identifyRes.ArgsCode() == "NetStream.Play.Start" { + break + } + } + return nil +} diff --git a/proxy/rtmp/amf0.go b/proxy/rtmp/amf0.go new file mode 100644 index 0000000000..a013d5eccb --- /dev/null +++ b/proxy/rtmp/amf0.go @@ -0,0 +1,771 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bytes" + "encoding" + "encoding/binary" + "fmt" + "math" + "sync" + + "srs-proxy/errors" +) + +// Please read @doc amf0_spec_121207.pdf, @page 4, @section 2.1 Types Overview +type amf0Marker uint8 + +const ( + amf0MarkerNumber amf0Marker = iota // 0 + amf0MarkerBoolean // 1 + amf0MarkerString // 2 + amf0MarkerObject // 3 + amf0MarkerMovieClip // 4 + amf0MarkerNull // 5 + amf0MarkerUndefined // 6 + amf0MarkerReference // 7 + amf0MarkerEcmaArray // 8 + amf0MarkerObjectEnd // 9 + amf0MarkerStrictArray // 10 + amf0MarkerDate // 11 + amf0MarkerLongString // 12 + amf0MarkerUnsupported // 13 + amf0MarkerRecordSet // 14 + amf0MarkerXmlDocument // 15 + amf0MarkerTypedObject // 16 + amf0MarkerAvmPlusObject // 17 + + amf0MarkerForbidden amf0Marker = 0xff +) + +func (v amf0Marker) String() string { + switch v { + case amf0MarkerNumber: + return "Amf0Number" + case amf0MarkerBoolean: + return "amf0Boolean" + case amf0MarkerString: + return "Amf0String" + case amf0MarkerObject: + return "Amf0Object" + case amf0MarkerNull: + return "Null" + case amf0MarkerUndefined: + return "Undefined" + case amf0MarkerReference: + return "Reference" + case amf0MarkerEcmaArray: + return "EcmaArray" + case amf0MarkerObjectEnd: + return "ObjectEnd" + case amf0MarkerStrictArray: + return "StrictArray" + case amf0MarkerDate: + return "Date" + case amf0MarkerLongString: + return "LongString" + case amf0MarkerUnsupported: + return "Unsupported" + case amf0MarkerXmlDocument: + return "XmlDocument" + case amf0MarkerTypedObject: + return "TypedObject" + case amf0MarkerAvmPlusObject: + return "AvmPlusObject" + case amf0MarkerMovieClip: + return "MovieClip" + case amf0MarkerRecordSet: + return "RecordSet" + default: + return "Forbidden" + } +} + +// For utest to mock it. +type amf0Buffer interface { + Bytes() []byte + WriteByte(c byte) error + Write(p []byte) (n int, err error) +} + +var createBuffer = func() amf0Buffer { + return &bytes.Buffer{} +} + +// All AMF0 things. +type amf0Any interface { + // Binary marshaler and unmarshaler. + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + // Get the size of bytes to marshal this object. + Size() int + + // Get the Marker of any AMF0 stuff. + amf0Marker() amf0Marker +} + +type amf0Converter struct { + from amf0Any +} + +func NewAmf0Converter(from amf0Any) *amf0Converter { + return &amf0Converter{from: from} +} + +func (v *amf0Converter) ToNumber() *amf0Number { + return amf0AnyTo[*amf0Number](v.from) +} + +func (v *amf0Converter) ToBoolean() *amf0Boolean { + return amf0AnyTo[*amf0Boolean](v.from) +} + +func (v *amf0Converter) ToString() *amf0String { + return amf0AnyTo[*amf0String](v.from) +} + +func (v *amf0Converter) ToObject() *amf0Object { + return amf0AnyTo[*amf0Object](v.from) +} + +func (v *amf0Converter) ToNull() *amf0Null { + return amf0AnyTo[*amf0Null](v.from) +} + +func (v *amf0Converter) ToUndefined() *amf0Undefined { + return amf0AnyTo[*amf0Undefined](v.from) +} + +func (v *amf0Converter) ToEcmaArray() *amf0EcmaArray { + return amf0AnyTo[*amf0EcmaArray](v.from) +} + +func (v *amf0Converter) ToStrictArray() *amf0StrictArray { + return amf0AnyTo[*amf0StrictArray](v.from) +} + +// Convert any to specified object. +func amf0AnyTo[T amf0Any](a amf0Any) T { + var to T + if a != nil { + if v, ok := a.(T); ok { + return v + } + } + return to +} + +// Discovery the amf0 object from the bytes b. +func Amf0Discovery(p []byte) (a amf0Any, err error) { + if len(p) < 1 { + return nil, errors.Errorf("require 1 bytes only %v", len(p)) + } + m := amf0Marker(p[0]) + + switch m { + case amf0MarkerNumber: + return NewAmf0Number(0), nil + case amf0MarkerBoolean: + return NewAmf0Boolean(false), nil + case amf0MarkerString: + return NewAmf0String(""), nil + case amf0MarkerObject: + return NewAmf0Object(), nil + case amf0MarkerNull: + return NewAmf0Null(), nil + case amf0MarkerUndefined: + return NewAmf0Undefined(), nil + case amf0MarkerReference: + case amf0MarkerEcmaArray: + return NewAmf0EcmaArray(), nil + case amf0MarkerObjectEnd: + return &amf0ObjectEOF{}, nil + case amf0MarkerStrictArray: + return NewAmf0StrictArray(), nil + case amf0MarkerDate, amf0MarkerLongString, amf0MarkerUnsupported, amf0MarkerXmlDocument, + amf0MarkerTypedObject, amf0MarkerAvmPlusObject, amf0MarkerForbidden, amf0MarkerMovieClip, + amf0MarkerRecordSet: + return nil, errors.Errorf("Marker %v is not supported", m) + } + return nil, errors.Errorf("Marker %v is invalid", m) +} + +// The UTF8 string, please read @doc amf0_spec_121207.pdf, @page 3, @section 1.3.1 Strings and UTF-8 +type amf0UTF8 string + +func (v *amf0UTF8) Size() int { + return 2 + len(string(*v)) +} + +func (v *amf0UTF8) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + size := uint16(p[0])<<8 | uint16(p[1]) + + if p = data[2:]; len(p) < int(size) { + return errors.Errorf("require %v bytes only %v", int(size), len(p)) + } + *v = amf0UTF8(string(p[:size])) + + return +} + +func (v *amf0UTF8) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + + size := uint16(len(string(*v))) + data[0] = byte(size >> 8) + data[1] = byte(size) + + if size > 0 { + copy(data[2:], []byte(*v)) + } + + return +} + +// The number object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.2 Number Type +type amf0Number float64 + +func NewAmf0Number(f float64) *amf0Number { + v := amf0Number(f) + return &v +} + +func (v *amf0Number) amf0Marker() amf0Marker { + return amf0MarkerNumber +} + +func (v *amf0Number) Size() int { + return 1 + 8 +} + +func (v *amf0Number) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 9 { + return errors.Errorf("require 9 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerNumber { + return errors.Errorf("Amf0Number amf0Marker %v is illegal", m) + } + + f := binary.BigEndian.Uint64(p[1:]) + *v = amf0Number(math.Float64frombits(f)) + return +} + +func (v *amf0Number) MarshalBinary() (data []byte, err error) { + data = make([]byte, 9) + data[0] = byte(amf0MarkerNumber) + f := math.Float64bits(float64(*v)) + binary.BigEndian.PutUint64(data[1:], f) + return +} + +// The string objet, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.4 String Type +type amf0String string + +func NewAmf0String(s string) *amf0String { + v := amf0String(s) + return &v +} + +func (v *amf0String) amf0Marker() amf0Marker { + return amf0MarkerString +} + +func (v *amf0String) Size() int { + u := amf0UTF8(*v) + return 1 + u.Size() +} + +func (v *amf0String) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerString { + return errors.Errorf("Amf0String amf0Marker %v is illegal", m) + } + + var sv amf0UTF8 + if err = sv.UnmarshalBinary(p[1:]); err != nil { + return errors.WithMessage(err, "utf8") + } + *v = amf0String(string(sv)) + return +} + +func (v *amf0String) MarshalBinary() (data []byte, err error) { + u := amf0UTF8(*v) + + var pb []byte + if pb, err = u.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "utf8") + } + + data = append([]byte{byte(amf0MarkerString)}, pb...) + return +} + +// The AMF0 object end type, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.11 Object End Type +type amf0ObjectEOF struct { +} + +func (v *amf0ObjectEOF) amf0Marker() amf0Marker { + return amf0MarkerObjectEnd +} + +func (v *amf0ObjectEOF) Size() int { + return 3 +} + +func (v *amf0ObjectEOF) UnmarshalBinary(data []byte) (err error) { + p := data + + if len(p) < 3 { + return errors.Errorf("require 3 bytes only %v", len(p)) + } + + if p[0] != 0 || p[1] != 0 || p[2] != 9 { + return errors.Errorf("EOF amf0Marker %v is illegal", p[0:3]) + } + return +} + +func (v *amf0ObjectEOF) MarshalBinary() (data []byte, err error) { + return []byte{0, 0, 9}, nil +} + +// Use array for object and ecma array, to keep the original order. +type amf0Property struct { + key amf0UTF8 + value amf0Any +} + +// The object-like AMF0 structure, like object and ecma array and strict array. +type amf0ObjectBase struct { + properties []*amf0Property + lock sync.Mutex +} + +func (v *amf0ObjectBase) Size() int { + v.lock.Lock() + defer v.lock.Unlock() + + var size int + + for _, p := range v.properties { + key, value := p.key, p.value + size += key.Size() + value.Size() + } + + return size +} + +func (v *amf0ObjectBase) Get(key string) amf0Any { + v.lock.Lock() + defer v.lock.Unlock() + + for _, p := range v.properties { + if string(p.key) == key { + return p.value + } + } + + return nil +} + +func (v *amf0ObjectBase) Set(key string, value amf0Any) *amf0ObjectBase { + v.lock.Lock() + defer v.lock.Unlock() + + prop := &amf0Property{key: amf0UTF8(key), value: value} + + var ok bool + for i, p := range v.properties { + if string(p.key) == key { + v.properties[i] = prop + ok = true + } + } + + if !ok { + v.properties = append(v.properties, prop) + } + + return v +} + +func (v *amf0ObjectBase) unmarshal(p []byte, eof bool, maxElems int) (err error) { + // if no eof, elems specified by maxElems. + if !eof && maxElems < 0 { + return errors.Errorf("maxElems=%v without eof", maxElems) + } + // if eof, maxElems must be -1. + if eof && maxElems != -1 { + return errors.Errorf("maxElems=%v with eof", maxElems) + } + + readOne := func() (amf0UTF8, amf0Any, error) { + var u amf0UTF8 + if err = u.UnmarshalBinary(p); err != nil { + return "", nil, errors.WithMessage(err, "prop name") + } + + p = p[u.Size():] + var a amf0Any + if a, err = Amf0Discovery(p); err != nil { + return "", nil, errors.WithMessage(err, fmt.Sprintf("discover prop %v", string(u))) + } + return u, a, nil + } + + pushOne := func(u amf0UTF8, a amf0Any) error { + // For object property, consume the whole bytes. + if err = a.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, fmt.Sprintf("unmarshal prop %v", string(u))) + } + + v.Set(string(u), a) + p = p[a.Size():] + return nil + } + + for eof { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + // For object EOF, we should only consume total 3bytes. + if u.Size() == 2 && a.amf0Marker() == amf0MarkerObjectEnd { + // 2 bytes is consumed by u(name), the a(eof) should only consume 1 byte. + p = p[1:] + return nil + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + for len(v.properties) < maxElems { + u, a, err := readOne() + if err != nil { + return errors.WithMessage(err, "read") + } + + if err := pushOne(u, a); err != nil { + return errors.WithMessage(err, "push") + } + } + + return +} + +func (v *amf0ObjectBase) marshal(b amf0Buffer) (err error) { + v.lock.Lock() + defer v.lock.Unlock() + + var pb []byte + for _, p := range v.properties { + key, value := p.key, p.value + + if pb, err = key.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "write %v", string(key)) + } + + if pb, err = value.MarshalBinary(); err != nil { + return errors.WithMessage(err, fmt.Sprintf("marshal value for %v", string(key))) + } + if _, err = b.Write(pb); err != nil { + return errors.Wrapf(err, "marshal value for %v", string(key)) + } + } + + return +} + +// The AMF0 object, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.5 Object Type +type amf0Object struct { + amf0ObjectBase + eof amf0ObjectEOF +} + +func NewAmf0Object() *amf0Object { + v := &amf0Object{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0Object) amf0Marker() amf0Marker { + return amf0MarkerObject +} + +func (v *amf0Object) Size() int { + return int(1) + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0Object) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerObject { + return errors.Errorf("Amf0Object amf0Marker %v is illegal", m) + } + p = p[1:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + + return +} + +func (v *amf0Object) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerObject)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 ecma array, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.10 ECMA Array Type +type amf0EcmaArray struct { + amf0ObjectBase + count uint32 + eof amf0ObjectEOF +} + +func NewAmf0EcmaArray() *amf0EcmaArray { + v := &amf0EcmaArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0EcmaArray) amf0Marker() amf0Marker { + return amf0MarkerEcmaArray +} + +func (v *amf0EcmaArray) Size() int { + return int(1) + 4 + v.eof.Size() + v.amf0ObjectBase.Size() +} + +func (v *amf0EcmaArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerEcmaArray { + return errors.Errorf("EcmaArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if err = v.unmarshal(p, true, -1); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0EcmaArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerEcmaArray)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + var pb []byte + if pb, err = v.eof.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + if _, err = b.Write(pb); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + return b.Bytes(), nil +} + +// The AMF0 strict array, please read @doc amf0_spec_121207.pdf, @page 7, @section 2.12 Strict Array Type +type amf0StrictArray struct { + amf0ObjectBase + count uint32 +} + +func NewAmf0StrictArray() *amf0StrictArray { + v := &amf0StrictArray{} + v.properties = []*amf0Property{} + return v +} + +func (v *amf0StrictArray) amf0Marker() amf0Marker { + return amf0MarkerStrictArray +} + +func (v *amf0StrictArray) Size() int { + return int(1) + 4 + v.amf0ObjectBase.Size() +} + +func (v *amf0StrictArray) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 5 { + return errors.Errorf("require 5 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerStrictArray { + return errors.Errorf("StrictArray amf0Marker %v is illegal", m) + } + v.count = binary.BigEndian.Uint32(p[1:]) + p = p[5:] + + if int(v.count) <= 0 { + return + } + + if err = v.unmarshal(p, false, int(v.count)); err != nil { + return errors.WithMessage(err, "unmarshal") + } + return +} + +func (v *amf0StrictArray) MarshalBinary() (data []byte, err error) { + b := createBuffer() + + if err = b.WriteByte(byte(amf0MarkerStrictArray)); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = binary.Write(b, binary.BigEndian, v.count); err != nil { + return nil, errors.Wrap(err, "marshal") + } + + if err = v.marshal(b); err != nil { + return nil, errors.WithMessage(err, "marshal") + } + + return b.Bytes(), nil +} + +// The single amf0Marker object, for all AMF0 which only has the amf0Marker, like null and undefined. +type amf0SingleMarkerObject struct { + target amf0Marker +} + +func newAmf0SingleMarkerObject(m amf0Marker) amf0SingleMarkerObject { + return amf0SingleMarkerObject{target: m} +} + +func (v *amf0SingleMarkerObject) amf0Marker() amf0Marker { + return v.target +} + +func (v *amf0SingleMarkerObject) Size() int { + return int(1) +} + +func (v *amf0SingleMarkerObject) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 1 { + return errors.Errorf("require 1 byte only %v", len(p)) + } + if m := amf0Marker(p[0]); m != v.target { + return errors.Errorf("%v amf0Marker %v is illegal", v.target, m) + } + return +} + +func (v *amf0SingleMarkerObject) MarshalBinary() (data []byte, err error) { + return []byte{byte(v.target)}, nil +} + +// The AMF0 null, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.7 null Type +type amf0Null struct { + amf0SingleMarkerObject +} + +func NewAmf0Null() *amf0Null { + v := amf0Null{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerNull) + return &v +} + +// The AMF0 undefined, please read @doc amf0_spec_121207.pdf, @page 6, @section 2.8 undefined Type +type amf0Undefined struct { + amf0SingleMarkerObject +} + +func NewAmf0Undefined() amf0Any { + v := amf0Undefined{} + v.amf0SingleMarkerObject = newAmf0SingleMarkerObject(amf0MarkerUndefined) + return &v +} + +// The AMF0 boolean, please read @doc amf0_spec_121207.pdf, @page 5, @section 2.3 Boolean Type +type amf0Boolean bool + +func NewAmf0Boolean(b bool) amf0Any { + v := amf0Boolean(b) + return &v +} + +func (v *amf0Boolean) amf0Marker() amf0Marker { + return amf0MarkerBoolean +} + +func (v *amf0Boolean) Size() int { + return int(2) +} + +func (v *amf0Boolean) UnmarshalBinary(data []byte) (err error) { + var p []byte + if p = data; len(p) < 2 { + return errors.Errorf("require 2 bytes only %v", len(p)) + } + if m := amf0Marker(p[0]); m != amf0MarkerBoolean { + return errors.Errorf("BOOL amf0Marker %v is illegal", m) + } + if p[1] == 0 { + *v = false + } else { + *v = true + } + return +} + +func (v *amf0Boolean) MarshalBinary() (data []byte, err error) { + var b byte + if *v { + b = 1 + } + return []byte{byte(amf0MarkerBoolean), b}, nil +} diff --git a/proxy/rtmp/rtmp.go b/proxy/rtmp/rtmp.go new file mode 100644 index 0000000000..ee0970e960 --- /dev/null +++ b/proxy/rtmp/rtmp.go @@ -0,0 +1,1792 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package rtmp + +import ( + "bufio" + "bytes" + "context" + "encoding" + "encoding/binary" + "fmt" + "io" + "math/rand" + "sync" + + "srs-proxy/errors" +) + +// The handshake implements the RTMP handshake protocol. +type Handshake struct { + // The random number generator. + r *rand.Rand + // The c1s1 cache. + c1s1 []byte +} + +func NewHandshake(r *rand.Rand) *Handshake { + return &Handshake{r: r} +} + +func (v *Handshake) C1S1() []byte { + return v.c1s1 +} + +func (v *Handshake) WriteC0S0(w io.Writer) (err error) { + r := bytes.NewReader([]byte{0x03}) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s0") + } + + return +} + +func (v *Handshake) ReadC0S0(r io.Reader) (c0 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1); err != nil { + return nil, errors.Wrap(err, "read c0s0") + } + + c0 = b.Bytes() + + return +} + +func (v *Handshake) WriteC1S1(w io.Writer) (err error) { + p := make([]byte, 1536) + + for i := 8; i < len(p); i++ { + p[i] = byte(v.r.Int()) + } + + r := bytes.NewReader(p) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c0s1") + } + + return +} + +func (v *Handshake) ReadC1S1(r io.Reader) (c1s1 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c1s1") + } + + c1s1 = b.Bytes() + v.c1s1 = c1s1 + + return +} + +func (v *Handshake) WriteC2S2(w io.Writer, s1c1 []byte) (err error) { + r := bytes.NewReader(s1c1[:]) + if _, err = io.Copy(w, r); err != nil { + return errors.Wrap(err, "write c2s2") + } + + return +} + +func (v *Handshake) ReadC2S2(r io.Reader) (c2 []byte, err error) { + b := &bytes.Buffer{} + if _, err = io.CopyN(b, r, 1536); err != nil { + return nil, errors.Wrap(err, "read c2s2") + } + + c2 = b.Bytes() + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 16, @section 6.1. Chunk Format +// Extended timestamp: 0 or 4 bytes +// This field MUST be sent when the normal timsestamp is set to +// 0xffffff, it MUST NOT be sent if the normal timestamp is set to +// anything else. So for values less than 0xffffff the normal +// timestamp field SHOULD be used in which case the extended timestamp +// MUST NOT be present. For values greater than or equal to 0xffffff +// the normal timestamp field MUST NOT be used and MUST be set to +// 0xffffff and the extended timestamp MUST be sent. +const extendedTimestamp = uint64(0xffffff) + +// The default chunk size of RTMP is 128 bytes. +const defaultChunkSize = 128 + +// The intput or output settings for RTMP protocol. +type settings struct { + chunkSize uint32 +} + +func newSettings() *settings { + return &settings{ + chunkSize: defaultChunkSize, + } +} + +// The chunk stream which transport a message once. +type chunkStream struct { + format formatType + cid chunkID + header messageHeader + message *Message + count uint64 + extendedTimestamp bool +} + +func newChunkStream() *chunkStream { + return &chunkStream{} +} + +// The protocol implements the RTMP command and chunk stack. +type Protocol struct { + r *bufio.Reader + w *bufio.Writer + input struct { + opt *settings + chunks map[chunkID]*chunkStream + + transactions map[amf0Number]amf0String + ltransactions sync.Mutex + } + output struct { + opt *settings + } +} + +func NewProtocol(rw io.ReadWriter) *Protocol { + v := &Protocol{ + r: bufio.NewReader(rw), + w: bufio.NewWriter(rw), + } + + v.input.opt = newSettings() + v.input.chunks = map[chunkID]*chunkStream{} + v.input.transactions = map[amf0Number]amf0String{} + + v.output.opt = newSettings() + + return v +} + +func ExpectPacket[T Packet](ctx context.Context, v *Protocol, ppkt *T) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + var pkt Packet + if pkt, err = v.DecodeMessage(m); err != nil { + return nil, errors.WithMessage(err, "decode message") + } + + if p, ok := pkt.(T); ok { + *ppkt = p + break + } + } + + return +} + +// Deprecated: Please use rtmp.ExpectPacket instead. +func (v *Protocol) ExpectPacket(ctx context.Context, ppkt any) (m *Message, err error) { + panic("Please use rtmp.ExpectPacket instead") +} + +func (v *Protocol) ExpectMessage(ctx context.Context, types ...MessageType) (m *Message, err error) { + for { + if m, err = v.ReadMessage(ctx); err != nil { + return nil, errors.WithMessage(err, "read message") + } + + if len(types) == 0 { + return + } + + for _, t := range types { + if m.MessageType == t { + return + } + } + } + + return +} + +func (v *Protocol) parseAMFObject(p []byte) (pkt Packet, err error) { + var commandName amf0String + if err = commandName.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, "unmarshal command name") + } + + switch commandName { + case commandResult, commandError: + var transactionID amf0Number + if err = transactionID.UnmarshalBinary(p[commandName.Size():]); err != nil { + return nil, errors.WithMessage(err, "unmarshal tid") + } + + var requestName amf0String + if err = func() error { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + var ok bool + if requestName, ok = v.input.transactions[transactionID]; !ok { + return errors.Errorf("No matched request for tid=%v", transactionID) + } + delete(v.input.transactions, transactionID) + + return nil + }(); err != nil { + return nil, errors.WithMessage(err, "discovery request name") + } + + switch requestName { + case commandConnect: + return NewConnectAppResPacket(transactionID), nil + case commandCreateStream: + return NewCreateStreamResPacket(transactionID), nil + case commandReleaseStream, commandFCPublish, commandFCUnpublish: + call := NewCallPacket() + call.TransactionID = transactionID + return call, nil + default: + return nil, errors.Errorf("No request for %v", string(requestName)) + } + case commandConnect: + return NewConnectAppPacket(), nil + case commandPublish: + return NewPublishPacket(), nil + case commandPlay: + return NewPlayPacket(), nil + default: + return NewCallPacket(), nil + } +} + +func (v *Protocol) DecodeMessage(m *Message) (pkt Packet, err error) { + p := m.Payload[:] + if len(p) == 0 { + return nil, errors.New("Empty packet") + } + + switch m.MessageType { + case MessageTypeAMF3Command, MessageTypeAMF3Data: + p = p[1:] + } + + switch m.MessageType { + case MessageTypeSetChunkSize: + pkt = NewSetChunkSize() + case MessageTypeWindowAcknowledgementSize: + pkt = NewWindowAcknowledgementSize() + case MessageTypeSetPeerBandwidth: + pkt = NewSetPeerBandwidth() + case MessageTypeAMF0Command, MessageTypeAMF3Command, MessageTypeAMF0Data, MessageTypeAMF3Data: + if pkt, err = v.parseAMFObject(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Parse AMF %v", m.MessageType)) + } + case MessageTypeUserControl: + pkt = NewUserControl() + default: + return nil, errors.Errorf("Unknown message %v", m.MessageType) + } + + if err = pkt.UnmarshalBinary(p); err != nil { + return nil, errors.WithMessage(err, fmt.Sprintf("Unmarshal %v", m.MessageType)) + } + + return +} + +func (v *Protocol) ReadMessage(ctx context.Context) (m *Message, err error) { + for m == nil { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return nil, ctx.Err() + } + + var cid chunkID + var format formatType + if format, cid, err = v.readBasicHeader(ctx); err != nil { + return nil, errors.WithMessage(err, "read basic header") + } + + var ok bool + var chunk *chunkStream + if chunk, ok = v.input.chunks[cid]; !ok { + chunk = newChunkStream() + v.input.chunks[cid] = chunk + chunk.header.betterCid = cid + } + + if err = v.readMessageHeader(ctx, chunk, format); err != nil { + return nil, errors.WithMessage(err, "read message header") + } + + if m, err = v.readMessagePayload(ctx, chunk); err != nil { + return nil, errors.WithMessage(err, "read message payload") + } + + if err = v.onMessageArrivated(m); err != nil { + return nil, errors.WithMessage(err, "on message") + } + } + + return +} + +func (v *Protocol) readMessagePayload(ctx context.Context, chunk *chunkStream) (m *Message, err error) { + // Empty payload message. + if chunk.message.payloadLength == 0 { + m = chunk.message + chunk.message = nil + return + } + + // Calculate the chunk payload size. + chunkedPayloadSize := int(chunk.message.payloadLength) - len(chunk.message.Payload) + if chunkedPayloadSize > int(v.input.opt.chunkSize) { + chunkedPayloadSize = int(v.input.opt.chunkSize) + } + + b := make([]byte, chunkedPayloadSize) + if _, err = io.ReadFull(v.r, b); err != nil { + return nil, errors.Wrapf(err, "read chunk %vB", chunkedPayloadSize) + } + chunk.message.Payload = append(chunk.message.Payload, b...) + + // Got entire RTMP message? + if int(chunk.message.payloadLength) == len(chunk.message.Payload) { + m = chunk.message + chunk.message = nil + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 18, @section 6.1.2. Chunk Message Header +// There are four different formats for the chunk message header, +// selected by the "fmt" field in the chunk basic header. +type formatType uint8 + +const ( + // 6.1.2.1. Type 0 + // Chunks of Type 0 are 11 bytes long. This type MUST be used at the + // start of a chunk stream, and whenever the stream timestamp goes + // backward (e.g., because of a backward seek). + formatType0 formatType = iota + // 6.1.2.2. Type 1 + // Chunks of Type 1 are 7 bytes long. The message stream ID is not + // included; this chunk takes the same stream ID as the preceding chunk. + // Streams with variable-sized messages (for example, many video + // formats) SHOULD use this format for the first chunk of each new + // message after the first. + formatType1 + // 6.1.2.3. Type 2 + // Chunks of Type 2 are 3 bytes long. Neither the stream ID nor the + // message length is included; this chunk has the same stream ID and + // message length as the preceding chunk. Streams with constant-sized + // messages (for example, some audio and data formats) SHOULD use this + // format for the first chunk of each message after the first. + formatType2 + // 6.1.2.4. Type 3 + // Chunks of Type 3 have no header. Stream ID, message length and + // timestamp delta are not present; chunks of this type take values from + // the preceding chunk. When a single message is split into chunks, all + // chunks of a message except the first one, SHOULD use this type. Refer + // to example 2 in section 6.2.2. Stream consisting of messages of + // exactly the same size, stream ID and spacing in time SHOULD use this + // type for all chunks after chunk of Type 2. Refer to example 1 in + // section 6.2.1. If the delta between the first message and the second + // message is same as the time stamp of first message, then chunk of + // type 3 would immediately follow the chunk of type 0 as there is no + // need for a chunk of type 2 to register the delta. If Type 3 chunk + // follows a Type 0 chunk, then timestamp delta for this Type 3 chunk is + // the same as the timestamp of Type 0 chunk. + formatType3 +) + +// The message header size, index is format. +var messageHeaderSizes = []int{11, 7, 3, 0} + +// Parse the chunk message header. +// 3bytes: timestamp delta, fmt=0,1,2 +// 3bytes: payload length, fmt=0,1 +// 1bytes: message type, fmt=0,1 +// 4bytes: stream id, fmt=0 +// where: +// fmt=0, 0x0X +// fmt=1, 0x4X +// fmt=2, 0x8X +// fmt=3, 0xCX +func (v *Protocol) readMessageHeader(ctx context.Context, chunk *chunkStream, format formatType) (err error) { + // We should not assert anything about fmt, for the first packet. + // (when first packet, the chunk.message is nil). + // the fmt maybe 0/1/2/3, the FMLE will send a 0xC4 for some audio packet. + // the previous packet is: + // 04 // fmt=0, cid=4 + // 00 00 1a // timestamp=26 + // 00 00 9d // payload_length=157 + // 08 // message_type=8(audio) + // 01 00 00 00 // stream_id=1 + // the current packet maybe: + // c4 // fmt=3, cid=4 + // it's ok, for the packet is audio, and timestamp delta is 26. + // the current packet must be parsed as: + // fmt=0, cid=4 + // timestamp=26+26=52 + // payload_length=157 + // message_type=8(audio) + // stream_id=1 + // so we must update the timestamp even fmt=3 for first packet. + // + // The fresh packet used to update the timestamp even fmt=3 for first packet. + // fresh packet always means the chunk is the first one of message. + var isFirstChunkOfMsg bool + if chunk.message == nil { + isFirstChunkOfMsg = true + } + + // But, we can ensure that when a chunk stream is fresh, + // the fmt must be 0, a new stream. + if chunk.count == 0 && format != formatType0 { + // For librtmp, if ping, it will send a fresh stream with fmt=1, + // 0x42 where: fmt=1, cid=2, protocol contorl user-control message + // 0x00 0x00 0x00 where: timestamp=0 + // 0x00 0x00 0x06 where: payload_length=6 + // 0x04 where: message_type=4(protocol control user-control message) + // 0x00 0x06 where: event Ping(0x06) + // 0x00 0x00 0x0d 0x0f where: event data 4bytes ping timestamp. + // @see: https://github.com/ossrs/srs/issues/98 + if chunk.cid == chunkIDProtocolControl && format == formatType1 { + // We accept cid=2, fmt=1 to make librtmp happy. + } else { + return errors.Errorf("For fresh chunk, fmt %v != %v(required), cid is %v", format, formatType0, chunk.cid) + } + } + + // When exists cache msg, means got an partial message, + // the fmt must not be type0 which means new message. + if chunk.message != nil && format == formatType0 { + return errors.Errorf("For exists chunk, fmt is %v, cid is %v", format, chunk.cid) + } + + // Create msg when new chunk stream start + if chunk.message == nil { + chunk.message = NewMessage() + } + + // Read the message header. + p := make([]byte, messageHeaderSizes[format]) + if _, err = io.ReadFull(v.r, p); err != nil { + return errors.Wrapf(err, "read %vB message header", len(p)) + } + + // Prse the message header. + // 3bytes: timestamp delta, fmt=0,1,2 + // 3bytes: payload length, fmt=0,1 + // 1bytes: message type, fmt=0,1 + // 4bytes: stream id, fmt=0 + // where: + // fmt=0, 0x0X + // fmt=1, 0x4X + // fmt=2, 0x8X + // fmt=3, 0xCX + if format <= formatType2 { + chunk.header.timestampDelta = uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // fmt: 0 + // timestamp: 3 bytes + // If the timestamp is greater than or equal to 16777215 + // (hexadecimal 0x00ffffff), this value MUST be 16777215, and the + // 'extended timestamp header' MUST be present. Otherwise, this value + // SHOULD be the entire timestamp. + // + // fmt: 1 or 2 + // timestamp delta: 3 bytes + // If the delta is greater than or equal to 16777215 (hexadecimal + // 0x00ffffff), this value MUST be 16777215, and the 'extended + // timestamp header' MUST be present. Otherwise, this value SHOULD be + // the entire delta. + chunk.extendedTimestamp = uint64(chunk.header.timestampDelta) >= extendedTimestamp + if !chunk.extendedTimestamp { + // Extended timestamp: 0 or 4 bytes + // This field MUST be sent when the normal timsestamp is set to + // 0xffffff, it MUST NOT be sent if the normal timestamp is set to + // anything else. So for values less than 0xffffff the normal + // timestamp field SHOULD be used in which case the extended timestamp + // MUST NOT be present. For values greater than or equal to 0xffffff + // the normal timestamp field MUST NOT be used and MUST be set to + // 0xffffff and the extended timestamp MUST be sent. + if format == formatType0 { + // 6.1.2.1. Type 0 + // For a type-0 chunk, the absolute timestamp of the message is sent + // here. + chunk.header.Timestamp = uint64(chunk.header.timestampDelta) + } else { + // 6.1.2.2. Type 1 + // 6.1.2.3. Type 2 + // For a type-1 or type-2 chunk, the difference between the previous + // chunk's timestamp and the current chunk's timestamp is sent here. + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + if format <= formatType1 { + payloadLength := uint32(p[0])<<16 | uint32(p[1])<<8 | uint32(p[2]) + p = p[3:] + + // For a message, if msg exists in cache, the size must not changed. + // always use the actual msg size to compare, for the cache payload length can changed, + // for the fmt type1(stream_id not changed), user can change the payload + // length(it's not allowed in the continue chunks). + if !isFirstChunkOfMsg && chunk.header.payloadLength != payloadLength { + return errors.Errorf("Chunk message size %v != %v(required)", payloadLength, chunk.header.payloadLength) + } + chunk.header.payloadLength = payloadLength + + chunk.header.MessageType = MessageType(p[0]) + p = p[1:] + + if format == formatType0 { + chunk.header.streamID = uint32(p[0]) | uint32(p[1])<<8 | uint32(p[2])<<16 | uint32(p[3])<<24 + p = p[4:] + } + } + } else { + // Update the timestamp even fmt=3 for first chunk packet + if isFirstChunkOfMsg && !chunk.extendedTimestamp { + chunk.header.Timestamp += uint64(chunk.header.timestampDelta) + } + } + + // Read extended-timestamp + if chunk.extendedTimestamp { + var timestamp uint32 + if err = binary.Read(v.r, binary.BigEndian, ×tamp); err != nil { + return errors.Wrapf(err, "read ext-ts, pkt-ts=%v", chunk.header.Timestamp) + } + + // We always use 31bits timestamp, for some server may use 32bits extended timestamp. + // @see https://github.com/ossrs/srs/issues/111 + timestamp &= 0x7fffffff + + // TODO: FIXME: Support detect the extended timestamp. + // @see http://blog.csdn.net/win_lin/article/details/13363699 + chunk.header.Timestamp = uint64(timestamp) + } + + // The extended-timestamp must be unsigned-int, + // 24bits timestamp: 0xffffff = 16777215ms = 16777.215s = 4.66h + // 32bits timestamp: 0xffffffff = 4294967295ms = 4294967.295s = 1193.046h = 49.71d + // because the rtmp protocol says the 32bits timestamp is about "50 days": + // 3. Byte Order, Alignment, and Time Format + // Because timestamps are generally only 32 bits long, they will roll + // over after fewer than 50 days. + // + // but, its sample says the timestamp is 31bits: + // An application could assume, for example, that all + // adjacent timestamps are within 2^31 milliseconds of each other, so + // 10000 comes after 4000000000, while 3000000000 comes before + // 4000000000. + // and flv specification says timestamp is 31bits: + // Extension of the Timestamp field to form a SI32 value. This + // field represents the upper 8 bits, while the previous + // Timestamp field represents the lower 24 bits of the time in + // milliseconds. + // in a word, 31bits timestamp is ok. + // convert extended timestamp to 31bits. + chunk.header.Timestamp &= 0x7fffffff + + // Copy header to msg + chunk.message.messageHeader = chunk.header + + // Increase the msg count, the chunk stream can accept fmt=1/2/3 message now. + chunk.count++ + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +// The Chunk Basic Header encodes the chunk stream ID and the chunk +// type(represented by fmt field in the figure below). Chunk type +// determines the format of the encoded message header. Chunk Basic +// Header field may be 1, 2, or 3 bytes, depending on the chunk stream +// ID. +// +// The bits 0-5 (least significant) in the chunk basic header represent +// the chunk stream ID. +// +// Chunk stream IDs 2-63 can be encoded in the 1-byte version of this +// field. +// 0 1 2 3 4 5 6 7 +// +-+-+-+-+-+-+-+-+ +// |fmt| cs id | +// +-+-+-+-+-+-+-+-+ +// Figure 6 Chunk basic header 1 +// +// Chunk stream IDs 64-319 can be encoded in the 2-byte version of this +// field. ID is computed as (the second byte + 64). +// 0 1 +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 0 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 7 Chunk basic header 2 +// +// Chunk stream IDs 64-65599 can be encoded in the 3-byte version of +// this field. ID is computed as ((the third byte)*256 + the second byte +// + 64). +// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// |fmt| 1 | cs id - 64 | +// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +// Figure 8 Chunk basic header 3 +// +// cs id: 6 bits +// fmt: 2 bits +// cs id - 64: 8 or 16 bits +// +// Chunk stream IDs with values 64-319 could be represented by both 2- +// byte version and 3-byte version of this field. +func (v *Protocol) readBasicHeader(ctx context.Context) (format formatType, cid chunkID, err error) { + // 2-63, 1B chunk header + var t uint8 + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrap(err, "read basic header") + } + cid = chunkID(t & 0x3f) + format = formatType((t >> 6) & 0x03) + + if cid > 1 { + return + } + + // 64-319, 2B chunk header + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid = chunkID(64 + uint32(t)) + + // 64-65599, 3B chunk header + if cid == 1 { + if err = binary.Read(v.r, binary.BigEndian, &t); err != nil { + return format, cid, errors.Wrapf(err, "read basic header for cid=%v", cid) + } + cid += chunkID(uint32(t) * 256) + } + + return +} + +func (v *Protocol) WritePacket(ctx context.Context, pkt Packet, streamID int) (err error) { + m := NewMessage() + + if m.Payload, err = pkt.MarshalBinary(); err != nil { + return errors.WithMessage(err, "marshal payload") + } + + m.MessageType = pkt.Type() + m.streamID = uint32(streamID) + m.betterCid = pkt.BetterCid() + + if err = v.WriteMessage(ctx, m); err != nil { + return errors.WithMessage(err, "write message") + } + + if err = v.onPacketWriten(m, pkt); err != nil { + return errors.WithMessage(err, "on write packet") + } + + return +} + +func (v *Protocol) onPacketWriten(m *Message, pkt Packet) (err error) { + var tid amf0Number + var name amf0String + + switch pkt := pkt.(type) { + case *ConnectAppPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CreateStreamPacket: + tid, name = pkt.TransactionID, pkt.CommandName + case *CallPacket: + tid, name = pkt.TransactionID, pkt.CommandName + } + + if tid > 0 && len(name) > 0 { + v.input.ltransactions.Lock() + defer v.input.ltransactions.Unlock() + + v.input.transactions[tid] = name + } + + return +} + +func (v *Protocol) onMessageArrivated(m *Message) (err error) { + if m == nil { + return + } + + var pkt Packet + switch m.MessageType { + case MessageTypeSetChunkSize, MessageTypeUserControl, MessageTypeWindowAcknowledgementSize: + if pkt, err = v.DecodeMessage(m); err != nil { + return errors.Errorf("decode message %v", m.MessageType) + } + } + + switch pkt := pkt.(type) { + case *SetChunkSize: + v.input.opt.chunkSize = pkt.ChunkSize + } + + return +} + +func (v *Protocol) WriteMessage(ctx context.Context, m *Message) (err error) { + m.payloadLength = uint32(len(m.Payload)) + + var c0h, c3h []byte + if c0h, err = m.generateC0Header(); err != nil { + return errors.WithMessage(err, "generate c0 header") + } + if c3h, err = m.generateC3Header(); err != nil { + return errors.WithMessage(err, "generate c3 header") + } + + var h []byte + p := m.Payload + for len(p) > 0 { + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + if h == nil { + h = c0h + } else { + h = c3h + } + + if _, err = io.Copy(v.w, bytes.NewReader(h)); err != nil { + return errors.Wrapf(err, "write c0c3 header %x", h) + } + + size := len(p) + if size > int(v.output.opt.chunkSize) { + size = int(v.output.opt.chunkSize) + } + + if _, err = io.Copy(v.w, bytes.NewReader(p[:size])); err != nil { + return errors.Wrapf(err, "write chunk payload %vB", size) + } + p = p[size:] + } + + // TODO: We should convert buffered io to async io, because we will be stuck in block io here, + // TODO: but the risk is acceptable because we literally will set the underlay io timeout. + if ctx.Err() != nil { + return ctx.Err() + } + + // TODO: FIXME: Use writev to write for high performance. + if err = v.w.Flush(); err != nil { + return errors.Wrapf(err, "flush writer") + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +// 1byte. One byte field to represent the message type. A range of type IDs +// (1-7) are reserved for protocol control messages. +type MessageType uint8 + +const ( + // Please read @doc rtmp_specification_1.0.pdf, @page 30, @section 5. Protocol Control Messages + // RTMP reserves message type IDs 1-7 for protocol control messages. + // These messages contain information needed by the RTM Chunk Stream + // protocol or RTMP itself. Protocol messages with IDs 1 & 2 are + // reserved for usage with RTM Chunk Stream protocol. Protocol messages + // with IDs 3-6 are reserved for usage of RTMP. Protocol message with ID + // 7 is used between edge server and origin server. + MessageTypeSetChunkSize MessageType = 0x01 + MessageTypeAbort MessageType = 0x02 // 0x02 + MessageTypeAcknowledgement MessageType = 0x03 // 0x03 + MessageTypeUserControl MessageType = 0x04 // 0x04 + MessageTypeWindowAcknowledgementSize MessageType = 0x05 // 0x05 + MessageTypeSetPeerBandwidth MessageType = 0x06 // 0x06 + MessageTypeEdgeAndOriginServerCommand MessageType = 0x07 // 0x07 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3. Types of messages + // The server and the client send messages over the network to + // communicate with each other. The messages can be of any type which + // includes audio messages, video messages, command messages, shared + // object messages, data messages, and user control messages. + // + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.4. Audio message + // The client or the server sends this message to send audio data to the + // peer. The message type value of 8 is reserved for audio messages. + MessageTypeAudio MessageType = 0x08 + // Please read @doc rtmp_specification_1.0.pdf, @page 41, @section 3.5. Video message + // The client or the server sends this message to send video data to the + // peer. The message type value of 9 is reserved for video messages. + // These messages are large and can delay the sending of other type of + // messages. To avoid such a situation, the video message is assigned + // the lowest priority. + MessageTypeVideo MessageType = 0x09 // 0x09 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.1. Command message + // Command messages carry the AMF-encoded commands between the client + // and the server. These messages have been assigned message type value + // of 20 for AMF0 encoding and message type value of 17 for AMF3 + // encoding. These messages are sent to perform some operations like + // connect, createStream, publish, play, pause on the peer. Command + // messages like onstatus, result etc. are used to inform the sender + // about the status of the requested commands. A command message + // consists of command name, transaction ID, and command object that + // contains related parameters. A client or a server can request Remote + // Procedure Calls (RPC) over streams that are communicated using the + // command messages to the peer. + MessageTypeAMF3Command MessageType = 17 // 0x11 + MessageTypeAMF0Command MessageType = 20 // 0x14 + // Please read @doc rtmp_specification_1.0.pdf, @page 38, @section 3.2. Data message + // The client or the server sends this message to send Metadata or any + // user data to the peer. Metadata includes details about the + // data(audio, video etc.) like creation time, duration, theme and so + // on. These messages have been assigned message type value of 18 for + // AMF0 and message type value of 15 for AMF3. + MessageTypeAMF0Data MessageType = 18 // 0x12 + MessageTypeAMF3Data MessageType = 15 // 0x0f +) + +// The header of message. +type messageHeader struct { + // 3bytes. + // Three-byte field that contains a timestamp delta of the message. + // @remark, only used for decoding message from chunk stream. + timestampDelta uint32 + // 3bytes. + // Three-byte field that represents the size of the payload in bytes. + // It is set in big-endian format. + payloadLength uint32 + // 1byte. + // One byte field to represent the message type. A range of type IDs + // (1-7) are reserved for protocol control messages. + MessageType MessageType + // 4bytes. + // Four-byte field that identifies the stream of the message. These + // bytes are set in little-endian format. + streamID uint32 + + // The chunk stream id over which transport. + betterCid chunkID + + // Four-byte field that contains a timestamp of the message. + // The 4 bytes are packed in the big-endian order. + // @remark, we use 64bits for large time for jitter detect and for large tbn like HLS. + Timestamp uint64 +} + +// The RTMP message, transport over chunk stream in RTMP. +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 30, @section 4.1. Message Header +type Message struct { + messageHeader + + // The payload which carries the RTMP packet. + Payload []byte +} + +func NewMessage() *Message { + return &Message{} +} + +func NewStreamMessage(streamID int) *Message { + v := NewMessage() + v.streamID = uint32(streamID) + v.betterCid = chunkIDOverStream + return v +} + +func (v *Message) generateC3Header() ([]byte, error) { + var c3h []byte + if v.Timestamp < extendedTimestamp { + c3h = make([]byte, 1) + } else { + c3h = make([]byte, 1+4) + } + + p := c3h + p[0] = 0xc0 | byte(v.betterCid&0x3f) + p = p[1:] + + // In RTMP protocol, there must not any timestamp in C3 header, + // but actually all products from adobe, such as FMS/AMS and Flash player and FMLE, + // always carry a extended timestamp in C3 header. + // @see: http://blog.csdn.net/win_lin/article/details/13363699 + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c3h, nil +} + +func (v *Message) generateC0Header() ([]byte, error) { + var c0h []byte + if v.Timestamp < extendedTimestamp { + c0h = make([]byte, 1+3+3+1+4) + } else { + c0h = make([]byte, 1+3+3+1+4+4) + } + + p := c0h + p[0] = byte(v.betterCid) & 0x3f + p = p[1:] + + if v.Timestamp < extendedTimestamp { + p[0] = byte(v.Timestamp >> 16) + p[1] = byte(v.Timestamp >> 8) + p[2] = byte(v.Timestamp) + } else { + p[0] = 0xff + p[1] = 0xff + p[2] = 0xff + } + p = p[3:] + + p[0] = byte(v.payloadLength >> 16) + p[1] = byte(v.payloadLength >> 8) + p[2] = byte(v.payloadLength) + p = p[3:] + + p[0] = byte(v.MessageType) + p = p[1:] + + p[0] = byte(v.streamID) + p[1] = byte(v.streamID >> 8) + p[2] = byte(v.streamID >> 16) + p[3] = byte(v.streamID >> 24) + p = p[4:] + + if v.Timestamp >= extendedTimestamp { + p[0] = byte(v.Timestamp >> 24) + p[1] = byte(v.Timestamp >> 16) + p[2] = byte(v.Timestamp >> 8) + p[3] = byte(v.Timestamp) + } + + return c0h, nil +} + +// Please read the cs id of @doc rtmp_specification_1.0.pdf, @page 17, @section 6.1.1. Chunk Basic Header +type chunkID uint32 + +const ( + chunkIDProtocolControl chunkID = 0x02 + chunkIDOverConnection chunkID = 0x03 + chunkIDOverConnection2 chunkID = 0x04 + chunkIDOverStream chunkID = 0x05 + chunkIDOverStream2 chunkID = 0x06 + chunkIDVideo chunkID = 0x07 + chunkIDAudio chunkID = 0x08 +) + +// The Command Name of message. +const ( + commandConnect amf0String = amf0String("connect") + commandCreateStream amf0String = amf0String("createStream") + commandCloseStream amf0String = amf0String("closeStream") + commandPlay amf0String = amf0String("play") + commandPause amf0String = amf0String("pause") + commandOnBWDone amf0String = amf0String("onBWDone") + commandOnStatus amf0String = amf0String("onStatus") + commandResult amf0String = amf0String("_result") + commandError amf0String = amf0String("_error") + commandReleaseStream amf0String = amf0String("releaseStream") + commandFCPublish amf0String = amf0String("FCPublish") + commandFCUnpublish amf0String = amf0String("FCUnpublish") + commandPublish amf0String = amf0String("publish") + commandRtmpSampleAccess amf0String = amf0String("|RtmpSampleAccess") +) + +// The RTMP packet, transport as payload of RTMP message. +type Packet interface { + // Marshaler and unmarshaler + Size() int + encoding.BinaryUnmarshaler + encoding.BinaryMarshaler + + // RTMP protocol fields for each packet. + BetterCid() chunkID + Type() MessageType +} + +// A Call packet, both object and args are AMF0 objects. +type objectCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject *amf0Object + Args *amf0Object +} + +func (v *objectCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *objectCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *objectCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + v.CommandObject.Size() + if v.Args != nil { + size += v.Args.Size() + } + return size +} + +func (v *objectCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command") + } + p = p[v.CommandObject.Size():] + + if len(p) == 0 { + return + } + + v.Args = NewAmf0Object() + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + + return +} + +func (v *objectCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 45, @section 4.1.1. connect +// The client sends the connect command to the server to request +// connection to a server application instance. +type ConnectAppPacket struct { + objectCallPacket +} + +func NewConnectAppPacket() *ConnectAppPacket { + v := &ConnectAppPacket{} + v.CommandName = commandConnect + v.CommandObject = NewAmf0Object() + v.TransactionID = amf0Number(1.0) + return v +} + +func (v *ConnectAppPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandConnect { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + if v.TransactionID != 1.0 { + return errors.Errorf("Invalid transaction ID %v", float64(v.TransactionID)) + } + + return +} + +func (v *ConnectAppPacket) TcUrl() string { + if v.CommandObject != nil { + if v, ok := v.CommandObject.Get("tcUrl").(*amf0String); ok { + return string(*v) + } + } + return "" +} + +// The response for ConnectAppPacket. +type ConnectAppResPacket struct { + objectCallPacket +} + +func NewConnectAppResPacket(tid amf0Number) *ConnectAppResPacket { + v := &ConnectAppResPacket{} + v.CommandName = commandResult + v.CommandObject = NewAmf0Object() + v.Args = NewAmf0Object() + v.TransactionID = tid + return v +} + +func (v *ConnectAppResPacket) SrsID() string { + if v.Args != nil { + if v, ok := v.Args.Get("data").(*amf0EcmaArray); ok { + if v, ok := v.Get("srs_id").(*amf0String); ok { + return string(*v) + } + } + } + return "" +} + +func (v *ConnectAppResPacket) UnmarshalBinary(data []byte) (err error) { + if err = v.objectCallPacket.UnmarshalBinary(data); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + + if v.CommandName != commandResult { + return errors.Errorf("Invalid command name %v", string(v.CommandName)) + } + + return +} + +// A Call object, command object is variant. +type variantCallPacket struct { + CommandName amf0String + TransactionID amf0Number + CommandObject amf0Any // object or null +} + +func (v *variantCallPacket) BetterCid() chunkID { + return chunkIDOverConnection +} + +func (v *variantCallPacket) Type() MessageType { + return MessageTypeAMF0Command +} + +func (v *variantCallPacket) Size() int { + size := v.CommandName.Size() + v.TransactionID.Size() + + if v.CommandObject != nil { + size += v.CommandObject.Size() + } + + return size +} + +func (v *variantCallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.CommandName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command name") + } + p = p[v.CommandName.Size():] + + if err = v.TransactionID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal tid") + } + p = p[v.TransactionID.Size():] + + if len(p) > 0 { + if v.CommandObject, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery command object") + } + if err = v.CommandObject.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal command object") + } + p = p[v.CommandObject.Size():] + } + + return +} + +func (v *variantCallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.CommandName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command name") + } + data = append(data, pb...) + + if pb, err = v.TransactionID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal tid") + } + data = append(data, pb...) + + if v.CommandObject != nil { + if pb, err = v.CommandObject.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal command object") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 51, @section 4.1.2. Call +// The call method of the NetConnection object runs remote procedure +// calls (RPC) at the receiving end. The called RPC name is passed as a +// parameter to the call command. +// @remark onStatus packet is a call packet. +type CallPacket struct { + variantCallPacket + Args amf0Any // optional or object or null +} + +func NewCallPacket() *CallPacket { + return &CallPacket{} +} + +func (v *CallPacket) ArgsCode() string { + if v.Args != nil { + if v, ok := v.Args.(*amf0Object); ok { + if code, ok := v.Get("code").(*amf0String); ok { + return string(*code) + } + } + } + return "" +} + +func (v *CallPacket) Size() int { + size := v.variantCallPacket.Size() + + if v.Args != nil { + size += v.Args.Size() + } + + return size +} + +func (v *CallPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if len(p) > 0 { + if v.Args, err = Amf0Discovery(p); err != nil { + return errors.WithMessage(err, "discovery args") + } + if err = v.Args.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal args") + } + } + + return +} + +func (v *CallPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if v.Args != nil { + if pb, err = v.Args.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal args") + } + data = append(data, pb...) + } + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 52, @section 4.1.3. createStream +// The client sends this command to the server to create a logical +// channel for message communication The publishing of audio, video, and +// metadata is carried out over stream channel created using the +// createStream command. +type CreateStreamPacket struct { + variantCallPacket +} + +func NewCreateStreamPacket() *CreateStreamPacket { + v := &CreateStreamPacket{} + v.CommandName = commandCreateStream + v.TransactionID = amf0Number(2) + v.CommandObject = NewAmf0Null() + return v +} + +// The response for create stream +type CreateStreamResPacket struct { + variantCallPacket + StreamID amf0Number +} + +func NewCreateStreamResPacket(tid amf0Number) *CreateStreamResPacket { + v := &CreateStreamResPacket{} + v.CommandName = commandResult + v.TransactionID = tid + v.CommandObject = NewAmf0Null() + v.StreamID = 0 + return v +} + +func (v *CreateStreamResPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamID.Size() +} + +func (v *CreateStreamResPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamID.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal sid") + } + + return +} + +func (v *CreateStreamResPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamID.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal sid") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 64, @section 4.2.6. Publish +type PublishPacket struct { + variantCallPacket + StreamName amf0String + StreamType amf0String +} + +func NewPublishPacket() *PublishPacket { + v := &PublishPacket{} + v.CommandName = commandPublish + v.CommandObject = NewAmf0Null() + v.StreamType = "live" + return v +} + +func (v *PublishPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() + v.StreamType.Size() +} + +func (v *PublishPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + if err = v.StreamType.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream type") + } + + return +} + +func (v *PublishPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + if pb, err = v.StreamType.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream type") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 54, @section 4.2.1. play +type PlayPacket struct { + variantCallPacket + StreamName amf0String +} + +func NewPlayPacket() *PlayPacket { + v := &PlayPacket{} + v.CommandName = commandPlay + v.CommandObject = NewAmf0Null() + return v +} + +func (v *PlayPacket) Size() int { + return v.variantCallPacket.Size() + v.StreamName.Size() +} + +func (v *PlayPacket) UnmarshalBinary(data []byte) (err error) { + p := data + + if err = v.variantCallPacket.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal call") + } + p = p[v.variantCallPacket.Size():] + + if err = v.StreamName.UnmarshalBinary(p); err != nil { + return errors.WithMessage(err, "unmarshal stream name") + } + p = p[v.StreamName.Size():] + + return +} + +func (v *PlayPacket) MarshalBinary() (data []byte, err error) { + var pb []byte + if pb, err = v.variantCallPacket.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal call") + } + data = append(data, pb...) + + if pb, err = v.StreamName.MarshalBinary(); err != nil { + return nil, errors.WithMessage(err, "marshal stream name") + } + data = append(data, pb...) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 31, @section 5.1. Set Chunk Size +// Protocol control message 1, Set Chunk Size, is used to notify the +// peer about the new maximum chunk size. +type SetChunkSize struct { + ChunkSize uint32 +} + +func NewSetChunkSize() *SetChunkSize { + return &SetChunkSize{ + ChunkSize: defaultChunkSize, + } +} + +func (v *SetChunkSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetChunkSize) Type() MessageType { + return MessageTypeSetChunkSize +} + +func (v *SetChunkSize) Size() int { + return 4 +} + +func (v *SetChunkSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.ChunkSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *SetChunkSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.ChunkSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.5. Window Acknowledgement Size (5) +// The client or the server sends this message to inform the peer which +// window size to use when sending acknowledgment. +type WindowAcknowledgementSize struct { + AckSize uint32 +} + +func NewWindowAcknowledgementSize() *WindowAcknowledgementSize { + return &WindowAcknowledgementSize{} +} + +func (v *WindowAcknowledgementSize) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *WindowAcknowledgementSize) Type() MessageType { + return MessageTypeWindowAcknowledgementSize +} + +func (v *WindowAcknowledgementSize) Size() int { + return 4 +} + +func (v *WindowAcknowledgementSize) UnmarshalBinary(data []byte) (err error) { + if len(data) < 4 { + return errors.Errorf("requires 4 only %v bytes, %x", len(data), data) + } + v.AckSize = binary.BigEndian.Uint32(data) + + return +} + +func (v *WindowAcknowledgementSize) MarshalBinary() (data []byte, err error) { + data = make([]byte, 4) + binary.BigEndian.PutUint32(data, v.AckSize) + + return +} + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The sender can mark this message hard (0), soft (1), or dynamic (2) +// using the Limit type field. +type LimitType uint8 + +const ( + LimitTypeHard LimitType = iota + LimitTypeSoft + LimitTypeDynamic +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 33, @section 5.6. Set Peer Bandwidth (6) +// The client or the server sends this message to update the output +// bandwidth of the peer. +type SetPeerBandwidth struct { + Bandwidth uint32 + LimitType LimitType +} + +func NewSetPeerBandwidth() *SetPeerBandwidth { + return &SetPeerBandwidth{} +} + +func (v *SetPeerBandwidth) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *SetPeerBandwidth) Type() MessageType { + return MessageTypeSetPeerBandwidth +} + +func (v *SetPeerBandwidth) Size() int { + return 4 + 1 +} + +func (v *SetPeerBandwidth) UnmarshalBinary(data []byte) (err error) { + if len(data) < 5 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + v.Bandwidth = binary.BigEndian.Uint32(data) + v.LimitType = LimitType(data[4]) + + return +} + +func (v *SetPeerBandwidth) MarshalBinary() (data []byte, err error) { + data = make([]byte, 5) + binary.BigEndian.PutUint32(data, v.Bandwidth) + data[4] = byte(v.LimitType) + + return +} + +type EventType uint16 + +const ( + // Generally, 4bytes event-data + + // The server sends this event to notify the client + // that a stream has become functional and can be + // used for communication. By default, this event + // is sent on ID 0 after the application connect + // command is successfully received from the + // client. The event data is 4-byte and represents + // The stream ID of the stream that became + // Functional. + EventTypeStreamBegin = 0x00 + + // The server sends this event to notify the client + // that the playback of data is over as requested + // on this stream. No more data is sent without + // issuing additional commands. The client discards + // The messages received for the stream. The + // 4 bytes of event data represent the ID of the + // stream on which playback has ended. + EventTypeStreamEOF = 0x01 + + // The server sends this event to notify the client + // that there is no more data on the stream. If the + // server does not detect any message for a time + // period, it can notify the subscribed clients + // that the stream is dry. The 4 bytes of event + // data represent the stream ID of the dry stream. + EventTypeStreamDry = 0x02 + + // The client sends this event to inform the server + // of the buffer size (in milliseconds) that is + // used to buffer any data coming over a stream. + // This event is sent before the server starts + // processing the stream. The first 4 bytes of the + // event data represent the stream ID and the next + // 4 bytes represent the buffer length, in + // milliseconds. + EventTypeSetBufferLength = 0x03 // 8bytes event-data + + // The server sends this event to notify the client + // that the stream is a recorded stream. The + // 4 bytes event data represent the stream ID of + // The recorded stream. + EventTypeStreamIsRecorded = 0x04 + + // The server sends this event to test whether the + // client is reachable. Event data is a 4-byte + // timestamp, representing the local server time + // When the server dispatched the command. The + // client responds with kMsgPingResponse on + // receiving kMsgPingRequest. + EventTypePingRequest = 0x06 + + // The client sends this event to the server in + // Response to the ping request. The event data is + // a 4-byte timestamp, which was received with the + // kMsgPingRequest request. + EventTypePingResponse = 0x07 + + // For PCUC size=3, for example the payload is "00 1A 01", + // it's a FMS control event, where the event type is 0x001a and event data is 0x01, + // please notice that the event data is only 1 byte for this event. + EventTypeFmsEvent0 = 0x1a +) + +// Please read @doc rtmp_specification_1.0.pdf, @page 32, @5.4. User Control Message (4) +// The client or the server sends this message to notify the peer about the user control events. +// This message carries Event type and Event data. +type UserControl struct { + // Event type is followed by Event data. + // @see: SrcPCUCEventType + EventType EventType + // The event data generally in 4bytes. + // @remark for event type is 0x001a, only 1bytes. + // @see SrsPCUCFmsEvent0 + EventData int32 + // 4bytes if event_type is SetBufferLength; otherwise 0. + ExtraData int32 +} + +func NewUserControl() *UserControl { + return &UserControl{} +} + +func (v *UserControl) BetterCid() chunkID { + return chunkIDProtocolControl +} + +func (v *UserControl) Type() MessageType { + return MessageTypeUserControl +} + +func (v *UserControl) Size() int { + size := 2 + + if v.EventType == EventTypeFmsEvent0 { + size += 1 + } else { + size += 4 + } + + if v.EventType == EventTypeSetBufferLength { + size += 4 + } + + return size +} + +func (v *UserControl) UnmarshalBinary(data []byte) (err error) { + if len(data) < 3 { + return errors.Errorf("requires 5 only %v bytes, %x", len(data), data) + } + + v.EventType = EventType(binary.BigEndian.Uint16(data)) + if len(data) < v.Size() { + return errors.Errorf("requires %v only %v bytes, %x", v.Size(), len(data), data) + } + + if v.EventType == EventTypeFmsEvent0 { + v.EventData = int32(uint8(data[2])) + } else { + v.EventData = int32(binary.BigEndian.Uint32(data[2:])) + } + + if v.EventType == EventTypeSetBufferLength { + v.ExtraData = int32(binary.BigEndian.Uint32(data[6:])) + } + + return +} + +func (v *UserControl) MarshalBinary() (data []byte, err error) { + data = make([]byte, v.Size()) + binary.BigEndian.PutUint16(data, uint16(v.EventType)) + + if v.EventType == EventTypeFmsEvent0 { + data[2] = uint8(v.EventData) + } else { + binary.BigEndian.PutUint32(data[2:], uint32(v.EventData)) + } + + if v.EventType == EventTypeSetBufferLength { + binary.BigEndian.PutUint32(data[6:], uint32(v.ExtraData)) + } + + return +} diff --git a/proxy/signal.go b/proxy/signal.go new file mode 100644 index 0000000000..367543f4a7 --- /dev/null +++ b/proxy/signal.go @@ -0,0 +1,44 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func installSignals(ctx context.Context, cancel context.CancelFunc) { + sc := make(chan os.Signal, 1) + signal.Notify(sc, syscall.SIGINT, syscall.SIGTERM, os.Interrupt) + + go func() { + for s := range sc { + logger.Df(ctx, "Got signal %v", s) + cancel() + } + }() +} + +func installForceQuit(ctx context.Context) error { + var forceTimeout time.Duration + if t, err := time.ParseDuration(envForceQuitTimeout()); err != nil { + return errors.Wrapf(err, "parse force timeout %v", envForceQuitTimeout()) + } else { + forceTimeout = t + } + + go func() { + <-ctx.Done() + time.Sleep(forceTimeout) + logger.Wf(ctx, "Force to exit by timeout") + os.Exit(1) + }() + return nil +} diff --git a/proxy/srs.go b/proxy/srs.go new file mode 100644 index 0000000000..d05a39c610 --- /dev/null +++ b/proxy/srs.go @@ -0,0 +1,553 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "os" + "strconv" + "strings" + "time" + + // Use v8 because we use Go 1.16+, while v9 requires Go 1.18+ + "github.com/go-redis/redis/v8" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// If server heartbeat in this duration, it's alive. +const srsServerAliveDuration = 300 * time.Second + +// If HLS streaming update in this duration, it's alive. +const srsHLSAliveDuration = 120 * time.Second + +// If WebRTC streaming update in this duration, it's alive. +const srsRTCAliveDuration = 120 * time.Second + +type SRSServer struct { + // The server IP. + IP string `json:"ip,omitempty"` + // The server device ID, configured by user. + DeviceID string `json:"device_id,omitempty"` + // The server id of SRS, store in file, may not change, mandatory. + ServerID string `json:"server_id,omitempty"` + // The service id of SRS, always change when restarted, mandatory. + ServiceID string `json:"service_id,omitempty"` + // The process id of SRS, always change when restarted, mandatory. + PID string `json:"pid,omitempty"` + // The RTMP listen endpoints. + RTMP []string `json:"rtmp,omitempty"` + // The HTTP Stream listen endpoints. + HTTP []string `json:"http,omitempty"` + // The HTTP API listen endpoints. + API []string `json:"api,omitempty"` + // The SRT server listen endpoints. + SRT []string `json:"srt,omitempty"` + // The RTC server listen endpoints. + RTC []string `json:"rtc,omitempty"` + // Last update time. + UpdatedAt time.Time `json:"update_at,omitempty"` +} + +func (v *SRSServer) ID() string { + return fmt.Sprintf("%v-%v-%v", v.ServerID, v.ServiceID, v.PID) +} + +func (v *SRSServer) String() string { + return fmt.Sprintf("%v", v) +} + +func (v *SRSServer) Format(f fmt.State, c rune) { + switch c { + case 'v', 's': + if f.Flag('+') { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("pid=%v, server=%v, service=%v", v.PID, v.ServerID, v.ServiceID)) + if v.DeviceID != "" { + sb.WriteString(fmt.Sprintf(", device=%v", v.DeviceID)) + } + if len(v.RTMP) > 0 { + sb.WriteString(fmt.Sprintf(", rtmp=[%v]", strings.Join(v.RTMP, ","))) + } + if len(v.HTTP) > 0 { + sb.WriteString(fmt.Sprintf(", http=[%v]", strings.Join(v.HTTP, ","))) + } + if len(v.API) > 0 { + sb.WriteString(fmt.Sprintf(", api=[%v]", strings.Join(v.API, ","))) + } + if len(v.SRT) > 0 { + sb.WriteString(fmt.Sprintf(", srt=[%v]", strings.Join(v.SRT, ","))) + } + if len(v.RTC) > 0 { + sb.WriteString(fmt.Sprintf(", rtc=[%v]", strings.Join(v.RTC, ","))) + } + sb.WriteString(fmt.Sprintf(", update=%v", v.UpdatedAt.Format("2006-01-02 15:04:05.999"))) + fmt.Fprintf(f, "SRS ip=%v, id=%v, %v", v.IP, v.ID(), sb.String()) + } else { + fmt.Fprintf(f, "SRS ip=%v, id=%v", v.IP, v.ID()) + } + default: + fmt.Fprintf(f, "%v, fmt=%%%c", v, c) + } +} + +func NewSRSServer(opts ...func(*SRSServer)) *SRSServer { + v := &SRSServer{} + for _, opt := range opts { + opt(v) + } + return v +} + +// NewDefaultSRSForDebugging initialize the default SRS media server, for debugging only. +func NewDefaultSRSForDebugging() (*SRSServer, error) { + if envDefaultBackendEnabled() != "on" { + return nil, nil + } + + if envDefaultBackendIP() == "" { + return nil, fmt.Errorf("empty default backend ip") + } + if envDefaultBackendRTMP() == "" { + return nil, fmt.Errorf("empty default backend rtmp") + } + + server := NewSRSServer(func(srs *SRSServer) { + srs.IP = envDefaultBackendIP() + srs.RTMP = []string{envDefaultBackendRTMP()} + srs.ServerID = fmt.Sprintf("default-%v", logger.GenerateContextID()) + srs.ServiceID = logger.GenerateContextID() + srs.PID = fmt.Sprintf("%v", os.Getpid()) + srs.UpdatedAt = time.Now() + }) + + if envDefaultBackendHttp() != "" { + server.HTTP = []string{envDefaultBackendHttp()} + } + if envDefaultBackendAPI() != "" { + server.API = []string{envDefaultBackendAPI()} + } + if envDefaultBackendRTC() != "" { + server.RTC = []string{envDefaultBackendRTC()} + } + if envDefaultBackendSRT() != "" { + server.SRT = []string{envDefaultBackendSRT()} + } + return server, nil +} + +// SRSLoadBalancer is the interface to load balance the SRS servers. +type SRSLoadBalancer interface { + // Initialize the load balancer. + Initialize(ctx context.Context) error + // Update the backer server. + Update(ctx context.Context, server *SRSServer) error + // Pick a backend server for the specified stream URL. + Pick(ctx context.Context, streamURL string) (*SRSServer, error) + // Load or store the HLS streaming for the specified stream URL. + LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) + // Load the HLS streaming by SPBHID, the SRS Proxy Backend HLS ID. + LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) + // Store the WebRTC streaming for the specified stream URL. + StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error + // Load the WebRTC streaming by ufrag, the ICE username. + LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) +} + +// srsLoadBalancer is the global SRS load balancer. +var srsLoadBalancer SRSLoadBalancer + +// srsMemoryLoadBalancer stores state in memory. +type srsMemoryLoadBalancer struct { + // All available SRS servers, key is server ID. + servers sync.Map[string, *SRSServer] + // The picked server to servce client by specified stream URL, key is stream url. + picked sync.Map[string, *SRSServer] + // The HLS streaming, key is stream URL. + hlsStreamURL sync.Map[string, *HLSPlayStream] + // The HLS streaming, key is SPBHID. + hlsSPBHID sync.Map[string, *HLSPlayStream] + // The WebRTC streaming, key is stream URL. + rtcStreamURL sync.Map[string, *RTCConnection] + // The WebRTC streaming, key is ufrag. + rtcUfrag sync.Map[string, *RTCConnection] +} + +func NewMemoryLoadBalancer() SRSLoadBalancer { + return &srsMemoryLoadBalancer{} +} + +func (v *srsMemoryLoadBalancer) Initialize(ctx context.Context) error { + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "MemoryLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsMemoryLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + v.servers.Store(server.ID(), server) + return nil +} + +func (v *srsMemoryLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + // Always proxy to the same server for the same stream URL. + if server, ok := v.picked.Load(streamURL); ok { + return server, nil + } + + // Gather all servers that were alive within the last few seconds. + var servers []*SRSServer + v.servers.Range(func(key string, server *SRSServer) bool { + if time.Since(server.UpdatedAt) < srsServerAliveDuration { + servers = append(servers, server) + } + return true + }) + + // If no servers available, use all possible servers. + if len(servers) == 0 { + v.servers.Range(func(key string, server *SRSServer) bool { + servers = append(servers, server) + return true + }) + } + + // No server found, failed. + if len(servers) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // Pick a server randomly from servers. + server := servers[rand.Intn(len(servers))] + v.picked.Store(streamURL, server) + return server, nil +} + +func (v *srsMemoryLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + // Load the HLS streaming for the SPBHID, for TS files. + if actual, ok := v.hlsSPBHID.Load(spbhid); !ok { + return nil, errors.Errorf("no HLS streaming for SPBHID %v", spbhid) + } else { + return actual, nil + } +} + +func (v *srsMemoryLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + // Update the HLS streaming for the stream URL, for M3u8. + actual, _ := v.hlsStreamURL.LoadOrStore(streamURL, value) + if actual == nil { + return nil, errors.Errorf("load or store HLS streaming for %v failed", streamURL) + } + + // Update the HLS streaming for the SPBHID, for TS files. + v.hlsSPBHID.Store(value.SRSProxyBackendHLSID, actual) + + return actual, nil +} + +func (v *srsMemoryLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + // Update the WebRTC streaming for the stream URL. + v.rtcStreamURL.Store(streamURL, value) + + // Update the WebRTC streaming for the ufrag. + v.rtcUfrag.Store(value.Ufrag, value) + return nil +} + +func (v *srsMemoryLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + if actual, ok := v.rtcUfrag.Load(ufrag); !ok { + return nil, errors.Errorf("no WebRTC streaming for ufrag %v", ufrag) + } else { + return actual, nil + } +} + +type srsRedisLoadBalancer struct { + // The redis client sdk. + rdb *redis.Client +} + +func NewRedisLoadBalancer() SRSLoadBalancer { + return &srsRedisLoadBalancer{} +} + +func (v *srsRedisLoadBalancer) Initialize(ctx context.Context) error { + redisDatabase, err := strconv.Atoi(envRedisDB()) + if err != nil { + return errors.Wrapf(err, "invalid PROXY_REDIS_DB %v", envRedisDB()) + } + + rdb := redis.NewClient(&redis.Options{ + Addr: fmt.Sprintf("%v:%v", envRedisHost(), envRedisPort()), + Password: envRedisPassword(), + DB: redisDatabase, + }) + v.rdb = rdb + + if err := rdb.Ping(ctx).Err(); err != nil { + return errors.Wrapf(err, "unable to connect to redis %v", rdb.String()) + } + logger.Df(ctx, "RedisLB: connected to redis %v ok", rdb.String()) + + if server, err := NewDefaultSRSForDebugging(); err != nil { + return errors.Wrapf(err, "initialize default SRS") + } else if server != nil { + if err := v.Update(ctx, server); err != nil { + return errors.Wrapf(err, "update default SRS %+v", server) + } + + // Keep alive. + go func() { + for { + select { + case <-ctx.Done(): + return + case <-time.After(30 * time.Second): + if err := v.Update(ctx, server); err != nil { + logger.Wf(ctx, "update default SRS %+v failed, %+v", server, err) + } + } + } + }() + logger.Df(ctx, "RedisLB: Initialize default SRS media server, %+v", server) + } + return nil +} + +func (v *srsRedisLoadBalancer) Update(ctx context.Context, server *SRSServer) error { + b, err := json.Marshal(server) + if err != nil { + return errors.Wrapf(err, "marshal server %+v", server) + } + + key := v.redisKeyServer(server.ID()) + if err = v.rdb.Set(ctx, key, b, srsServerAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v server %+v", key, server) + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // Check each server expiration, if not exists in redis, remove from servers. + for i := len(serverKeys) - 1; i >= 0; i-- { + if _, err := v.rdb.Get(ctx, serverKeys[i]).Bytes(); err != nil { + serverKeys = append(serverKeys[:i], serverKeys[i+1:]...) + } + } + + // Add server to servers if not exists. + var found bool + for _, serverKey := range serverKeys { + if serverKey == key { + found = true + break + } + } + if !found { + serverKeys = append(serverKeys, key) + } + + // Update all servers to redis. + b, err = json.Marshal(serverKeys) + if err != nil { + return errors.Wrapf(err, "marshal servers %+v", serverKeys) + } + if err = v.rdb.Set(ctx, v.redisKeyServers(), b, 0).Err(); err != nil { + return errors.Wrapf(err, "set key=%v servers %+v", v.redisKeyServers(), serverKeys) + } + + return nil +} + +func (v *srsRedisLoadBalancer) Pick(ctx context.Context, streamURL string) (*SRSServer, error) { + key := fmt.Sprintf("srs-proxy-url:%v", streamURL) + + // Always proxy to the same server for the same stream URL. + if serverKey, err := v.rdb.Get(ctx, key).Result(); err == nil { + // If server not exists, ignore and pick another server for the stream URL. + if b, err := v.rdb.Get(ctx, serverKey).Bytes(); err == nil && len(b) > 0 { + var server SRSServer + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", key, string(b)) + } + + // TODO: If server fail, we should migrate the streams to another server. + return &server, nil + } + } + + // Query all servers from redis, in json string. + var serverKeys []string + if b, err := v.rdb.Get(ctx, v.redisKeyServers()).Bytes(); err == nil { + if err := json.Unmarshal(b, &serverKeys); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v servers %v", v.redisKeyServers(), string(b)) + } + } + + // No server found, failed. + if len(serverKeys) == 0 { + return nil, fmt.Errorf("no server available for %v", streamURL) + } + + // All server should be alive, if not, should have been removed by redis. So we only + // random pick one that is always available. + var serverKey string + var server SRSServer + for i := 0; i < 3; i++ { + tryServerKey := serverKeys[rand.Intn(len(serverKeys))] + b, err := v.rdb.Get(ctx, tryServerKey).Bytes() + if err == nil && len(b) > 0 { + if err := json.Unmarshal(b, &server); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v server %v", serverKey, string(b)) + } + + serverKey = tryServerKey + break + } + } + if serverKey == "" { + return nil, errors.Errorf("no server available in %v for %v", serverKeys, streamURL) + } + + // Update the picked server for the stream URL. + if err := v.rdb.Set(ctx, key, []byte(serverKey), 0).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v server %v", key, serverKey) + } + + return &server, nil +} + +func (v *srsRedisLoadBalancer) LoadHLSBySPBHID(ctx context.Context, spbhid string) (*HLSPlayStream, error) { + key := v.redisKeySPBHID(spbhid) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) LoadOrStoreHLS(ctx context.Context, streamURL string, value *HLSPlayStream) (*HLSPlayStream, error) { + b, err := json.Marshal(value) + if err != nil { + return nil, errors.Wrapf(err, "marshal HLS %v", value) + } + + key := v.redisKeyHLS(streamURL) + if err = v.rdb.Set(ctx, key, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key, value) + } + + key2 := v.redisKeySPBHID(value.SRSProxyBackendHLSID) + if err := v.rdb.Set(ctx, key2, b, srsHLSAliveDuration).Err(); err != nil { + return nil, errors.Wrapf(err, "set key=%v HLS %v", key2, value) + } + + // Query the HLS streaming from redis. + b2, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v HLS", key) + } + + var actual HLSPlayStream + if err := json.Unmarshal(b2, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v HLS %v", key, string(b2)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) StoreWebRTC(ctx context.Context, streamURL string, value *RTCConnection) error { + b, err := json.Marshal(value) + if err != nil { + return errors.Wrapf(err, "marshal WebRTC %v", value) + } + + key := v.redisKeyRTC(streamURL) + if err = v.rdb.Set(ctx, key, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key, value) + } + + key2 := v.redisKeyUfrag(value.Ufrag) + if err := v.rdb.Set(ctx, key2, b, srsRTCAliveDuration).Err(); err != nil { + return errors.Wrapf(err, "set key=%v WebRTC %v", key2, value) + } + + return nil +} + +func (v *srsRedisLoadBalancer) LoadWebRTCByUfrag(ctx context.Context, ufrag string) (*RTCConnection, error) { + key := v.redisKeyUfrag(ufrag) + + b, err := v.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, errors.Wrapf(err, "get key=%v WebRTC", key) + } + + var actual RTCConnection + if err := json.Unmarshal(b, &actual); err != nil { + return nil, errors.Wrapf(err, "unmarshal key=%v WebRTC %v", key, string(b)) + } + + return &actual, nil +} + +func (v *srsRedisLoadBalancer) redisKeyUfrag(ufrag string) string { + return fmt.Sprintf("srs-proxy-ufrag:%v", ufrag) +} + +func (v *srsRedisLoadBalancer) redisKeyRTC(streamURL string) string { + return fmt.Sprintf("srs-proxy-rtc:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeySPBHID(spbhid string) string { + return fmt.Sprintf("srs-proxy-spbhid:%v", spbhid) +} + +func (v *srsRedisLoadBalancer) redisKeyHLS(streamURL string) string { + return fmt.Sprintf("srs-proxy-hls:%v", streamURL) +} + +func (v *srsRedisLoadBalancer) redisKeyServer(serverID string) string { + return fmt.Sprintf("srs-proxy-server:%v", serverID) +} + +func (v *srsRedisLoadBalancer) redisKeyServers() string { + return fmt.Sprintf("srs-proxy-all-servers") +} diff --git a/proxy/srt.go b/proxy/srt.go new file mode 100644 index 0000000000..e4c629af8d --- /dev/null +++ b/proxy/srt.go @@ -0,0 +1,574 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "bytes" + "context" + "encoding/binary" + "fmt" + "net" + "strings" + stdSync "sync" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" + "srs-proxy/sync" +) + +// srsSRTServer is the proxy for SRS server via SRT. It will figure out which backend server to +// proxy to. It only parses the SRT handshake messages, parses the stream id, and proxy to the +// backend server. +type srsSRTServer struct { + // The UDP listener for SRT server. + listener *net.UDPConn + + // The SRT connections, identify by the socket ID. + sockets sync.Map[uint32, *SRTConnection] + // The system start time. + start time.Time + + // The wait group for server. + wg stdSync.WaitGroup +} + +func NewSRSSRTServer(opts ...func(*srsSRTServer)) *srsSRTServer { + v := &srsSRTServer{ + start: time.Now(), + } + + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *srsSRTServer) Close() error { + if v.listener != nil { + v.listener.Close() + } + + v.wg.Wait() + return nil +} + +func (v *srsSRTServer) Run(ctx context.Context) error { + // Parse address to listen. + endpoint := envSRTServer() + if !strings.Contains(endpoint, ":") { + endpoint = ":" + endpoint + } + + saddr, err := net.ResolveUDPAddr("udp", endpoint) + if err != nil { + return errors.Wrapf(err, "resolve udp addr %v", endpoint) + } + + listener, err := net.ListenUDP("udp", saddr) + if err != nil { + return errors.Wrapf(err, "listen udp %v", saddr) + } + v.listener = listener + logger.Df(ctx, "SRT server listen at %v", saddr) + + // Consume all messages from UDP media transport. + v.wg.Add(1) + go func() { + defer v.wg.Done() + + for ctx.Err() == nil { + buf := make([]byte, 4096) + n, caddr, err := v.listener.ReadFromUDP(buf) + if err != nil { + // TODO: If SRT server closed unexpectedly, we should notice the main loop to quit. + logger.Wf(ctx, "read from udp failed, err=%+v", err) + continue + } + + if err := v.handleClientUDP(ctx, caddr, buf[:n]); err != nil { + logger.Wf(ctx, "handle udp %vB failed, addr=%v, err=%+v", n, caddr, err) + } + } + }() + + return nil +} + +func (v *srsSRTServer) handleClientUDP(ctx context.Context, addr *net.UDPAddr, data []byte) error { + socketID := srtParseSocketID(data) + + var pkt *SRTHandshakePacket + if srtIsHandshake(data) { + pkt = &SRTHandshakePacket{} + if err := pkt.UnmarshalBinary(data); err != nil { + return err + } + + if socketID == 0 { + socketID = pkt.SRTSocketID + } + } + + conn, ok := v.sockets.LoadOrStore(socketID, NewSRTConnection(func(c *SRTConnection) { + c.ctx = logger.WithContext(ctx) + c.listenerUDP, c.socketID = v.listener, socketID + c.start = v.start + })) + + ctx = conn.ctx + if !ok { + logger.Df(ctx, "Create new SRT connection skt=%v", socketID) + } + + if newSocketID, err := conn.HandlePacket(pkt, addr, data); err != nil { + return errors.Wrapf(err, "handle packet") + } else if newSocketID != 0 && newSocketID != socketID { + // The connection may use a new socket ID. + // TODO: FIXME: Should cleanup the dead SRT connection. + v.sockets.Store(newSocketID, conn) + } + + return nil +} + +// SRTConnection is an SRT connection proxy, for both caller and listener. It represents an SRT +// connection, identify by the socket ID. +// +// It's similar to RTMP or HTTP FLV/TS proxy connection, which are stateless and all state is in +// the client request. The SRTConnection is stateless, and no need to sync between proxy servers. +// +// Unlike the WebRTC connection, SRTConnection does not support address changes. This means the +// client should never switch to another network or port. If this occurs, the client may be served +// by a different proxy server and fail because the other proxy server cannot identify the client. +type SRTConnection struct { + // The stream context for SRT connection. + ctx context.Context + + // The current socket ID. + socketID uint32 + + // The UDP connection proxy to backend. + backendUDP *net.UDPConn + // The listener UDP connection, used to send messages to client. + listenerUDP *net.UDPConn + + // Listener start time. + start time.Time + + // Handshake packets with client. + handshake0 *SRTHandshakePacket + handshake1 *SRTHandshakePacket + handshake2 *SRTHandshakePacket + handshake3 *SRTHandshakePacket +} + +func NewSRTConnection(opts ...func(*SRTConnection)) *SRTConnection { + v := &SRTConnection{} + for _, opt := range opts { + opt(v) + } + return v +} + +func (v *SRTConnection) HandlePacket(pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) (uint32, error) { + ctx := v.ctx + + // If not handshake, try to proxy to backend directly. + if pkt == nil { + // Proxy client message to backend. + if v.backendUDP != nil { + if _, err := v.backendUDP.Write(data); err != nil { + return v.socketID, errors.Wrapf(err, "write to backend") + } + } + + return v.socketID, nil + } + + // Handle handshake messages. + if err := v.handleHandshake(ctx, pkt, addr, data); err != nil { + return v.socketID, errors.Wrapf(err, "handle handshake %v", pkt) + } + + return v.socketID, nil +} + +func (v *SRTConnection) handleHandshake(ctx context.Context, pkt *SRTHandshakePacket, addr *net.UDPAddr, data []byte) error { + // Handle handshake 0 and 1 messages. + if pkt.SynCookie == 0 { + // Save handshake 0 packet. + v.handshake0 = pkt + logger.Df(ctx, "SRT Handshake 0: %v", v.handshake0) + + // Response handshake 1. + v.handshake1 = &SRTHandshakePacket{ + ControlFlag: pkt.ControlFlag, + ControlType: 0, + SubType: 0, + AdditionalInfo: 0, + Timestamp: uint32(time.Since(v.start).Microseconds()), + SocketID: pkt.SRTSocketID, + Version: 5, + EncryptionField: 0, + ExtensionField: 0x4A17, + InitSequence: pkt.InitSequence, + MTU: pkt.MTU, + FlowWindow: pkt.FlowWindow, + HandshakeType: 1, + SRTSocketID: pkt.SRTSocketID, + SynCookie: 0x418d5e4e, + PeerIP: net.ParseIP("127.0.0.1"), + } + logger.Df(ctx, "SRT Handshake 1: %v", v.handshake1) + + if b, err := v.handshake1.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 1") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 1") + } + + return nil + } + + // Handle handshake 2 and 3 messages. + // Parse stream id from packet. + streamID, err := pkt.StreamID() + if err != nil { + return errors.Wrapf(err, "parse stream id") + } + + // Save handshake packet. + v.handshake2 = pkt + logger.Df(ctx, "SRT Handshake 2: %v, sid=%v", v.handshake2, streamID) + + // Start the UDP proxy to backend. + if err := v.connectBackend(ctx, streamID); err != nil { + return errors.Wrapf(err, "connect backend for %v", streamID) + } + + // Proxy client message to backend. + if v.backendUDP == nil { + return errors.Errorf("no backend for %v", streamID) + } + + // Proxy handshake 0 to backend server. + if b, err := v.handshake0.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 0") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 0") + } + logger.Df(ctx, "Proxy send handshake 0: %v", v.handshake0) + + // Read handshake 1 from backend server. + b := make([]byte, 4096) + handshake1p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 1") + } else if err := handshake1p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 1") + } + logger.Df(ctx, "Proxy got handshake 1: %v", handshake1p) + + // Proxy handshake 2 to backend server. + handshake2p := *v.handshake2 + handshake2p.SynCookie = handshake1p.SynCookie + if b, err := handshake2p.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 2") + } else if _, err = v.backendUDP.Write(b); err != nil { + return errors.Wrapf(err, "write handshake 2") + } + logger.Df(ctx, "Proxy send handshake 2: %v", handshake2p) + + // Read handshake 3 from backend server. + handshake3p := &SRTHandshakePacket{} + if nn, err := v.backendUDP.Read(b); err != nil { + return errors.Wrapf(err, "read handshake 3") + } else if err := handshake3p.UnmarshalBinary(b[:nn]); err != nil { + return errors.Wrapf(err, "unmarshal handshake 3") + } + logger.Df(ctx, "Proxy got handshake 3: %v", handshake3p) + + // Response handshake 3 to client. + v.handshake3 = &*handshake3p + v.handshake3.SynCookie = v.handshake1.SynCookie + v.socketID = handshake3p.SRTSocketID + logger.Df(ctx, "Handshake 3: %v", v.handshake3) + + if b, err := v.handshake3.MarshalBinary(); err != nil { + return errors.Wrapf(err, "marshal handshake 3") + } else if _, err = v.listenerUDP.WriteToUDP(b, addr); err != nil { + return errors.Wrapf(err, "write handshake 3") + } + + // Start a goroutine to proxy message from backend to client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + go func() { + for ctx.Err() == nil { + nn, err := v.backendUDP.Read(b) + if err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "read from backend failed, err=%v", err) + return + } + if _, err = v.listenerUDP.WriteToUDP(b[:nn], addr); err != nil { + // TODO: If backend server closed unexpectedly, we should notice the stream to quit. + logger.Wf(ctx, "write to client failed, err=%v", err) + return + } + } + }() + return nil +} + +func (v *SRTConnection) connectBackend(ctx context.Context, streamID string) error { + if v.backendUDP != nil { + return nil + } + + // Parse stream id to host and resource. + host, resource, err := parseSRTStreamID(streamID) + if err != nil { + return errors.Wrapf(err, "parse stream id %v", streamID) + } + + if host == "" { + host = "localhost" + } + + streamURL, err := buildStreamURL(fmt.Sprintf("srt://%v/%v", host, resource)) + if err != nil { + return errors.Wrapf(err, "build stream url %v", streamID) + } + + // Pick a backend SRS server to proxy the SRT stream. + backend, err := srsLoadBalancer.Pick(ctx, streamURL) + if err != nil { + return errors.Wrapf(err, "pick backend for %v", streamURL) + } + + // Parse UDP port from backend. + if len(backend.SRT) == 0 { + return errors.Errorf("no udp server %v for %v", backend, streamURL) + } + + _, _, udpPort, err := parseListenEndpoint(backend.SRT[0]) + if err != nil { + return errors.Wrapf(err, "parse udp port %v of %v for %v", backend.SRT[0], backend, streamURL) + } + + // Connect to backend SRS server via UDP client. + // TODO: FIXME: Support close the connection when timeout or client disconnected. + backendAddr := net.UDPAddr{IP: net.ParseIP(backend.IP), Port: int(udpPort)} + if backendUDP, err := net.DialUDP("udp", nil, &backendAddr); err != nil { + return errors.Wrapf(err, "dial udp to %v of %v for %v", backendAddr, backend, streamURL) + } else { + v.backendUDP = backendUDP + } + + return nil +} + +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2 +// See https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-3.2.1 +type SRTHandshakePacket struct { + // F: 1 bit. Packet Type Flag. The control packet has this flag set to + // "1". The data packet has this flag set to "0". + ControlFlag uint8 + // Control Type: 15 bits. Control Packet Type. The use of these bits + // is determined by the control packet type definition. + // Handshake control packets (Control Type = 0x0000) are used to + // exchange peer configurations, to agree on connection parameters, and + // to establish a connection. + ControlType uint16 + // Subtype: 16 bits. This field specifies an additional subtype for + // specific packets. + SubType uint16 + // Type-specific Information: 32 bits. The use of this field depends on + // the particular control packet type. Handshake packets do not use + // this field. + AdditionalInfo uint32 + // Timestamp: 32 bits. + Timestamp uint32 + // Destination Socket ID: 32 bits. + SocketID uint32 + + // Version: 32 bits. A base protocol version number. Currently used + // values are 4 and 5. Values greater than 5 are reserved for future + // use. + Version uint32 + // Encryption Field: 16 bits. Block cipher family and key size. The + // values of this field are described in Table 2. The default value + // is AES-128. + // 0 | No Encryption Advertised + // 2 | AES-128 + // 3 | AES-192 + // 4 | AES-256 + EncryptionField uint16 + // Extension Field: 16 bits. This field is message specific extension + // related to Handshake Type field. The value MUST be set to 0 + // except for the following cases. (1) If the handshake control + // packet is the INDUCTION message, this field is sent back by the + // Listener. (2) In the case of a CONCLUSION message, this field + // value should contain a combination of Extension Type values. + // 0x00000001 | HSREQ + // 0x00000002 | KMREQ + // 0x00000004 | CONFIG + // 0x4A17 if HandshakeType is INDUCTION, see https://datatracker.ietf.org/doc/html/draft-sharabayko-srt-01#section-4.3.1.1 + ExtensionField uint16 + // Initial Packet Sequence Number: 32 bits. The sequence number of the + // very first data packet to be sent. + InitSequence uint32 + // Maximum Transmission Unit Size: 32 bits. This value is typically set + // to 1500, which is the default Maximum Transmission Unit (MTU) size + // for Ethernet, but can be less. + MTU uint32 + // Maximum Flow Window Size: 32 bits. The value of this field is the + // maximum number of data packets allowed to be "in flight" (i.e. the + // number of sent packets for which an ACK control packet has not yet + // been received). + FlowWindow uint32 + // Handshake Type: 32 bits. This field indicates the handshake packet + // type. + // 0xFFFFFFFD | DONE + // 0xFFFFFFFE | AGREEMENT + // 0xFFFFFFFF | CONCLUSION + // 0x00000000 | WAVEHAND + // 0x00000001 | INDUCTION + HandshakeType uint32 + // SRT Socket ID: 32 bits. This field holds the ID of the source SRT + // socket from which a handshake packet is issued. + SRTSocketID uint32 + // SYN Cookie: 32 bits. Randomized value for processing a handshake. + // The value of this field is specified by the handshake message + // type. + SynCookie uint32 + // Peer IP Address: 128 bits. IPv4 or IPv6 address of the packet's + // sender. The value consists of four 32-bit fields. + PeerIP net.IP + // Extensions. + // Extension Type: 16 bits. The value of this field is used to process + // an integrated handshake. Each extension can have a pair of + // request and response types. + // Extension Length: 16 bits. The length of the Extension Contents + // field in four-byte blocks. + // Extension Contents: variable length. The payload of the extension. + ExtraData []byte +} + +func (v *SRTHandshakePacket) IsData() bool { + return v.ControlFlag == 0x00 +} + +func (v *SRTHandshakePacket) IsControl() bool { + return v.ControlFlag == 0x80 +} + +func (v *SRTHandshakePacket) IsHandshake() bool { + return v.IsControl() && v.ControlType == 0x00 && v.SubType == 0x00 +} + +func (v *SRTHandshakePacket) StreamID() (string, error) { + p := v.ExtraData + for { + if len(p) < 2 { + return "", errors.Errorf("Require 2 bytes, actual=%v, extra=%v", len(p), len(v.ExtraData)) + } + + extType := binary.BigEndian.Uint16(p) + extSize := binary.BigEndian.Uint16(p[2:]) + p = p[4:] + + if len(p) < int(extSize*4) { + return "", errors.Errorf("Require %v bytes, actual=%v, extra=%v", extSize*4, len(p), len(v.ExtraData)) + } + + // Ignore other packets except stream id. + if extType != 0x05 { + p = p[extSize*4:] + continue + } + + // We must copy it, because we will decode the stream id. + data := append([]byte{}, p[:extSize*4]...) + + // Reverse the stream id encoded in little-endian to big-endian. + for i := 0; i < len(data); i += 4 { + value := binary.LittleEndian.Uint32(data[i:]) + binary.BigEndian.PutUint32(data[i:], value) + } + + // Trim the trailing zero bytes. + data = bytes.TrimRight(data, "\x00") + return string(data), nil + } +} + +func (v *SRTHandshakePacket) String() string { + return fmt.Sprintf("Control=%v, CType=%v, SType=%v, Timestamp=%v, SocketID=%v, Version=%v, Encrypt=%v, Extension=%v, InitSequence=%v, MTU=%v, FlowWnd=%v, HSType=%v, SRTSocketID=%v, Cookie=%v, Peer=%vB, Extra=%vB", + v.IsControl(), v.ControlType, v.SubType, v.Timestamp, v.SocketID, v.Version, v.EncryptionField, v.ExtensionField, v.InitSequence, v.MTU, v.FlowWindow, v.HandshakeType, v.SRTSocketID, v.SynCookie, len(v.PeerIP), len(v.ExtraData)) +} + +func (v *SRTHandshakePacket) UnmarshalBinary(b []byte) error { + if len(b) < 4 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.ControlFlag = b[0] & 0x80 + v.ControlType = binary.BigEndian.Uint16(b[0:2]) & 0x7fff + v.SubType = binary.BigEndian.Uint16(b[2:4]) + + if len(b) < 64 { + return errors.Errorf("Invalid packet length %v", len(b)) + } + v.AdditionalInfo = binary.BigEndian.Uint32(b[4:]) + v.Timestamp = binary.BigEndian.Uint32(b[8:]) + v.SocketID = binary.BigEndian.Uint32(b[12:]) + v.Version = binary.BigEndian.Uint32(b[16:]) + v.EncryptionField = binary.BigEndian.Uint16(b[20:]) + v.ExtensionField = binary.BigEndian.Uint16(b[22:]) + v.InitSequence = binary.BigEndian.Uint32(b[24:]) + v.MTU = binary.BigEndian.Uint32(b[28:]) + v.FlowWindow = binary.BigEndian.Uint32(b[32:]) + v.HandshakeType = binary.BigEndian.Uint32(b[36:]) + v.SRTSocketID = binary.BigEndian.Uint32(b[40:]) + v.SynCookie = binary.BigEndian.Uint32(b[44:]) + + // Only support IPv4. + v.PeerIP = net.IPv4(b[51], b[50], b[49], b[48]) + + v.ExtraData = b[64:] + + return nil +} + +func (v *SRTHandshakePacket) MarshalBinary() ([]byte, error) { + b := make([]byte, 64+len(v.ExtraData)) + binary.BigEndian.PutUint16(b, uint16(v.ControlFlag)<<8|v.ControlType) + binary.BigEndian.PutUint16(b[2:], v.SubType) + binary.BigEndian.PutUint32(b[4:], v.AdditionalInfo) + binary.BigEndian.PutUint32(b[8:], v.Timestamp) + binary.BigEndian.PutUint32(b[12:], v.SocketID) + binary.BigEndian.PutUint32(b[16:], v.Version) + binary.BigEndian.PutUint16(b[20:], v.EncryptionField) + binary.BigEndian.PutUint16(b[22:], v.ExtensionField) + binary.BigEndian.PutUint32(b[24:], v.InitSequence) + binary.BigEndian.PutUint32(b[28:], v.MTU) + binary.BigEndian.PutUint32(b[32:], v.FlowWindow) + binary.BigEndian.PutUint32(b[36:], v.HandshakeType) + binary.BigEndian.PutUint32(b[40:], v.SRTSocketID) + binary.BigEndian.PutUint32(b[44:], v.SynCookie) + + // Only support IPv4. + ip := v.PeerIP.To4() + b[48] = ip[3] + b[49] = ip[2] + b[50] = ip[1] + b[51] = ip[0] + + if len(v.ExtraData) > 0 { + copy(b[64:], v.ExtraData) + } + + return b, nil +} diff --git a/proxy/sync/map.go b/proxy/sync/map.go new file mode 100644 index 0000000000..75db12f9a9 --- /dev/null +++ b/proxy/sync/map.go @@ -0,0 +1,45 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package sync + +import "sync" + +type Map[K comparable, V any] struct { + m sync.Map +} + +func (m *Map[K, V]) Delete(key K) { + m.m.Delete(key) +} + +func (m *Map[K, V]) Load(key K) (value V, ok bool) { + v, ok := m.m.Load(key) + if !ok { + return value, ok + } + return v.(V), ok +} + +func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) { + v, loaded := m.m.LoadAndDelete(key) + if !loaded { + return value, loaded + } + return v.(V), loaded +} + +func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) { + a, loaded := m.m.LoadOrStore(key, value) + return a.(V), loaded +} + +func (m *Map[K, V]) Range(f func(key K, value V) bool) { + m.m.Range(func(key, value any) bool { + return f(key.(K), value.(V)) + }) +} + +func (m *Map[K, V]) Store(key K, value V) { + m.m.Store(key, value) +} diff --git a/proxy/utils.go b/proxy/utils.go new file mode 100644 index 0000000000..f3c3930762 --- /dev/null +++ b/proxy/utils.go @@ -0,0 +1,276 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import ( + "context" + "encoding/binary" + "encoding/json" + stdErr "errors" + "fmt" + "io" + "io/ioutil" + "net" + "net/http" + "net/url" + "os" + "path" + "reflect" + "regexp" + "strconv" + "strings" + "syscall" + "time" + + "srs-proxy/errors" + "srs-proxy/logger" +) + +func apiResponse(ctx context.Context, w http.ResponseWriter, r *http.Request, data any) { + w.Header().Set("Server", fmt.Sprintf("%v/%v", Signature(), Version())) + + b, err := json.Marshal(data) + if err != nil { + apiError(ctx, w, r, errors.Wrapf(err, "marshal %v %v", reflect.TypeOf(data), data)) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(b) +} + +func apiError(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { + logger.Wf(ctx, "HTTP API error %+v", err) + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintln(w, fmt.Sprintf("%v", err)) +} + +func apiCORS(ctx context.Context, w http.ResponseWriter, r *http.Request) bool { + // Always support CORS. Note that browser may send origin header for m3u8, but no origin header + // for ts. So we always response CORS header. + if true { + // SRS does not need cookie or credentials, so we disable CORS credentials, and use * for CORS origin, + // headers, expose headers and methods. + w.Header().Set("Access-Control-Allow-Origin", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Headers + w.Header().Set("Access-Control-Allow-Headers", "*") + // See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Access-Control-Allow-Methods + w.Header().Set("Access-Control-Allow-Methods", "*") + } + + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusOK) + return true + } + + return false +} + +func parseGracefullyQuitTimeout() (time.Duration, error) { + if t, err := time.ParseDuration(envGraceQuitTimeout()); err != nil { + return 0, errors.Wrapf(err, "parse duration %v", envGraceQuitTimeout()) + } else { + return t, nil + } +} + +// ParseBody read the body from r, and unmarshal JSON to v. +func ParseBody(r io.ReadCloser, v interface{}) error { + b, err := ioutil.ReadAll(r) + if err != nil { + return errors.Wrapf(err, "read body") + } + defer r.Close() + + if len(b) == 0 { + return nil + } + + if err := json.Unmarshal(b, v); err != nil { + return errors.Wrapf(err, "json unmarshal %v", string(b)) + } + + return nil +} + +// buildStreamURL build as vhost/app/stream for stream URL r. +func buildStreamURL(r string) (string, error) { + u, err := url.Parse(r) + if err != nil { + return "", errors.Wrapf(err, "parse url %v", r) + } + + // If not domain or ip in hostname, it's __defaultVhost__. + defaultVhost := !strings.Contains(u.Hostname(), ".") + + // If hostname is actually an IP address, it's __defaultVhost__. + if ip := net.ParseIP(u.Hostname()); ip.To4() != nil { + defaultVhost = true + } + + if defaultVhost { + return fmt.Sprintf("__defaultVhost__%v", u.Path), nil + } + + // Ignore port, only use hostname as vhost. + return fmt.Sprintf("%v%v", u.Hostname(), u.Path), nil +} + +// isPeerClosedError indicates whether peer object closed the connection. +func isPeerClosedError(err error) bool { + causeErr := errors.Cause(err) + + if stdErr.Is(causeErr, io.EOF) { + return true + } + + if stdErr.Is(causeErr, syscall.EPIPE) { + return true + } + + if netErr, ok := causeErr.(*net.OpError); ok { + if sysErr, ok := netErr.Err.(*os.SyscallError); ok { + if stdErr.Is(sysErr.Err, syscall.ECONNRESET) { + return true + } + } + } + + return false +} + +// convertURLToStreamURL convert the URL in HTTP request to special URLs. The unifiedURL is the URL +// in unified, foramt as scheme://vhost/app/stream without extensions. While the fullURL is the unifiedURL +// with extension. +func convertURLToStreamURL(r *http.Request) (unifiedURL, fullURL string) { + scheme := "http" + if r.TLS != nil { + scheme = "https" + } + + hostname := "__defaultVhost__" + if strings.Contains(r.Host, ":") { + if v, _, err := net.SplitHostPort(r.Host); err == nil { + hostname = v + } + } + + var appStream, streamExt string + + // Parse app/stream from query string. + q := r.URL.Query() + if app := q.Get("app"); app != "" { + appStream = "/" + app + } + if stream := q.Get("stream"); stream != "" { + appStream = fmt.Sprintf("%v/%v", appStream, stream) + } + + // Parse app/stream from path. + if appStream == "" { + streamExt = path.Ext(r.URL.Path) + appStream = strings.TrimSuffix(r.URL.Path, streamExt) + } + + unifiedURL = fmt.Sprintf("%v://%v%v", scheme, hostname, appStream) + fullURL = fmt.Sprintf("%v%v", unifiedURL, streamExt) + return +} + +// rtcIsSTUN returns true if data of UDP payload is a STUN packet. +func rtcIsSTUN(data []byte) bool { + return len(data) > 0 && (data[0] == 0 || data[0] == 1) +} + +// rtcIsRTPOrRTCP returns true if data of UDP payload is a RTP or RTCP packet. +func rtcIsRTPOrRTCP(data []byte) bool { + return len(data) >= 12 && (data[0]&0xC0) == 0x80 +} + +// srtIsHandshake returns true if data of UDP payload is a SRT handshake packet. +func srtIsHandshake(data []byte) bool { + return len(data) >= 4 && binary.BigEndian.Uint32(data) == 0x80000000 +} + +// srtParseSocketID parse the socket id from the SRT packet. +func srtParseSocketID(data []byte) uint32 { + if len(data) >= 16 { + return binary.BigEndian.Uint32(data[12:]) + } + return 0 +} + +// parseIceUfragPwd parse the ice-ufrag and ice-pwd from the SDP. +func parseIceUfragPwd(sdp string) (ufrag, pwd string, err error) { + if true { + ufragRe := regexp.MustCompile(`a=ice-ufrag:([^\s]+)`) + ufragMatch := ufragRe.FindStringSubmatch(sdp) + if len(ufragMatch) <= 1 { + return "", "", errors.Errorf("no ice-ufrag in sdp %v", sdp) + } + ufrag = ufragMatch[1] + } + + if true { + pwdRe := regexp.MustCompile(`a=ice-pwd:([^\s]+)`) + pwdMatch := pwdRe.FindStringSubmatch(sdp) + if len(pwdMatch) <= 1 { + return "", "", errors.Errorf("no ice-pwd in sdp %v", sdp) + } + pwd = pwdMatch[1] + } + + return ufrag, pwd, nil +} + +// parseSRTStreamID parse the SRT stream id to host(optional) and resource(required). +// See https://ossrs.io/lts/en-us/docs/v7/doc/srt#srt-url +func parseSRTStreamID(sid string) (host, resource string, err error) { + if true { + hostRe := regexp.MustCompile(`h=([^,]+)`) + hostMatch := hostRe.FindStringSubmatch(sid) + if len(hostMatch) > 1 { + host = hostMatch[1] + } + } + + if true { + resourceRe := regexp.MustCompile(`r=([^,]+)`) + resourceMatch := resourceRe.FindStringSubmatch(sid) + if len(resourceMatch) <= 1 { + return "", "", errors.Errorf("no resource in sid %v", sid) + } + resource = resourceMatch[1] + } + + return host, resource, nil +} + +// parseListenEndpoint parse the listen endpoint as: +// port The tcp listen port, like 1935. +// protocol://ip:port The listen endpoint, like tcp://:1935 or tcp://0.0.0.0:1935 +func parseListenEndpoint(ep string) (protocol string, ip net.IP, port uint16, err error) { + // If no colon in ep, it's port in string. + if !strings.Contains(ep, ":") { + if p, err := strconv.Atoi(ep); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", ep) + } else { + return "tcp", nil, uint16(p), nil + } + } + + // Must be protocol://ip:port schema. + parts := strings.Split(ep, ":") + if len(parts) != 3 { + return "", nil, 0, errors.Errorf("invalid endpoint %v", ep) + } + + if p, err := strconv.Atoi(parts[2]); err != nil { + return "", nil, 0, errors.Wrapf(err, "parse port %v", parts[2]) + } else { + return parts[0], net.ParseIP(parts[1]), uint16(p), nil + } +} diff --git a/proxy/version.go b/proxy/version.go new file mode 100644 index 0000000000..94f668f96e --- /dev/null +++ b/proxy/version.go @@ -0,0 +1,27 @@ +// Copyright (c) 2024 Winlin +// +// SPDX-License-Identifier: MIT +package main + +import "fmt" + +func VersionMajor() int { + return 1 +} + +// VersionMinor specifies the typical version of SRS we adapt to. +func VersionMinor() int { + return 5 +} + +func VersionRevision() int { + return 0 +} + +func Version() string { + return fmt.Sprintf("%v.%v.%v", VersionMajor(), VersionMinor(), VersionRevision()) +} + +func Signature() string { + return "SRSProxy" +} diff --git a/trunk/conf/origin1-for-proxy.conf b/trunk/conf/origin1-for-proxy.conf new file mode 100644 index 0000000000..baca5c9f40 --- /dev/null +++ b/trunk/conf/origin1-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19351; +max_connections 1000; +pid objs/origin1.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8081; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19851; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10081; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin1; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin2-for-proxy.conf b/trunk/conf/origin2-for-proxy.conf new file mode 100644 index 0000000000..48f6398930 --- /dev/null +++ b/trunk/conf/origin2-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19352; +max_connections 1000; +pid objs/origin2.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8082; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19853; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10082; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin2; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/conf/origin3-for-proxy.conf b/trunk/conf/origin3-for-proxy.conf new file mode 100644 index 0000000000..95624fb773 --- /dev/null +++ b/trunk/conf/origin3-for-proxy.conf @@ -0,0 +1,57 @@ + +listen 19353; +max_connections 1000; +pid objs/origin3.pid; +daemon off; +srs_log_tank console; +http_server { + enabled on; + listen 8083; + dir ./objs/nginx/html; +} +http_api { + enabled on; + listen 19852; +} +rtc_server { + enabled on; + listen 8001; # UDP port + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#config-candidate + candidate $CANDIDATE; +} +srt_server { + enabled on; + listen 10083; + tsbpdmode off; + tlpktdrop off; +} +heartbeat { + enabled on; + interval 9; + url http://127.0.0.1:12025/api/v1/srs/register; + device_id origin3; + ports on; +} +vhost __defaultVhost__ { + http_remux { + enabled on; + mount [vhost]/[app]/[stream].flv; + } + hls { + enabled on; + hls_path ./objs/nginx/html; + hls_fragment 10; + hls_window 60; + } + rtc { + enabled on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtmp-to-rtc + rtmp_to_rtc on; + # @see https://ossrs.net/lts/zh-cn/docs/v4/doc/webrtc#rtc-to-rtmp + rtc_to_rtmp on; + } + srt { + enabled on; + srt_to_rtmp on; + } +} diff --git a/trunk/doc/CHANGELOG.md b/trunk/doc/CHANGELOG.md index 2772c0bf21..9e676930f2 100644 --- a/trunk/doc/CHANGELOG.md +++ b/trunk/doc/CHANGELOG.md @@ -7,6 +7,7 @@ The changelog for SRS. ## SRS 7.0 Changelog +* v7.0, 2024-09-09, Merge [#4158](https://github.com/ossrs/srs/pull/4158): Proxy: Support proxy server for SRS. v7.0.16 (#4158) * v7.0, 2024-09-09, Merge [#4171](https://github.com/ossrs/srs/pull/4171): Heartbeat: Report ports for proxy server. v7.0.15 (#4171) * v7.0, 2024-09-01, Merge [#4165](https://github.com/ossrs/srs/pull/4165): FLV: Refine source and http handler. v7.0.14 (#4165) * v7.0, 2024-09-01, Merge [#4166](https://github.com/ossrs/srs/pull/4166): Edge: Fix flv edge crash when http unmount. v7.0.13 (#4166) diff --git a/trunk/src/app/srs_app_st.cpp b/trunk/src/app/srs_app_st.cpp index 3e21e468cd..466cbe068f 100755 --- a/trunk/src/app/srs_app_st.cpp +++ b/trunk/src/app/srs_app_st.cpp @@ -342,7 +342,12 @@ SrsWaitGroup::SrsWaitGroup() SrsWaitGroup::~SrsWaitGroup() { - wait(); + // In the destructor, we should NOT wait for all coroutines to be done, because user should decide + // to wait or not. Similar to the Go's sync.WaitGroup, it also requires user to wait explicitly. For + // some special use scenarios, such as error handling, for example, if we started three servers with + // wait group, and one of them failed, user may want to return error and quit directly, without wait + // for other running servers to be done. If we wait in the destructor, it will continue to run without + // some servers, in unknown behaviors. srs_cond_destroy(done_); } diff --git a/trunk/src/core/srs_core_version7.hpp b/trunk/src/core/srs_core_version7.hpp index fed95c499b..458a6c3d84 100644 --- a/trunk/src/core/srs_core_version7.hpp +++ b/trunk/src/core/srs_core_version7.hpp @@ -9,6 +9,6 @@ #define VERSION_MAJOR 7 #define VERSION_MINOR 0 -#define VERSION_REVISION 15 +#define VERSION_REVISION 16 #endif \ No newline at end of file