From fcf6d6a02b201dd2660b9b97036f7d0ec8af1124 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Thu, 14 Dec 2023 17:12:41 +0800 Subject: [PATCH 01/39] log: provide registration for custom format encoder (#146) Fixes #145 --- log/zaplogger.go | 26 ++++++++++++++++------- log/zaplogger_test.go | 49 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) diff --git a/log/zaplogger.go b/log/zaplogger.go index 63b48af..974618f 100644 --- a/log/zaplogger.go +++ b/log/zaplogger.go @@ -119,14 +119,26 @@ func newEncoder(c *OutputConfig) zapcore.Encoder { if c.EnableColor { encoderCfg.EncodeLevel = zapcore.CapitalColorLevelEncoder } - switch c.Formatter { - case "console": - return zapcore.NewConsoleEncoder(encoderCfg) - case "json": - return zapcore.NewJSONEncoder(encoderCfg) - default: - return zapcore.NewConsoleEncoder(encoderCfg) + if newFormatEncoder, ok := formatEncoders[c.Formatter]; ok { + return newFormatEncoder(encoderCfg) } + // Defaults to console encoder. + return zapcore.NewConsoleEncoder(encoderCfg) +} + +var formatEncoders = map[string]NewFormatEncoder{ + "console": zapcore.NewConsoleEncoder, + "json": zapcore.NewJSONEncoder, +} + +// NewFormatEncoder is the function type for creating a format encoder out of an encoder config. +type NewFormatEncoder func(zapcore.EncoderConfig) zapcore.Encoder + +// RegisterFormatEncoder registers a NewFormatEncoder with the specified formatName key. +// The existing formats include "console" and "json", but you can override these format encoders +// or provide a new custom one. +func RegisterFormatEncoder(formatName string, newFormatEncoder NewFormatEncoder) { + formatEncoders[formatName] = newFormatEncoder } // GetLogEncoderKey gets user defined log output name, uses defKey if empty. diff --git a/log/zaplogger_test.go b/log/zaplogger_test.go index 58eca32..5846fec 100644 --- a/log/zaplogger_test.go +++ b/log/zaplogger_test.go @@ -17,12 +17,14 @@ import ( "errors" "fmt" "runtime" + "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" + "go.uber.org/zap/buffer" "go.uber.org/zap/zapcore" "go.uber.org/zap/zaptest/observer" @@ -334,3 +336,50 @@ func TestLogEnableColor(t *testing.T) { l.Warn("hello") l.Error("hello") } + +func TestLogNewFormatEncoder(t *testing.T) { + const myFormatter = "myformatter" + log.RegisterFormatEncoder(myFormatter, func(ec zapcore.EncoderConfig) zapcore.Encoder { + return &consoleEncoder{ + Encoder: zapcore.NewJSONEncoder(zapcore.EncoderConfig{}), + pool: buffer.NewPool(), + cfg: ec, + } + }) + cfg := []log.OutputConfig{{Writer: "console", Level: "trace", Formatter: myFormatter}} + l := log.NewZapLog(cfg).With(log.Field{Key: "trace-id", Value: "xx"}) + l.Trace("hello") + l.Debug("hello") + l.Info("hello") + l.Warn("hello") + l.Error("hello") + // 2023/12/14 10:54:55 {"trace-id":"xx"} DEBUG hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} DEBUG hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} INFO hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} WARN hello + // 2023/12/14 10:54:55 {"trace-id":"xx"} ERROR hello +} + +type consoleEncoder struct { + zapcore.Encoder + pool buffer.Pool + cfg zapcore.EncoderConfig +} + +func (c consoleEncoder) Clone() zapcore.Encoder { + return consoleEncoder{Encoder: c.Encoder.Clone(), pool: buffer.NewPool(), cfg: c.cfg} +} + +func (c consoleEncoder) EncodeEntry(entry zapcore.Entry, fields []zapcore.Field) (*buffer.Buffer, error) { + buf, err := c.Encoder.EncodeEntry(zapcore.Entry{}, nil) + if err != nil { + return nil, err + } + buffer := c.pool.Get() + buffer.AppendString(entry.Time.Format("2006/01/02 15:04:05")) + field := buf.String() + buffer.AppendString(" " + field[:len(field)-1] + " ") + buffer.AppendString(strings.ToUpper(entry.Level.String()) + " ") + buffer.AppendString(entry.Message + "\n") + return buffer, nil +} From 65a60b16ba4f1633a7b1d600d2c74b8741ce7507 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Thu, 14 Dec 2023 17:13:28 +0800 Subject: [PATCH 02/39] docs: add note on listen to all addresses for a service (#144) Fixes #143 --- docs/user_guide/framework_conf.md | 1 + docs/user_guide/framework_conf.zh_CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/user_guide/framework_conf.md b/docs/user_guide/framework_conf.md index 0198b47..a934838 100644 --- a/docs/user_guide/framework_conf.md +++ b/docs/user_guide/framework_conf.md @@ -110,6 +110,7 @@ server: - # Optional, whether to prohibit inheriting the upstream timeout time, used to close the full link timeout mechanism, the default is false disable_request_timeout: Boolean # Optional, the IP address of the service monitors, if it is empty, it will try to get the network card IP, if it is still empty, use global.local_ip + # To listen on all addresses, please use "0.0.0.0" (IPv4) or "::" (IPv6). ip: String(ipv4 or ipv6) # Required, the service name, used for service discovery name: String diff --git a/docs/user_guide/framework_conf.zh_CN.md b/docs/user_guide/framework_conf.zh_CN.md index fe85470..4ab74e7 100644 --- a/docs/user_guide/framework_conf.zh_CN.md +++ b/docs/user_guide/framework_conf.zh_CN.md @@ -108,6 +108,7 @@ server: - # 选填,是否禁止继承上游的超时时间,用于关闭全链路超时机制,默认为 false disable_request_timeout: Boolean # 选填,service 监听的 IP 地址,如果为空,则会尝试获取网卡 IP,如果仍为空,则使用 global.local_ip + # 如果需要监听所有地址的话,请使用 "0.0.0.0" (ipv4) 或 "::" (ipv6) ip: String(ipv4 or ipv6) # 必填,服务名,用于服务发现 name: String From 940dc64e42cb707d6fe1534e9b50555718537a1f Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 15:34:38 +0800 Subject: [PATCH 03/39] don't wait entire service timeout, but close quickly on no active request. --- server/service.go | 30 ++++++----- server/service_test.go | 111 ++++++++++++++++++++++++++--------------- 2 files changed, 88 insertions(+), 53 deletions(-) diff --git a/server/service.go b/server/service.go index 631a54f..953de5a 100644 --- a/server/service.go +++ b/server/service.go @@ -534,21 +534,28 @@ func (s *service) Close(ch chan struct{}) error { } } } - s.waitBeforeClose() + if remains := s.waitBeforeClose(); remains > 0 { + log.Infof("process %d service %s remains %d requests before close", + os.Getpid(), s.opts.ServiceName, remains) + } // this will cancel all children ctx. s.cancel() + timeout := time.Millisecond * 300 if s.opts.Timeout > timeout { // use the larger one timeout = s.opts.Timeout } - time.Sleep(timeout) + if remains := s.waitInactive(timeout); remains > 0 { + log.Infof("process %d service %s remains %d requests after close", + os.Getpid(), s.opts.ServiceName, remains) + } log.Infof("process:%d, %s service:%s, closed", pid, s.opts.protocol, s.opts.ServiceName) ch <- struct{}{} return nil } -func (s *service) waitBeforeClose() { +func (s *service) waitBeforeClose() int64 { closeWaitTime := s.opts.CloseWaitTime if closeWaitTime > MaxCloseWaitTime { closeWaitTime = MaxCloseWaitTime @@ -562,18 +569,17 @@ func (s *service) waitBeforeClose() { os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount), closeWaitTime) time.Sleep(closeWaitTime) } + return s.waitInactive(s.opts.MaxCloseWaitTime - closeWaitTime) +} + +func (s *service) waitInactive(maxWaitTime time.Duration) int64 { const sleepTime = 100 * time.Millisecond - if s.opts.MaxCloseWaitTime > closeWaitTime { - spinCount := int((s.opts.MaxCloseWaitTime - closeWaitTime) / sleepTime) - for i := 0; i < spinCount; i++ { - if atomic.LoadInt64(&s.activeCount) <= 0 { - break - } - time.Sleep(sleepTime) + for start := time.Now(); time.Since(start) < maxWaitTime; time.Sleep(sleepTime) { + if atomic.LoadInt64(&s.activeCount) <= 0 { + return 0 } - log.Infof("process %d service %s remain %d requests when closing service", - os.Getpid(), s.opts.ServiceName, atomic.LoadInt64(&s.activeCount)) } + return atomic.LoadInt64(&s.activeCount) } func checkProcessStatus() (isGracefulRestart, isParentalProcess bool) { diff --git a/server/service_test.go b/server/service_test.go index 61b222a..70a28f3 100644 --- a/server/service_test.go +++ b/server/service_test.go @@ -306,55 +306,84 @@ func TestServiceUDP(t *testing.T) { require.Nil(t, err) } -func TestServiceCloseWait(t *testing.T) { - const waitChildTime = 300 * time.Millisecond - const schTime = 10 * time.Millisecond - cases := []struct { - closeWaitTime time.Duration - maxCloseWaitTime time.Duration - waitTime time.Duration - }{ - { - waitTime: waitChildTime, - }, - { - closeWaitTime: 50 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, - { - closeWaitTime: 50 * time.Millisecond, - maxCloseWaitTime: 30 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, - { - closeWaitTime: 50 * time.Millisecond, - maxCloseWaitTime: 100 * time.Millisecond, - waitTime: waitChildTime + 50*time.Millisecond, - }, +func TestCloseWaitTime(t *testing.T) { + startService := func(opts ...server.Option) (chan struct{}, func()) { + received, done := make(chan struct{}), make(chan struct{}) + addr, stop := startService(t, &Greeter{}, append([]server.Option{server.WithFilter( + func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) { + received <- struct{}{} + <-done + return nil, errors.New("must fail") + })}, opts...)...) + go func() { + _, _ = pb.NewGreeterClientProxy(client.WithTarget("ip://"+addr)). + SayHello(context.Background(), &pb.HelloRequest{}) + }() + <-received + return done, stop } - for _, c := range cases { - service := server.New( - server.WithRegistry(&fakeRegistry{}), - server.WithCloseWaitTime(c.closeWaitTime), - server.WithMaxCloseWaitTime(c.maxCloseWaitTime), - ) + t.Run("active requests feature is not enabled on missing MaxCloseWaitTime", func(t *testing.T) { + done, stop := startService() + defer close(done) start := time.Now() - err := service.Close(nil) - assert.Nil(t, err) - cost := time.Since(start) - assert.GreaterOrEqual(t, cost, c.waitTime) - assert.LessOrEqual(t, cost, c.waitTime+schTime) - } + stop() + require.Less(t, time.Since(start), time.Millisecond*100) + }) + t.Run("total wait time should not significantly greater than MaxCloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + defer close(done) + start := time.Now() + stop() + require.WithinRange(t, time.Now(), + // 300ms comes from the internal implementation when close service + start.Add(maxCloseWaitTime).Add(time.Millisecond*300), + start.Add(maxCloseWaitTime).Add(time.Millisecond*500)) + }) + t.Run("total wait time is at least CloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + start := time.Now() + time.AfterFunc(closeWaitTime/2, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(closeWaitTime+time.Millisecond*100)) + }) + t.Run("no active request before MaxCloseWaitTime", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime = time.Millisecond * 500, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime)) + start := time.Now() + time.AfterFunc((closeWaitTime+maxCloseWaitTime)/2, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(closeWaitTime), start.Add(maxCloseWaitTime)) + }) + t.Run("no active request before service timeout", func(t *testing.T) { + const closeWaitTime, maxCloseWaitTime, timeout = time.Millisecond * 500, time.Second, time.Second + done, stop := startService( + server.WithMaxCloseWaitTime(maxCloseWaitTime), + server.WithCloseWaitTime(closeWaitTime), + server.WithTimeout(timeout)) + start := time.Now() + time.AfterFunc(maxCloseWaitTime+time.Millisecond*100, func() { close(done) }) + stop() + require.WithinRange(t, time.Now(), start.Add(maxCloseWaitTime+time.Millisecond*100), start.Add(maxCloseWaitTime+timeout)) + }) } func startService(t *testing.T, gs GreeterServer, opts ...server.Option) (addr string, stop func()) { l, err := net.Listen("tcp", "0.0.0.0:0") require.Nil(t, err) - s := server.New(append(append([]server.Option{ - server.WithNetwork("tcp"), - server.WithProtocol("trpc"), - }, opts...), + s := server.New(append(append( + []server.Option{ + server.WithNetwork("tcp"), + server.WithProtocol("trpc"), + }, opts...), server.WithListener(l), )...) require.Nil(t, s.Register(&GreeterServerServiceDesc, gs)) From 2b0a1183f6453547efad499c8a55d8fb55fb0179 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 15:41:10 +0800 Subject: [PATCH 04/39] support linux-386 --- server/service.go | 2 +- transport/client_transport_test.go | 29 ++++++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/server/service.go b/server/service.go index 953de5a..a407fef 100644 --- a/server/service.go +++ b/server/service.go @@ -102,13 +102,13 @@ type Stream interface { // service is an implementation of Service type service struct { + activeCount int64 // active requests count for graceful close if set MaxCloseWaitTime ctx context.Context // context of this service cancel context.CancelFunc // function that cancels this service opts *Options // options of this service handlers map[string]Handler // rpcname => handler streamHandlers map[string]StreamHandler streamInfo map[string]*StreamServerInfo - activeCount int64 // active requests count for graceful close if set MaxCloseWaitTime } // New creates a service. diff --git a/transport/client_transport_test.go b/transport/client_transport_test.go index 70db95f..1f94911 100644 --- a/transport/client_transport_test.go +++ b/transport/client_transport_test.go @@ -18,7 +18,9 @@ import ( "errors" "fmt" "io" + "math" "net" + "strings" "testing" "time" @@ -348,6 +350,24 @@ func TestClientTransport_RoundTrip(t *testing.T) { }() time.Sleep(20 * time.Millisecond) + t.Run("write: message too long", func(t *testing.T) { + c := mustListenUDP(t) + t.Cleanup(func() { + if err := c.Close(); err != nil { + t.Log(err) + } + }) + largeRequest := encodeLengthDelimited(strings.Repeat("1", math.MaxInt32/4)) + _, err := transport.RoundTrip(context.Background(), largeRequest, + transport.WithClientFramerBuilder(fb), + transport.WithDialNetwork("udp"), + transport.WithDialAddress(c.LocalAddr().String()), + transport.WithReqType(transport.SendAndRecv), + ) + require.Equal(t, errs.RetClientNetErr, errs.Code(err)) + require.Contains(t, errs.Msg(err), "udp client transport WriteTo") + }) + var err error _, err = transport.RoundTrip(context.Background(), encodeLengthDelimited("helloworld")) assert.NotNil(t, err) @@ -489,6 +509,14 @@ func TestClientTransport_RoundTrip(t *testing.T) { assert.Contains(t, err.Error(), remainingBytesError.Error()) } +func mustListenUDP(t *testing.T) net.PacketConn { + c, err := net.ListenPacket("udp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + return c +} + // Frame a stream of bytes based on a length prefix // +------------+--------------------------------+ // | len: uint8 | frame payload | @@ -609,7 +637,6 @@ func TestClientTransport_MultiplexedErr(t *testing.T) { } func TestClientTransport_RoundTrip_PreConnected(t *testing.T) { - go func() { err := transport.ListenAndServe( transport.WithListenNetwork("udp"), From f95e8a954fdf792fc27734251fd41f26a293f39a Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 15:54:42 +0800 Subject: [PATCH 05/39] package config expands env like trpc_go.yaml --- config.go | 6 +-- config/options.go | 10 +++++ config/trpc_config.go | 8 ++-- config/trpc_config_test.go | 40 +++++++++++++++++++ internal/expandenv/expand_env.go | 55 +++++++++++++++++++++++++++ internal/expandenv/expand_env_test.go | 47 +++++++++++++++++++++++ trpc_util.go | 53 -------------------------- 7 files changed, 159 insertions(+), 60 deletions(-) create mode 100644 internal/expandenv/expand_env.go create mode 100644 internal/expandenv/expand_env_test.go diff --git a/config.go b/config.go index 6ea24e6..f981c6c 100644 --- a/config.go +++ b/config.go @@ -25,6 +25,7 @@ import ( "time" yaml "gopkg.in/yaml.v3" + "trpc.group/trpc-go/trpc-go/internal/expandenv" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "trpc.group/trpc-go/trpc-go/client" @@ -608,11 +609,8 @@ func parseConfigFromFile(configPath string) (*Config, error) { if err != nil { return nil, err } - // expand environment variables - buf = []byte(expandEnv(string(buf))) - cfg := defaultConfig() - if err := yaml.Unmarshal(buf, cfg); err != nil { + if err := yaml.Unmarshal(expandenv.ExpandEnv(buf), cfg); err != nil { return nil, err } return cfg, nil diff --git a/config/options.go b/config/options.go index dc6d2f9..fed9b74 100644 --- a/config/options.go +++ b/config/options.go @@ -13,6 +13,8 @@ package config +import "trpc.group/trpc-go/trpc-go/internal/expandenv" + // WithCodec returns an option which sets the codec's name. func WithCodec(name string) LoadOption { return func(c *TrpcConfig) { @@ -27,6 +29,14 @@ func WithProvider(name string) LoadOption { } } +// WithExpandEnv replaces ${var} in raw bytes with environment value of var. +// Note, method TrpcConfig.Bytes will return the replaced bytes. +func WithExpandEnv() LoadOption { + return func(c *TrpcConfig) { + c.expandEnv = expandenv.ExpandEnv + } +} + // options is config option. type options struct{} diff --git a/config/trpc_config.go b/config/trpc_config.go index dea82c6..d780e3a 100644 --- a/config/trpc_config.go +++ b/config/trpc_config.go @@ -162,6 +162,7 @@ type TrpcConfig struct { path string decoder Codec rawData []byte + expandEnv func([]byte) []byte } func newTrpcConfig(path string) *TrpcConfig { @@ -170,6 +171,7 @@ func newTrpcConfig(path string) *TrpcConfig { unmarshalledData: make(map[string]interface{}), path: path, decoder: &YamlCodec{}, + expandEnv: func(bytes []byte) []byte { return bytes }, } } @@ -189,7 +191,7 @@ func (c *TrpcConfig) Load() error { return fmt.Errorf("trpc/config: failed to load %s: %s", c.path, err.Error()) } - c.rawData = data + c.rawData = c.expandEnv(data) if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil { return fmt.Errorf("trpc/config: failed to parse %s: %s", c.path, err.Error()) } @@ -208,8 +210,8 @@ func (c *TrpcConfig) Reload() { return } - c.rawData = data - if err := c.decoder.Unmarshal(data, &c.unmarshalledData); err != nil { + c.rawData = c.expandEnv(data) + if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil { log.Tracef("trpc/config: failed to parse %s: %v", c.path, err) return } diff --git a/config/trpc_config_test.go b/config/trpc_config_test.go index a83e60c..7e5e7d8 100644 --- a/config/trpc_config_test.go +++ b/config/trpc_config_test.go @@ -126,3 +126,43 @@ func TestYamlCodec_Unmarshal(t *testing.T) { require.NotNil(t, GetCodec("yaml").Unmarshal([]byte("[1, 2]"), &tt)) }) } + +func TestEnvExpanded(t *testing.T) { + RegisterProvider(NewEnvProvider(t.Name(), []byte(` +password: ${pwd} +`))) + + t.Setenv("pwd", t.Name()) + cfg, err := DefaultConfigLoader.Load( + t.Name(), + WithProvider(t.Name()), + WithExpandEnv()) + require.Nil(t, err) + + require.Equal(t, t.Name(), cfg.GetString("password", "")) + require.Contains(t, string(cfg.Bytes()), fmt.Sprintf("password: %s", t.Name())) +} + +func NewEnvProvider(name string, data []byte) *EnvProvider { + return &EnvProvider{ + name: name, + data: data, + } +} + +type EnvProvider struct { + name string + data []byte +} + +func (ep *EnvProvider) Name() string { + return ep.name +} + +func (ep *EnvProvider) Read(string) ([]byte, error) { + return ep.data, nil +} + +func (ep *EnvProvider) Watch(cb ProviderCallback) { + cb("", ep.data) +} diff --git a/internal/expandenv/expand_env.go b/internal/expandenv/expand_env.go new file mode 100644 index 0000000..ab1cc08 --- /dev/null +++ b/internal/expandenv/expand_env.go @@ -0,0 +1,55 @@ +// Package expandenv replaces ${key} in byte slices with the env value of key. +package expandenv + +import ( + "os" +) + +// ExpandEnv looks for ${var} in s and replaces them with value of the corresponding environment variable. +// It's not like os.ExpandEnv which handles both ${var} and $var. +// $var is considered invalid, since configurations like password for redis/mysql may contain $. +func ExpandEnv(s []byte) []byte { + var buf []byte + i := 0 + for j := 0; j < len(s); j++ { + if s[j] == '$' && j+2 < len(s) && s[j+1] == '{' { // only ${var} instead of $var is valid + if buf == nil { + buf = make([]byte, 0, 2*len(s)) + } + buf = append(buf, s[i:j]...) + name, w := getEnvName(s[j+1:]) + if name == nil && w > 0 { + // invalid matching, remove the $ + } else if name == nil { + buf = append(buf, s[j]) // keep the $ + } else { + buf = append(buf, os.Getenv(string(name))...) + } + j += w + i = j + 1 + } + } + if buf == nil { + return s + } + return append(buf, s[i:]...) +} + +// getEnvName gets env name, that is, var from ${var}. +// The env name and its len will be returned. +func getEnvName(s []byte) ([]byte, int) { + // look for right curly bracket '}' + // it's guaranteed that the first char is '{' and the string has at least two char + for i := 1; i < len(s); i++ { + if s[i] == ' ' || s[i] == '\n' || s[i] == '"' { // "xx${xxx" + return nil, 0 // encounter invalid char, keep the $ + } + if s[i] == '}' { + if i == 1 { // ${} + return nil, 2 // remove ${} + } + return s[1:i], i + 1 + } + } + return nil, 0 // no },keep the $ +} diff --git a/internal/expandenv/expand_env_test.go b/internal/expandenv/expand_env_test.go new file mode 100644 index 0000000..96439b5 --- /dev/null +++ b/internal/expandenv/expand_env_test.go @@ -0,0 +1,47 @@ +package expandenv_test + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" + . "trpc.group/trpc-go/trpc-go/internal/expandenv" +) + +func TestExpandEnv(t *testing.T) { + key := "env_key" + t.Run("no env", func(t *testing.T) { + require.Equal(t, []byte("abc"), ExpandEnv([]byte("abc"))) + }) + t.Run("${..} is expanded", func(t *testing.T) { + t.Setenv(key, t.Name()) + require.Equal(t, fmt.Sprintf("head_%s_tail", t.Name()), + string(ExpandEnv([]byte(fmt.Sprintf("head_${%s}_tail", key))))) + }) + t.Run("${ is not expanded", func(t *testing.T) { + require.Equal(t, "head_${_tail", + string(ExpandEnv([]byte(fmt.Sprintf("head_${_tail"))))) + }) + t.Run("${} is expanded as empty", func(t *testing.T) { + require.Equal(t, "head__tail", + string(ExpandEnv([]byte("head_${}_tail")))) + }) + t.Run("${..} is not expanded if .. contains any space", func(t *testing.T) { + t.Setenv("key key", t.Name()) + require.Equal(t, "head_${key key}_tail", + string(ExpandEnv([]byte("head_${key key}_tail")))) + }) + t.Run("${..} is not expanded if .. contains any new line", func(t *testing.T) { + t.Setenv("key\nkey", t.Name()) + require.Equal(t, t.Name(), os.Getenv("key\nkey")) + require.Equal(t, "head_${key\nkey}_tail", + string(ExpandEnv([]byte("head_${key\nkey}_tail")))) + }) + t.Run(`${..} is not expanded if .. contains any "`, func(t *testing.T) { + t.Setenv(`key"key`, t.Name()) + require.Equal(t, t.Name(), os.Getenv(`key"key`)) + require.Equal(t, `head_${key"key}_tail`, + string(ExpandEnv([]byte(`head_${key"key}_tail`)))) + }) +} diff --git a/trpc_util.go b/trpc_util.go index ae2e4d0..1aaf73e 100644 --- a/trpc_util.go +++ b/trpc_util.go @@ -16,7 +16,6 @@ package trpc import ( "context" "net" - "os" "runtime" "sync" "time" @@ -275,58 +274,6 @@ func Go(ctx context.Context, timeout time.Duration, handler func(context.Context return DefaultGoer.Go(ctx, timeout, handler) } -// expandEnv looks for ${var} in s and replaces them with value of the -// corresponding environment variable. -// $var is considered invalid. -// It's not like os.ExpandEnv which will handle both ${var} and $var. -// Since configurations like password for redis/mysql may contain $, this -// method is needed. -func expandEnv(s string) string { - var buf []byte - i := 0 - for j := 0; j < len(s); j++ { - if s[j] == '$' && j+2 < len(s) && s[j+1] == '{' { // only ${var} instead of $var is valid - if buf == nil { - buf = make([]byte, 0, 2*len(s)) - } - buf = append(buf, s[i:j]...) - name, w := getEnvName(s[j+1:]) - if name == "" && w > 0 { - // invalid matching, remove the $ - } else if name == "" { - buf = append(buf, s[j]) // keep the $ - } else { - buf = append(buf, os.Getenv(name)...) - } - j += w - i = j + 1 - } - } - if buf == nil { - return s - } - return string(buf) + s[i:] -} - -// getEnvName gets env name, that is, var from ${var}. -// And content of var and its len will be returned. -func getEnvName(s string) (string, int) { - // look for right curly bracket '}' - // it's guaranteed that the first char is '{' and the string has at least two char - for i := 1; i < len(s); i++ { - if s[i] == ' ' || s[i] == '\n' || s[i] == '"' { // "xx${xxx" - return "", 0 // encounter invalid char, keep the $ - } - if s[i] == '}' { - if i == 1 { // ${} - return "", 2 // remove ${} - } - return s[1:i], i + 1 - } - } - return "", 0 // no },keep the $ -} - // --------------- the following code is IP Config related -----------------// // nicIP defines the parameters used to record the ip address (ipv4 & ipv6) of the nic. From 4a2ef46499f51f4d7593a90eb28f306c57c2d2f5 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:20:34 +0800 Subject: [PATCH 06/39] fix watch callback leak when call TrpcConfigLoader.Load multiple times --- config/README.md | 2 +- config/README.zh_CN.md | 2 +- config/config.go | 1 - config/options.go | 18 +- config/trpc_config.go | 339 ++++++++++++++++++------ config/trpc_config_test.go | 112 ++++++++ examples/features/config/client/main.go | 22 ++ examples/features/config/server/main.go | 110 +++++++- 8 files changed, 509 insertions(+), 97 deletions(-) diff --git a/config/README.md b/config/README.md index 7f4320e..5ac469b 100644 --- a/config/README.md +++ b/config/README.md @@ -31,7 +31,7 @@ For managing business configuration, we recommend the best practice of using a c ## What is Multiple Data Sources? -A data source is the source from which configuration is retrieved and where it is stored. Common data sources include: file, etcd, configmap, etc. The tRPC framework supports setting different data sources for different business configurations. The framework uses a plugin-based approach to extend support for more data sources. In the implementation principle section later, we will describe in detail how the framework supports multiple data sources. +A data source is the source from which configuration is retrieved and where it is stored. Common data sources include: file, etcd, configmap, env, etc. The tRPC framework supports setting different data sources for different business configurations. The framework uses a plugin-based approach to extend support for more data sources. In the implementation principle section later, we will describe in detail how the framework supports multiple data sources. ## What is Codec? diff --git a/config/README.zh_CN.md b/config/README.zh_CN.md index 7642b1f..57c6b34 100644 --- a/config/README.zh_CN.md +++ b/config/README.zh_CN.md @@ -26,7 +26,7 @@ 业务配置也支持本地文件。对于本地文件,大部分使用场景是客户端作为独立的工具使用,或者程序在开发调试阶段使用。好处在于不需要依赖外部系统就能工作。 ## 什么是多数据源 -数据源就获取配置的来源,配置存储的地方。常见的数据源包括:file,etcd,configmap 等。tRPC 框架支持对不同业务配置设定不同的数据源。框架采用插件化方式来扩展对更多数据源的支持。在后面的实现原理章节,我们会详细介绍框架是如何实现对多数据源的支持的。 +数据源就获取配置的来源,配置存储的地方。常见的数据源包括:file,etcd,configmap,env 等。tRPC 框架支持对不同业务配置设定不同的数据源。框架采用插件化方式来扩展对更多数据源的支持。在后面的实现原理章节,我们会详细介绍框架是如何实现对多数据源的支持的。 ## 什么是 Codec 业务配置中的 Codec 是指从配置源获取到的配置的格式,常见的配置文件格式为:yaml,json,toml 等。框架采用插件化方式来扩展对更多解码格式的支持。 diff --git a/config/config.go b/config/config.go index 3648305..8bf13bf 100644 --- a/config/config.go +++ b/config/config.go @@ -23,7 +23,6 @@ import ( "sync" "github.com/BurntSushi/toml" - yaml "gopkg.in/yaml.v3" ) diff --git a/config/options.go b/config/options.go index fed9b74..ef7144b 100644 --- a/config/options.go +++ b/config/options.go @@ -13,8 +13,6 @@ package config -import "trpc.group/trpc-go/trpc-go/internal/expandenv" - // WithCodec returns an option which sets the codec's name. func WithCodec(name string) LoadOption { return func(c *TrpcConfig) { @@ -33,7 +31,21 @@ func WithProvider(name string) LoadOption { // Note, method TrpcConfig.Bytes will return the replaced bytes. func WithExpandEnv() LoadOption { return func(c *TrpcConfig) { - c.expandEnv = expandenv.ExpandEnv + c.expandEnv = true + } +} + +// WithWatch returns an option to start watch model +func WithWatch() LoadOption { + return func(c *TrpcConfig) { + c.watch = true + } +} + +// WithWatchHook returns an option to set log func for config change logger +func WithWatchHook(f func(msg WatchMessage)) LoadOption { + return func(c *TrpcConfig) { + c.watchHook = f } } diff --git a/config/trpc_config.go b/config/trpc_config.go index d780e3a..e58470c 100644 --- a/config/trpc_config.go +++ b/config/trpc_config.go @@ -23,6 +23,7 @@ import ( "github.com/BurntSushi/toml" "github.com/spf13/cast" yaml "gopkg.in/yaml.v3" + "trpc.group/trpc-go/trpc-go/internal/expandenv" "trpc.group/trpc-go/trpc-go/log" ) @@ -49,68 +50,59 @@ type LoadOption func(*TrpcConfig) // TrpcConfigLoader is a config loader for trpc. type TrpcConfigLoader struct { - configMap map[string]Config - rwl sync.RWMutex + watchers sync.Map } // Load returns the config specified by input parameter. func (loader *TrpcConfigLoader) Load(path string, opts ...LoadOption) (Config, error) { - yc := newTrpcConfig(path) - for _, o := range opts { - o(yc) - } - if yc.decoder == nil { - return nil, ErrCodecNotExist - } - if yc.p == nil { - return nil, ErrProviderNotExist + c, err := newTrpcConfig(path, opts...) + if err != nil { + return nil, err } - key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path) - loader.rwl.RLock() - if c, ok := loader.configMap[key]; ok { - loader.rwl.RUnlock() - return c, nil + w := &watcher{} + i, loaded := loader.watchers.LoadOrStore(c.p, w) + if !loaded { + c.p.Watch(w.watch) + } else { + w = i.(*watcher) } - loader.rwl.RUnlock() - if err := yc.Load(); err != nil { + c = w.getOrCreate(c.path).getOrStore(c) + if err = c.init(); err != nil { return nil, err } - - loader.rwl.Lock() - loader.configMap[key] = yc - loader.rwl.Unlock() - - yc.p.Watch(func(p string, data []byte) { - if p == path { - loader.rwl.Lock() - delete(loader.configMap, key) - loader.rwl.Unlock() - } - }) - return yc, nil + return c, nil } // Reload reloads config data. func (loader *TrpcConfigLoader) Reload(path string, opts ...LoadOption) error { - yc := newTrpcConfig(path) - for _, o := range opts { - o(yc) + c, err := newTrpcConfig(path, opts...) + if err != nil { + return err } - key := fmt.Sprintf("%s.%s.%s", yc.decoder.Name(), yc.p.Name(), path) - loader.rwl.RLock() - if config, ok := loader.configMap[key]; ok { - loader.rwl.RUnlock() - config.Reload() - return nil + + v, ok := loader.watchers.Load(c.p) + if !ok { + return ErrConfigNotExist + } + w := v.(*watcher) + + s := w.get(path) + if s == nil { + return ErrConfigNotExist } - loader.rwl.RUnlock() - return ErrConfigNotExist + + oc := s.get(c.id) + if oc == nil { + return ErrConfigNotExist + } + + return oc.Load() } func newTrpcConfigLoad() *TrpcConfigLoader { - return &TrpcConfigLoader{configMap: map[string]Config{}, rwl: sync.RWMutex{}} + return &TrpcConfigLoader{} } // DefaultConfigLoader is the default config loader. @@ -155,65 +147,249 @@ func (c *TomlCodec) Unmarshal(in []byte, out interface{}) error { return toml.Unmarshal(in, out) } +// watch manage one data provider +type watcher struct { + sets sync.Map // *set +} + +// get config item by path +func (w *watcher) get(path string) *set { + if i, ok := w.sets.Load(path); ok { + return i.(*set) + } + return nil +} + +// getOrCreate get config item by path if not exist and create and return +func (w *watcher) getOrCreate(path string) *set { + i, _ := w.sets.LoadOrStore(path, &set{}) + return i.(*set) +} + +// watch func +func (w *watcher) watch(path string, data []byte) { + if v := w.get(path); v != nil { + v.watch(data) + } +} + +// set manages configs with same provider and name with different type +// used config.id as unique identifier +type set struct { + path string + mutex sync.RWMutex + items []*TrpcConfig +} + +// get data +func (s *set) get(id string) *TrpcConfig { + s.mutex.RLock() + defer s.mutex.RUnlock() + for _, v := range s.items { + if v.id == id { + return v + } + } + return nil +} + +func (s *set) getOrStore(tc *TrpcConfig) *TrpcConfig { + if v := s.get(tc.id); v != nil { + return v + } + + s.mutex.Lock() + for _, item := range s.items { + if item.id == tc.id { + s.mutex.Unlock() + return item + } + } + // not found and add + s.items = append(s.items, tc) + s.mutex.Unlock() + return tc +} + +// watch data change, delete no watch model config and update watch model config and target notify +func (s *set) watch(data []byte) { + var items []*TrpcConfig + var del []*TrpcConfig + s.mutex.Lock() + for _, v := range s.items { + if v.watch { + items = append(items, v) + } else { + del = append(del, v) + } + } + s.items = items + s.mutex.Unlock() + + for _, item := range items { + err := item.doWatch(data) + item.notify(data, err) + } + + for _, item := range del { + item.notify(data, nil) + } +} + +// defaultNotifyChange default hook for notify config changed +var defaultWatchHook = func(message WatchMessage) {} + +// SetDefaultWatchHook set default hook notify when config changed +func SetDefaultWatchHook(f func(message WatchMessage)) { + defaultWatchHook = f +} + +// WatchMessage change message +type WatchMessage struct { + Provider string // provider name + Path string // config path + ExpandEnv bool // expend env status + Codec string // codec + Watch bool // status for start watch + Value []byte // config content diff ? + Error error // load error message, success is empty string +} + +var _ Config = (*TrpcConfig)(nil) + // TrpcConfig is used to parse yaml config file for trpc. type TrpcConfig struct { - p DataProvider - unmarshalledData interface{} - path string - decoder Codec - rawData []byte - expandEnv func([]byte) []byte + id string // config identity + msg WatchMessage // new to init message for notify only copy + + p DataProvider // config provider + path string // config name + decoder Codec // config codec + expandEnv bool // status for whether replace the variables in the configuration with environment variables + + // because function is not support comparable in singleton, so the following options work only for the first load + watch bool + watchHook func(message WatchMessage) + + mutex sync.RWMutex + value *entity // store config value } -func newTrpcConfig(path string) *TrpcConfig { - return &TrpcConfig{ - p: GetProvider("file"), - unmarshalledData: make(map[string]interface{}), - path: path, - decoder: &YamlCodec{}, - expandEnv: func(bytes []byte) []byte { return bytes }, +type entity struct { + raw []byte // current binary data + data interface{} // unmarshal type to use point type, save latest no error data +} + +func newTrpcConfig(path string, opts ...LoadOption) (*TrpcConfig, error) { + c := &TrpcConfig{ + path: path, + p: GetProvider("file"), + decoder: GetCodec("yaml"), + watchHook: func(message WatchMessage) { + defaultWatchHook(message) + }, + } + for _, o := range opts { + o(c) + } + if c.p == nil { + return nil, ErrProviderNotExist + } + if c.decoder == nil { + return nil, ErrCodecNotExist } + + c.msg.Provider = c.p.Name() + c.msg.Path = c.path + c.msg.Codec = c.decoder.Name() + c.msg.ExpandEnv = c.expandEnv + c.msg.Watch = c.watch + + // since reflect.String() cannot uniquely identify a type, this id is used as a preliminary judgment basis + const idFormat = "provider:%s path:%s codec:%s env:%t watch:%t" + c.id = fmt.Sprintf(idFormat, c.p.Name(), c.path, c.decoder.Name(), c.expandEnv, c.watch) + return c, nil } -// Unmarshal deserializes the config into input param. -func (c *TrpcConfig) Unmarshal(out interface{}) error { - return c.decoder.Unmarshal(c.rawData, out) +func (c *TrpcConfig) get() *entity { + c.mutex.RLock() + defer c.mutex.RUnlock() + if c.value != nil { + return c.value + } + return &entity{} } -// Load loads config. -func (c *TrpcConfig) Load() error { - if c.p == nil { - return ErrProviderNotExist +// init return config entity error when entity is empty and load run loads config once +func (c *TrpcConfig) init() error { + c.mutex.RLock() + if c.value != nil { + c.mutex.RUnlock() + return nil + } + c.mutex.RUnlock() + + c.mutex.Lock() + defer c.mutex.Unlock() + if c.value != nil { + return nil } data, err := c.p.Read(c.path) if err != nil { - return fmt.Errorf("trpc/config: failed to load %s: %s", c.path, err.Error()) + return fmt.Errorf("trpc/config failed to load error: %w config id: %s", err, c.id) + } + return c.set(data) +} +func (c *TrpcConfig) doWatch(data []byte) error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.set(data) +} +func (c *TrpcConfig) set(data []byte) error { + if c.expandEnv { + data = expandenv.ExpandEnv(data) } - c.rawData = c.expandEnv(data) - if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil { - return fmt.Errorf("trpc/config: failed to parse %s: %s", c.path, err.Error()) + e := &entity{raw: data} + err := c.decoder.Unmarshal(data, &e.data) + if err != nil { + return fmt.Errorf("trpc/config: failed to parse:%w, id:%s", err, c.id) } + c.value = e return nil } +func (c *TrpcConfig) notify(data []byte, err error) { + m := c.msg -// Reload reloads config. -func (c *TrpcConfig) Reload() { + m.Value = data + if err != nil { + m.Error = err + } + + c.watchHook(m) +} + +// Load loads config. +func (c *TrpcConfig) Load() error { if c.p == nil { - return + return ErrProviderNotExist } + c.mutex.Lock() + defer c.mutex.Unlock() data, err := c.p.Read(c.path) if err != nil { - log.Tracef("trpc/config: failed to reload %s: %v", c.path, err) - return + return fmt.Errorf("trpc/config failed to load error: %w config id: %s", err, c.id) } - c.rawData = c.expandEnv(data) - if err := c.decoder.Unmarshal(c.rawData, &c.unmarshalledData); err != nil { - log.Tracef("trpc/config: failed to parse %s: %v", c.path, err) - return + return c.set(data) +} + +// Reload reloads config. +func (c *TrpcConfig) Reload() { + if err := c.Load(); err != nil { + log.Tracef("trpc/config: failed to reload %s: %v", c.id, err) } } @@ -225,9 +401,14 @@ func (c *TrpcConfig) Get(key string, defaultValue interface{}) interface{} { return defaultValue } +// Unmarshal deserializes the config into input param. +func (c *TrpcConfig) Unmarshal(out interface{}) error { + return c.decoder.Unmarshal(c.get().raw, out) +} + // Bytes returns original config data as bytes. func (c *TrpcConfig) Bytes() []byte { - return c.rawData + return c.get().raw } // GetInt returns int value by key, the second parameter @@ -335,7 +516,9 @@ func (c *TrpcConfig) findWithDefaultValue(key string, defaultValue interface{}) } func (c *TrpcConfig) search(key string) (interface{}, bool) { - unmarshalledData, ok := c.unmarshalledData.(map[string]interface{}) + e := c.get() + + unmarshalledData, ok := e.data.(map[string]interface{}) if !ok { return nil, false } diff --git a/config/trpc_config_test.go b/config/trpc_config_test.go index 7e5e7d8..de6e6fa 100644 --- a/config/trpc_config_test.go +++ b/config/trpc_config_test.go @@ -16,10 +16,14 @@ package config import ( "errors" "fmt" + "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "trpc.group/trpc-go/trpc-go/errs" + "trpc.group/trpc-go/trpc-go/log" ) func Test_search(t *testing.T) { @@ -115,6 +119,14 @@ func Test_search(t *testing.T) { } } +func TestTrpcConfig_Load(t *testing.T) { + t.Run("parse failed", func(t *testing.T) { + c, _ := newTrpcConfig("../testdata/trpc_go.yaml") + c.decoder = &TomlCodec{} + err := c.Load() + require.Contains(t, errs.Msg(err), "failed to parse") + }) +} func TestYamlCodec_Unmarshal(t *testing.T) { t.Run("interface", func(t *testing.T) { var tt interface{} @@ -166,3 +178,103 @@ func (ep *EnvProvider) Read(string) ([]byte, error) { func (ep *EnvProvider) Watch(cb ProviderCallback) { cb("", ep.data) } + +func TestWatch(t *testing.T) { + p := manualTriggerWatchProvider{} + var msgs = make(chan WatchMessage) + SetDefaultWatchHook(func(msg WatchMessage) { + if msg.Error != nil { + log.Errorf("config watch error: %+v", msg) + } else { + log.Infof("config watch error: %+v", msg) + } + msgs <- msg + }) + + RegisterProvider(&p) + p.Set("key", []byte(`key: value`)) + ops := []LoadOption{WithProvider(p.Name()), WithCodec("yaml"), WithWatch()} + c1, err := DefaultConfigLoader.Load("key", ops...) + require.Nilf(t, err, "first load config:%+v", c1) + require.True(t, c1.IsSet("key"), "first load config key exist") + require.Equal(t, c1.Get("key", "default"), "value", "first load config get key value") + + var c2 Config + c2, err = DefaultConfigLoader.Load("key", ops...) + require.Nil(t, err, "second load config:%+v", c2) + require.Equal(t, c1, c2, "first and second load config not equal") + require.True(t, c2.IsSet("key"), "second load config key exist") + require.Equal(t, c2.Get("key", "default"), "value", "second load config get key value") + + var gw sync.WaitGroup + gw.Add(1) + go func() { + defer gw.Done() + tt := time.NewTimer(time.Second) + select { + case <-msgs: + case <-tt.C: + t.Errorf("receive message timeout") + } + }() + + p.Set("key", []byte(`:key: value:`)) + gw.Wait() + + var c3 Config + c3, err = DefaultConfigLoader.Load("key", WithProvider(p.Name()), WithWatchHook(func(msg WatchMessage) { + msgs <- msg + })) + require.Contains(t, errs.Msg(err), "failed to parse") + require.Nil(t, c3, "update error") + + require.True(t, c2.IsSet("key"), "third load config key exist") + require.Equal(t, c2.Get("key", "default"), "value", "third load config get key value") + + gw.Add(1) + go func() { + defer gw.Done() + for i := 0; i < 2; i++ { + tt := time.NewTimer(time.Second) + select { + case <-msgs: + case <-tt.C: + t.Errorf("receive message timeout number%d ", i) + } + } + }() + p.Set("key", []byte(`key: value2`)) + gw.Wait() + + require.Truef(t, c2.IsSet("key"), "after update config and get key exist") + require.Equal(t, c2.Get("key", "default"), "value2", "after update config and config get value") +} + +var _ DataProvider = (*manualTriggerWatchProvider)(nil) + +type manualTriggerWatchProvider struct { + values sync.Map + callbacks []ProviderCallback +} + +func (m *manualTriggerWatchProvider) Name() string { + return "manual_trigger_watch_provider" +} + +func (m *manualTriggerWatchProvider) Read(s string) ([]byte, error) { + if v, ok := m.values.Load(s); ok { + return v.([]byte), nil + } + return nil, fmt.Errorf("not found config") +} + +func (m *manualTriggerWatchProvider) Watch(callback ProviderCallback) { + m.callbacks = append(m.callbacks, callback) +} + +func (m *manualTriggerWatchProvider) Set(key string, v []byte) { + m.values.Store(key, v) + for _, callback := range m.callbacks { + callback(key, v) + } +} diff --git a/examples/features/config/client/main.go b/examples/features/config/client/main.go index 55f68f9..23eeb78 100644 --- a/examples/features/config/client/main.go +++ b/examples/features/config/client/main.go @@ -41,4 +41,26 @@ func main() { return } fmt.Printf("Get msg: %s\n", rsp.GetMsg()) + // print + // + // Get msg: trpc-go-server response: Hello trpc-go-client + // load once config: number_1 + // start watch config:number_1 + + req = &pb.HelloRequest{ + Msg: "change config", // target config change + } + + // Send request. + rsp, err = clientProxy.SayHello(ctx, req) + if err != nil { + fmt.Println("Say hi err:%v", err) + return + } + fmt.Printf("Get msg: %s\n", rsp.GetMsg()) + // print + // + // Get msg: trpc-go-server response: Hello trpc-go-client + // load once config: number_1 + // start watch config:number_2 } diff --git a/examples/features/config/server/main.go b/examples/features/config/server/main.go index 2e36ff8..0351fbd 100644 --- a/examples/features/config/server/main.go +++ b/examples/features/config/server/main.go @@ -17,6 +17,7 @@ package main import ( "context" "fmt" + "sync" trpc "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/config" @@ -27,33 +28,43 @@ import ( func main() { // Parse configuration files in yaml format. - conf, err := config.Load("server/custom.yaml", config.WithCodec("yaml"), config.WithProvider("file")) + // Load default codec is `yaml` and provider is `file` + c, err := config.Load("custom.yaml", config.WithCodec("yaml"), config.WithProvider("file")) if err != nil { fmt.Println(err) return } + fmt.Printf("test : %s \n", c.GetString("custom.test", "")) + fmt.Printf("key1 : %s \n", c.GetString("custom.test_obj.key1", "")) + fmt.Printf("key2 : %t \n", c.GetBool("custom.test_obj.key2", false)) + fmt.Printf("key2 : %d \n", c.GetInt32("custom.test_obj.key3", 0)) + + // print + // test : customConfigFromServer + // key1 : value1 + // key2 : true + // key3 : 1234 + // The format of the configuration file corresponds to custom struct. var custom customStruct - if err := conf.Unmarshal(&custom); err != nil { + if err := c.Unmarshal(&custom); err != nil { fmt.Println(err) } fmt.Printf("Get config - custom : %v \n", custom) - - fmt.Printf("test : %s \n", conf.GetString("custom.test", "")) - fmt.Printf("key1 : %s \n", conf.GetString("custom.test_obj.key1", "")) - fmt.Printf("key2 : %t \n", conf.GetBool("custom.test_obj.key2", false)) - fmt.Printf("key2 : %d \n", conf.GetInt32("custom.test_obj.key3", 0)) + // print: Get config - custom : {{customConfigFromServer {value1 true 1234}}} // Init server. s := trpc.NewServer() + config.RegisterProvider(p) // Register service. - greeterImpl := &greeterImpl{ - customConf: conf.GetString("custom.test", ""), - } - pb.RegisterGreeterService(s, greeterImpl) + imp := &greeterImpl{} + imp.once, _ = config.Load(p.Name(), config.WithProvider(p.Name())) + imp.watch, _ = config.Load(p.Name(), config.WithProvider(p.Name()), config.WithWatch()) + + pb.RegisterGreeterService(s, imp) // Serve and listen. if err := s.Serve(); err != nil { @@ -62,6 +73,58 @@ func main() { } +const cf = `custom : + test : number_%d + test_obj : + key1 : value_%d + key2 : %t + key3 : %d` + +var p = &provider{} + +// mock provider to trigger config change +type provider struct { + mu sync.Mutex + data []byte + num int + callbacks []config.ProviderCallback +} + +func (p *provider) Name() string { + return "test" +} + +func (p *provider) Read(s string) ([]byte, error) { + if s != p.Name() { + return nil, fmt.Errorf("not found config %s", s) + } + p.mu.Lock() + defer p.mu.Unlock() + + if p.data == nil { + p.num++ + p.data = []byte(fmt.Sprintf(cf, p.num, p.num, p.num%2 == 0, p.num)) + } + return p.data, nil +} + +func (p *provider) Watch(callback config.ProviderCallback) { + p.mu.Lock() + defer p.mu.Unlock() + p.callbacks = append(p.callbacks, callback) +} + +func (p *provider) update() { + p.mu.Lock() + p.num++ + p.data = []byte(fmt.Sprintf(cf, p.num, p.num, p.num%2 == 0, p.num)) + callbacks := p.callbacks + p.mu.Unlock() + for _, callback := range callbacks { + callback(p.Name(), p.data) + } +} + // customStruct it defines the struct of the custom configuration file read. type customStruct struct { Custom struct { @@ -78,16 +141,37 @@ type customStruct struct { type greeterImpl struct { common.GreeterServerImpl - customConf string + once config.Config + watch config.Config } // SayHello say hello request. Rewrite SayHello to inform server config. func (g *greeterImpl) SayHello(_ context.Context, req *pb.HelloRequest) (*pb.HelloReply, error) { fmt.Printf("trpc-go-server SayHello, req.msg:%s\n", req.Msg) + if req.Msg == "change config" { + p.update() + } + rsp := &pb.HelloReply{} - rsp.Msg = "trpc-go-server response: Hello " + req.Msg + ". Custom config from server: " + g.customConf + rsp.Msg = "trpc-go-server response: Hello " + req.Msg + + fmt.Sprintf("\nload once config: %s", g.once.GetString("custom.test", "")) + + fmt.Sprintf("\nstart watch config: %s", g.watch.GetString("custom.test", "")) + fmt.Printf("trpc-go-server SayHello, rsp.msg:%s\n", rsp.Msg) return rsp, nil } + +// first print +// +// trpc-go-server SayHello, rsp.msg:trpc-go-server response: Hello trpc-go-client +// load once config: number_1 +// start watch config:number_1 +// +// second print +// +// trpc-go-server SayHello, req.msg:change config +// trpc-go-server SayHello, rsp.msg:trpc-go-server response: Hello change config +// load once config: number_1 +// start watch config:number_2 From 93d46c3f09343a6d44a9e7108cbbd96dc4c6de10 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:33:44 +0800 Subject: [PATCH 07/39] config: set empty ip to default 0.0.0.0 --- config.go | 17 +++++++--- examples/go.mod | 3 -- examples/go.sum | 4 --- test/graceful_restart_test.go | 34 +++++++++++++++++++ .../gracefulrestart/trpc/trpc_go_emptyip.yaml | 13 +++++++ 5 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 test/gracefulrestart/trpc/trpc_go_emptyip.yaml diff --git a/config.go b/config.go index f981c6c..5d43149 100644 --- a/config.go +++ b/config.go @@ -714,11 +714,17 @@ func RepairConfig(cfg *Config) error { // repairServiceIPWithNic repairs the Config when service ip is empty according to the nic. func repairServiceIPWithNic(cfg *Config) error { + // Set empty ip to "0.0.0.0" to prevent malformed key matching + // for passed listeners during hot restart. + const defaultIP = "0.0.0.0" for index, item := range cfg.Server.Service { if item.IP == "" { ip := getIP(item.Nic) - if ip == "" && item.Nic != "" { - return fmt.Errorf("can't find service IP by the NIC: %s", item.Nic) + if ip == "" { + if item.Nic != "" { + return fmt.Errorf("can't find service IP by the NIC: %s", item.Nic) + } + ip = defaultIP } cfg.Server.Service[index].IP = ip } @@ -727,8 +733,11 @@ func repairServiceIPWithNic(cfg *Config) error { if cfg.Server.Admin.IP == "" { ip := getIP(cfg.Server.Admin.Nic) - if ip == "" && cfg.Server.Admin.Nic != "" { - return fmt.Errorf("can't find admin IP by the NIC: %s", cfg.Server.Admin.Nic) + if ip == "" { + if cfg.Server.Admin.Nic != "" { + return fmt.Errorf("can't find admin IP by the NIC: %s", cfg.Server.Admin.Nic) + } + ip = defaultIP } cfg.Server.Admin.IP = ip } diff --git a/examples/go.mod b/examples/go.mod index d6d8b8b..ab9f708 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -6,7 +6,6 @@ replace trpc.group/trpc-go/trpc-go => ../ require ( github.com/golang/protobuf v1.5.2 - github.com/stretchr/testify v1.8.0 google.golang.org/protobuf v1.30.0 trpc.group/trpc-go/trpc-go v0.0.0-00010101000000-000000000000 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 @@ -15,7 +14,6 @@ require ( require ( github.com/BurntSushi/toml v0.3.1 // indirect github.com/andybalholm/brotli v1.0.4 // indirect - github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.4.9 // indirect github.com/go-playground/form/v4 v4.2.0 // indirect github.com/golang/mock v1.4.4 // indirect @@ -31,7 +29,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/panjf2000/ants/v2 v2.4.6 // indirect github.com/pkg/errors v0.9.1 // indirect - github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.43.0 // indirect diff --git a/examples/go.sum b/examples/go.sum index f50aca0..e0d6081 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -59,13 +59,10 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.43.0 h1:Gy4sb32C98fbzVWZlTM1oTMdLWGyvxR03VhM6cBIU4g= @@ -128,7 +125,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/test/graceful_restart_test.go b/test/graceful_restart_test.go index 741cffb..dd4bf3c 100644 --- a/test/graceful_restart_test.go +++ b/test/graceful_restart_test.go @@ -42,6 +42,9 @@ func (s *TestSuite) TestServerGracefulRestart() { s.Run("SendNonGracefulRestartSignal", func() { s.testSendNonGracefulRestartSignal() }) + s.Run("GracefulRestartForEmptyIP", func() { + s.testGracefulRestartForEmptyIP() + }) } func (s *TestSuite) testServerGracefulRestartIsIdempotent() { @@ -250,6 +253,37 @@ func (s *TestSuite) testSendNonGracefulRestartSignal() { }) } +func (s *TestSuite) testGracefulRestartForEmptyIP() { + const ( + binaryFile = "./gracefulrestart/trpc/server.o" + sourceFile = "./gracefulrestart/trpc/server.go" + configFile = "./gracefulrestart/trpc/trpc_go_emptyip.yaml" + ) + + cmd, err := startServerFromBash( + sourceFile, + configFile, + binaryFile, + ) + require.Nil(s.T(), err) + defer func() { + require.Nil(s.T(), exec.Command("rm", binaryFile).Run()) + require.Nil(s.T(), cmd.Process.Kill()) + }() + + const target = "ip://127.0.0.1:17777" + sp, err := getServerProcessByEmptyCall(target) + require.Nil(s.T(), err) + pid := sp.Pid + require.Nil(s.T(), sp.Signal(server.DefaultServerGracefulSIG)) + time.Sleep(1 * time.Second) + sp, err = getServerProcessByEmptyCall(target) + require.Nil(s.T(), err) + require.NotEqual(s.T(), pid, sp.Pid) + pid = sp.Pid + require.Nil(s.T(), sp.Kill()) +} + func startServerFromBash(sourceFile, configFile, targetFile string) (*exec.Cmd, error) { cmd := exec.Command( "bash", diff --git a/test/gracefulrestart/trpc/trpc_go_emptyip.yaml b/test/gracefulrestart/trpc/trpc_go_emptyip.yaml new file mode 100644 index 0000000..28f6367 --- /dev/null +++ b/test/gracefulrestart/trpc/trpc_go_emptyip.yaml @@ -0,0 +1,13 @@ +global: + namespace: Development + env_name: test +server: + app: testing + server: end2end + admin: + port: 19999 + service: + - name: trpc.testing.end2end.TestTRPC + protocol: trpc + network: tcp + port: 17777 From c56a795ce41194d2c7deac3021af599bea04e728 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:37:08 +0800 Subject: [PATCH 08/39] re-enable Config.Global.LocalIP --- config.go | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/config.go b/config.go index 5d43149..fc8184c 100644 --- a/config.go +++ b/config.go @@ -676,6 +676,17 @@ func RepairConfig(cfg *Config) error { } codec.SetReaderSize(*cfg.Global.ReadBufferSize) + // nic -> ip + if err := repairServiceIPWithNic(cfg); err != nil { + return err + } + + // Set empty ip to "0.0.0.0" to prevent malformed key matching + // for passed listeners during hot restart. + const defaultIP = "0.0.0.0" + setDefault(&cfg.Global.LocalIP, defaultIP) + setDefault(&cfg.Server.Admin.IP, cfg.Global.LocalIP) + // protocol network ip empty for _, serviceCfg := range cfg.Server.Service { setDefault(&serviceCfg.Protocol, cfg.Server.Protocol) @@ -714,17 +725,11 @@ func RepairConfig(cfg *Config) error { // repairServiceIPWithNic repairs the Config when service ip is empty according to the nic. func repairServiceIPWithNic(cfg *Config) error { - // Set empty ip to "0.0.0.0" to prevent malformed key matching - // for passed listeners during hot restart. - const defaultIP = "0.0.0.0" for index, item := range cfg.Server.Service { if item.IP == "" { ip := getIP(item.Nic) - if ip == "" { - if item.Nic != "" { - return fmt.Errorf("can't find service IP by the NIC: %s", item.Nic) - } - ip = defaultIP + if ip == "" && item.Nic != "" { + return fmt.Errorf("can't find service IP by the NIC: %s", item.Nic) } cfg.Server.Service[index].IP = ip } @@ -733,11 +738,8 @@ func repairServiceIPWithNic(cfg *Config) error { if cfg.Server.Admin.IP == "" { ip := getIP(cfg.Server.Admin.Nic) - if ip == "" { - if cfg.Server.Admin.Nic != "" { - return fmt.Errorf("can't find admin IP by the NIC: %s", cfg.Server.Admin.Nic) - } - ip = defaultIP + if ip == "" && cfg.Server.Admin.Nic != "" { + return fmt.Errorf("can't find admin IP by the NIC: %s", cfg.Server.Admin.Nic) } cfg.Server.Admin.IP = ip } From 97c15dd2797910458ef42cc9690da561707f145e Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:41:08 +0800 Subject: [PATCH 09/39] log/rollwriter: skip buffer and write directly when data packet exceeds expected size --- log/rollwriter/async_roll_writer.go | 11 +++++++++++ log/rollwriter/roll_writer_test.go | 10 ++++++++++ 2 files changed, 21 insertions(+) diff --git a/log/rollwriter/async_roll_writer.go b/log/rollwriter/async_roll_writer.go index 697ee42..1b84534 100644 --- a/log/rollwriter/async_roll_writer.go +++ b/log/rollwriter/async_roll_writer.go @@ -117,6 +117,17 @@ func (w *AsyncRollWriter) batchWriteLog() { buffer.Reset() } case data := <-w.logQueue: + if len(data) >= w.opts.WriteLogSize { + // If the length of the current data exceeds the expected maximum value, + // we directly write it to the underlying logger instead of placing it into the buffer. + // This prevents the buffer from being overwhelmed by excessively large data, + // which could lead to memory leaks. + // Prior to that, we need to write the existing data in the buffer to the underlying logger. + _, _ = w.logger.Write(buffer.Bytes()) + buffer.Reset() + _, _ = w.logger.Write(data) + continue + } buffer.Write(data) if buffer.Len() >= w.opts.WriteLogSize { _, err := w.logger.Write(buffer.Bytes()) diff --git a/log/rollwriter/roll_writer_test.go b/log/rollwriter/roll_writer_test.go index 79bc7a2..357c70a 100644 --- a/log/rollwriter/roll_writer_test.go +++ b/log/rollwriter/roll_writer_test.go @@ -376,6 +376,16 @@ func TestAsyncRollWriterSyncTwice(t *testing.T) { require.Nil(t, w.Close()) } +func TestAsyncRollWriterDirectWrite(t *testing.T) { + logSize := 1 + w := NewAsyncRollWriter(&noopWriteCloser{}, WithWriteLogSize(logSize)) + _, _ = w.Write([]byte("hello")) + time.Sleep(time.Millisecond) + require.Nil(t, w.Sync()) + require.Nil(t, w.Sync()) + require.Nil(t, w.Close()) +} + func TestRollWriterError(t *testing.T) { logDir := t.TempDir() t.Run("reopen file", func(t *testing.T) { From 84e143cf72a87efd58bdf10f29c560cd1da8a5bf Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:48:49 +0800 Subject: [PATCH 10/39] http: do not ignore server no response error --- http/transport.go | 6 ++++-- test/http_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) diff --git a/http/transport.go b/http/transport.go index 980c35a..4a19ed8 100644 --- a/http/transport.go +++ b/http/transport.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "net" + "net/http" stdhttp "net/http" "net/http/httptrace" "net/url" @@ -139,8 +140,9 @@ func (t *ServerTransport) listenAndServeHTTP(ctx context.Context, opts *transpor if err != nil { span.SetAttribute(rpcz.TRPCAttributeError, err) log.Errorf("http server transport handle fail:%v", err) - if err == ErrEncodeMissingHeader { - w.WriteHeader(500) + if err == ErrEncodeMissingHeader || errors.Is(err, errs.ErrServerNoResponse) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(fmt.Sprintf("http server handle error: %+v", err))) } return } diff --git a/test/http_test.go b/test/http_test.go index c0cc32c..91cbcaa 100644 --- a/test/http_test.go +++ b/test/http_test.go @@ -151,6 +151,34 @@ func (s *TestSuite) testSendHTTPSRequestToHTTPServer(e *httpRPCEnv) { require.Contains(s.T(), errs.Msg(err), "codec empty") } +func (s *TestSuite) TestHandleErrServerNoResponse() { + for _, e := range allHTTPRPCEnvs { + if e.client.multiplexed { + continue + } + s.Run(e.String(), func() { s.testHandleErrServerNoResponse(e) }) + } +} +func (s *TestSuite) testHandleErrServerNoResponse(e *httpRPCEnv) { + s.startServer(&testHTTPService{TRPCService: TRPCService{UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { + return nil, errs.ErrServerNoResponse + }}}, server.WithServerAsync(e.server.async)) + + s.T().Cleanup(func() { s.closeServer(nil) }) + + bts, err := proto.Marshal(s.defaultSimpleRequest) + require.Nil(s.T(), err) + + c := thttp.NewStdHTTPClient("http-client") + rsp, err := c.Post(s.unaryCallCustomURL(), "application/pb", bytes.NewReader(bts)) + require.Nil(s.T(), err) + require.Equal(s.T(), http.StatusInternalServerError, rsp.StatusCode) + + bts, err = io.ReadAll(rsp.Body) + require.Nil(s.T(), err) + require.Containsf(s.T(), string(bts), "http server handle error: type:framework, code:0, msg:server no response", "full err: %+v", err) +} + func (s *TestSuite) TestStatusBadRequestDueToServerValidateFail() { for _, e := range allHTTPRPCEnvs { if e.client.multiplexed { From c8df21e20437e8d27357f747241cf7a607e9ddc2 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 16:57:23 +0800 Subject: [PATCH 11/39] http: httptrace should not replace req ctx with transport ctx --- http/transport.go | 38 +++++++++++++++++++++++++++++- http/transport_test.go | 1 + internal/context/value_ctx.go | 29 +++++++++++++++++++++++ internal/context/value_ctx_test.go | 19 +++++++++++++++ 4 files changed, 86 insertions(+), 1 deletion(-) create mode 100644 internal/context/value_ctx.go create mode 100644 internal/context/value_ctx_test.go diff --git a/http/transport.go b/http/transport.go index 4a19ed8..a27bd7c 100644 --- a/http/transport.go +++ b/http/transport.go @@ -23,6 +23,7 @@ import ( "encoding/base64" "errors" "fmt" + "io" "net" "net/http" stdhttp "net/http" @@ -36,6 +37,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + icontext "trpc.group/trpc-go/trpc-go/internal/context" "trpc.group/trpc-go/trpc-go/internal/reuseport" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" @@ -476,7 +478,24 @@ func (ct *ClientTransport) RoundTrip( msg.WithRemoteAddr(tcpAddr) }, } - request := req.WithContext(httptrace.WithClientTrace(ctx, trace)) + reqCtx := ctx + cancel := context.CancelFunc(func() {}) + if rspHeader.ManualReadBody { + // In the scenario of Manual Read body, the lifecycle of rsp body is different + // from that of invoke ctx, and is independently controlled by body.Close(). + // Therefore, the timeout/cancel function in the original context needs to be replaced. + controlCtx := context.Background() + if deadline, ok := ctx.Deadline(); ok { + controlCtx, cancel = context.WithDeadline(context.Background(), deadline) + } + reqCtx = icontext.NewContextWithValues(controlCtx, ctx) + } + defer func() { + if err != nil { + cancel() + } + }() + request := req.WithContext(httptrace.WithClientTrace(reqCtx, trace)) client, err := ct.getStdHTTPClient(opts.CACertFile, opts.TLSCertFile, opts.TLSKeyFile, opts.TLSServerName) @@ -499,9 +518,26 @@ func (ct *ClientTransport) RoundTrip( return nil, errs.NewFrameError(errs.RetClientNetErr, "http client transport RoundTrip: "+err.Error()) } + rspHeader.Response.Body = &responseBodyWithCancel{body: rspHeader.Response.Body, cancel: cancel} return emptyBuf, nil } +// responseBodyWithCancel implements io.ReadCloser. +// It wraps response body and cancel function. +type responseBodyWithCancel struct { + body io.ReadCloser + cancel context.CancelFunc +} + +func (b *responseBodyWithCancel) Read(p []byte) (int, error) { + return b.body.Read(p) +} + +func (b *responseBodyWithCancel) Close() error { + b.cancel() + return b.body.Close() +} + func (ct *ClientTransport) getStdHTTPClient(caFile, certFile, keyFile, serverName string) (*stdhttp.Client, error) { if len(caFile) == 0 { // HTTP requests share one client. diff --git a/http/transport_test.go b/http/transport_test.go index ef15810..7a14527 100644 --- a/http/transport_test.go +++ b/http/transport_test.go @@ -1405,6 +1405,7 @@ func TestHTTPSendAndReceiveSSE(t *testing.T) { client.WithCurrentCompressType(codec.CompressTypeNoop), client.WithReqHead(reqHeader), client.WithRspHead(rspHead), + client.WithTimeout(time.Minute), )) body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body. defer body.Close() // Do remember to close the body. diff --git a/internal/context/value_ctx.go b/internal/context/value_ctx.go new file mode 100644 index 0000000..eaf7341 --- /dev/null +++ b/internal/context/value_ctx.go @@ -0,0 +1,29 @@ +// Package context provides extensions to context.Context. +package context + +import ( + "context" +) + +// NewContextWithValues will use the valuesCtx's Value function. +// Effects of the returned context: +// +// Whether it has timed out or canceled: decided by ctx. +// Retrieve value using key: first use valuesCtx.Value, then ctx.Value. +func NewContextWithValues(ctx, valuesCtx context.Context) context.Context { + return &valueCtx{Context: ctx, values: valuesCtx} +} + +type valueCtx struct { + context.Context + values context.Context +} + +// Value re-implements context.Value, valueCtx.values.Value has the highest +// priority. +func (c *valueCtx) Value(key interface{}) interface{} { + if v := c.values.Value(key); v != nil { + return v + } + return c.Context.Value(key) +} diff --git a/internal/context/value_ctx_test.go b/internal/context/value_ctx_test.go new file mode 100644 index 0000000..994d65e --- /dev/null +++ b/internal/context/value_ctx_test.go @@ -0,0 +1,19 @@ +package context_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + icontext "trpc.group/trpc-go/trpc-go/internal/context" +) + +func TestWithValues(t *testing.T) { + type testKey struct{} + testValue := "value" + ctx := context.WithValue(context.TODO(), testKey{}, testValue) + ctx1 := icontext.NewContextWithValues(context.TODO(), ctx) + require.NotNil(t, ctx1.Value(testKey{})) + type notExist struct{} + require.Nil(t, ctx1.Value(notExist{})) +} From 8f59ab3b526bed8b1a7c7faf26be07fcf352725f Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 17:00:14 +0800 Subject: [PATCH 12/39] use GotConn for obtaining remote addr in connection reuse case --- http/transport.go | 5 ++--- http/transport_test.go | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/http/transport.go b/http/transport.go index a27bd7c..1f366d3 100644 --- a/http/transport.go +++ b/http/transport.go @@ -473,9 +473,8 @@ func (ct *ClientTransport) RoundTrip( return nil, err } trace := &httptrace.ClientTrace{ - ConnectStart: func(network, addr string) { - tcpAddr, _ := net.ResolveTCPAddr(network, addr) - msg.WithRemoteAddr(tcpAddr) + GotConn: func(info httptrace.GotConnInfo) { + msg.WithRemoteAddr(info.Conn.RemoteAddr()) }, } reqCtx := ctx diff --git a/http/transport_test.go b/http/transport_test.go index 7a14527..4659d05 100644 --- a/http/transport_test.go +++ b/http/transport_test.go @@ -1481,6 +1481,25 @@ func TestHTTPClientReqRspDifferentContentType(t *testing.T) { require.Equal(t, hello+t.Name(), rsp.Message) } +func TestHTTPGotConnectionRemoteAddr(t *testing.T) { + ctx := context.Background() + for i := 0; i < 3; i++ { + proxy := thttp.NewClientProxy(t.Name(), client.WithTarget("dns://new.qq.com/")) + rsp := &codec.Body{} + require.Nil(t, proxy.Get(ctx, "/", rsp, + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithFilter( + func(ctx context.Context, req, rsp interface{}, next filter.ClientHandleFunc) error { + err := next(ctx, req, rsp) + msg := codec.Message(ctx) + addr := msg.RemoteAddr() + require.NotNil(t, addr, "expect to get remote addr from msg in connection reuse case") + t.Logf("addr = %+v\n", addr) + return err + }))) + } +} + type h struct{} func (*h) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) { From dc266ecd4c69eafbd4b34296d7e14c0a02d338b0 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 17:34:15 +0800 Subject: [PATCH 13/39] fix the memory leak issue that occurs when stream.NewStream fails --- client/client_test.go | 23 ++-- client/stream.go | 26 +++-- client/stream_test.go | 236 ++++++++++++++++++++++++++++-------------- stream/client.go | 5 +- stream/client_test.go | 53 +++++++++- test/filter_test.go | 4 + 6 files changed, 242 insertions(+), 105 deletions(-) diff --git a/client/client_test.go b/client/client_test.go index 3734816..34aaf51 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -423,7 +423,11 @@ func (t *multiplexedTransport) RoundTrip( return t.fakeTransport.RoundTrip(ctx, req, opts...) } -type fakeTransport struct{} +type fakeTransport struct { + send func() error + recv func() ([]byte, error) + close func() +} func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, roundTripOpts ...transport.RoundTripOption) (rsp []byte, err error) { @@ -447,18 +451,15 @@ func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, } func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { + if c.send != nil { + return c.send() + } return nil } func (c *fakeTransport) Recv(ctx context.Context, opts ...transport.RoundTripOption) ([]byte, error) { - body, ok := ctx.Value("recv-decode-error").(string) - if ok { - return []byte(body), nil - } - - err, ok := ctx.Value("recv-error").(string) - if ok { - return nil, errors.New(err) + if c.recv != nil { + return c.recv() } return []byte("body"), nil } @@ -467,7 +468,9 @@ func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOpt return nil } func (c *fakeTransport) Close(ctx context.Context) { - return + if c.close != nil { + c.close() + } } type fakeCodec struct { diff --git a/client/stream.go b/client/stream.go index a40ac8e..3a33159 100644 --- a/client/stream.go +++ b/client/stream.go @@ -66,11 +66,16 @@ type RecvControl interface { // It serializes the message and sends it to server through stream transport. // It's safe to call Recv and Send in different goroutines concurrently, but calling // Send in different goroutines concurrently is not thread-safe. -func (s *stream) Send(ctx context.Context, m interface{}) error { +func (s *stream) Send(ctx context.Context, m interface{}) (err error) { + defer func() { + if err != nil { + s.opts.StreamTransport.Close(ctx) + } + }() + msg := codec.Message(ctx) reqBodyBuf, err := serializeAndCompress(ctx, msg, m, s.opts) if err != nil { - s.opts.StreamTransport.Close(ctx) return err } @@ -87,7 +92,6 @@ func (s *stream) Send(ctx context.Context, m interface{}) error { } if err := s.opts.StreamTransport.Send(ctx, reqBuf); err != nil { - s.opts.StreamTransport.Close(ctx) return err } return nil @@ -97,18 +101,24 @@ func (s *stream) Send(ctx context.Context, m interface{}) error { // It decodes and decompresses the message and leaves serialization to upper layer. // It's safe to call Recv and Send in different goroutines concurrently, but calling // Send in different goroutines concurrently is not thread-safe. -func (s *stream) Recv(ctx context.Context) ([]byte, error) { +func (s *stream) Recv(ctx context.Context) (buf []byte, err error) { + defer func() { + if err != nil { + s.opts.StreamTransport.Close(ctx) + } + }() rspBuf, err := s.opts.StreamTransport.Recv(ctx) if err != nil { - s.opts.StreamTransport.Close(ctx) return nil, err } msg := codec.Message(ctx) rspBodyBuf, err := s.opts.Codec.Decode(msg, rspBuf) if err != nil { - s.opts.StreamTransport.Close(ctx) return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decode: "+err.Error()) } + if err := msg.ClientRspErr(); err != nil { + return nil, err + } if len(rspBodyBuf) > 0 { compressType := msg.CompressType() if icodec.IsValidCompressType(s.opts.CurrentCompressType) { @@ -118,9 +128,7 @@ func (s *stream) Recv(ctx context.Context) ([]byte, error) { if icodec.IsValidCompressType(compressType) && compressType != codec.CompressTypeNoop { rspBodyBuf, err = codec.Decompress(compressType, rspBodyBuf) if err != nil { - s.opts.StreamTransport.Close(ctx) - return nil, - errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error()) + return nil, errs.NewFrameError(errs.RetClientDecodeFail, "client codec Decompress: "+err.Error()) } } } diff --git a/client/stream_test.go b/client/stream_test.go index 6b9d398..24a9bd7 100644 --- a/client/stream_test.go +++ b/client/stream_test.go @@ -15,6 +15,7 @@ package client_test import ( "context" + "errors" "testing" "time" @@ -37,85 +38,123 @@ func TestStream(t *testing.T) { // calling without error streamCli := client.NewStream() - require.NotNil(t, streamCli) - opts, err := streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - err = streamCli.Send(ctx, reqBody) - require.Nil(t, err) - rsp, err := streamCli.Recv(ctx) - require.Nil(t, err) - require.Equal(t, []byte("body"), rsp) - err = streamCli.Close(ctx) - require.Nil(t, err) + t.Run("calling without error", func(t *testing.T) { + require.NotNil(t, streamCli) + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + err = streamCli.Send(ctx, reqBody) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, err) + require.Equal(t, []byte("body"), rsp) + err = streamCli.Close(ctx) + require.Nil(t, err) + }) - // test nil Codec - opts, err = streamCli.Init(ctx, - client.WithTarget("ip://127.0.0.1:8080"), - client.WithTimeout(time.Second), - client.WithProtocol("fake-nil"), - client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{})) - require.NotNil(t, err) - require.Nil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) + t.Run("test nil Codec", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8080"), + client.WithTimeout(time.Second), + client.WithProtocol("fake-nil"), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{})) + require.NotNil(t, err) + require.Nil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + }) - // test selectNode with error - opts, err = streamCli.Init(ctx, client.WithTarget("ip/:/127.0.0.1:8080"), - client.WithProtocol("fake")) - require.NotNil(t, err) - require.Contains(t, err.Error(), "invalid") - require.Nil(t, opts) - - // test stream recv failure - ctx = context.WithValue(ctx, "recv-error", "recv failed") - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - rsp, err = streamCli.Recv(ctx) - require.Nil(t, rsp) - require.NotNil(t, err) - - // test decode failure - ctx = context.WithValue(ctx, "recv-decode-error", "businessfail") - rsp, err = streamCli.Recv(ctx) - require.Nil(t, rsp) - require.NotNil(t, err) - - // test compress failure - ctx = context.Background() - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeGzip), - client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - _, err = streamCli.Recv(ctx) - require.NotNil(t, err) - - // test compress without error - opts, err = streamCli.Init(ctx, client.WithTarget("ip://127.0.0.1:8000"), - client.WithTimeout(time.Second), client.WithSerializationType(codec.SerializationTypeNoop), - client.WithStreamTransport(&fakeTransport{}), client.WithCurrentCompressType(codec.CompressTypeNoop), - client.WithProtocol("fake")) - require.Nil(t, err) - require.NotNil(t, opts) - err = streamCli.Invoke(ctx) - require.Nil(t, err) - rsp, err = streamCli.Recv(ctx) - require.Nil(t, err) - require.NotNil(t, rsp) + t.Run("test selectNode with error", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip/:/127.0.0.1:8080"), + client.WithProtocol("fake"), + ) + require.NotNil(t, err) + require.Contains(t, err.Error(), "invalid") + require.Nil(t, opts) + }) + + t.Run("test stream recv failure", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return nil, errors.New("recv failed") + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, rsp) + require.NotNil(t, err) + }) + + t.Run("test decode failure", func(t *testing.T) { + _, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return []byte("businessfail"), nil + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, rsp) + require.NotNil(t, err) + }) + + t.Run("test compress failure", func(t *testing.T) { + opts, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithCurrentCompressType(codec.CompressTypeGzip), + client.WithProtocol("fake")) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + _, err = streamCli.Recv(ctx) + require.NotNil(t, err) + }) + + t.Run("test compress without error", func(t *testing.T) { + opts, err := streamCli.Init(ctx, + client.WithTarget("ip://127.0.0.1:8000"), + client.WithTimeout(time.Second), + client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(&fakeTransport{}), + client.WithCurrentCompressType(codec.CompressTypeNoop), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, opts) + err = streamCli.Invoke(ctx) + require.Nil(t, err) + rsp, err := streamCli.Recv(ctx) + require.Nil(t, err) + require.NotNil(t, rsp) + }) } func TestGetStreamFilter(t *testing.T) { @@ -151,3 +190,46 @@ func TestStreamGetAddress(t *testing.T) { require.NotNil(t, msg.RemoteAddr()) require.Equal(t, addr, msg.RemoteAddr().String()) } + +func TestStreamCloseTransport(t *testing.T) { + codec.Register("fake", nil, &fakeCodec{}) + t.Run("close transport when send fail", func(t *testing.T) { + var isClose bool + streamCli := client.NewStream() + _, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(&fakeTransport{ + send: func() error { + return errors.New("expected error") + }, + close: func() { + isClose = true + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + require.NotNil(t, streamCli.Send(context.Background(), nil)) + require.True(t, isClose) + }) + t.Run("close transport when recv fail", func(t *testing.T) { + var isClose bool + streamCli := client.NewStream() + _, err := streamCli.Init(context.Background(), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(&fakeTransport{ + recv: func() ([]byte, error) { + return nil, errors.New("expected error") + }, + close: func() { + isClose = true + }, + }), + client.WithProtocol("fake"), + ) + require.Nil(t, err) + _, err = streamCli.Recv(context.Background()) + require.NotNil(t, err) + require.True(t, isClose) + }) +} diff --git a/stream/client.go b/stream/client.go index 65a6f85..be7b754 100644 --- a/stream/client.go +++ b/stream/client.go @@ -252,9 +252,6 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) if _, err := cs.stream.Recv(newCtx); err != nil { return nil, err } - if newMsg.ClientRspErr() != nil { - return nil, newMsg.ClientRspErr() - } initWindowSize := defaultInitWindowSize if initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta); ok { @@ -344,7 +341,7 @@ func (cs *clientStream) dispatch() { if err != nil { // return to client on error. cs.recvQueue.Put(&response{ - err: errs.Wrap(err, errs.RetClientStreamReadEnd, streamClosed), + err: errs.WrapFrameError(err, errs.RetClientStreamReadEnd, streamClosed), }) return } diff --git a/stream/client_test.go b/stream/client_test.go index 281770d..3fbc7a6 100644 --- a/stream/client_test.go +++ b/stream/client_test.go @@ -41,6 +41,8 @@ var ctx = context.Background() type fakeTransport struct { expectChan chan recvExpect + send func() error + close func() } // RoundTrip Mock RoundTrip method. @@ -51,9 +53,8 @@ func (c *fakeTransport) RoundTrip(ctx context.Context, req []byte, // Send Mock Send method. func (c *fakeTransport) Send(ctx context.Context, req []byte, opts ...transport.RoundTripOption) error { - err, ok := ctx.Value("send-error").(string) - if ok { - return errors.New(err) + if c.send != nil { + return c.send() } return nil } @@ -80,7 +81,9 @@ func (c *fakeTransport) Init(ctx context.Context, opts ...transport.RoundTripOpt // Close Mock Close method. func (c *fakeTransport) Close(ctx context.Context) { - return + if c.close != nil { + c.close() + } } type fakeCodec struct { @@ -687,7 +690,8 @@ func TestClientStreamReturn(t *testing.T) { rsp := getBytes(dataLen) err = clientStream.RecvMsg(rsp) - assert.EqualValues(t, int32(101), err.(*errs.Error).Code) + + assert.EqualValues(t, int32(101), errs.Code(err.(*errs.Error).Unwrap())) } // TestClientSendFailWhenServerUnavailable test when the client blocks @@ -746,3 +750,42 @@ func TestClientReceiveErrorWhenServerUnavailable(t *testing.T) { assert.NotEqual(t, io.EOF, err) assert.ErrorIs(t, err, io.EOF) } + +func TestClientNewStreamFail(t *testing.T) { + codec.Register("mock", nil, &fakeCodec{}) + t.Run("Close Transport when Send Fail", func(t *testing.T) { + var isClosed bool + tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} + tp.send = func() error { + return errors.New("client error") + } + tp.close = func() { + isClosed = true + } + _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", + client.WithProtocol("mock"), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(tp), + ) + assert.NotNil(t, err) + assert.True(t, isClosed) + }) + t.Run("Close Transport when Recv Fail", func(t *testing.T) { + var isClosed bool + tp := &fakeTransport{expectChan: make(chan recvExpect, 1)} + tp.expectChan <- func(fh *trpc.FrameHead, m codec.Msg) ([]byte, error) { + m.WithClientRspErr(errors.New("server error")) + return nil, nil + } + tp.close = func() { + isClosed = true + } + _, err := stream.NewStreamClient().NewStream(ctx, &client.ClientStreamDesc{}, "", + client.WithProtocol("mock"), + client.WithTarget("ip://127.0.0.1:8000"), + client.WithStreamTransport(tp), + ) + assert.NotNil(t, err) + assert.True(t, isClosed) + }) +} diff --git a/test/filter_test.go b/test/filter_test.go index d534779..b23dec4 100644 --- a/test/filter_test.go +++ b/test/filter_test.go @@ -15,6 +15,7 @@ package test import ( "context" + "errors" "time" "github.com/stretchr/testify/require" @@ -189,6 +190,9 @@ func (s *TestSuite) TestStreamServerFilter() { s1.Send(&testpb.StreamingInputCallRequest{}) require.Nil(s.T(), err) _, err = s1.CloseAndRecv() + require.Equal(s.T(), errs.RetClientStreamReadEnd, errs.Code(err)) + + err = errors.Unwrap(err) require.Equal(s.T(), errs.Code(filterTestError), errs.Code(err)) require.Equal(s.T(), errs.Msg(filterTestError), errs.Msg(err)) From 1a65f3a30c2d34295bad7358f6040ee28982bcd6 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 17:34:57 +0800 Subject: [PATCH 14/39] remove unused copyRspHead function --- codec.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/codec.go b/codec.go index 3c690c2..b01e780 100644 --- a/codec.go +++ b/codec.go @@ -665,19 +665,6 @@ func loadOrStoreDefaultUnaryFrameHead(msg codec.Msg) *FrameHead { return frameHead } -func copyRspHead(dst, src *trpcpb.ResponseProtocol) { - dst.Version = src.Version - dst.CallType = src.CallType - dst.RequestId = src.RequestId - dst.Ret = src.Ret - dst.FuncRet = src.FuncRet - dst.ErrorMsg = src.ErrorMsg - dst.MessageType = src.MessageType - dst.TransInfo = src.TransInfo - dst.ContentType = src.ContentType - dst.ContentEncoding = src.ContentEncoding -} - func updateMsg(msg codec.Msg, frameHead *FrameHead, rsp *trpcpb.ResponseProtocol, attm []byte) error { msg.WithFrameHead(frameHead) msg.WithCompressType(int(rsp.GetContentEncoding())) From 1c9fc8c6ec84d7da82bc0308eed8f93c9b97feac Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 18:00:49 +0800 Subject: [PATCH 15/39] promise that dst of codec.Unmarshal is always map[string]interface{} --- config/trpc_config.go | 11 +++++++++-- config/trpc_config_test.go | 33 +++++++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/config/trpc_config.go b/config/trpc_config.go index e58470c..4a88cdf 100644 --- a/config/trpc_config.go +++ b/config/trpc_config.go @@ -280,6 +280,12 @@ type entity struct { data interface{} // unmarshal type to use point type, save latest no error data } +func newEntity() *entity { + return &entity{ + data: make(map[string]interface{}), + } +} + func newTrpcConfig(path string, opts ...LoadOption) (*TrpcConfig, error) { c := &TrpcConfig{ path: path, @@ -317,7 +323,7 @@ func (c *TrpcConfig) get() *entity { if c.value != nil { return c.value } - return &entity{} + return newEntity() } // init return config entity error when entity is empty and load run loads config once @@ -351,7 +357,8 @@ func (c *TrpcConfig) set(data []byte) error { data = expandenv.ExpandEnv(data) } - e := &entity{raw: data} + e := newEntity() + e.raw = data err := c.decoder.Unmarshal(data, &e.data) if err != nil { return fmt.Errorf("trpc/config: failed to parse:%w, id:%s", err, c.id) diff --git a/config/trpc_config_test.go b/config/trpc_config_test.go index de6e6fa..35d5a6c 100644 --- a/config/trpc_config_test.go +++ b/config/trpc_config_test.go @@ -16,6 +16,8 @@ package config import ( "errors" "fmt" + "os" + "reflect" "sync" "testing" "time" @@ -121,9 +123,10 @@ func Test_search(t *testing.T) { func TestTrpcConfig_Load(t *testing.T) { t.Run("parse failed", func(t *testing.T) { - c, _ := newTrpcConfig("../testdata/trpc_go.yaml") + c, err := newTrpcConfig("../testdata/trpc_go.yaml") + require.Nil(t, err) c.decoder = &TomlCodec{} - err := c.Load() + err = c.Load() require.Contains(t, errs.Msg(err), "failed to parse") }) } @@ -155,6 +158,14 @@ password: ${pwd} require.Contains(t, string(cfg.Bytes()), fmt.Sprintf("password: %s", t.Name())) } +func TestCodecUnmarshalDstMustBeMap(t *testing.T) { + filePath := t.TempDir() + "/conf.map" + require.Nil(t, os.WriteFile(filePath, []byte{}, 0644)) + RegisterCodec(dstMustBeMapCodec{}) + _, err := DefaultConfigLoader.Load(filePath, WithCodec(dstMustBeMapCodec{}.Name())) + require.Nil(t, err) +} + func NewEnvProvider(name string, data []byte) *EnvProvider { return &EnvProvider{ name: name, @@ -278,3 +289,21 @@ func (m *manualTriggerWatchProvider) Set(key string, v []byte) { callback(key, v) } } + +type dstMustBeMapCodec struct{} + +func (c dstMustBeMapCodec) Name() string { + return "map" +} + +func (c dstMustBeMapCodec) Unmarshal(bts []byte, dst interface{}) error { + rv := reflect.ValueOf(dst) + if rv.Kind() != reflect.Ptr || + rv.Elem().Kind() != reflect.Interface || + rv.Elem().Elem().Kind() != reflect.Map || + rv.Elem().Elem().Type().Key().Kind() != reflect.String || + rv.Elem().Elem().Type().Elem().Kind() != reflect.Interface { + return errors.New("the dst of codec.Unmarshal must be a map") + } + return nil +} From 5215ace90f066e40118add14e43d11d4567ea12f Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 18:52:55 +0800 Subject: [PATCH 16/39] server: do not close old listener immediately after hot restart --- http/restful_server_transport.go | 12 +++++-- http/transport.go | 4 +++ server/serve_unix.go | 3 -- server/server_unix_test.go | 12 +------ server/service.go | 1 + transport/server_listenserve_options.go | 10 ++++++ transport/server_transport.go | 45 ++++++++++++++++--------- transport/server_transport_tcp.go | 10 ------ transport/server_transport_test.go | 21 ++++++++++++ transport/tnet/server_transport_tcp.go | 4 +++ 10 files changed, 80 insertions(+), 42 deletions(-) diff --git a/http/restful_server_transport.go b/http/restful_server_transport.go index 499b787..a44c42b 100644 --- a/http/restful_server_transport.go +++ b/http/restful_server_transport.go @@ -160,12 +160,20 @@ func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transp ln = tls.NewListener(ln, tlsConf) } + go func() { + <-opts.StopListening + ln.Close() + }() + return st.serve(ctx, ln, opts) } // serve starts service. -func (st *RESTServerTransport) serve(ctx context.Context, ln net.Listener, - opts *transport.ListenServeOptions) error { +func (st *RESTServerTransport) serve( + ctx context.Context, + ln net.Listener, + opts *transport.ListenServeOptions, +) error { // Get router. router := restful.GetRouter(opts.ServiceName) if router == nil { diff --git a/http/transport.go b/http/transport.go index 1f366d3..46e8293 100644 --- a/http/transport.go +++ b/http/transport.go @@ -198,6 +198,10 @@ func (t *ServerTransport) serve(ctx context.Context, s *stdhttp.Server, opts *tr _ = s.Shutdown(context.TODO()) }() } + go func() { + <-opts.StopListening + ln.Close() + }() return nil } diff --git a/server/serve_unix.go b/server/serve_unix.go index 16f4902..102f8a6 100644 --- a/server/serve_unix.go +++ b/server/serve_unix.go @@ -115,9 +115,6 @@ func (s *Server) StartNewProcess(args ...string) (uintptr, error) { return 0, err } - for _, f := range listenersFds { - f.OriginalListenCloser.Close() - } return uintptr(childPID), nil } diff --git a/server/server_unix_test.go b/server/server_unix_test.go index 8447e17..4873e06 100644 --- a/server/server_unix_test.go +++ b/server/server_unix_test.go @@ -59,20 +59,10 @@ func TestStartNewProcess(t *testing.T) { s.AddService("trpc.test.helloworld.Greeter1", service) err := s.Register(nil, nil) - assert.NotNil(t, err) impl := &GreeterServerImpl{} err = s.Register(&GreeterServerServiceDesc, impl) - assert.Nil(t, err) - go func() { - var netOpError *net.OpError - assert.ErrorAs( - t, - s.Serve(), - &netOpError, - `it is normal to have "use of closed network connection" error during hot restart`, - ) - }() + go s.Serve() time.Sleep(time.Second * 1) log.Info(os.Environ()) diff --git a/server/service.go b/server/service.go index a407fef..25980fe 100644 --- a/server/service.go +++ b/server/service.go @@ -109,6 +109,7 @@ type service struct { handlers map[string]Handler // rpcname => handler streamHandlers map[string]StreamHandler streamInfo map[string]*StreamServerInfo + stopListening chan<- struct{} } // New creates a service. diff --git a/transport/server_listenserve_options.go b/transport/server_listenserve_options.go index f8b6c8f..dac0397 100644 --- a/transport/server_listenserve_options.go +++ b/transport/server_listenserve_options.go @@ -43,6 +43,9 @@ type ListenServeOptions struct { // This used for rpc transport layer like http, it's unrelated to // the TCP keep-alives. DisableKeepAlives bool + + // StopListening is used to instruct the server transport to stop listening. + StopListening <-chan struct{} } // ListenServeOption modifies the ListenServeOptions. @@ -149,3 +152,10 @@ func WithServerIdleTimeout(timeout time.Duration) ListenServeOption { options.IdleTimeout = timeout } } + +// WithStopListening returns a ListenServeOption which notifies the transport to stop listening. +func WithStopListening(ch <-chan struct{}) ListenServeOption { + return func(options *ListenServeOptions) { + options.StopListening = ch + } +} diff --git a/transport/server_transport.go b/transport/server_transport.go index 6ca98eb..36c86c7 100644 --- a/transport/server_transport.go +++ b/transport/server_transport.go @@ -18,7 +18,6 @@ import ( "crypto/tls" "errors" "fmt" - "io" "net" "os" "runtime" @@ -242,6 +241,23 @@ func mayLiftToTLSListener(ln net.Listener, opts *ListenServeOptions) (net.Listen } func (s *serverTransport) serveStream(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { + var once sync.Once + closeListener := func() { ln.Close() } + defer once.Do(closeListener) + // Create a goroutine to watch ctx.Done() channel. + // Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection. + go func() { + select { + case <-ctx.Done(): + // ctx.Done will perform the following two actions: + // 1. Stop listening. + // 2. Cancel all currently established connections. + // Whereas opts.StopListening will only stop listening. + case <-opts.StopListening: + } + log.Tracef("recv server close event") + once.Do(closeListener) + }() return s.serveTCP(ctx, ln, opts) } @@ -386,11 +402,10 @@ func getPassedListener(network, address string) (interface{}, error) { // ListenFd is the listener fd. type ListenFd struct { - OriginalListenCloser io.Closer - Fd uintptr - Name string - Network string - Address string + Fd uintptr + Name string + Network string + Address string } // inheritListeners stores the listener according to start listenfd and number of listenfd passed @@ -460,11 +475,10 @@ func getPacketConnFd(c net.PacketConn) (*ListenFd, error) { return nil, fmt.Errorf("getPacketConnFd getRawFd err: %w", err) } return &ListenFd{ - OriginalListenCloser: c, - Fd: lnFd, - Name: "a udp listener fd", - Network: c.LocalAddr().Network(), - Address: c.LocalAddr().String(), + Fd: lnFd, + Name: "a udp listener fd", + Network: c.LocalAddr().Network(), + Address: c.LocalAddr().String(), }, nil } @@ -478,11 +492,10 @@ func getListenerFd(ln net.Listener) (*ListenFd, error) { return nil, fmt.Errorf("getListenerFd getRawFd err: %w", err) } return &ListenFd{ - OriginalListenCloser: ln, - Fd: fd, - Name: "a tcp listener fd", - Network: ln.Addr().Network(), - Address: ln.Addr().String(), + Fd: fd, + Name: "a tcp listener fd", + Network: ln.Addr().Network(), + Address: ln.Addr().String(), }, nil } diff --git a/transport/server_transport_tcp.go b/transport/server_transport_tcp.go index ec91715..b930ee6 100644 --- a/transport/server_transport_tcp.go +++ b/transport/server_transport_tcp.go @@ -79,16 +79,6 @@ func createRoutinePool(size int) *ants.PoolWithFunc { } func (s *serverTransport) serveTCP(ctx context.Context, ln net.Listener, opts *ListenServeOptions) error { - var once sync.Once - closeListener := func() { ln.Close() } - defer once.Do(closeListener) - // Create a goroutine to watch ctx.Done() channel. - // Once Server.Close(), TCP listener should be closed immediately and won't accept any new connection. - go func() { - <-ctx.Done() - log.Tracef("recv server close event") - once.Do(closeListener) - }() // Create a goroutine pool if ServerAsync enabled. var pool *ants.PoolWithFunc if opts.ServerAsync { diff --git a/transport/server_transport_test.go b/transport/server_transport_test.go index 7596237..fd5ad20 100644 --- a/transport/server_transport_test.go +++ b/transport/server_transport_test.go @@ -1024,3 +1024,24 @@ func TestListenAndServeTLSFail(t *testing.T) { transport.WithListener(ln), )) } + +func TestListenAndServeWithStopListener(t *testing.T) { + s := transport.NewServerTransport() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.Nil(t, err) + ch := make(chan struct{}) + require.Nil(t, s.ListenAndServe(ctx, + transport.WithListenNetwork("tcp"), + transport.WithServerFramerBuilder(&framerBuilder{}), + transport.WithListener(ln), + transport.WithStopListening(ch), + )) + _, err = net.Dial("tcp", ln.Addr().String()) + require.Nil(t, err) + close(ch) + time.Sleep(time.Millisecond) + _, err = net.Dial("tcp", ln.Addr().String()) + require.NotNil(t, err) +} diff --git a/transport/tnet/server_transport_tcp.go b/transport/tnet/server_transport_tcp.go index ec5ab4f..442ce85 100644 --- a/transport/tnet/server_transport_tcp.go +++ b/transport/tnet/server_transport_tcp.go @@ -154,6 +154,10 @@ func (s *serverTransport) startService( pool *ants.PoolWithFunc, opts *transport.ListenServeOptions, ) error { + go func() { + <-opts.StopListening + listener.Close() + }() tnetOpts := []tnet.Option{ tnet.WithOnTCPOpened(func(conn tnet.Conn) error { tc := s.onConnOpened(conn, pool, opts) From 0d85c39ca38a7d1c32bc73b46b65b13b85ff01e7 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 18:56:41 +0800 Subject: [PATCH 17/39] **http:** expose possible io.Writer interface for http response body --- http/transport.go | 40 +++++++++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/http/transport.go b/http/transport.go index 46e8293..31e410c 100644 --- a/http/transport.go +++ b/http/transport.go @@ -521,24 +521,50 @@ func (ct *ClientTransport) RoundTrip( return nil, errs.NewFrameError(errs.RetClientNetErr, "http client transport RoundTrip: "+err.Error()) } - rspHeader.Response.Body = &responseBodyWithCancel{body: rspHeader.Response.Body, cancel: cancel} + decorateWithCancel(rspHeader, cancel) return emptyBuf, nil } -// responseBodyWithCancel implements io.ReadCloser. +func decorateWithCancel(rspHeader *ClientRspHeader, cancel context.CancelFunc) { + // Quoted from: https://github.com/golang/go/blob/go1.21.4/src/net/http/response.go#L69 + // + // "As of Go 1.12, the Body will also implement io.Writer on a successful "101 Switching Protocols" response, + // as used by WebSockets and HTTP/2's "h2c" mode." + // + // Therefore, we require an extra check to ensure io.Writer's conformity, + // which will then expose the corresponding method. + // + // It's important to note that an embedded body may not be capable of exposing all the attached interfaces. + // Consequently, we perform an explicit interface assertion here. + if body, ok := rspHeader.Response.Body.(io.ReadWriteCloser); ok { + rspHeader.Response.Body = &writableResponseBodyWithCancel{ReadWriteCloser: body, cancel: cancel} + } else { + rspHeader.Response.Body = &responseBodyWithCancel{ReadCloser: rspHeader.Response.Body, cancel: cancel} + } +} + +// writableResponseBodyWithCancel implements io.ReadWriteCloser. // It wraps response body and cancel function. -type responseBodyWithCancel struct { - body io.ReadCloser +type writableResponseBodyWithCancel struct { + io.ReadWriteCloser cancel context.CancelFunc } -func (b *responseBodyWithCancel) Read(p []byte) (int, error) { - return b.body.Read(p) +func (b *writableResponseBodyWithCancel) Close() error { + b.cancel() + return b.ReadWriteCloser.Close() +} + +// responseBodyWithCancel implements io.ReadCloser. +// It wraps response body and cancel function. +type responseBodyWithCancel struct { + io.ReadCloser + cancel context.CancelFunc } func (b *responseBodyWithCancel) Close() error { b.cancel() - return b.body.Close() + return b.ReadCloser.Close() } func (ct *ClientTransport) getStdHTTPClient(caFile, certFile, From a4cc84869ae744bfe6387b0d041d7b3690675666 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 18:59:40 +0800 Subject: [PATCH 18/39] **http:** check type of url.Values for form serialization --- http/serialization_form.go | 10 +++++++++- http/serialization_form_test.go | 20 ++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/http/serialization_form.go b/http/serialization_form.go index 8de5736..a273b99 100644 --- a/http/serialization_form.go +++ b/http/serialization_form.go @@ -61,7 +61,8 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { } switch body.(type) { // go-playground/form does not support map structure. - case map[string]interface{}, *map[string]interface{}, map[string]string, *map[string]string: + case map[string]interface{}, *map[string]interface{}, map[string]string, *map[string]string, + url.Values, *url.Values: // Essentially, the underlying type of 'url.Values' is also a map. return unmarshalValues(j.tagname, values, body) default: } @@ -80,6 +81,13 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { // unmarshalValues parses the corresponding fields in values according to tagname. func unmarshalValues(tagname string, values url.Values, body interface{}) error { + // To handle the scenario where the underlying type of 'body' is 'url.Values'. + if b, ok := body.(url.Values); ok && b != nil { + for k, v := range values { + b[k] = v + } + return nil + } params := map[string]interface{}{} for k, v := range values { if len(v) == 1 { diff --git a/http/serialization_form_test.go b/http/serialization_form_test.go index a8836e1..da3711f 100644 --- a/http/serialization_form_test.go +++ b/http/serialization_form_test.go @@ -106,7 +106,27 @@ func TestFormSerializer(t *testing.T) { buf, _ := formSerializer.Marshal(&query) require.Equal(string(buf), expectedQueries[i], "x should be equal") } +} +func TestFromSerializerURLValues(t *testing.T) { + in := make(url.Values) + const ( + key = "key" + val = "val" + ) + in.Add(key, val) + bs, err := codec.Marshal(codec.SerializationTypeForm, in) + require.Nil(t, err) + out := make(url.Values) + require.Nil(t, codec.Unmarshal(codec.SerializationTypeForm, bs, out)) + require.Equal(t, val, out.Get(key)) + + out2 := make(url.Values) + require.Nil(t, codec.Unmarshal(codec.SerializationTypeForm, bs, &out2)) + require.Equal(t, val, out2.Get(key)) + + var out3 url.Values + require.NotNil(t, codec.Unmarshal(codec.SerializationTypeForm, bs, out3)) } func TestUnmarshal(t *testing.T) { From f0396588a9442986c41daaaa1fc7fce5fb468b9f Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 19:03:22 +0800 Subject: [PATCH 19/39] **codec:**: get serializer should also be able to unmarshal nested structure --- http/serialization_get.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/http/serialization_get.go b/http/serialization_get.go index 59421ab..bce7d7f 100644 --- a/http/serialization_get.go +++ b/http/serialization_get.go @@ -15,7 +15,6 @@ package http import ( "errors" - "net/url" "trpc.group/trpc-go/trpc-go/codec" ) @@ -26,24 +25,20 @@ func init() { // NewGetSerialization initializes the get serialized object. func NewGetSerialization(tag string) codec.Serializer { - + formSerializer := NewFormSerialization(tag) return &GetSerialization{ - tagname: tag, + formSerializer: formSerializer.(*FormSerialization), } } // GetSerialization packages kv structure of the http get request. type GetSerialization struct { - tagname string + formSerializer *FormSerialization } // Unmarshal unpacks kv structure. func (s *GetSerialization) Unmarshal(in []byte, body interface{}) error { - values, err := url.ParseQuery(string(in)) - if err != nil { - return err - } - return unmarshalValues(s.tagname, values, body) + return s.formSerializer.Unmarshal(in, body) } // Marshal packages kv structure. From 64d222b6d112ad5d2a3246c635098cb36158292d Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 19:36:48 +0800 Subject: [PATCH 20/39] **stream:** ensure server returns an error when connection is closed --- stream/server.go | 14 +++++++---- stream/server_test.go | 54 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/stream/server.go b/stream/server.go index 3e13185..f87053b 100644 --- a/stream/server.go +++ b/stream/server.go @@ -20,6 +20,7 @@ import ( "net" "sync" + "go.uber.org/atomic" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" trpc "trpc.group/trpc-go/trpc-go" @@ -40,7 +41,7 @@ type serverStream struct { opts *server.Options recvQueue *queue.Queue[*response] done chan struct{} - err error // Carry the server tcp failure information. + err atomic.Error // Carry the server tcp failure information. once sync.Once rControl *receiveControl // Receiver flow control. sControl *sendControl // Sender flow control. @@ -48,6 +49,9 @@ type serverStream struct { // SendMsg is the API that users use to send streaming messages. func (s *serverStream) SendMsg(m interface{}) error { + if err := s.err.Load(); err != nil { + return errs.WrapFrameError(err, errs.Code(err), "stream sending error") + } msg := codec.Message(s.ctx) ctx, newMsg := codec.WithCloneContextAndMessage(s.ctx) defer codec.PutBackMessage(newMsg) @@ -119,8 +123,8 @@ func (s *serverStream) serializationAndCompressType(msg codec.Msg) (int, int) { func (s *serverStream) RecvMsg(m interface{}) error { resp, ok := s.recvQueue.Get() if !ok { - if s.err != nil { - return s.err + if err := s.err.Load(); err != nil { + return err } return errs.NewFrameError(errs.RetServerSystemErr, streamClosed) } @@ -320,7 +324,7 @@ func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32, err = ss.CloseSend(int32(trpcpb.TrpcStreamCloseType_TRPC_STREAM_CLOSE), 0, "") } if err != nil { - ss.err = err + ss.err.Store(err) log.Trace(closeSendFail, err) } } @@ -435,7 +439,7 @@ func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) } for streamID, ss := range addrToStream { - ss.err = msg.ServerRspErr() + ss.err.Store(msg.ServerRspErr()) ss.once.Do(func() { close(ss.done) }) delete(addrToStream, streamID) } diff --git a/stream/server_test.go b/stream/server_test.go index 7a5cb80..3a4edfd 100644 --- a/stream/server_test.go +++ b/stream/server_test.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "math/rand" + "net" "sync" "testing" "time" @@ -911,3 +912,56 @@ func serverFilterAdd2(ss server.Stream, si *server.StreamServerInfo, err := handler(newWrappedServerStream(ss)) return err } + +// TestServerStreamAllFailWhenConnectionClosedAndReconnect tests when a connection +// is closed and then reconnected (with the same client IP and port), both SendMsg +// and RecvMsg on the server side result in errors. +func TestServerStreamAllFailWhenConnectionClosedAndReconnect(t *testing.T) { + ch := make(chan struct{}) + addr := "127.0.0.1:30211" + svrOpts := []server.Option{ + server.WithAddress(addr), + } + handle := func(s server.Stream) error { + <-ch + err := s.SendMsg(getBytes(100)) + assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) + err = s.RecvMsg(getBytes(100)) + assert.Equal(t, errs.Code(err), errs.RetServerSystemErr) + ch <- struct{}{} + return nil + } + svr := startStreamServer(handle, svrOpts) + defer closeStreamServer(svr) + + // Init a stream + dialer := net.Dialer{ + LocalAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 20001}, + } + conn, err := dialer.Dial("tcp", addr) + assert.Nil(t, err) + _, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{ + FrameType: uint8(trpcpb.TrpcDataFrameType_TRPC_STREAM_FRAME), + StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT), + }) + msg.WithClientRPCName("/trpc.test.stream.Greeter/StreamSayHello") + initReq, err := trpc.DefaultClientCodec.Encode(msg, nil) + assert.Nil(t, err) + _, err = conn.Write(initReq) + assert.Nil(t, err) + + // Close the connection + conn.Close() + + // Dial another connection using the same client ip:port + time.Sleep(time.Millisecond * 200) + _, err = dialer.Dial("tcp", addr) + assert.Nil(t, err) + + // Notify server to send and receive + ch <- struct{}{} + + // Wait server sending and receiving result assertion + <-ch +} From 7cc8c858b476a31e4bef3347304e63b94dff7c78 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 20:09:29 +0800 Subject: [PATCH 21/39] **log:** log.Info("a","b") print "a b" instead of "ab" --- log/log.go | 24 ++++++++++++------------ log/logger.go | 12 ++++++------ log/zaplogger.go | 15 ++++++++------- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/log/log.go b/log/log.go index 311dc95..6460660 100644 --- a/log/log.go +++ b/log/log.go @@ -96,7 +96,7 @@ func RedirectStdLogAt(logger Logger, level zapcore.Level) (func(), error) { return nil, errors.New("log: only supports redirecting std logs to trpc zap logger") } -// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. func Trace(args ...interface{}) { if traceEnabled { GetDefaultLogger().Trace(args...) @@ -110,7 +110,7 @@ func Tracef(format string, args ...interface{}) { } } -// TraceContext logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// TraceContext logs to TRACE log. Arguments are handled in the manner of fmt.Println. func TraceContext(ctx context.Context, args ...interface{}) { if !traceEnabled { return @@ -134,7 +134,7 @@ func TraceContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Tracef(format, args...) } -// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func Debug(args ...interface{}) { GetDefaultLogger().Debug(args...) } @@ -144,7 +144,7 @@ func Debugf(format string, args ...interface{}) { GetDefaultLogger().Debugf(format, args...) } -// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. +// Info logs to INFO log. Arguments are handled in the manner of fmt.Println. func Info(args ...interface{}) { GetDefaultLogger().Info(args...) } @@ -154,7 +154,7 @@ func Infof(format string, args ...interface{}) { GetDefaultLogger().Infof(format, args...) } -// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. func Warn(args ...interface{}) { GetDefaultLogger().Warn(args...) } @@ -164,7 +164,7 @@ func Warnf(format string, args ...interface{}) { GetDefaultLogger().Warnf(format, args...) } -// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. func Error(args ...interface{}) { GetDefaultLogger().Error(args...) } @@ -174,7 +174,7 @@ func Errorf(format string, args ...interface{}) { GetDefaultLogger().Errorf(format, args...) } -// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. func Fatal(args ...interface{}) { @@ -211,7 +211,7 @@ func WithContextFields(ctx context.Context, fields ...string) context.Context { return ctx } -// DebugContext logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// DebugContext logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func DebugContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Debug(args...) @@ -229,7 +229,7 @@ func DebugContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Debugf(format, args...) } -// InfoContext logs to INFO log. Arguments are handled in the manner of fmt.Print. +// InfoContext logs to INFO log. Arguments are handled in the manner of fmt.Println. func InfoContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Info(args...) @@ -247,7 +247,7 @@ func InfoContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Infof(format, args...) } -// WarnContext logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// WarnContext logs to WARNING log. Arguments are handled in the manner of fmt.Println. func WarnContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Warn(args...) @@ -266,7 +266,7 @@ func WarnContextf(ctx context.Context, format string, args ...interface{}) { } -// ErrorContext logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// ErrorContext logs to ERROR log. Arguments are handled in the manner of fmt.Println. func ErrorContext(ctx context.Context, args ...interface{}) { if l, ok := codec.Message(ctx).Logger().(Logger); ok { l.Error(args...) @@ -284,7 +284,7 @@ func ErrorContextf(ctx context.Context, format string, args ...interface{}) { GetDefaultLogger().Errorf(format, args...) } -// FatalContext logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// FatalContext logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. func FatalContext(ctx context.Context, args ...interface{}) { diff --git a/log/logger.go b/log/logger.go index ff1a208..9c8c31f 100644 --- a/log/logger.go +++ b/log/logger.go @@ -74,27 +74,27 @@ type Field struct { // Logger is the underlying logging work for tRPC framework. type Logger interface { - // Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. + // Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. Trace(args ...interface{}) // Tracef logs to TRACE log. Arguments are handled in the manner of fmt.Printf. Tracef(format string, args ...interface{}) - // Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. + // Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. Debug(args ...interface{}) // Debugf logs to DEBUG log. Arguments are handled in the manner of fmt.Printf. Debugf(format string, args ...interface{}) - // Info logs to INFO log. Arguments are handled in the manner of fmt.Print. + // Info logs to INFO log. Arguments are handled in the manner of fmt.Println. Info(args ...interface{}) // Infof logs to INFO log. Arguments are handled in the manner of fmt.Printf. Infof(format string, args ...interface{}) - // Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. + // Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. Warn(args ...interface{}) // Warnf logs to WARNING log. Arguments are handled in the manner of fmt.Printf. Warnf(format string, args ...interface{}) - // Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. + // Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. Error(args ...interface{}) // Errorf logs to ERROR log. Arguments are handled in the manner of fmt.Printf. Errorf(format string, args ...interface{}) - // Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Print. + // Fatal logs to ERROR log. Arguments are handled in the manner of fmt.Println. // All Fatal logs will exit by calling os.Exit(1). // Implementations may also call os.Exit() with a non-zero exit code. Fatal(args ...interface{}) diff --git a/log/zaplogger.go b/log/zaplogger.go index 974618f..4b61c5e 100644 --- a/log/zaplogger.go +++ b/log/zaplogger.go @@ -282,7 +282,8 @@ func (l *zapLog) With(fields ...Field) Logger { } func getLogMsg(args ...interface{}) string { - msg := fmt.Sprint(args...) + msg := fmt.Sprintln(args...) + msg = msg[:len(msg)-1] report.LogWriteSize.IncrBy(float64(len(msg))) return msg } @@ -293,7 +294,7 @@ func getLogMsgf(format string, args ...interface{}) string { return msg } -// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Print. +// Trace logs to TRACE log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Trace(args ...interface{}) { if l.logger.Core().Enabled(zapcore.DebugLevel) { l.logger.Debug(getLogMsg(args...)) @@ -307,7 +308,7 @@ func (l *zapLog) Tracef(format string, args ...interface{}) { } } -// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Print. +// Debug logs to DEBUG log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Debug(args ...interface{}) { if l.logger.Core().Enabled(zapcore.DebugLevel) { l.logger.Debug(getLogMsg(args...)) @@ -321,7 +322,7 @@ func (l *zapLog) Debugf(format string, args ...interface{}) { } } -// Info logs to INFO log. Arguments are handled in the manner of fmt.Print. +// Info logs to INFO log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Info(args ...interface{}) { if l.logger.Core().Enabled(zapcore.InfoLevel) { l.logger.Info(getLogMsg(args...)) @@ -335,7 +336,7 @@ func (l *zapLog) Infof(format string, args ...interface{}) { } } -// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Print. +// Warn logs to WARNING log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Warn(args ...interface{}) { if l.logger.Core().Enabled(zapcore.WarnLevel) { l.logger.Warn(getLogMsg(args...)) @@ -349,7 +350,7 @@ func (l *zapLog) Warnf(format string, args ...interface{}) { } } -// Error logs to ERROR log. Arguments are handled in the manner of fmt.Print. +// Error logs to ERROR log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Error(args ...interface{}) { if l.logger.Core().Enabled(zapcore.ErrorLevel) { l.logger.Error(getLogMsg(args...)) @@ -363,7 +364,7 @@ func (l *zapLog) Errorf(format string, args ...interface{}) { } } -// Fatal logs to FATAL log. Arguments are handled in the manner of fmt.Print. +// Fatal logs to FATAL log. Arguments are handled in the manner of fmt.Println. func (l *zapLog) Fatal(args ...interface{}) { if l.logger.Core().Enabled(zapcore.FatalLevel) { l.logger.Fatal(getLogMsg(args...)) From 1120e8e6c3f9109934e6da925970c4bd58daffb9 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 20:14:15 +0800 Subject: [PATCH 22/39] **stream:**: return an error when receiving an unexpected frame type --- stream/client.go | 12 +++++++----- stream/client_test.go | 12 ++++++++++++ stream/flow_control.go | 9 --------- stream/flow_control_test.go | 8 ++------ 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/stream/client.go b/stream/client.go index be7b754..c2de5f3 100644 --- a/stream/client.go +++ b/stream/client.go @@ -17,6 +17,7 @@ package stream import ( "context" "errors" + "fmt" "io" "sync" "sync/atomic" @@ -252,12 +253,13 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) if _, err := cs.stream.Recv(newCtx); err != nil { return nil, err } - - initWindowSize := defaultInitWindowSize - if initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta); ok { - // If the server has feedback, use the server's window, if not, use the default window. - initWindowSize = initRspMeta.GetInitWindowSize() + initRspMeta, ok := newMsg.StreamFrame().(*trpcpb.TrpcStreamInitMeta) + if !ok { + return nil, fmt.Errorf("client stream (method = %s, streamID = %d) recv "+ + "unexpected frame type: %T, expected: %T", + cs.method, cs.streamID, newMsg.StreamFrame(), (*trpcpb.TrpcStreamInitMeta)(nil)) } + initWindowSize := initRspMeta.GetInitWindowSize() cs.configSendControl(initWindowSize) // Start the dispatch goroutine loop to send packets. diff --git a/stream/client_test.go b/stream/client_test.go index 3fbc7a6..430fd0f 100644 --- a/stream/client_test.go +++ b/stream/client_test.go @@ -333,6 +333,18 @@ func TestClientError(t *testing.T) { assert.Nil(t, cs) assert.NotNil(t, err) + // receive unexpected stream frame type + f = func(fh *trpc.FrameHead, msg codec.Msg) ([]byte, error) { + msg.WithStreamFrame(int(1)) + return nil, nil + } + ft.expectChan <- f + cs, err = cli.NewStream(ctx, bidiDesc, "/trpc.test.helloworld.Greeter/SayHello", + client.WithTarget("ip://127.0.0.1:8000"), + client.WithProtocol("fake"), client.WithSerializationType(codec.SerializationTypeNoop), + client.WithStreamTransport(ft), client.WithClientStreamQueueSize(100000)) + assert.Nil(t, cs) + assert.Contains(t, err.Error(), "unexpected frame type") } // TestClientContext tests the case of streaming client context cancel and timeout. diff --git a/stream/flow_control.go b/stream/flow_control.go index 1395345..a379897 100644 --- a/stream/flow_control.go +++ b/stream/flow_control.go @@ -85,7 +85,6 @@ func checkUpdate(updatedWindow, increment int64) bool { type receiveControl struct { buffer uint32 // upper limit. unUpdated uint32 // Consumed, no window update sent. - left uint32 // remaining available buffer. fb feedback // function for feedback. } @@ -93,7 +92,6 @@ func newReceiveControl(buffer uint32, fb feedback) *receiveControl { return &receiveControl{ buffer: buffer, fb: fb, - left: buffer, } } @@ -103,17 +101,10 @@ func (r *receiveControl) OnRecv(n uint32) error { if r.unUpdated >= r.buffer/4 { increment := r.unUpdated r.unUpdated = 0 - r.updateLeft() if r.fb != nil { return r.fb(increment) } return nil } - r.updateLeft() return nil } - -// updateLeft updates the remaining available buffers. -func (r *receiveControl) updateLeft() { - atomic.StoreUint32(&r.left, r.buffer-r.unUpdated) -} diff --git a/stream/flow_control_test.go b/stream/flow_control_test.go index 4ffa4e4..e0b4935 100644 --- a/stream/flow_control_test.go +++ b/stream/flow_control_test.go @@ -15,7 +15,6 @@ package stream import ( "errors" - "sync/atomic" "testing" "time" @@ -41,8 +40,8 @@ func TestSendControl(t *testing.T) { }() err = sc.GetWindow(200) assert.Nil(t, err) - t2 := int64(time.Now().Sub(t1)) - assert.GreaterOrEqual(t, t2, int64(500*time.Millisecond)) + t2 := time.Since(t1) + assert.GreaterOrEqual(t, t2, 500*time.Millisecond) } // TestReceiveControl test. @@ -54,9 +53,6 @@ func TestReceiveControl(t *testing.T) { err := rc.OnRecv(100) assert.Nil(t, err) - n := atomic.LoadUint32(&rc.left) - assert.Equal(t, defaultInitWindowSize-uint32(100), n) - // need to send updates. err = rc.OnRecv(defaultInitWindowSize / 4) assert.Nil(t, err) From 86096b64a7672adcc8ae16f7d0c6ead6132024e2 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 20:18:48 +0800 Subject: [PATCH 23/39] **stream:** fix client's compression type setting not working --- codec/message_impl.go | 3 +++ stream/client.go | 3 +++ stream/client_test.go | 37 +++++++++++++++++++++++++++++++++++++ stream/server.go | 3 ++- 4 files changed, 45 insertions(+), 1 deletion(-) diff --git a/codec/message_impl.go b/codec/message_impl.go index e4d78c0..00ee2c0 100644 --- a/codec/message_impl.go +++ b/codec/message_impl.go @@ -616,6 +616,9 @@ func WithCloneContextAndMessage(ctx context.Context) (context.Context, Msg) { // copyCommonMessage copy common data of message. func copyCommonMessage(m *msg, newMsg *msg) { + // Do not copy compress type here, as it will cause subsequence RPC calls to inherit the upstream + // compress type which is not the expected behavior. Compress type should not be propagated along + // the entire RPC invocation chain. newMsg.frameHead = m.frameHead newMsg.requestTimeout = m.requestTimeout newMsg.serializationType = m.serializationType diff --git a/stream/client.go b/stream/client.go index c2de5f3..839d6de 100644 --- a/stream/client.go +++ b/stream/client.go @@ -189,6 +189,7 @@ func (cs *clientStream) SendMsg(m interface{}) error { msg.WithStreamID(cs.streamID) msg.WithClientRPCName(cs.method) msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) + msg.WithCompressType(codec.Message(cs.ctx).CompressType()) return cs.stream.Send(ctx, m) } @@ -240,6 +241,7 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) newMsg.WithClientRPCName(cs.method) newMsg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) newMsg.WithStreamID(cs.streamID) + newMsg.WithCompressType(codec.Message(cs.ctx).CompressType()) newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{ RequestMeta: &trpcpb.TrpcStreamInitRequestMeta{}, InitWindowSize: w, @@ -338,6 +340,7 @@ func (cs *clientStream) dispatch() { }() for { ctx, msg := codec.WithCloneContextAndMessage(cs.ctx) + msg.WithCompressType(codec.Message(cs.ctx).CompressType()) msg.WithStreamID(cs.streamID) respData, err := cs.stream.Recv(ctx) if err != nil { diff --git a/stream/client_test.go b/stream/client_test.go index 430fd0f..cded911 100644 --- a/stream/client_test.go +++ b/stream/client_test.go @@ -801,3 +801,40 @@ func TestClientNewStreamFail(t *testing.T) { assert.True(t, isClosed) }) } + +func TestClientServerCompress(t *testing.T) { + var ( + dataLen = 1024 + compressType = codec.CompressTypeSnappy + ) + svrOpts := []server.Option{ + server.WithAddress("127.0.0.1:30211"), + } + handle := func(s server.Stream) error { + assert.Equal(t, compressType, codec.Message(s.Context()).CompressType()) + req := getBytes(dataLen) + s.RecvMsg(req) + rsp := req + s.SendMsg(rsp) + return nil + } + svr := startStreamServer(handle, svrOpts) + defer closeStreamServer(svr) + + cliOpts := []client.Option{ + client.WithTarget("ip://127.0.0.1:30211"), + client.WithCompressType(compressType), + } + + clientStream, err := getClientStream(context.Background(), clientDesc, cliOpts) + assert.Nil(t, err) + req := getBytes(dataLen) + rand.Read(req.Data) + err = clientStream.SendMsg(req) + assert.Nil(t, err) + + rsp := getBytes(dataLen) + err = clientStream.RecvMsg(rsp) + assert.Equal(t, rsp.Data, req.Data) + assert.Nil(t, err) +} diff --git a/stream/server.go b/stream/server.go index f87053b..31c3ed5 100644 --- a/stream/server.go +++ b/stream/server.go @@ -57,6 +57,7 @@ func (s *serverStream) SendMsg(m interface{}) error { defer codec.PutBackMessage(newMsg) newMsg.WithLocalAddr(msg.LocalAddr()) newMsg.WithRemoteAddr(msg.RemoteAddr()) + newMsg.WithCompressType(msg.CompressType()) newMsg.WithStreamID(s.streamID) // Refer to the pb code generated by trpc.proto, common to each language, automatically generated code. newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, s.streamID)) @@ -65,7 +66,7 @@ func (s *serverStream) SendMsg(m interface{}) error { err error reqBodyBuffer []byte ) - serializationType, compressType := s.serializationAndCompressType(msg) + serializationType, compressType := s.serializationAndCompressType(newMsg) if icodec.IsValidSerializationType(serializationType) { reqBodyBuffer, err = codec.Marshal(serializationType, m) if err != nil { From 562fb9db4ff713d638551064e1798f9391201ac9 Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 22 Jan 2024 20:31:33 +0800 Subject: [PATCH 24/39] stream: fix server stream being overwritten --- stream/server.go | 51 +++++++++++++------------- stream/server_test.go | 84 +++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 108 insertions(+), 27 deletions(-) diff --git a/stream/server.go b/stream/server.go index 31c3ed5..0a669f0 100644 --- a/stream/server.go +++ b/stream/server.go @@ -17,10 +17,10 @@ import ( "context" "errors" "io" - "net" "sync" "go.uber.org/atomic" + "trpc.group/trpc-go/trpc-go/internal/addrutil" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" trpc "trpc.group/trpc-go/trpc-go" @@ -222,9 +222,12 @@ func (s *serverStream) Context() context.Context { // The structure of streamDispatcher is used to distribute streaming data. type streamDispatcher struct { - m sync.RWMutex - streamIDToServerStream map[net.Addr]map[uint32]*serverStream - opts *server.Options + m sync.RWMutex + // local address + remote address + network + // => stream ID + // => serverStream + addrToServerStream map[string]map[uint32]*serverStream + opts *server.Options } // DefaultStreamDispatcher is the default implementation of the trpc dispatcher, @@ -234,45 +237,45 @@ var DefaultStreamDispatcher = NewStreamDispatcher() // NewStreamDispatcher returns a new dispatcher. func NewStreamDispatcher() server.StreamHandle { return &streamDispatcher{ - streamIDToServerStream: make(map[net.Addr]map[uint32]*serverStream), + addrToServerStream: make(map[string]map[uint32]*serverStream), } } // storeServerStream msg contains the socket address of the client connection, // there are multiple streams under each socket address, and map it to serverStream // again according to the id of the stream. -func (sd *streamDispatcher) storeServerStream(addr net.Addr, streamID uint32, ss *serverStream) { +func (sd *streamDispatcher) storeServerStream(addr string, streamID uint32, ss *serverStream) { sd.m.Lock() defer sd.m.Unlock() - if addrToStreamID, ok := sd.streamIDToServerStream[addr]; !ok { + if addrToStreamID, ok := sd.addrToServerStream[addr]; !ok { // Does not exist, indicating that a new connection is coming, re-create the structure. - sd.streamIDToServerStream[addr] = map[uint32]*serverStream{streamID: ss} + sd.addrToServerStream[addr] = map[uint32]*serverStream{streamID: ss} } else { addrToStreamID[streamID] = ss } } // deleteServerStream deletes the serverStream from cache. -func (sd *streamDispatcher) deleteServerStream(addr net.Addr, streamID uint32) { +func (sd *streamDispatcher) deleteServerStream(addr string, streamID uint32) { sd.m.Lock() defer sd.m.Unlock() - if addrToStreamID, ok := sd.streamIDToServerStream[addr]; ok { + if addrToStreamID, ok := sd.addrToServerStream[addr]; ok { if _, ok = addrToStreamID[streamID]; ok { delete(addrToStreamID, streamID) } if len(addrToStreamID) == 0 { - delete(sd.streamIDToServerStream, addr) + delete(sd.addrToServerStream, addr) } } } // loadServerStream loads the stored serverStream through the socket address // of the client connection and the id of the stream. -func (sd *streamDispatcher) loadServerStream(addr net.Addr, streamID uint32) (*serverStream, error) { +func (sd *streamDispatcher) loadServerStream(addr string, streamID uint32) (*serverStream, error) { sd.m.RLock() defer sd.m.RUnlock() - addrToStream, ok := sd.streamIDToServerStream[addr] - if !ok || addr == nil { + addrToStream, ok := sd.addrToServerStream[addr] + if !ok { return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) } @@ -298,7 +301,7 @@ func (sd *streamDispatcher) Init(opts *server.Options) error { // startStreamHandler is used to start the goroutine, execute streamHandler, // streamHandler is implemented for the specific streaming server. -func (sd *streamDispatcher) startStreamHandler(addr net.Addr, streamID uint32, +func (sd *streamDispatcher) startStreamHandler(addr string, streamID uint32, ss *serverStream, si *server.StreamServerInfo, sh server.StreamHandler) { defer func() { sd.deleteServerStream(addr, streamID) @@ -362,7 +365,7 @@ func (sd *streamDispatcher) handleInit(ctx context.Context, ss := newServerStream(ctx, streamID, sd.opts) w := getWindowSize(sd.opts.MaxWindowSize) ss.rControl = newReceiveControl(w, ss.feedback) - sd.storeServerStream(msg.RemoteAddr(), streamID, ss) + sd.storeServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss) cw, err := ss.setSendControl(msg) if err != nil { @@ -395,13 +398,13 @@ func (sd *streamDispatcher) handleInit(ctx context.Context, } // Initiate a goroutine to execute specific business logic. - go sd.startStreamHandler(msg.RemoteAddr(), streamID, ss, si, sh) + go sd.startStreamHandler(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), streamID, ss, si, sh) return nil, errs.ErrServerNoResponse } // handleData handles data messages. func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { return nil, err } @@ -411,7 +414,7 @@ func (sd *streamDispatcher) handleData(msg codec.Msg, req []byte) ([]byte, error // handleClose handles the Close message. func (sd *streamDispatcher) handleClose(msg codec.Msg) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { // The server has sent the Close frame. // Since the timing of the Close frame is unpredictable, when the server receives the Close frame from the client, @@ -434,9 +437,9 @@ func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { sd.m.Lock() defer sd.m.Unlock() - addr := msg.RemoteAddr() - addrToStream, ok := sd.streamIDToServerStream[addr] - if !ok || addr == nil { + addr := addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()) + addrToStream, ok := sd.addrToServerStream[addr] + if !ok { return nil, errs.NewFrameError(errs.RetServerSystemErr, noSuchAddr) } for streamID, ss := range addrToStream { @@ -444,7 +447,7 @@ func (sd *streamDispatcher) handleError(msg codec.Msg) ([]byte, error) { ss.once.Do(func() { close(ss.done) }) delete(addrToStream, streamID) } - delete(sd.streamIDToServerStream, addr) + delete(sd.addrToServerStream, addr) return nil, errs.ErrServerNoResponse } @@ -467,7 +470,7 @@ func (sd *streamDispatcher) StreamHandleFunc(ctx context.Context, // handleFeedback handles the feedback frame. func (sd *streamDispatcher) handleFeedback(msg codec.Msg) ([]byte, error) { - ss, err := sd.loadServerStream(msg.RemoteAddr(), msg.StreamID()) + ss, err := sd.loadServerStream(addrutil.AddrToKey(msg.LocalAddr(), msg.RemoteAddr()), msg.StreamID()) if err != nil { return nil, err } diff --git a/stream/server_test.go b/stream/server_test.go index 3a4edfd..4fd79ae 100644 --- a/stream/server_test.go +++ b/stream/server_test.go @@ -148,6 +148,7 @@ func TestStreamDispatcherHandleInit(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) @@ -208,6 +209,7 @@ func TestStreamDispatcherHandleData(t *testing.T) { msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) @@ -220,7 +222,8 @@ func TestStreamDispatcherHandleData(t *testing.T) { assert.Equal(t, err, errs.ErrServerNoResponse) // handleData error no such addr - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("data")) @@ -265,6 +268,7 @@ func TestStreamDispatcherHandleClose(t *testing.T) { addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) msg.WithFrameHead(fh) rsp, err := dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("init")) assert.Nil(t, rsp) @@ -279,7 +283,8 @@ func TestStreamDispatcherHandleClose(t *testing.T) { // handle close no such addr msg.WithFrameHead(fh) - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) rsp, err = dispatcher.StreamHandleFunc(ctx, streamHandler, si, []byte("close")) assert.Nil(t, rsp) assert.Equal(t, errs.ErrServerNoResponse, err) @@ -331,6 +336,7 @@ func TestServerStreamSendMsg(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeNoop @@ -408,6 +414,7 @@ func TestServerStreamRecvMsg(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeNoop opts.CurrentSerializationType = codec.SerializationTypeNoop @@ -468,6 +475,7 @@ func TestServerStreamRecvMsgFail(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeGzip @@ -521,6 +529,7 @@ func TestHandleError(t *testing.T) { msg.WithFrameHead(fh) msg.WithStreamID(uint32(100)) msg.WithRemoteAddr(&fakeAddr{}) + msg.WithLocalAddr(&fakeAddr{}) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) opts.CurrentCompressType = codec.CompressTypeGzip @@ -586,12 +595,15 @@ func TestStreamDispatcherHandleFeedback(t *testing.T) { addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) rsp, err := dispatcher.StreamHandleFunc(ctx, sh, si, []byte("init")) assert.Nil(t, rsp) assert.Equal(t, err, errs.ErrServerNoResponse) // handle feedback get server stream fail - msg.WithRemoteAddr(nil) + raddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:1") + msg.WithRemoteAddr(raddr) + msg.WithLocalAddr(raddr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) @@ -600,6 +612,7 @@ func TestStreamDispatcherHandleFeedback(t *testing.T) { // handle feedback invalid stream msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) fh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK) msg.WithFrameHead(fh) rsp, err = dispatcher.StreamHandleFunc(ctx, nil, si, []byte("feedback")) @@ -636,6 +649,7 @@ func TestServerFlowControl(t *testing.T) { msg.WithStreamID(uint32(100)) addr := &fakeAddr{} msg.WithRemoteAddr(addr) + msg.WithLocalAddr(addr) msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{InitWindowSize: 65535}) opts.CurrentCompressType = codec.CompressTypeNoop opts.CurrentSerializationType = codec.SerializationTypeNoop @@ -662,6 +676,7 @@ func TestServerFlowControl(t *testing.T) { newCtx, newMsg := codec.WithNewMessage(newCtx) newMsg.WithStreamID(uint32(100)) newMsg.WithRemoteAddr(addr) + newMsg.WithLocalAddr(addr) newFh := &trpc.FrameHead{} newFh.StreamID = uint32(100) newFh.StreamFrameType = uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA) @@ -965,3 +980,66 @@ func TestServerStreamAllFailWhenConnectionClosedAndReconnect(t *testing.T) { // Wait server sending and receiving result assertion <-ch } + +func TestSameClientAddrDiffServerAddr(t *testing.T) { + dp := stream.NewStreamDispatcher() + dp.Init(&server.Options{ + Transport: &fakeServerTransport{}, + Codec: &fakeServerCodec{}, + CurrentSerializationType: codec.SerializationTypeNoop}) + wg := sync.WaitGroup{} + + initFrame := func(localAddr, remoteAddr net.Addr) { + ctx, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT)}) + msg.WithStreamID(200) + msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) + msg.WithRemoteAddr(remoteAddr) + msg.WithLocalAddr(localAddr) + msg.WithSerializationType(codec.SerializationTypeNoop) + wg.Add(1) + rsp, err := dp.StreamHandleFunc( + ctx, + func(s server.Stream) error { + err := s.RecvMsg(&codec.Body{}) + assert.Nil(t, err) + wg.Done() + return nil + }, + &server.StreamServerInfo{}, + []byte("init")) + assert.Nil(t, rsp) + assert.Equal(t, errs.ErrServerNoResponse, err) + } + + dataFrame := func(localAddr, remoteAddr net.Addr) { + ctx, msg := codec.WithNewMessage(context.Background()) + msg.WithFrameHead(&trpc.FrameHead{StreamFrameType: uint8(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA)}) + msg.WithStreamID(200) + msg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{}) + msg.WithRemoteAddr(remoteAddr) + msg.WithLocalAddr(localAddr) + rsp, err := dp.StreamHandleFunc(ctx, nil, &server.StreamServerInfo{}, []byte("data")) + assert.Nil(t, rsp) + assert.Equal(t, errs.ErrServerNoResponse, err) + } + + clientAddr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:9000") + serverAddr1, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10001") + serverAddr2, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:10002") + initFrame(serverAddr1, clientAddr) + initFrame(serverAddr2, clientAddr) + dataFrame(serverAddr1, clientAddr) + dataFrame(serverAddr2, clientAddr) + + c := make(chan struct{}) + go func() { + defer close(c) + wg.Wait() + }() + select { + case <-c: // completed normally + case <-time.After(time.Millisecond * 500): // timed out + assert.FailNow(t, "server did not receive data frame") + } +} From b58f1e0bc1c99b6ff823b6aff267e3bbbe1faec3 Mon Sep 17 00:00:00 2001 From: goodliu Date: Tue, 23 Jan 2024 10:10:23 +0800 Subject: [PATCH 25/39] **http:** fix form serialization panicking for compatibility --- http/serialization_form.go | 35 +++++++++++++++++++++++++++------ http/serialization_form_test.go | 14 +++++++++++++ 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/http/serialization_form.go b/http/serialization_form.go index a273b99..916b067 100644 --- a/http/serialization_form.go +++ b/http/serialization_form.go @@ -41,16 +41,16 @@ func NewFormSerialization(tag string) codec.Serializer { decoder.SetTagName(tag) return &FormSerialization{ tagname: tag, - encoder: encoder, - decoder: decoder, + encode: encoder.Encode, + decode: wrapDecodeWithRecovery(decoder.Decode), } } // FormSerialization packages the kv structure of http get request. type FormSerialization struct { tagname string - encoder *form.Encoder - decoder *form.Decoder + encode func(interface{}) (url.Values, error) + decode func(interface{}, url.Values) error } // Unmarshal unpacks kv structure. @@ -68,7 +68,7 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { } // First try using go-playground/form, it can handle nested struct. // But it cannot handle Chinese characters in byte slice. - err = j.decoder.Decode(body, values) + err = j.decode(body, values) if err == nil { return nil } @@ -79,6 +79,29 @@ func (j *FormSerialization) Unmarshal(in []byte, body interface{}) error { return nil } +// wrapDecodeWithRecovery wraps the decode function, adding panic recovery to handle +// panics as errors. This function is designed to prevent malformed query parameters +// from causing a panic, which is the default behavior of the go-playground/form decoder +// implementation. This is because, in certain cases, it's more acceptable to receive +// a degraded result rather than experiencing a direct server crash. +// Besides, the behavior of not panicking also ensures backward compatibility ( Date: Tue, 23 Jan 2024 10:16:26 +0800 Subject: [PATCH 26/39] **client:** fix client wildcard match for config --- client/config.go | 5 +++++ client/config_test.go | 22 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/client/config.go b/client/config.go index e1747ab..64c12ea 100644 --- a/client/config.go +++ b/client/config.go @@ -374,6 +374,11 @@ func RegisterConfig(conf map[string]*BackendConfig) error { // RegisterClientConfig is called to replace backend config of single callee service by name. func RegisterClientConfig(callee string, conf *BackendConfig) error { + if callee == "*" { + // Reset the callee and service name to enable wildcard matching. + conf.Callee = "" + conf.ServiceName = "" + } opts, err := conf.genOptions() if err != nil { return err diff --git a/client/config_test.go b/client/config_test.go index 9de0ef9..639877a 100644 --- a/client/config_test.go +++ b/client/config_test.go @@ -312,3 +312,25 @@ func TestConfig(t *testing.T) { } require.Nil(t, client.RegisterClientConfig("trpc.test.helloworld3", backconfig)) } + +func TestRegisterWildcardClient(t *testing.T) { + cfg := client.Config("*") + t.Cleanup(func() { + client.RegisterClientConfig("*", cfg) + }) + client.RegisterClientConfig("*", &client.BackendConfig{ + DisableServiceRouter: true, + }) + + ch := make(chan *client.Options, 1) + c := client.New() + ctx, _ := codec.EnsureMessage(context.Background()) + require.Nil(t, c.Invoke(ctx, nil, nil, client.WithFilter( + func(ctx context.Context, _, _ interface{}, _ filter.ClientHandleFunc) error { + ch <- client.OptionsFromContext(ctx) + // Skip next. + return nil + }))) + opts := <-ch + require.True(t, opts.DisableServiceRouter) +} From 944802a62f43096016e6b4d570438b286058cc97 Mon Sep 17 00:00:00 2001 From: goodliu Date: Tue, 23 Jan 2024 10:26:04 +0800 Subject: [PATCH 27/39] **codec:** revert "optimize performance of extracting method name out of rpc name" --- codec.go | 5 +-- codec/message_impl.go | 49 +++++++++++++++++++++++ codec/message_internal_test.go | 67 ++++++++++++++++++++++++++++++++ codec/message_test.go | 41 +++++++++++++++++++ codec_stream.go | 9 +---- http/codec.go | 2 +- http/restful_server_transport.go | 1 - internal/codec/method.go | 8 ---- restful/router.go | 1 - stream/client.go | 4 -- 10 files changed, 161 insertions(+), 26 deletions(-) create mode 100644 codec/message_internal_test.go delete mode 100644 internal/codec/method.go diff --git a/codec.go b/codec.go index b01e780..2fa0d0d 100644 --- a/codec.go +++ b/codec.go @@ -27,7 +27,6 @@ import ( "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/internal/attachment" - icodec "trpc.group/trpc-go/trpc-go/internal/codec" "trpc.group/trpc-go/trpc-go/transport" "google.golang.org/protobuf/proto" @@ -314,9 +313,7 @@ func msgWithRequestProtocol(msg codec.Msg, req *trpcpb.RequestProtocol, attm []b msg.WithCallerServiceName(string(req.GetCaller())) msg.WithCalleeServiceName(string(req.GetCallee())) // set server handler method name - rpcName := string(req.GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(req.GetFunc())) // set body serialization type msg.WithSerializationType(int(req.GetContentType())) // set body compression type diff --git a/codec/message_impl.go b/codec/message_impl.go index 00ee2c0..1cc89e0 100644 --- a/codec/message_impl.go +++ b/codec/message_impl.go @@ -259,6 +259,10 @@ func (m *msg) WithClientRPCName(s string) { } func (m *msg) updateMethodNameUsingRPCName(s string) { + if rpcNameIsTRPCForm(s) { + m.WithCalleeMethod(methodFromRPCName(s)) + return + } if m.CalleeMethod() == "" { m.WithCalleeMethod(s) } @@ -732,3 +736,48 @@ func getAppServerService(s string) (app, server, service string) { service = s[j:] return } + +// methodFromRPCName returns the method parsed from rpc string. +func methodFromRPCName(s string) string { + return s[strings.LastIndex(s, "/")+1:] +} + +// rpcNameIsTRPCForm checks whether the given string is of trpc form. +// It is equivalent to: +// +// var r = regexp.MustCompile(`^/[^/.]+\.[^/]+/[^/.]+$`) +// +// func rpcNameIsTRPCForm(s string) bool { +// return r.MatchString(s) +// } +// +// But regexp is much slower than the current version. +// Refer to BenchmarkRPCNameIsTRPCForm in message_bench_test.go. +func rpcNameIsTRPCForm(s string) bool { + if len(s) == 0 { + return false + } + if s[0] != '/' { // ^/ + return false + } + const start = 1 + firstDot := strings.Index(s[start:], ".") + if firstDot == -1 || firstDot == 0 { // [^.]+\. + return false + } + if strings.Contains(s[start:start+firstDot], "/") { // [^/]+\. + return false + } + secondSlash := strings.Index(s[start+firstDot:], "/") + if secondSlash == -1 || secondSlash == 1 { // [^/]+/ + return false + } + if start+firstDot+secondSlash == len(s)-1 { // The second slash should not be the last character. + return false + } + const offset = 1 + if strings.ContainsAny(s[start+firstDot+secondSlash+offset:], "/.") { // [^/.]+$ + return false + } + return true +} diff --git a/codec/message_internal_test.go b/codec/message_internal_test.go new file mode 100644 index 0000000..d9a546b --- /dev/null +++ b/codec/message_internal_test.go @@ -0,0 +1,67 @@ +package codec + +import ( + "regexp" + "testing" + + "github.com/stretchr/testify/require" +) + +func BenchmarkRPCNameIsTRPCForm(b *testing.B) { + rpcNames := []string{ + "/trpc.app.server.service/method", + "/sdadfasd/xadfasdf/zxcasd/asdfasd/v2", + "trpc.app.server.service", + "/trpc.app.server.service", + "/trpc.app.", + "/trpc/asdf/asdf", + "/trpc.asdfasdf/asdfasdf/sdfasdfa/", + "/trpc.app/method/", + "/trpc.app/method/hhhhh", + } + b.Run("bench regexp", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for j := range rpcNames { + rpcNameIsTRPCFormRegExp(rpcNames[j]) + } + } + }) + b.Run("bench vanilla", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for j := range rpcNames { + rpcNameIsTRPCForm(rpcNames[j]) + } + } + }) +} + +func TestEnsureEqualSemacticOfTRPCFormChecking(t *testing.T) { + rpcNames := []string{ + "/trpc.app.server.service/method", + "/trpc.app.server.service/", + "/trpc", + "//", + "/./", + "/xx/.", + "/x./method", + "/.x/method", + "/sdadfasd/xadfasdf/zxcasd/asdfasd/v2", + "trpc.app.server.service", + "/trpc.app.server.service", + "/trpc.app.", + "/trpc/asdf/asdf", + "/trpc.asdfasdf/asdfasdf/sdfasdfa/", + "/trpc.app/method/", + "/trpc.app/method/hhhhh", + } + for _, s := range rpcNames { + v1, v2 := rpcNameIsTRPCFormRegExp(s), rpcNameIsTRPCForm(s) + require.True(t, v1 == v2, "%s %v %v", s, v1, v2) + } +} + +var r = regexp.MustCompile(`^/[^/.]+\.[^/]+/[^/.]+$`) + +func rpcNameIsTRPCFormRegExp(s string) bool { + return r.MatchString(s) +} diff --git a/codec/message_test.go b/codec/message_test.go index 31f9308..799e4b5 100644 --- a/codec/message_test.go +++ b/codec/message_test.go @@ -477,3 +477,44 @@ func TestEnsureMessage(t *testing.T) { require.Equal(t, ctx, newCtx) require.Equal(t, msg, newMsg) } + +func TestSetMethodNameUsingRPCName(t *testing.T) { + msg := codec.Message(context.Background()) + testSetMethodNameUsingRPCName(t, msg, msg.WithServerRPCName) + testSetMethodNameUsingRPCName(t, msg, msg.WithClientRPCName) +} + +func testSetMethodNameUsingRPCName(t *testing.T, msg codec.Msg, msgWithRPCName func(string)) { + var cases = []struct { + name string + originalMethod string + rpcName string + expectMethod string + }{ + {"normal trpc rpc name", "", "/trpc.app.server.service/method", "method"}, + {"normal http url path", "", "/v1/subject/info/get", "/v1/subject/info/get"}, + {"invalid trpc rpc name (method name is empty)", "", "trpc.app.server.service", "trpc.app.server.service"}, + {"invalid trpc rpc name (method name is not mepty)", "/v1/subject/info/get", "trpc.app.server.service", "/v1/subject/info/get"}, + {"valid trpc rpc name will override existing method name", "/v1/subject/info/get", "/trpc.app.server.service/method", "method"}, + {"invalid trpc rpc will not override exising method name", "/v1/subject/info/get", "/trpc.app.server.service", "/v1/subject/info/get"}, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + resetMsgRPCNameAndMethodName(msg) + msg.WithCalleeMethod(tt.originalMethod) + msgWithRPCName(tt.rpcName) + method := msg.CalleeMethod() + if method != tt.expectMethod { + t.Errorf("given original method %s and rpc name %s, expect new method name %s, got %s", + tt.originalMethod, tt.rpcName, tt.expectMethod, method) + } + }) + } +} + +func resetMsgRPCNameAndMethodName(msg codec.Msg) { + msg.WithCalleeMethod("") + msg.WithClientRPCName("") + msg.WithServerRPCName("") +} diff --git a/codec_stream.go b/codec_stream.go index a7e6bad..d6ee81d 100644 --- a/codec_stream.go +++ b/codec_stream.go @@ -23,7 +23,6 @@ import ( "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/internal/addrutil" - icodec "trpc.group/trpc-go/trpc-go/internal/codec" trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" "google.golang.org/protobuf/proto" @@ -377,9 +376,7 @@ func (s *ServerStreamCodec) setInitMeta(msg codec.Msg) error { defer s.m.RUnlock() if streamIDToInitMeta, ok := s.initMetas[addr]; ok { if initMeta, ok := streamIDToInitMeta[streamID]; ok { - rpcName := string(initMeta.GetRequestMeta().GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(initMeta.GetRequestMeta().GetFunc())) return nil } } @@ -468,9 +465,7 @@ func (s *ServerStreamCodec) updateMsg(msg codec.Msg, initMeta *trpcpb.TrpcStream msg.WithCallerServiceName(string(req.GetCaller())) msg.WithCalleeServiceName(string(req.GetCallee())) // set server handler method name - rpcName := string(req.GetFunc()) - msg.WithServerRPCName(rpcName) - msg.WithCalleeMethod(icodec.MethodFromRPCName(rpcName)) + msg.WithServerRPCName(string(req.GetFunc())) // set body serialization type msg.WithSerializationType(int(initMeta.GetContentType())) // set body compression type diff --git a/http/codec.go b/http/codec.go index 6f58a28..0ea39f7 100644 --- a/http/codec.go +++ b/http/codec.go @@ -361,8 +361,8 @@ func unmarshalTransInfo(msg codec.Msg, v string) (map[string][]byte, error) { // getReqbody gets the body of request. func (sc *ServerCodec) getReqbody(head *Header, msg codec.Msg) ([]byte, error) { - msg.WithServerRPCName(head.Request.URL.Path) msg.WithCalleeMethod(head.Request.URL.Path) + msg.WithServerRPCName(head.Request.URL.Path) if !sc.AutoReadBody { return nil, nil diff --git a/http/restful_server_transport.go b/http/restful_server_transport.go index a44c42b..3a4a40e 100644 --- a/http/restful_server_transport.go +++ b/http/restful_server_transport.go @@ -83,7 +83,6 @@ func putRESTMsgInCtx( ctx, msg := codec.WithNewMessage(ctx) msg.WithCalleeServiceName(service) msg.WithServerRPCName(method) - msg.WithCalleeMethod(method) msg.WithSerializationType(codec.SerializationTypePB) if v := headerGetter(TrpcTimeout); v != "" { i, _ := strconv.Atoi(v) diff --git a/internal/codec/method.go b/internal/codec/method.go deleted file mode 100644 index 5edc239..0000000 --- a/internal/codec/method.go +++ /dev/null @@ -1,8 +0,0 @@ -package codec - -import "strings" - -// MethodFromRPCName returns the method parsed from rpc string. -func MethodFromRPCName(s string) string { - return s[strings.LastIndex(s, "/")+1:] -} diff --git a/restful/router.go b/restful/router.go index 2a7df82..2683e9f 100644 --- a/restful/router.go +++ b/restful/router.go @@ -199,7 +199,6 @@ var DefaultHeaderMatcher = func( func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context { ctx, msg := codec.WithNewMessage(ctx) msg.WithServerRPCName(methodName) - msg.WithCalleeMethod(methodName) msg.WithCalleeServiceName(serviceName) msg.WithSerializationType(codec.SerializationTypePB) return ctx diff --git a/stream/client.go b/stream/client.go index 839d6de..e68ba6e 100644 --- a/stream/client.go +++ b/stream/client.go @@ -188,7 +188,6 @@ func (cs *clientStream) SendMsg(m interface{}) error { msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_DATA, cs.streamID)) msg.WithStreamID(cs.streamID) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) msg.WithCompressType(codec.Message(cs.ctx).CompressType()) return cs.stream.Send(ctx, m) } @@ -217,7 +216,6 @@ func (cs *clientStream) CloseSend() error { func (cs *clientStream) prepare(opt ...client.Option) error { msg := codec.Message(cs.ctx) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) msg.WithStreamID(cs.streamID) opt = append([]client.Option{client.WithStreamTransport(transport.DefaultClientStreamTransport)}, opt...) @@ -239,7 +237,6 @@ func (cs *clientStream) invoke(ctx context.Context, _ *client.ClientStreamDesc) copyMetaData(newMsg, codec.Message(cs.ctx)) newMsg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_INIT, cs.streamID)) newMsg.WithClientRPCName(cs.method) - newMsg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) newMsg.WithStreamID(cs.streamID) newMsg.WithCompressType(codec.Message(cs.ctx).CompressType()) newMsg.WithStreamFrame(&trpcpb.TrpcStreamInitMeta{ @@ -287,7 +284,6 @@ func (cs *clientStream) feedback(i uint32) error { msg.WithFrameHead(newFrameHead(trpcpb.TrpcStreamFrameType_TRPC_STREAM_FRAME_FEEDBACK, cs.streamID)) msg.WithStreamID(cs.streamID) msg.WithClientRPCName(cs.method) - msg.WithCalleeMethod(icodec.MethodFromRPCName(cs.method)) msg.WithStreamFrame(&trpcpb.TrpcStreamFeedBackMeta{WindowSizeIncrement: i}) return cs.stream.Send(ctx, nil) } From 9fd8bb2372a70a7bfb24c8525cbb55ed07ac6f7c Mon Sep 17 00:00:00 2001 From: goodliu Date: Tue, 23 Jan 2024 10:55:48 +0800 Subject: [PATCH 28/39] fix lint and typo --- codec/message_test.go | 2 +- config/trpc_config_test.go | 2 +- stream/client_test.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/codec/message_test.go b/codec/message_test.go index 799e4b5..2f49427 100644 --- a/codec/message_test.go +++ b/codec/message_test.go @@ -496,7 +496,7 @@ func testSetMethodNameUsingRPCName(t *testing.T, msg codec.Msg, msgWithRPCName f {"invalid trpc rpc name (method name is empty)", "", "trpc.app.server.service", "trpc.app.server.service"}, {"invalid trpc rpc name (method name is not mepty)", "/v1/subject/info/get", "trpc.app.server.service", "/v1/subject/info/get"}, {"valid trpc rpc name will override existing method name", "/v1/subject/info/get", "/trpc.app.server.service/method", "method"}, - {"invalid trpc rpc will not override exising method name", "/v1/subject/info/get", "/trpc.app.server.service", "/v1/subject/info/get"}, + {"invalid trpc rpc will not override existing method name", "/v1/subject/info/get", "/trpc.app.server.service", "/v1/subject/info/get"}, } for _, tt := range cases { diff --git a/config/trpc_config_test.go b/config/trpc_config_test.go index 35d5a6c..11eb852 100644 --- a/config/trpc_config_test.go +++ b/config/trpc_config_test.go @@ -160,7 +160,7 @@ password: ${pwd} func TestCodecUnmarshalDstMustBeMap(t *testing.T) { filePath := t.TempDir() + "/conf.map" - require.Nil(t, os.WriteFile(filePath, []byte{}, 0644)) + require.Nil(t, os.WriteFile(filePath, []byte{}, 0600)) RegisterCodec(dstMustBeMapCodec{}) _, err := DefaultConfigLoader.Load(filePath, WithCodec(dstMustBeMapCodec{}.Name())) require.Nil(t, err) diff --git a/stream/client_test.go b/stream/client_test.go index cded911..3e28f2d 100644 --- a/stream/client_test.go +++ b/stream/client_test.go @@ -16,11 +16,11 @@ package stream_test import ( "context" + "crypto/rand" "encoding/binary" "errors" "fmt" "io" - "math/rand" "testing" "time" From 5dad29ff72faf4a1ccbadb067b2ace0a940cb2fe Mon Sep 17 00:00:00 2001 From: Andrew Chang Date: Thu, 25 Jan 2024 10:07:18 +0800 Subject: [PATCH 29/39] docs: add comments informing that global variable JSONAPI should not be modified (#157) (#160) docs: add comments informing that global variable JSONAPI should not be modified (#157) Discussions took place in #157 --- codec/serialization_json.go | 11 +++++++++-- codec/serialization_jsonpb.go | 4 +++- restful/serialize_jsonpb.go | 11 +++++++++++ 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/codec/serialization_json.go b/codec/serialization_json.go index f7fad9e..10dfc66 100644 --- a/codec/serialization_json.go +++ b/codec/serialization_json.go @@ -17,8 +17,15 @@ import ( jsoniter "github.com/json-iterator/go" ) -// JSONAPI is json packing and unpacking object, users can change -// the internal parameter. +// JSONAPI is used by tRPC JSON serialization when the object does +// not conform to protobuf proto.Message interface. +// +// Deprecated: This global variable is exportable due to backward comparability issue but +// should not be modified. If users want to change the default behavior of +// internal JSON serialization, please use register your customized serializer +// function like: +// +// codec.RegisterSerializer(codec.SerializationTypeJSON, yourOwnJSONSerializer) var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary // JSONSerialization provides json serialization mode. diff --git a/codec/serialization_jsonpb.go b/codec/serialization_jsonpb.go index cc50d45..79e245c 100644 --- a/codec/serialization_jsonpb.go +++ b/codec/serialization_jsonpb.go @@ -29,7 +29,9 @@ var Marshaler = protojson.MarshalOptions{EmitUnpopulated: true, UseProtoNames: t var Unmarshaler = protojson.UnmarshalOptions{DiscardUnknown: false} // JSONPBSerialization provides jsonpb serialization mode. It is based on -// protobuf/jsonpb. +// protobuf/jsonpb. This serializer will firstly try jsonpb's serialization. If +// object does not conform to protobuf proto.Message interface, json-iterator +// will be used. type JSONPBSerialization struct{} // Unmarshal deserialize the in bytes into body. diff --git a/restful/serialize_jsonpb.go b/restful/serialize_jsonpb.go index 0e6d1ea..67297f4 100644 --- a/restful/serialize_jsonpb.go +++ b/restful/serialize_jsonpb.go @@ -31,12 +31,23 @@ func init() { // JSONPBSerializer is used for content-Type: application/json. // It's based on google.golang.org/protobuf/encoding/protojson. +// +// This serializer will firstly try jsonpb's serialization. If object does not +// conform to protobuf proto.Message interface, the serialization will switch to +// json-iterator. type JSONPBSerializer struct { AllowUnmarshalNil bool // allow unmarshalling nil body } // JSONAPI is a copy of jsoniter.ConfigCompatibleWithStandardLibrary. // github.com/json-iterator/go is faster than Go's standard json library. +// +// Deprecated: This global variable is exportable due to backward comparability issue but +// should not be modified. If users want to change the default behavior of +// internal JSON serialization, please use register your customized serializer +// function like: +// +// restful.RegisterSerializer(yourOwnJSONSerializer) var JSONAPI = jsoniter.ConfigCompatibleWithStandardLibrary // Marshaller is a configurable protojson marshaler. From c8cb302d333995ab252efee57241cd357b7dbf5d Mon Sep 17 00:00:00 2001 From: Ash Liu Date: Tue, 30 Jan 2024 14:27:44 +0800 Subject: [PATCH 30/39] docs: correct the spelling error (#163) docs: correct the spelling error Signed-off-by: iutx --- docs/basics_tutorial.md | 2 +- docs/basics_tutorial.zh_CN.md | 2 +- docs/user_guide/server/overview.md | 2 +- docs/user_guide/server/overview.zh_CN.md | 2 +- stream/README.zh_CN.md | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/basics_tutorial.md b/docs/basics_tutorial.md index c53395f..71295e7 100644 --- a/docs/basics_tutorial.md +++ b/docs/basics_tutorial.md @@ -39,7 +39,7 @@ Note that `Method` has a `{}` at the end, which can also have content. We will s ### Write Client and Server Code -What protobuf gives is a language-independent service definition, and we need to use [trpc command line tool](https://github.com/trpc-group/trpc-cmdline) to translate it into a corresponding language stub code. You can see the various options it supports with `$ tprc create -h`. You can refer to the quick start [helloworld](/examples/helloworld/pb/Makefile) project to quickly create your own stub code. +What protobuf gives is a language-independent service definition, and we need to use [trpc command line tool](https://github.com/trpc-group/trpc-cmdline) to translate it into a corresponding language stub code. You can see the various options it supports with `$ trpc create -h`. You can refer to the quick start [helloworld](/examples/helloworld/pb/Makefile) project to quickly create your own stub code. The stub code is mainly divided into two parts: client and server. Below is part of the generated client code. In [Quick Start](./quick_start.md), we use `NewGreeterClientProxy` to create a client instance and call its `Hello` method: diff --git a/docs/basics_tutorial.zh_CN.md b/docs/basics_tutorial.zh_CN.md index 35bb22f..cc178a9 100644 --- a/docs/basics_tutorial.zh_CN.md +++ b/docs/basics_tutorial.zh_CN.md @@ -39,7 +39,7 @@ message HelloRsp { ### 编写客户端和服务端代码 -protobuf 给出的是一个语言无关的服务定义,我们还要用 [trpc 命令行工具](https://github.com/trpc-group/trpc-cmdline)将它翻译成对应语言的桩代码。你可以通过 `$ tprc create -h` 查看它支持的各种选项。你可以参考快速开始的 [helloworld](/examples/helloworld/pb/Makefile) 项目来快速创建你自己的桩代码。 +protobuf 给出的是一个语言无关的服务定义,我们还要用 [trpc 命令行工具](https://github.com/trpc-group/trpc-cmdline)将它翻译成对应语言的桩代码。你可以通过 `$ trpc create -h` 查看它支持的各种选项。你可以参考快速开始的 [helloworld](/examples/helloworld/pb/Makefile) 项目来快速创建你自己的桩代码。 桩代码主要分为 client 和 server 两部分。 下面是生成的部分 client 代码。在[快速开始](./quick_start.zh_CN.md)中,我们通过 `NewGreeterClientProxy` 来创建一个 client 实例,并调用了它的 `Hello` 方法: diff --git a/docs/user_guide/server/overview.md b/docs/user_guide/server/overview.md index d2bc2a5..5dd4ffe 100644 --- a/docs/user_guide/server/overview.md +++ b/docs/user_guide/server/overview.md @@ -288,4 +288,4 @@ tRPC-Go allows businesses to define and register serialization and deserializati ## Setting the maximum number of service coroutines -tRPC-Go supports service-level synchronous/asynchronous packet processing modes. For asynchronous mode, a coroutine pool is used to improve coroutine usage efficiency and performance. Users can set the maximum number of service coroutines through framework configuration and Option configuration. For details, please refer to the service configuration in the [tPRC-Go Framework Configuration](/docs/user_guide/framework_conf.md) section. +tRPC-Go supports service-level synchronous/asynchronous packet processing modes. For asynchronous mode, a coroutine pool is used to improve coroutine usage efficiency and performance. Users can set the maximum number of service coroutines through framework configuration and Option configuration. For details, please refer to the service configuration in the [tRPC-Go Framework Configuration](/docs/user_guide/framework_conf.md) section. diff --git a/docs/user_guide/server/overview.zh_CN.md b/docs/user_guide/server/overview.zh_CN.md index 72ee65b..640697b 100644 --- a/docs/user_guide/server/overview.zh_CN.md +++ b/docs/user_guide/server/overview.zh_CN.md @@ -288,4 +288,4 @@ tRPC-Go 自定义 RPC 消息体的序列化、反序列化方式,业务可以 ## 设置服务最大协程数 -tRPC-Go 支持服务级别的同/异步包处理模式,对于异步模式采用协程池来提升协程使用效率和性能。用户可以通过框架配置和 Option 配置两种方式来设置服务的最大协程数,具体请参考 [tPRC-Go 框架配置](/docs/user_guide/framework_conf.zh_CN.md) 章节的 service 配置。 +tRPC-Go 支持服务级别的同/异步包处理模式,对于异步模式采用协程池来提升协程使用效率和性能。用户可以通过框架配置和 Option 配置两种方式来设置服务的最大协程数,具体请参考 [tRPC-Go 框架配置](/docs/user_guide/framework_conf.zh_CN.md) 章节的 service 配置。 diff --git a/stream/README.zh_CN.md b/stream/README.zh_CN.md index 892d870..0bedb1f 100644 --- a/stream/README.zh_CN.md +++ b/stream/README.zh_CN.md @@ -322,7 +322,7 @@ func main() { - 接收端每消费 1/4 的初始窗口大小进行 feedback,发送一个 feedback 帧,携带增量的 window size,发送端接收到这个增量 window size 之后加到本地可发送的 window 大小 - 帧分优先级,对于 feedback 的帧不做流控,优先级高于 Data 帧,防止因为优先级问题导致 feedback 帧发生阻塞 -tPRC-Go 默认启用流控,目前默认窗口大小为 65535,如果连续发送超过 65535 大小的数据(序列化和压缩后),接收方没调用 Recv,则发送方会 block +tRPC-Go 默认启用流控,目前默认窗口大小为 65535,如果连续发送超过 65535 大小的数据(序列化和压缩后),接收方没调用 Recv,则发送方会 block 如果要设置客户端接收窗口大小,使用 client option `WithMaxWindowSize` ```go From 6087fd094d357f81d8825ea3a1460bc1a7f0e0f2 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 16 Apr 2024 10:26:04 +0800 Subject: [PATCH 31/39] docs: add notes on service idle timeout (#166) docs: add notes on service idle timeout RELEASE NOTES: NONE --- docs/user_guide/server/overview.md | 13 +++++++++++++ docs/user_guide/server/overview.zh_CN.md | 15 ++++++++++++++- http/transport_test.go | 3 +++ pool/connpool/checker_unix_test.go | 2 +- 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/docs/user_guide/server/overview.md b/docs/user_guide/server/overview.md index 5dd4ffe..eac2b36 100644 --- a/docs/user_guide/server/overview.md +++ b/docs/user_guide/server/overview.md @@ -268,6 +268,19 @@ tRPC-Go provides three timeout mechanisms for RPC calls: link timeout, message t This feature requires protocol support (the protocol needs to carry timeout metadata downstream). The tRPC protocol, generic HTTP RPC protocol all support timeout control. +## Idle Timeout + +The server has a default idle timeout of 60 seconds to prevent excessive idle connections from consuming server-side resources. This value can be modified through the `idletimeout` setting in the framework configuration: + +```yaml +server: + service: + - name: trpc.server.service.Method + network: tcp + protocol: trpc + idletime: 60000 # The unit is milliseconds. Setting it to -1 means there is no idle timeout (setting it to 0 will still default to the 60s by the framework) +``` + ## Link transmission The tRPC-Go framework provides a mechanism for passing fields between the client and server and passing them down the entire call chain. For the mechanism and usage of link transmission, please refer to [tRPC-Go Link Transmission](/docs/user_guide/metadata_transmission.md). diff --git a/docs/user_guide/server/overview.zh_CN.md b/docs/user_guide/server/overview.zh_CN.md index 640697b..7c74efd 100644 --- a/docs/user_guide/server/overview.zh_CN.md +++ b/docs/user_guide/server/overview.zh_CN.md @@ -266,7 +266,20 @@ tRPC-Go 从设计之初就考虑了框架的易测性,在通过 pb 生成桩 tRPC-Go 为 RPC 调用提供了 3 种超时机制控制:链路超时,消息超时和调用超时。关于这 3 种超时机制的原理介绍和相关配置,请参考 [tRPC-Go 超时控制](/docs/user_guide/timeout_control.zh_CN.md)。 -此功能需要协议的支持(协议需要携带 timeout 元数据到下游),tRPC 协议,泛 HTTP RPC 协议均支持超时控制功能。其 +此功能需要协议的支持(协议需要携带 timeout 元数据到下游),tRPC 协议,泛 HTTP RPC 协议均支持超时控制功能。 + +## 空闲超时 + +服务默认存在一个 60s 的空闲超时时间,以防止过多空闲连接消耗服务侧的资源,这个值可以通过框架配置中的 `idletimeout` 来进行修改: + +```yaml +server: + service: + - name: trpc.server.service.Method + network: tcp + protocol: trpc + idletime: 60000 # 单位是毫秒, 设置为 -1 的时候表示没有空闲超时(这里设置为 0 时框架仍会自动转为默认的 60s) +``` ## 链路透传 diff --git a/http/transport_test.go b/http/transport_test.go index 4659d05..3c51c00 100644 --- a/http/transport_test.go +++ b/http/transport_test.go @@ -809,6 +809,9 @@ func TestCheckRedirect(t *testing.T) { return nil } thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = checkRedirect + defer func() { + thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = nil + }() proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", client.WithTarget("ip://"+ln.Addr().String()), client.WithSerializationType(codec.SerializationTypeNoop), diff --git a/pool/connpool/checker_unix_test.go b/pool/connpool/checker_unix_test.go index 6f303a6..5ec677c 100644 --- a/pool/connpool/checker_unix_test.go +++ b/pool/connpool/checker_unix_test.go @@ -53,7 +53,7 @@ func TestRemoteEOF(t *testing.T) { require.Nil(t, pc.Close()) } -func TestUnexceptedRead(t *testing.T) { +func TestUnexpectedRead(t *testing.T) { var s server require.Nil(t, s.init()) From e17e291f01db81367fdce9ddd94f4a0799678ce1 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 16 Apr 2024 10:27:06 +0800 Subject: [PATCH 32/39] pool: fix typos (#167) RELEASE NOTES: NONE From b82b6a79f4fcc29e15157ee314a78e5a20cce90f Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 16 Apr 2024 10:27:36 +0800 Subject: [PATCH 33/39] docs: add notes on client connpool idletimeout (#170) RELEASE NOTES: NONE --- docs/user_guide/client/connection_mode.md | 42 +++++++++++++++++++ .../client/connection_mode.zh_CN.md | 42 ++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/docs/user_guide/client/connection_mode.md b/docs/user_guide/client/connection_mode.md index 2a00bb6..5040874 100644 --- a/docs/user_guide/client/connection_mode.md +++ b/docs/user_guide/client/connection_mode.md @@ -128,6 +128,48 @@ if err != nil { log.Info("req:%v, rsp:%v, err:%v", req, rsp, err) ``` + +#### Setting Idle Connection Timeout + +For the client's connection pool mode, the framework sets a default idle timeout of 50 seconds. + +* For `go-net`, the connection pool maintains a list of idle connections. The idle timeout only affects the connections in this idle list and is only triggered when the connection is retrieved next time, causing idle connections to be closed due to the idle timeout. +* For `tnet`, the idle timeout is implemented by maintaining a timer on each connection. Even if a connection is being used for a client's call, if the downstream does not return a result within the idle timeout period, the connection will still be triggered by the idle timeout and forcibly closed. + +The methods to change the idle timeout are as follows: + +* `go-net` + +```go +import "trpc.group/trpc-go/trpc-go/pool/connpool" + +func init() { + connpool.DefaultConnectionPool = connpool.NewConnectionPool( + connpool.WithIdleTimeout(0), // Setting to 0 disables it. + ) +} +``` + +tnet + +```go +import ( + "trpc.group/trpc-go/trpc-go/pool/connpool" + tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" +) + +func init() { + tnettrans.DefaultConnPool = connpool.NewConnectionPool( + connpool.WithDialFunc(tnettrans.Dial), + connpool.WithIdleTimeout(0), // Setting to 0 disables it. + connpool.WithHealthChecker(tnettrans.HealthChecker), + ) +} +``` + +**Note**: The server also has a default idle timeout, which is 60 seconds. This time is designed to be longer than the 50 seconds, so that under default conditions, it is the client that triggers the idle connection timeout to actively close the connection, rather than the server triggering a forced cleanup. For methods to change the server's idle timeout, see the server usage documentation. + + ### I/O multiplexing ```go diff --git a/docs/user_guide/client/connection_mode.zh_CN.md b/docs/user_guide/client/connection_mode.zh_CN.md index dbe025b..469b2f1 100644 --- a/docs/user_guide/client/connection_mode.zh_CN.md +++ b/docs/user_guide/client/connection_mode.zh_CN.md @@ -126,7 +126,47 @@ if err != nil { log.Info("req:%v, rsp:%v, err:%v", req, rsp, err) ``` -###连接多路复用 +#### 设置空闲连接超时 + +对于客户端的连接池模式来说,框架会设置一个默认的 50s 的空闲超时时间。 + +* 对于 `go-net` 而言,连接池中会维持一个空闲连接列表,空闲超时时间只会对空闲连接列表中的连接生效,并且只会在下次获取的时候触发空闲连接触发空闲超时的关闭 +* 对于 `tnet` 而言,空闲超时通过在每个连接上维护定时器来实现,即使该连接被用于客户端发起调用,假如下游未在空闲连接超时时间内返回结果的话,该连接仍然会被触发空闲超时并强制被关闭 + +更改空闲超时时间的方式如下: + +* `go-net` + +```go +import "trpc.group/trpc-go/trpc-go/pool/connpool" + +func init() { + connpool.DefaultConnectionPool = connpool.NewConnectionPool( + connpool.WithIdleTimeout(0), // 设置为 0 是禁用 + ) +} +``` + +* `tnet` + +```go +import ( + "trpc.group/trpc-go/trpc-go/pool/connpool" + tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" +) + +func init() { + tnettrans.DefaultConnPool = connpool.NewConnectionPool( + connpool.WithDialFunc(tnettrans.Dial), + connpool.WithIdleTimeout(0), // 设置为 0 是禁用 + connpool.WithHealthChecker(tnettrans.HealthChecker), + ) +} +``` + +**注**:服务端默认也有一个空闲超时时间,为 60s,该时间设计得比 50s 打,从而在默认情况下是客户端主动触发空闲连接超时以主动关闭连接,而非服务端触发强制清理。服务端空闲超时的更改方法见服务端使用文档。 + +### 连接多路复用 ```go opts := []client.Option{ From 3ce75697d0483b678940efe63f8318f33a8290e1 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 16 Apr 2024 10:28:39 +0800 Subject: [PATCH 34/39] go.mod: upgrade tnet version to handle negative idle timeout (#169) Fixes #165 --- examples/go.mod | 2 +- examples/go.sum | 4 ++-- go.mod | 2 +- go.sum | 4 ++-- test/go.mod | 2 +- test/go.sum | 4 ++-- transport/tnet/client_transport_test.go | 13 ------------- 7 files changed, 9 insertions(+), 22 deletions(-) diff --git a/examples/go.mod b/examples/go.mod index ab9f708..18b4a5f 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -41,5 +41,5 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect - trpc.group/trpc-go/tnet v1.0.0 // indirect + trpc.group/trpc-go/tnet v1.0.1 // indirect ) diff --git a/examples/go.sum b/examples/go.sum index e0d6081..21e4cf7 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -128,7 +128,7 @@ gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/go.mod b/go.mod index c98e1c8..a610973 100644 --- a/go.mod +++ b/go.mod @@ -28,7 +28,7 @@ require ( golang.org/x/sys v0.13.0 google.golang.org/protobuf v1.30.0 gopkg.in/yaml.v3 v3.0.1 - trpc.group/trpc-go/tnet v1.0.0 + trpc.group/trpc-go/tnet v1.0.1 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 ) diff --git a/go.sum b/go.sum index 258b5d9..d251db4 100644 --- a/go.sum +++ b/go.sum @@ -137,7 +137,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/test/go.mod b/test/go.mod index 632f2bb..f1075bb 100644 --- a/test/go.mod +++ b/test/go.mod @@ -43,5 +43,5 @@ require ( go.uber.org/multierr v1.10.0 // indirect golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect - trpc.group/trpc-go/tnet v1.0.0 // indirect + trpc.group/trpc-go/tnet v1.0.1 // indirect ) diff --git a/test/go.sum b/test/go.sum index 2821134..4a0ce4c 100644 --- a/test/go.sum +++ b/test/go.sum @@ -125,7 +125,7 @@ gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= -trpc.group/trpc-go/tnet v1.0.0 h1:XsdA82/sOHLa4TFAlCZbb3xi4+Q92NNuxEMTj0UfFZ0= -trpc.group/trpc-go/tnet v1.0.0/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/transport/tnet/client_transport_test.go b/transport/tnet/client_transport_test.go index 503e3f9..e3b807b 100644 --- a/transport/tnet/client_transport_test.go +++ b/transport/tnet/client_transport_test.go @@ -50,19 +50,6 @@ func TestDial(t *testing.T) { return assert.Contains(t, err.Error(), "unknown network") }, }, - { - name: "invalid idle timeout", - opts: &connpool.DialOptions{ - CACertFile: "", - Network: "tcp", - Address: l.Addr().String(), - IdleTimeout: -1, - }, - want: nil, - wantErr: func(t assert.TestingT, err error, msg ...interface{}) bool { - return assert.Contains(t, err.Error(), "delay time is too short") - }, - }, { name: "wrong CACertFile and TLSServerName ", opts: &connpool.DialOptions{ From 82f4e4a027281b70ffcc2b605c5b2f5553e828bd Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 16 Apr 2024 10:29:37 +0800 Subject: [PATCH 35/39] http: improve stability of http test (#168) RELEASE NOTES: NONE From ccdde66e9010fa54e074877587918e88ffba26ef Mon Sep 17 00:00:00 2001 From: goodliu Date: Tue, 16 Apr 2024 11:01:29 +0800 Subject: [PATCH 36/39] {trpc, examples, test}: upgrade google.golang.org/protobuf v1.30.0 => v1.33.0 (#171) Fixes - https://github.com/trpc-group/trpc-go/security/dependabot/13 - https://github.com/trpc-group/trpc-go/security/dependabot/14 - https://github.com/trpc-group/trpc-go/security/dependabot/15 --- examples/go.mod | 2 +- examples/go.sum | 3 ++- go.mod | 2 +- go.sum | 2 ++ test/go.mod | 2 +- test/go.sum | 4 ++-- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/go.mod b/examples/go.mod index 18b4a5f..08de7b7 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -6,7 +6,7 @@ replace trpc.group/trpc-go/trpc-go => ../ require ( github.com/golang/protobuf v1.5.2 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 trpc.group/trpc-go/trpc-go v0.0.0-00010101000000-000000000000 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 ) diff --git a/examples/go.sum b/examples/go.sum index 21e4cf7..53365b8 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -117,8 +117,9 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/go.mod b/go.mod index a610973..ad60b03 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( golang.org/x/net v0.17.0 golang.org/x/sync v0.1.0 golang.org/x/sys v0.13.0 - google.golang.org/protobuf v1.30.0 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 trpc.group/trpc-go/tnet v1.0.1 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 diff --git a/go.sum b/go.sum index d251db4..237009b 100644 --- a/go.sum +++ b/go.sum @@ -127,6 +127,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/test/go.mod b/test/go.mod index f1075bb..a92e774 100644 --- a/test/go.mod +++ b/test/go.mod @@ -10,7 +10,7 @@ require ( go.uber.org/zap v1.26.0 golang.org/x/net v0.17.0 golang.org/x/sync v0.4.0 - google.golang.org/protobuf v1.31.0 + google.golang.org/protobuf v1.33.0 gopkg.in/yaml.v3 v3.0.1 trpc.group/trpc-go/trpc-go v0.0.0-00010101000000-000000000000 trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 diff --git a/test/go.sum b/test/go.sum index 4a0ce4c..c0689bf 100644 --- a/test/go.sum +++ b/test/go.sum @@ -114,8 +114,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= From 7826836aa11f4c1b826b7a10cf2a4f8bd6f54221 Mon Sep 17 00:00:00 2001 From: goodliu Date: Tue, 23 Apr 2024 16:31:51 +0800 Subject: [PATCH 37/39] github-action: allow dependabot to contribute and bump cla to v2.3.2 (#174) Fixes #173 --- .github/workflows/cla.yml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml index de11c47..795467b 100644 --- a/.github/workflows/cla.yml +++ b/.github/workflows/cla.yml @@ -18,7 +18,7 @@ jobs: steps: - name: "CLA Assistant" if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' - uses: contributor-assistant/github-action@v2.3.1 + uses: contributor-assistant/github-action@v2.3.2 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_DATABASE_ACCESS_TOKEN }} @@ -28,4 +28,5 @@ jobs: path-to-signatures: 'signatures/${{ github.event.repository.name }}-${{ github.repository_id }}/cla.json' path-to-document: 'https://github.com/trpc-group/cla-database/blob/main/Tencent-Contributor-License-Agreement.md' # branch should not be protected - branch: 'main' \ No newline at end of file + branch: 'main' + allowlist: dependabot \ No newline at end of file From f7276020ae28159f8d66e2d25d3c21ad23a56b04 Mon Sep 17 00:00:00 2001 From: wineandchord Date: Tue, 23 Apr 2024 16:32:14 +0800 Subject: [PATCH 38/39] docs: refine idle timeout settings documentation for clarity (#172) --- docs/user_guide/client/connection_mode.zh_CN.md | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/user_guide/client/connection_mode.zh_CN.md b/docs/user_guide/client/connection_mode.zh_CN.md index 469b2f1..de1b4d2 100644 --- a/docs/user_guide/client/connection_mode.zh_CN.md +++ b/docs/user_guide/client/connection_mode.zh_CN.md @@ -128,12 +128,12 @@ log.Info("req:%v, rsp:%v, err:%v", req, rsp, err) #### 设置空闲连接超时 -对于客户端的连接池模式来说,框架会设置一个默认的 50s 的空闲超时时间。 +在客户端的连接池模式中,框架默认会设置一个 50 秒的空闲超时时间。 -* 对于 `go-net` 而言,连接池中会维持一个空闲连接列表,空闲超时时间只会对空闲连接列表中的连接生效,并且只会在下次获取的时候触发空闲连接触发空闲超时的关闭 -* 对于 `tnet` 而言,空闲超时通过在每个连接上维护定时器来实现,即使该连接被用于客户端发起调用,假如下游未在空闲连接超时时间内返回结果的话,该连接仍然会被触发空闲超时并强制被关闭 +* 对于 `go-net` 来说,连接池会维护一个空闲连接列表。空闲超时时间仅对列表中的空闲连接有效,并且只有在下一次尝试获取连接时,才会触发检查并关闭超时的空闲连接。 +* 对于 `tnet`,则是通过在每个连接上设置定时器来实现空闲超时。即便连接正在被用于客户端的调用,如果下游服务在空闲超时时间内没有返回结果,该连接仍然会因为空闲超时而被强制关闭。 -更改空闲超时时间的方式如下: +可以按照以下方式更改空闲超时时间: * `go-net` @@ -142,7 +142,7 @@ import "trpc.group/trpc-go/trpc-go/pool/connpool" func init() { connpool.DefaultConnectionPool = connpool.NewConnectionPool( - connpool.WithIdleTimeout(0), // 设置为 0 是禁用 + connpool.WithIdleTimeout(0), // 设置为 0 以禁用空闲超时 ) } ``` @@ -158,14 +158,13 @@ import ( func init() { tnettrans.DefaultConnPool = connpool.NewConnectionPool( connpool.WithDialFunc(tnettrans.Dial), - connpool.WithIdleTimeout(0), // 设置为 0 是禁用 + connpool.WithIdleTimeout(0), // 设置为 0 以禁用空闲超时 connpool.WithHealthChecker(tnettrans.HealthChecker), ) } ``` -**注**:服务端默认也有一个空闲超时时间,为 60s,该时间设计得比 50s 打,从而在默认情况下是客户端主动触发空闲连接超时以主动关闭连接,而非服务端触发强制清理。服务端空闲超时的更改方法见服务端使用文档。 - +**注**:服务端默认也设置了一个空闲超时时间,为 60 秒。这个时间比客户端的默认时间长,以确保在大多数情况下,是客户端主动触发空闲超时并关闭连接,而不是服务端强制进行清理。服务端空闲超时时间的修改方法,请参见服务端使用文档。 ### 连接多路复用 ```go From 82ee6e836cccbc3a881222c7a3ad1145df661072 Mon Sep 17 00:00:00 2001 From: YoungFr <43751910+YoungFr@users.noreply.github.com> Date: Wed, 15 May 2024 14:49:53 +0800 Subject: [PATCH 39/39] {client, naming}: allow selector to define its own net.Addr parser (#176) This is used to avoid unnecessary addr parse which is commonly used in trpc-database. To properly update the DSN library, we need to introduce this feature into the open-source tRPC-Go. --- client/client.go | 15 ++++++++-- client/client_test.go | 63 +++++++++++++++++++++++++++++++++++++++++ client/stream.go | 2 +- naming/registry/node.go | 4 +++ 4 files changed, 80 insertions(+), 4 deletions(-) diff --git a/client/client.go b/client/client.go index 8860c91..946d58b 100644 --- a/client/client.go +++ b/client/client.go @@ -393,7 +393,7 @@ func selectorFilter(ctx context.Context, req interface{}, rsp interface{}, next if err != nil { return OptionsFromContext(ctx).fixTimeout(err) } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) // Start to process the next filter and report. begin := time.Now() @@ -471,11 +471,21 @@ func getNode(opts *Options) (*registry.Node, error) { return node, nil } -func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { +func ensureMsgRemoteAddr( + msg codec.Msg, + network, address string, + parseAddr func(network, address string) net.Addr, +) { // If RemoteAddr has already been set, just return. if msg.RemoteAddr() != nil { return } + + if parseAddr != nil { + msg.WithRemoteAddr(parseAddr(network, address)) + return + } + switch network { case "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": // Check if address can be parsed as an ip. @@ -484,7 +494,6 @@ func ensureMsgRemoteAddr(msg codec.Msg, network string, address string) { return } } - var addr net.Addr switch network { case "tcp", "tcp4", "tcp6": diff --git a/client/client_test.go b/client/client_test.go index 34aaf51..22fb5be 100644 --- a/client/client_test.go +++ b/client/client_test.go @@ -16,6 +16,8 @@ package client_test import ( "context" "errors" + "fmt" + "net" "testing" "time" @@ -409,6 +411,31 @@ func TestFixTimeout(t *testing.T) { }) } +func TestSelectorRemoteAddrUseUserProvidedParser(t *testing.T) { + selector.Register(t.Name(), &fSelector{ + selectNode: func(s string, option ...selector.Option) (*registry.Node, error) { + return ®istry.Node{ + Network: t.Name(), + Address: t.Name(), + ParseAddr: func(network, address string) net.Addr { + return newUnresolvedAddr(network, address) + }}, nil + }, + report: func(node *registry.Node, duration time.Duration, err error) error { return nil }, + }) + fake := "fake" + codec.Register(fake, nil, &fakeCodec{}) + ctx := trpc.BackgroundContext() + require.NotNil(t, client.New().Invoke(ctx, "failbody", nil, + client.WithServiceName(t.Name()), + client.WithProtocol(fake), + client.WithTarget(fmt.Sprintf("%s://xxx", t.Name())))) + addr := trpc.Message(ctx).RemoteAddr() + require.NotNil(t, addr) + require.Equal(t, t.Name(), addr.Network()) + require.Equal(t, t.Name(), addr.String()) +} + type multiplexedTransport struct { require func(context.Context, []byte, ...transport.RoundTripOption) fakeTransport @@ -527,3 +554,39 @@ func (c *fakeSelector) Select(serviceName string, opt ...selector.Option) (*regi func (c *fakeSelector) Report(node *registry.Node, cost time.Duration, err error) error { return nil } + +type fSelector struct { + selectNode func(string, ...selector.Option) (*registry.Node, error) + report func(*registry.Node, time.Duration, error) error +} + +func (s *fSelector) Select(serviceName string, opts ...selector.Option) (*registry.Node, error) { + return s.selectNode(serviceName, opts...) +} + +func (s *fSelector) Report(node *registry.Node, cost time.Duration, err error) error { + return s.report(node, cost, err) +} + +// newUnresolvedAddr returns a new unresolvedAddr. +func newUnresolvedAddr(network, address string) *unresolvedAddr { + return &unresolvedAddr{network: network, address: address} +} + +var _ net.Addr = (*unresolvedAddr)(nil) + +// unresolvedAddr is a net.Addr which returns the original network or address. +type unresolvedAddr struct { + network string + address string +} + +// Network returns the unresolved original network. +func (a *unresolvedAddr) Network() string { + return a.network +} + +// String returns the unresolved original address. +func (a *unresolvedAddr) String() string { + return a.address +} diff --git a/client/stream.go b/client/stream.go index 3a33159..fbfd136 100644 --- a/client/stream.go +++ b/client/stream.go @@ -162,7 +162,7 @@ func (s *stream) Init(ctx context.Context, opt ...Option) (*Options, error) { report.SelectNodeFail.Incr() return nil, err } - ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address) + ensureMsgRemoteAddr(msg, findFirstNonEmpty(node.Network, opts.Network), node.Address, node.ParseAddr) const invalidCost = -1 opts.Node.set(node, node.Address, invalidCost) if opts.Codec == nil { diff --git a/naming/registry/node.go b/naming/registry/node.go index d4f4d19..c5f0a34 100644 --- a/naming/registry/node.go +++ b/naming/registry/node.go @@ -15,6 +15,7 @@ package registry import ( "fmt" + "net" "time" ) @@ -30,6 +31,9 @@ type Node struct { CostTime time.Duration // 当次请求耗时 EnvKey string // 透传的环境信息 Metadata map[string]interface{} + // ParseAddr should be used to convert Node to net.Addr if it's not nil. + // See test case TestSelectorRemoteAddrUseUserProvidedParser in client package. + ParseAddr func(network, address string) net.Addr } // String returns an abbreviation information of node.