Skip to content

Commit

Permalink
Ftr: enable filter and cluster when client consumer provider directly (
Browse files Browse the repository at this point in the history
…#1181)

* URL directly call add filter and cluster

* update

* update

* add mockFilter

Co-authored-by: kezhan <kezhan@shizhuang-inc.com>
Co-authored-by: Xin.Zh <dragoncharlie@foxmail.com>
  • Loading branch information
3 people authored May 15, 2021
1 parent 0cb4217 commit bda9a0b
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
34 changes: 33 additions & 1 deletion config/reference_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/common/extension"
"dubbo.apache.org/dubbo-go/v3/common/proxy"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/protocol/protocolwrapper"
)

// ReferenceConfig is the configuration of service consumer
Expand Down Expand Up @@ -134,11 +135,42 @@ func (c *ReferenceConfig) Refer(_ interface{}) {

if len(c.urls) == 1 {
c.invoker = extension.GetProtocol(c.urls[0].Protocol).Refer(c.urls[0])
// c.URL != "" is direct call
if c.URL != "" {
//filter
c.invoker = protocolwrapper.BuildInvokerChain(c.invoker, constant.REFERENCE_FILTER_KEY)

// cluster
invokers := make([]protocol.Invoker, 0, len(c.urls))
invokers = append(invokers, c.invoker)
// TODO(decouple from directory, config should not depend on directory module)
var hitClu string
// not a registry url, must be direct invoke.
hitClu = constant.FAILOVER_CLUSTER_NAME
if len(invokers) > 0 {
u := invokers[0].GetURL()
if nil != &u {
hitClu = u.GetParam(constant.CLUSTER_KEY, constant.ZONEAWARE_CLUSTER_NAME)
}
}

cluster := extension.GetCluster(hitClu)
// If 'zone-aware' policy select, the invoker wrap sequence would be:
// ZoneAwareClusterInvoker(StaticDirectory) ->
// FailoverClusterInvoker(RegistryDirectory, routing happens here) -> Invoker
c.invoker = cluster.Join(directory.NewStaticDirectory(invokers))
}
} else {
invokers := make([]protocol.Invoker, 0, len(c.urls))
var regURL *common.URL
for _, u := range c.urls {
invokers = append(invokers, extension.GetProtocol(u.Protocol).Refer(u))
invoker := extension.GetProtocol(u.Protocol).Refer(u)
// c.URL != "" is direct call
if c.URL != "" {
//filter
invoker = protocolwrapper.BuildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
}
invokers = append(invokers, invoker)
if u.Protocol == constant.REGISTRY_PROTOCOL {
regURL = u
}
Expand Down
31 changes: 29 additions & 2 deletions config/reference_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package config

import (
"context"
"sync"
"testing"
)
Expand All @@ -31,6 +32,7 @@ import (
"dubbo.apache.org/dubbo-go/v3/common"
"dubbo.apache.org/dubbo-go/v3/common/constant"
"dubbo.apache.org/dubbo-go/v3/common/extension"
"dubbo.apache.org/dubbo-go/v3/filter"
"dubbo.apache.org/dubbo-go/v3/protocol"
"dubbo.apache.org/dubbo-go/v3/registry"
)
Expand Down Expand Up @@ -193,7 +195,6 @@ func TestReferMultiReg(t *testing.T) {
doInitConsumer()
extension.SetProtocol("registry", GetProtocol)
extension.SetCluster(constant.ZONEAWARE_CLUSTER_NAME, cluster_impl.NewZoneAwareCluster)

for _, reference := range consumerConfig.References {
reference.Refer(nil)
assert.NotNil(t, reference.invoker)
Expand Down Expand Up @@ -234,6 +235,7 @@ func TestReferAsync(t *testing.T) {
func TestReferP2P(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000"

Expand All @@ -248,6 +250,7 @@ func TestReferP2P(t *testing.T) {
func TestReferMultiP2P(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;dubbo://127.0.0.2:20000"

Expand All @@ -263,6 +266,7 @@ func TestReferMultiP2PWithReg(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"

Expand Down Expand Up @@ -291,6 +295,7 @@ func TestForking(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"

Expand All @@ -308,6 +313,7 @@ func TestSticky(t *testing.T) {
doInitConsumer()
extension.SetProtocol("dubbo", GetProtocol)
extension.SetProtocol("registry", GetProtocol)
mockFilter()
m := consumerConfig.References["MockService"]
m.URL = "dubbo://127.0.0.1:20000;registry://127.0.0.2:20000"

Expand All @@ -333,7 +339,8 @@ func newRegistryProtocol() protocol.Protocol {
return &mockRegistryProtocol{}
}

type mockRegistryProtocol struct{}
type mockRegistryProtocol struct {
}

func (*mockRegistryProtocol) Refer(url *common.URL) protocol.Invoker {
return protocol.NewBaseInvoker(url)
Expand Down Expand Up @@ -375,3 +382,23 @@ func getRegistryURL(invoker protocol.Invoker) *common.URL {
func (p *mockRegistryProtocol) GetRegistries() []registry.Registry {
return []registry.Registry{&mockServiceDiscoveryRegistry{}}
}

func mockFilter() {
consumerFiler := &mockShutdownFilter{}
extension.SetFilter(constant.CONSUMER_SHUTDOWN_FILTER, func() filter.Filter {
return consumerFiler
})
}

type mockShutdownFilter struct {
}

// Invoke adds the requests count and block the new requests if application is closing
func (gf *mockShutdownFilter) Invoke(ctx context.Context, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return invoker.Invoke(ctx, invocation)
}

// OnResponse reduces the number of active processes then return the process result
func (gf *mockShutdownFilter) OnResponse(ctx context.Context, result protocol.Result, invoker protocol.Invoker, invocation protocol.Invocation) protocol.Result {
return result
}
6 changes: 3 additions & 3 deletions protocol/protocolwrapper/protocol_filter_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (pfw *ProtocolFilterWrapper) Export(invoker protocol.Invoker) protocol.Expo
if pfw.protocol == nil {
pfw.protocol = extension.GetProtocol(invoker.GetURL().Protocol)
}
invoker = buildInvokerChain(invoker, constant.SERVICE_FILTER_KEY)
invoker = BuildInvokerChain(invoker, constant.SERVICE_FILTER_KEY)
return pfw.protocol.Export(invoker)
}

Expand All @@ -63,15 +63,15 @@ func (pfw *ProtocolFilterWrapper) Refer(url *common.URL) protocol.Invoker {
if invoker == nil {
return nil
}
return buildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
return BuildInvokerChain(invoker, constant.REFERENCE_FILTER_KEY)
}

// Destroy will destroy all invoker and exporter.
func (pfw *ProtocolFilterWrapper) Destroy() {
pfw.protocol.Destroy()
}

func buildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
func BuildInvokerChain(invoker protocol.Invoker, key string) protocol.Invoker {
filterName := invoker.GetURL().GetParam(key, "")
if filterName == "" {
return invoker
Expand Down

0 comments on commit bda9a0b

Please sign in to comment.