From 795e995798c91a6b6cc481ae7aedec6252e1d557 Mon Sep 17 00:00:00 2001 From: goodliu Date: Fri, 22 Dec 2023 10:37:59 +0800 Subject: [PATCH 1/3] Merge remote-tracking branch 'upstream/main' into r1.0 to release kafka-v1.1.0 (#24) * clickhouse: fix go reference API doc (#18) https://pkg.go.dev/trpc.group/trpc-go/trpc-database/clickhouse * kafka: update sarama dependence (#21) * kafka: update sarama dependence * fix unit test --------- Co-authored-by: Leo --- clickhouse/client.go | 1 - clickhouse/client_test.go | 6 -- clickhouse/codec_test.go | 1 - clickhouse/dsn_test.go | 6 -- clickhouse/mockclickhouse/clickhouse_mock.go | 2 +- clickhouse/transport_test.go | 1 - kafka/README.md | 6 +- kafka/README.zh_CN.md | 6 +- kafka/client.go | 2 +- kafka/client_test.go | 2 +- kafka/client_transport.go | 2 +- kafka/client_transport_test.go | 2 +- kafka/config.go | 2 +- kafka/config_parser.go | 2 +- kafka/config_test.go | 2 +- kafka/examples/batchconsumer/main.go | 2 +- kafka/examples/consumer/main.go | 2 +- .../consumer_with_mulit_service/main.go | 2 +- kafka/go.mod | 19 ++-- kafka/go.sum | 68 +++++-------- kafka/kafka.go | 4 +- kafka/kafka_test.go | 98 ++++++++++++++++++- kafka/mockkafka/kafka_mock.go | 2 +- kafka/plugin.go | 2 +- kafka/scram_auth.go | 2 +- kafka/scram_auth_test.go | 2 +- kafka/server_transport.go | 2 +- kafka/server_transport_test.go | 2 +- kafka/service_desc.go | 2 +- kafka/service_desc_test.go | 2 +- 30 files changed, 157 insertions(+), 97 deletions(-) diff --git a/clickhouse/client.go b/clickhouse/client.go index f5ba27f..539053e 100644 --- a/clickhouse/client.go +++ b/clickhouse/client.go @@ -1,4 +1,3 @@ -// Package clickhouse 封装标准库clickhouse package clickhouse import ( diff --git a/clickhouse/client_test.go b/clickhouse/client_test.go index 2566ca6..7d23871 100644 --- a/clickhouse/client_test.go +++ b/clickhouse/client_test.go @@ -1,9 +1,3 @@ -/** - * @author aceyugong - * @date 2022/7/4 - */ - -// Package clickhouse packages standard library clickhouse. 
package clickhouse import ( diff --git a/clickhouse/codec_test.go b/clickhouse/codec_test.go index 8e16683..a073dc4 100644 --- a/clickhouse/codec_test.go +++ b/clickhouse/codec_test.go @@ -1,4 +1,3 @@ -// Package clickhouse packages standard library clickhouse. package clickhouse import ( diff --git a/clickhouse/dsn_test.go b/clickhouse/dsn_test.go index b6af6c3..c1d2555 100644 --- a/clickhouse/dsn_test.go +++ b/clickhouse/dsn_test.go @@ -1,9 +1,3 @@ -/** - * @author aceyugong - * @date 2022/8/19 - */ - -// Package clickhouse packages standard library clickhouse. package clickhouse import ( diff --git a/clickhouse/mockclickhouse/clickhouse_mock.go b/clickhouse/mockclickhouse/clickhouse_mock.go index b4974f3..acd5633 100644 --- a/clickhouse/mockclickhouse/clickhouse_mock.go +++ b/clickhouse/mockclickhouse/clickhouse_mock.go @@ -9,8 +9,8 @@ import ( sql "database/sql" reflect "reflect" - clickhouse "trpc.group/trpc-go/trpc-database/clickhouse" gomock "github.com/golang/mock/gomock" + clickhouse "trpc.group/trpc-go/trpc-database/clickhouse" ) // MockClient is a mock of Client interface. diff --git a/clickhouse/transport_test.go b/clickhouse/transport_test.go index 34aa8f5..800152a 100644 --- a/clickhouse/transport_test.go +++ b/clickhouse/transport_test.go @@ -1,4 +1,3 @@ -// Package clickhouse packages standard library clickhouse. package clickhouse import ( diff --git a/kafka/README.md b/kafka/README.md index db8e6ad..15b9054 100644 --- a/kafka/README.md +++ b/kafka/README.md @@ -7,7 +7,7 @@ English | [中文](README.zh_CN.md) [![Tests](https://github.com/trpc-ecosystem/go-database/actions/workflows/kafka.yml/badge.svg)](https://github.com/trpc-ecosystem/go-database/actions/workflows/kafka.yml) [![Coverage](https://codecov.io/gh/trpc-ecosystem/go-database/branch/main/graph/badge.svg?flag=kafka&precision=2)](https://app.codecov.io/gh/trpc-ecosystem/go-database/tree/main/kafka) -wrapping community [sarama](https://github.com/Shopify/sarama), used with trpc. 
+wrapping community [sarama](https://github.com/IBM/sarama), used with trpc. ## producer client @@ -68,7 +68,7 @@ import ( "trpc.group/trpc-go/trpc-database/kafka" trpc "trpc.group/trpc-go/trpc-go" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" ) func main() { @@ -99,7 +99,7 @@ import ( "trpc.group/trpc-go/trpc-database/kafka" trpc "trpc.group/trpc-go/trpc-go" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" ) func main() { diff --git a/kafka/README.zh_CN.md b/kafka/README.zh_CN.md index 423074c..7dbe504 100644 --- a/kafka/README.zh_CN.md +++ b/kafka/README.zh_CN.md @@ -7,7 +7,7 @@ [![Tests](https://github.com/trpc-ecosystem/go-database/actions/workflows/kafka.yml/badge.svg)](https://github.com/trpc-ecosystem/go-database/actions/workflows/kafka.yml) [![Coverage](https://codecov.io/gh/trpc-ecosystem/go-database/branch/main/graph/badge.svg?flag=kafka&precision=2)](https://app.codecov.io/gh/trpc-ecosystem/go-database/tree/main/kafka) -封装社区的 [sarama](https://github.com/Shopify/sarama) ,配合 trpc 使用。 +封装社区的 [sarama](https://github.com/IBM/sarama) ,配合 trpc 使用。 ## producer client @@ -67,7 +67,7 @@ import ( "trpc.group/trpc-go/trpc-database/kafka" trpc "trpc.group/trpc-go/trpc-go" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" ) func main() { @@ -97,7 +97,7 @@ import ( "trpc.group/trpc-go/trpc-database/kafka" trpc "trpc.group/trpc-go/trpc-go" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" ) func main() { diff --git a/kafka/client.go b/kafka/client.go index 35b4135..3aba891 100644 --- a/kafka/client.go +++ b/kafka/client.go @@ -4,7 +4,7 @@ import ( "context" "fmt" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-go/client" "trpc.group/trpc-go/trpc-go/codec" ) diff --git a/kafka/client_test.go b/kafka/client_test.go index 19b007a..e43fde5 100644 --- a/kafka/client_test.go +++ b/kafka/client_test.go @@ -5,7 +5,7 @@ import ( "flag" "testing" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" 
"github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go/client" ) diff --git a/kafka/client_transport.go b/kafka/client_transport.go index 058bff9..e0d4206 100644 --- a/kafka/client_transport.go +++ b/kafka/client_transport.go @@ -5,7 +5,7 @@ import ( "sync" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/log" diff --git a/kafka/client_transport_test.go b/kafka/client_transport_test.go index 9038ff6..4f3a967 100644 --- a/kafka/client_transport_test.go +++ b/kafka/client_transport_test.go @@ -8,7 +8,7 @@ import ( "testing" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/codec" diff --git a/kafka/config.go b/kafka/config.go index 250d7b3..5e9c69c 100644 --- a/kafka/config.go +++ b/kafka/config.go @@ -7,7 +7,7 @@ import ( "sync" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-go/naming/discovery" "trpc.group/trpc-go/trpc-go/naming/servicerouter" ) diff --git a/kafka/config_parser.go b/kafka/config_parser.go index cfde88f..dc7230d 100644 --- a/kafka/config_parser.go +++ b/kafka/config_parser.go @@ -9,7 +9,7 @@ import ( "strings" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" ) // configParseFunc set UserConfig property. 
diff --git a/kafka/config_test.go b/kafka/config_test.go index f913221..3890d41 100644 --- a/kafka/config_test.go +++ b/kafka/config_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go/naming/discovery" "trpc.group/trpc-go/trpc-go/naming/registry" diff --git a/kafka/examples/batchconsumer/main.go b/kafka/examples/batchconsumer/main.go index 8e5662d..8276420 100644 --- a/kafka/examples/batchconsumer/main.go +++ b/kafka/examples/batchconsumer/main.go @@ -6,7 +6,7 @@ import ( "context" "fmt" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-database/kafka" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/log" diff --git a/kafka/examples/consumer/main.go b/kafka/examples/consumer/main.go index 183b766..06d426f 100644 --- a/kafka/examples/consumer/main.go +++ b/kafka/examples/consumer/main.go @@ -5,7 +5,7 @@ package main import ( "context" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-database/kafka" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/log" diff --git a/kafka/examples/consumer_with_mulit_service/main.go b/kafka/examples/consumer_with_mulit_service/main.go index e8c5f93..de22544 100644 --- a/kafka/examples/consumer_with_mulit_service/main.go +++ b/kafka/examples/consumer_with_mulit_service/main.go @@ -4,7 +4,7 @@ package main import ( "context" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-database/kafka" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/log" diff --git a/kafka/go.mod b/kafka/go.mod index 7302855..b22c7e1 100644 --- a/kafka/go.mod +++ b/kafka/go.mod @@ -3,10 +3,10 @@ module trpc.group/trpc-go/trpc-database/kafka go 1.20 require ( - github.com/Shopify/sarama v1.29.1 + github.com/IBM/sarama v1.40.1 github.com/golang/mock v1.4.4 github.com/smartystreets/goconvey v1.8.0 - github.com/stretchr/testify v1.8.1 
+ github.com/stretchr/testify v1.8.4 github.com/xdg-go/scram v1.1.2 golang.org/x/time v0.3.0 gopkg.in/yaml.v3 v3.0.1 @@ -35,15 +35,16 @@ require ( github.com/jcmturner/rpc/v2 v2.0.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect - github.com/klauspost/compress v1.15.14 // indirect + github.com/klauspost/compress v1.16.6 // indirect github.com/lestrrat-go/strftime v1.0.6 // indirect github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/panjf2000/ants/v2 v2.4.6 // indirect - github.com/pierrec/lz4 v2.6.0+incompatible // indirect + github.com/pierrec/lz4/v4 v4.1.17 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect + github.com/rogpeppe/go-internal v1.11.0 // indirect github.com/smartystreets/assertions v1.13.1 // indirect github.com/spf13/cast v1.3.1 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect @@ -54,11 +55,11 @@ require ( go.uber.org/automaxprocs v1.3.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.24.0 // indirect - golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa // indirect - golang.org/x/net v0.5.0 // indirect - golang.org/x/sync v0.1.0 // indirect - golang.org/x/sys v0.6.0 // indirect - golang.org/x/text v0.6.0 // indirect + golang.org/x/crypto v0.11.0 // indirect + golang.org/x/net v0.12.0 // indirect + golang.org/x/sync v0.3.0 // indirect + golang.org/x/sys v0.10.0 // indirect + golang.org/x/text v0.11.0 // indirect google.golang.org/protobuf v1.30.0 // indirect trpc.group/trpc-go/tnet v0.0.0-20230810071536-9d05338021cf // indirect trpc.group/trpc/trpc-protocol/pb/go/trpc v0.0.0-20230803031059-de4168eb5952 // indirect diff --git a/kafka/go.sum b/kafka/go.sum index 33c77c9..c7aa4b5 100644 --- a/kafka/go.sum +++ b/kafka/go.sum @@ 
-1,40 +1,31 @@ github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/Shopify/sarama v1.29.1 h1:wBAacXbYVLmWieEA/0X/JagDdCZ8NVFOfS6l6+2u5S0= -github.com/Shopify/sarama v1.29.1/go.mod h1:mdtqvCSg8JOxk8PmpTNGyo6wzd4BMm4QXSfDnTXmgkE= -github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc= -github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= +github.com/IBM/sarama v1.40.1 h1:lL01NNg/iBeigUbT+wpPysuTYW6roHo6kc1QrffRf0k= +github.com/IBM/sarama v1.40.1/go.mod h1:+5OFwA5Du9I6QrznhaMHsuwWdWZNMjaBSIxEWEgKOYE= +github.com/Shopify/toxiproxy/v2 v2.5.0 h1:i4LPT+qrSlKNtQf5QliVjdP08GyAH8+BUIc9gT0eahc= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.3.0 h1:RRL0nge+cWGlxXbUzJ7yMcq6w2XBEr19dCN6HECGaT0= github.com/eapache/go-resiliency v1.3.0/go.mod h1:5yPzW0MIvSe0JDsv0v+DvcjEv2FyD6iZYSs1ZI+iQho= -github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/go-xerial-snappy v0.0.0-20230111030713-bf00bc1b83b6 h1:8yY/I9ndfrgrXUbOGObLHKBR4Fl3nZXwM2c7OYTT8hM= github.com/eapache/go-xerial-snappy 
v0.0.0-20230111030713-bf00bc1b83b6/go.mod h1:YvSRo5mw33fLEx1+DlK6L2VV43tJt5Eyel9n9XBcR+0= github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= github.com/fortytw2/leaktest v1.3.0 h1:u8491cBMTQ8ft8aeV+adlcytMZylmA5nnwwkRZjI8vw= -github.com/fortytw2/leaktest v1.3.0/go.mod h1:jDsjWgpAGjm2CA7WthBh/CdZYEPF31XHquHwclZch5g= -github.com/frankban/quicktest v1.11.3 h1:8sXhOn0uLys67V8EsXLc6eszDs8VXWxL3iRvebPhedY= -github.com/frankban/quicktest v1.11.3/go.mod h1:wRf/ReqHper53s+kmmSZizM8NamnL3IM0I9ntUbOk+k= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v2.0.0+incompatible h1:dicJ2oXwypfwUGnB2/TYWYEKiuk9eYQlQO/AnOHl5mI= github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= -github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -54,12 +45,10 @@ github.com/jcmturner/aescts/v2 v2.0.0 
h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFK github.com/jcmturner/aescts/v2 v2.0.0/go.mod h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs= github.com/jcmturner/dnsutils/v2 v2.0.0 h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo= github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM= -github.com/jcmturner/gofork v1.0.0/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o= github.com/jcmturner/gofork v1.7.6 h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg= github.com/jcmturner/gofork v1.7.6/go.mod h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo= github.com/jcmturner/goidentity/v6 v6.0.1 h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o= github.com/jcmturner/goidentity/v6 v6.0.1/go.mod h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg= -github.com/jcmturner/gokrb5/v8 v8.4.2/go.mod h1:sb+Xq/fTY5yktf/VxLsE3wlfPqQjp0aWNYyvBVK62bc= github.com/jcmturner/gokrb5/v8 v8.4.3 h1:iTonLeSJOn7MVUtyMT+arAn5AKAPrkilzhGw8wE/Tq8= github.com/jcmturner/gokrb5/v8 v8.4.3/go.mod h1:dqRwJGXznQrzw6cWmyo6kH+E7jksEQG/CyVWsJEsJO0= github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY= @@ -69,17 +58,14 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.12.2/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= -github.com/klauspost/compress v1.15.14 h1:i7WCKDToww0wA+9qrUZ1xOjp218vfFo3nTU6UHp+gOc= -github.com/klauspost/compress v1.15.14/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM= +github.com/klauspost/compress v1.16.6 h1:91SKEy4K37vkp255cJ8QesJhjyRO0hn9i9G0GoUwLsk= +github.com/klauspost/compress 
v1.16.6/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc h1:RKf14vYWi2ttpEmkA4aQ3j4u9dStX2t4M8UM6qqNsG8= github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc/go.mod h1:kopuH9ugFRkIXf3YoqHKyrJ9YfUFsckUU9S7B+XP+is= github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ= @@ -91,8 +77,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/panjf2000/ants/v2 v2.4.6 h1:drmj9mcygn2gawZ155dRbo+NfXEfAssjZNU1qoIb4gQ= github.com/panjf2000/ants/v2 v2.4.6/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= -github.com/pierrec/lz4 v2.6.0+incompatible h1:Ix9yFKn1nSPBLFl/yZknTp8TU5G4Ps0JDmguYK6iH1A= -github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4/v4 v4.1.17 h1:kV4Ip+/hUBC+8T6+2EgburRtkE9ef4nbY3f4dFhGjMc= +github.com/pierrec/lz4/v4 v4.1.17/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= 
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -100,6 +86,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 h1:N/ElC8H3+5XpJzTSTfLsJV/mx9Q9g7kxmchpfZyxgzM= github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/smartystreets/assertions v1.13.1 h1:Ef7KhSmjZcK6AVf9YbJdvPYG9avaF0ZxudX+ThRdWfU= github.com/smartystreets/assertions v1.13.1/go.mod h1:cXr/IwVfSo/RbCSPhoAPv73p3hlSdrBH/b3SdnW/LMY= github.com/smartystreets/goconvey v1.8.0 h1:Oi49ha/2MURE0WexF052Z0m+BNSGirfjg5RL+JXWq3w= @@ -108,16 +96,13 @@ github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify 
v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= -github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.43.0 h1:Gy4sb32C98fbzVWZlTM1oTMdLWGyvxR03VhM6cBIU4g= @@ -129,8 +114,6 @@ github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= -github.com/xdg/scram v1.0.3/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I= -github.com/xdg/stringprep v1.0.3/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= @@ -144,12 +127,11 @@ go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20201112155050-0c6587e931a9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto 
v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa h1:zuSxTR4o9y82ebqCUJYNGJbGPo6sKVl54f/TVDObg1c= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.11.0 h1:6Ewdq3tDic1mg5xRO4milcWCfMVQhI4NkqWWvqejpuA= +golang.org/x/crypto v0.11.0/go.mod h1:xgJhtzW8F9jGdVFWZESrid1U1bjeNy4zgy5cRr/CIio= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= @@ -158,17 +140,16 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200114155413-6afb5195e5aa/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220725212005-46097bf591d3/go.mod h1:AaygXjzTFtRAg2ttMY5RMuhpJ3cNnI0XpyFJD1iQRSM= golang.org/x/net v0.0.0-20220906165146-f3363e06e74c/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= -golang.org/x/net v0.5.0 
h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= -golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= +golang.org/x/net v0.12.0 h1:cfawfvKITfUsFCeJIHJrbSxpeu/E81khclypR0GVT50= +golang.org/x/net v0.12.0/go.mod h1:zEVYFnQC7m/vmpQFELhcD1EWkZlX69l4oqgmer6hfKA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.3.0 h1:ftCYgMx6zT/asHUrPw8BLLscYtGznsLAnjq5RH9P66E= +golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -178,8 +159,8 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0 h1:SqMFp9UcQJZa+pmYuAKjd9xq1f0j5rLcDIk0mj4qAsA= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term 
v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -187,8 +168,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.11.0 h1:LAntKIrcmeSKERyiOh0XMV39LXS8IE9UL2yP7+f5ij4= +golang.org/x/text v0.11.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -199,19 +180,16 @@ golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 
v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/kafka/kafka.go b/kafka/kafka.go index f793376..7b875ea 100644 --- a/kafka/kafka.go +++ b/kafka/kafka.go @@ -1,5 +1,5 @@ /* -Package kafka encapsulated from github.com/Shopify/sarama +Package kafka encapsulated from github.com/IBM/sarama Producer sending through trpc.Client Implement Consumer logic through trpc.Service */ @@ -12,7 +12,7 @@ import ( "strings" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "golang.org/x/time/rate" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/codec" diff --git a/kafka/kafka_test.go b/kafka/kafka_test.go index 6f0da8a..3984fa0 100644 --- a/kafka/kafka_test.go +++ b/kafka/kafka_test.go @@ -7,7 +7,7 @@ import ( "testing" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/transport" @@ -68,19 +68,34 @@ func (tgc *testGroupClaim) Messages() <-chan *sarama.ConsumerMessage { type testGroup struct{} +// Consume 
implements sarama.ConsumerGroupHandler. func (tg testGroup) Consume(ctx context.Context, topics []string, handler sarama.ConsumerGroupHandler) error { return nil } +// Errors implements sarama.ConsumerGroupHandler. func (tg testGroup) Errors() <-chan error { ec := make(chan error, 1) return ec } +// Close implements sarama.ConsumerGroupHandler. func (tg testGroup) Close() error { return nil } +// Pause implements sarama.ConsumerGroupHandler. +func (tg testGroup) Pause(partitions map[string][]int32) {} + +// Resume implements sarama.ConsumerGroupHandler. +func (tg testGroup) Resume(partitions map[string][]int32) {} + +// PauseAll implements sarama.ConsumerGroupHandler. +func (tg testGroup) PauseAll() {} + +// ResumeAll implements sarama.ConsumerGroupHandler. +func (tg testGroup) ResumeAll() {} + type testGroupSession struct { ctx context.Context cancel context.CancelFunc @@ -127,6 +142,33 @@ func (tap testAsyncProducer) Close() error { return nil } +// TxnStatus return current producer transaction status. +func (tap testAsyncProducer) TxnStatus() sarama.ProducerTxnStatusFlag { + return sarama.ProducerTxnFlagReady +} + +// IsTransactional return true when current producer is transactional. +func (tap testAsyncProducer) IsTransactional() bool { return false } + +// BeginTxn mark current transaction as ready. +func (tap testAsyncProducer) BeginTxn() error { return nil } + +// CommitTxn commit current transaction. +func (tap testAsyncProducer) CommitTxn() error { return nil } + +// AbortTxn abort current transaction. +func (tap testAsyncProducer) AbortTxn() error { return nil } + +// AddOffsetsToTxn add associated offsets to current transaction. +func (tap testAsyncProducer) AddOffsetsToTxn(offsets map[string][]*sarama.PartitionOffsetMetadata, groupId string) error { + return nil +} + +// AddMessageToTxn add message offsets to current transaction. 
+func (tap testAsyncProducer) AddMessageToTxn(msg *sarama.ConsumerMessage, groupId string, metadata *string) error { + return nil +} + func (tap testAsyncProducer) Input() chan<- *sarama.ProducerMessage { ret := make(chan *sarama.ProducerMessage, 2) return ret @@ -150,6 +192,33 @@ func (t testBlockedAsyncProducer) Close() error { return nil } +// TxnStatus return current producer transaction status. +func (t testBlockedAsyncProducer) TxnStatus() sarama.ProducerTxnStatusFlag { + return sarama.ProducerTxnFlagReady +} + +// IsTransactional return true when current producer is transactional. +func (t testBlockedAsyncProducer) IsTransactional() bool { return false } + +// BeginTxn mark current transaction as ready. +func (t testBlockedAsyncProducer) BeginTxn() error { return nil } + +// CommitTxn commit current transaction. +func (t testBlockedAsyncProducer) CommitTxn() error { return nil } + +// AbortTxn abort current transaction. +func (t testBlockedAsyncProducer) AbortTxn() error { return nil } + +// AddOffsetsToTxn add associated offsets to current transaction. +func (t testBlockedAsyncProducer) AddOffsetsToTxn(offsets map[string][]*sarama.PartitionOffsetMetadata, groupId string) error { + return nil +} + +// AddMessageToTxn add message offsets to current transaction. +func (t testBlockedAsyncProducer) AddMessageToTxn(msg *sarama.ConsumerMessage, groupId string, metadata *string) error { + return nil +} + func (t testBlockedAsyncProducer) Input() chan<- *sarama.ProducerMessage { return make(chan *sarama.ProducerMessage) // block forever } @@ -176,6 +245,33 @@ func (tsp testSyncProducer) Close() error { return nil } +// TxnStatus return current producer transaction status. +func (tsp testSyncProducer) TxnStatus() sarama.ProducerTxnStatusFlag { + return sarama.ProducerTxnFlagReady +} + +// IsTransactional return true when current producer is transactional. 
+func (tsp testSyncProducer) IsTransactional() bool { return false } + +// BeginTxn mark current transaction as ready. +func (tsp testSyncProducer) BeginTxn() error { return nil } + +// CommitTxn commit current transaction. +func (tsp testSyncProducer) CommitTxn() error { return nil } + +// AbortTxn abort current transaction. +func (tsp testSyncProducer) AbortTxn() error { return nil } + +// AddOffsetsToTxn add associated offsets to current transaction. +func (tsp testSyncProducer) AddOffsetsToTxn(offsets map[string][]*sarama.PartitionOffsetMetadata, groupId string) error { + return nil +} + +// AddMessageToTxn add message offsets to current transaction. +func (tsp testSyncProducer) AddMessageToTxn(msg *sarama.ConsumerMessage, groupId string, metadata *string) error { + return nil +} + func newMsg(offset int64) *sarama.ConsumerMessage { return &sarama.ConsumerMessage{ Timestamp: time.Now(), diff --git a/kafka/mockkafka/kafka_mock.go b/kafka/mockkafka/kafka_mock.go index 95852e8..9779efb 100644 --- a/kafka/mockkafka/kafka_mock.go +++ b/kafka/mockkafka/kafka_mock.go @@ -6,7 +6,7 @@ package mockkafka import ( context "context" - sarama "github.com/Shopify/sarama" + sarama "github.com/IBM/sarama" gomock "github.com/golang/mock/gomock" reflect "reflect" ) diff --git a/kafka/plugin.go b/kafka/plugin.go index cf628ee..d458c52 100644 --- a/kafka/plugin.go +++ b/kafka/plugin.go @@ -3,7 +3,7 @@ package kafka import ( "fmt" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-go/log" "trpc.group/trpc-go/trpc-go/plugin" ) diff --git a/kafka/scram_auth.go b/kafka/scram_auth.go index 5595ea1..d2d293a 100644 --- a/kafka/scram_auth.go +++ b/kafka/scram_auth.go @@ -6,7 +6,7 @@ import ( "crypto/sha512" "hash" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/xdg-go/scram" "trpc.group/trpc-go/trpc-go/errs" ) diff --git a/kafka/scram_auth_test.go b/kafka/scram_auth_test.go index 25b6545..8bdbb64 100644 --- a/kafka/scram_auth_test.go +++ 
b/kafka/scram_auth_test.go @@ -3,7 +3,7 @@ package kafka import ( "testing" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/golang/mock/gomock" "github.com/xdg-go/scram" "trpc.group/trpc-go/trpc-database/kafka/mockkafka" diff --git a/kafka/server_transport.go b/kafka/server_transport.go index c3fa807..fd3f36c 100644 --- a/kafka/server_transport.go +++ b/kafka/server_transport.go @@ -3,7 +3,7 @@ package kafka import ( "context" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "golang.org/x/time/rate" "trpc.group/trpc-go/trpc-go/log" "trpc.group/trpc-go/trpc-go/transport" diff --git a/kafka/server_transport_test.go b/kafka/server_transport_test.go index b2c8a8c..ab4e1ae 100644 --- a/kafka/server_transport_test.go +++ b/kafka/server_transport_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go/transport" ) diff --git a/kafka/service_desc.go b/kafka/service_desc.go index 46689f4..6eca0e1 100644 --- a/kafka/service_desc.go +++ b/kafka/service_desc.go @@ -3,7 +3,7 @@ package kafka import ( "context" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "trpc.group/trpc-go/trpc-go/codec" "trpc.group/trpc-go/trpc-go/errs" "trpc.group/trpc-go/trpc-go/server" diff --git a/kafka/service_desc_test.go b/kafka/service_desc_test.go index 6d79946..0ed53ae 100644 --- a/kafka/service_desc_test.go +++ b/kafka/service_desc_test.go @@ -5,7 +5,7 @@ import ( "errors" "testing" - "github.com/Shopify/sarama" + "github.com/IBM/sarama" "github.com/stretchr/testify/assert" "trpc.group/trpc-go/trpc-go" "trpc.group/trpc-go/trpc-go/filter" From 9e865e3588e11ff94209976dad7059c90a75707d Mon Sep 17 00:00:00 2001 From: goodliu Date: Mon, 20 May 2024 14:35:18 +0800 Subject: [PATCH 2/3] sync main to r1.0 to release localcache v1.0.0 (#27) * clickhouse: fix go reference API doc (#18) https://pkg.go.dev/trpc.group/trpc-go/trpc-database/clickhouse * kafka: 
update sarama dependence (#21) * kafka: update sarama dependence * fix unit test * kafka: release v1.1.0 (#22) * workflows: add cla.yaml (#26) * add localcache (#25) * feat: add localcache plugin * chore: update LICENSE * test: add localcache workflow * chore: yaml version * test: flaky test --------- Co-authored-by: Leo Co-authored-by: Flash-LHR <47357603+Flash-LHR@users.noreply.github.com> --- .github/workflows/cla.yml | 32 + .github/workflows/localcache.yml | 33 + LICENSE | 6 + kafka/CHANGELOG.md | 12 +- localcache/CHANGELOG.md | 1 + localcache/README.md | 278 +++++ localcache/README.zh_CN.md | 271 +++++ localcache/cache.go | 700 +++++++++++ localcache/cache_test.go | 1089 ++++++++++++++++++ localcache/examples/custom_callback/main.go | 50 + localcache/examples/custom_load/main.go | 27 + localcache/examples/delayed_deletion/main.go | 28 + localcache/examples/with_expiration/main.go | 31 + localcache/func.go | 37 + localcache/func_test.go | 87 ++ localcache/go.mod | 18 + localcache/go.sum | 35 + localcache/lru.go | 95 ++ localcache/lru_test.go | 102 ++ localcache/mocklocalcache/localcache_mock.go | 204 ++++ localcache/mocklocalcache/localcache_test.go | 112 ++ localcache/policy.go | 23 + localcache/ring.go | 63 + localcache/ring_test.go | 81 ++ localcache/store.go | 113 ++ localcache/store_test.go | 118 ++ localcache/timer.go | 98 ++ localcache/timer_test.go | 163 +++ 28 files changed, 3906 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/cla.yml create mode 100644 .github/workflows/localcache.yml create mode 100644 localcache/CHANGELOG.md create mode 100644 localcache/README.md create mode 100644 localcache/README.zh_CN.md create mode 100644 localcache/cache.go create mode 100644 localcache/cache_test.go create mode 100644 localcache/examples/custom_callback/main.go create mode 100644 localcache/examples/custom_load/main.go create mode 100644 localcache/examples/delayed_deletion/main.go create mode 100644 
localcache/examples/with_expiration/main.go create mode 100644 localcache/func.go create mode 100644 localcache/func_test.go create mode 100644 localcache/go.mod create mode 100644 localcache/go.sum create mode 100644 localcache/lru.go create mode 100644 localcache/lru_test.go create mode 100644 localcache/mocklocalcache/localcache_mock.go create mode 100644 localcache/mocklocalcache/localcache_test.go create mode 100644 localcache/policy.go create mode 100644 localcache/ring.go create mode 100644 localcache/ring_test.go create mode 100644 localcache/store.go create mode 100644 localcache/store_test.go create mode 100644 localcache/timer.go create mode 100644 localcache/timer_test.go diff --git a/.github/workflows/cla.yml b/.github/workflows/cla.yml new file mode 100644 index 0000000..395c5d8 --- /dev/null +++ b/.github/workflows/cla.yml @@ -0,0 +1,32 @@ +name: "CLA Assistant" +on: + issue_comment: + types: [created] + pull_request_target: + types: [opened, closed, synchronize, reopened] + +# explicitly configure permissions, in case your GITHUB_TOKEN workflow permissions are set to read-only in repository settings +permissions: + actions: write + contents: write + pull-requests: write + statuses: write + +jobs: + CLAAssistant: + runs-on: ubuntu-latest + steps: + - name: "CLA Assistant" + if: (github.event.comment.body == 'recheck' || github.event.comment.body == 'I have read the CLA Document and I hereby sign the CLA') || github.event_name == 'pull_request_target' + uses: contributor-assistant/github-action@v2.4.0 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PERSONAL_ACCESS_TOKEN: ${{ secrets.CLA_DATABASE_ACCESS_TOKEN }} + with: + remote-organization-name: trpc-group + remote-repository-name: cla-database + path-to-signatures: 'signatures/${{ github.event.repository.name }}-${{ github.repository_id }}/cla.json' + path-to-document: 'https://github.com/trpc-group/cla-database/blob/main/Tencent-Contributor-License-Agreement.md' + # branch should not be 
protected + branch: 'main' + allowlist: dependabot \ No newline at end of file diff --git a/.github/workflows/localcache.yml b/.github/workflows/localcache.yml new file mode 100644 index 0000000..b3d38d1 --- /dev/null +++ b/.github/workflows/localcache.yml @@ -0,0 +1,33 @@ +name: Localcache Pull Request Check +on: + pull_request: + paths: + - 'localcache/**' + - '.github/workflows/localcache.yml' + push: + paths: + - 'localcache/**' + - '.github/workflows/localcache.yml' + workflow_dispatch: +permissions: + contents: read +jobs: + build: + name: build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version: 1.19 + - name: Build + run: cd localcache && go build -v ./... + - name: Test + run: cd localcache && go test -v -coverprofile=coverage.out ./... + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./localcache/coverage.out + flags: localcache + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/LICENSE b/LICENSE index 627dbd0..a6f6d0e 100644 --- a/LICENSE +++ b/LICENSE @@ -170,6 +170,12 @@ Copyright (c) 2013 Shopify 10. go-sqlite3 Copyright (c) 2014 Yasuhiro Matsumoto +11. timingwheel +Copyright (c) 2022 Luo Peng + +12. 
xxhash +Copyright (c) 2016 Caleb Spare + Terms of the MIT License: -------------------------------------------------------------------- diff --git a/kafka/CHANGELOG.md b/kafka/CHANGELOG.md index fa4d35e..34a9e5f 100644 --- a/kafka/CHANGELOG.md +++ b/kafka/CHANGELOG.md @@ -1 +1,11 @@ -# Change Log \ No newline at end of file +# Change Log + +## [1.1.0](https://github.com/trpc-ecosystem/go-database/releases/tag/kafka%2Fv1.1.0) (2023-12-22) + +### Breaking Changes + +- update sarama dependence from to github.com/Shopify/sarama v1.29.1 to github.com/IBM/sarama v1.40.1 (#21) + +### Bug Fixes + +- fix unit test (#21) \ No newline at end of file diff --git a/localcache/CHANGELOG.md b/localcache/CHANGELOG.md new file mode 100644 index 0000000..420e6f2 --- /dev/null +++ b/localcache/CHANGELOG.md @@ -0,0 +1 @@ +# Change Log diff --git a/localcache/README.md b/localcache/README.md new file mode 100644 index 0000000..f7aacde --- /dev/null +++ b/localcache/README.md @@ -0,0 +1,278 @@ +English | [中文](README.zh_CN.md) + +# tRPC-Go localcache plugin + +[![Go Reference](https://pkg.go.dev/badge/trpc.group/trpc-go/trpc-database/localcache.svg)](https://pkg.go.dev/trpc.group/trpc-go/trpc-database/localcache) +[![Go Report Card](https://goreportcard.com/badge/trpc.group/trpc-go/trpc-database/localcache)](https://goreportcard.com/report/trpc.group/trpc-go/trpc-database/localcache) +[![Tests](https://github.com/trpc-ecosystem/go-database/actions/workflows/localcache.yml/badge.svg)](https://github.com/trpc-ecosystem/go-database/actions/workflows/localcache.yml) +[![Coverage](https://codecov.io/gh/trpc-ecosystem/go-database/branch/main/graph/badge.svg?flag=localcache&precision=2)](https://app.codecov.io/gh/trpc-ecosystem/go-database/tree/main/localcache) + +localcache is a standalone local K-V cache component that allows multiple goroutines to access it concurrently and supports LRU and expiration time based elimination policy. 
+After the capacity of localcache reaches the upper limit, it will carry out data elimination based on LRU, and the deletion of expired key-value is realized based on time wheel. + +**applies to readcache scenarios, not to writecache scenarios.** + + +## Quick Start +Use the functions directly from the localcache package. +```go +package main + +import ( + "trpc.group/trpc-go/trpc-database/localcache" +) + +func LoadData(ctx context.Context, key string) (interface{}, error) { + return "cat", nil +} + +func main() { + // Cache the key-value, and set the expiration time to 5 seconds. + localcache.Set("foo", "bar", 5) + + // Get the value corresponding to the key + value, found := localcache.Get("foo") + + // Get the value corresponding to the key. + // If the key does not exist in the cache, it is loaded from the data source using the custom LoadData function + // and cached in the cache. + // And Set an expiration time of 5 seconds. + value, err := localcache.GetWithLoad(context.TODO(), "tom", LoadData, 5) + + // Delete key + localcache.Del("foo") + + // Clear cache + localcache.Clear() +} +``` + +## Configuring Usage +New() generates a Cache instance and calls the functional functions of that instance +### Optional parameters +#### **WithCapacity(capacity int)** +Sets the maximum size of the cache, with a minimum value of 1 and a maximum value of 1e30. When the cache is full, the last element of the queue is eliminated based on LRU. Default value is 1e30. +#### **WithExpiration(ttl int64)** +Sets the expiration time of the element in seconds. The default value is 60 seconds. +#### **WithLoad(f LoadFunc)** +```go +type LoadFunc func(ctx context.Context, key string) (interface{}, error) +``` +Set data load function, when the key does not exist in cache, use this function to load the corresponding value, and cache it in cache. Use with GetWithLoad(). 
+#### **WithMLoad(f MLoadFunc)** +```go +type MLoadFunc func(ctx context.Context, keys []string) (map[string]interface{}, error) +``` +Set **bulk** data loading function, use this function to bulk load keys that don't exist in cache. Use with MGetWithLoad(). +#### **WithDelay(duration int64)** +Sets the interval in seconds for delayed deletion of expired keys in cache. The default value is 0, the key is deleted immediately after it expires. + +Usage Scenario: When the key expires, at the same time the data downstream service exception of the business, I hope to be able to get the expired value from the cache in the business as the backing data. + +#### **WithOnDel(delCallBack ItemCallBackFunc)** + +```go +type ItemCallBackFunc func(item *Item) +``` + +Setting the callback function when an element is deleted: expired deletions, active deletions, and LRU deletions all trigger this callback function. + +#### **WithOnExpire(expireCallback ItemCallbackFunc)** + +```go +type ItemCallBackFunc func(item *Item) +``` + +Set the callback function when an element expires. Two callback functions are triggered when an element expires: the expiration callback and the deletion callback. + +#### Cache Interface + +```go +type Cache interface { + // Get returns the value of the key, bool returns true. bool returns false if the key does not exist or is expired. + Get(key string) (interface{}, bool) + + // GetWithLoad returns the value corresponding to the key, if the key does not exist, use the customized load function + // to get the data and cache it. + GetWithLoad(ctx context.Context, key string) (interface{}, error) + + // MGetWithLoad returns values corresponding to multiple keys. when some keys don't exist, use a customized bulk load + // function to fetch the data and cache it. 
+ // For a key that does not exist in the cache, and does not exist in the result of the mLoad function call, the return + // result of MGetWithLoad will contain the key, and the corresponding value will be nil. + MGetWithLoad(ctx context.Context, keys []string) (map[string]interface{}, error) + + // GetWithCustomLoad returns the value corresponding to the key. If the key does not exist, it is loaded using the load + // function passed in and caches the ttl time. + // load function does not exist will return err, if you do not need to pass in the load function every time you get + // please use the option to set load when new cache, and use the GetWithLoad method to get the cache value. + GetWithCustomLoad(ctx context.Context, key string, customLoad LoadFunc, ttl int64) (interface{}, error) + + // MGetWithCustomLoad returns values corresponding to multiple keys. when some keys don't exist, they are loaded using the + // load function passed in and cached ttl time. For a key that does not exist in the cache, and does not exist in the result + // of the mLoad function call, the result of MGetWithLoad contains the key, and the corresponding value is nil. oad function + // does not exist will return err, if you do not need to pass in the load function every time you get please use the option + // to set load when new cache, and use the MGetWithLoad method to get the cache value. + MGetWithCustomLoad(ctx context.Context, keys []string, customLoad MLoadFunc, ttl int64) (map[string]interface{}, error) + + // Set cache key-value + Set(key string, value interface{}) bool + + // SetWithExpire caches key-values, and sets a specific ttl (expiration time in seconds) for a key. 
+ SetWithExpire(key string, value interface{}, ttl int64) bool + + // Delete key + Del(key string) + + // Clear all queues and caches + Clear() + + // Close cache + Close() +} +``` + +## Example +#### Setting capacity and expiration time + +```go +func main() { + var lc localcache.Cache + + // Create a cache with a size of 100 and an element expiration time of 5 seconds. + lc = localcache.New(localcache.WithCapacity(100), localcache.WithExpiration(5)) + + // Set key-value, expiration time 5 seconds + lc.Set("foo", "bar") + + // Set a specific expiration time (10 seconds) for the key, without using the expiration parameter in the New() method + lc.SetWithExpire("tom", "cat", 10) + + // Short wait for asynchronous processing to complete from cache + time.Sleep(time.Millisecond) + + // Get value + val, found := lc.Get("foo") + + // Delete key: "foo" + lc.Del("foo") +} +``` + +### Delayed deletion +```go +func main() { + // Set the delayed deletion interval to 3 seconds + lc := localcache.New(localcache.WithDelay(3)) + // Set key expiration time to 1 second: key expires after 1 second and is removed from cache after 4 seconds + lc.SetWithExpire("tom", "cat", 1) + // sleep 2s + time.Sleep(time.Second * 2) + + value, ok := lc.Get("tom") + if !ok { + // key has expired after 2 seconds of sleep. + fmt.Printf("key:%s is expired or empty\n", "tom") + } + + if s, ok := value.(string); ok && s == "cat" { + // Expired values are still returned, and the business side can decide whether or not to use them. + fmt.Printf("get expired value: %s\n", "cat") + } +} +``` + +#### Custom Load Functions +Set up a custom data loading function and use GetWithLoad(key) to get the value +```go +func main() { + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + return "cat", nil + } + + lc := localcache.New(localcache.WithLoad(loadFunc), localcache.WithExpiration(5)) + + // err is the error message returned directly by the loadFunc function. 
+ val, err := lc.GetWithLoad(context.TODO(), "tom") + + // Or you can pass in the load function at get time + otherLoadFunc := func(ctx context.Context, key string) (interface{}, error) { + return "dog", nil + } + + val,err := lc.GetWithCustomLoad(context.TODO(),"tom",otherLoadFunc,10) +} +``` +Set the bulk data load function and use MGetWithLoad(keys) to get values +```go +func main() { + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) { + return map[string]interface{} { + "foo": "bar", + "tom": "cat", + }, nil + } + + lc := localcache.New(localcache.WithMLoad(mLoadFunc), localcache.WithExpiration(5)) + + // err is the error message returned directly by the mLoadFunc function. + val, err := lc.MGetWithLoad(context.TODO(), []string{"foo", "tom"}) + + // Or you can pass in the load function at get time + val,err := lc.MGetWithCustomLoad(context.TODO(),"tom",mLoadFunc,10) +} +``` + +#### Customizing the callback on expiration/deletion of deleted elements + +```go +func main() { + delCount := map[string]int{"A": 0, "B": 0, "": 0} + expireCount := map[string]int{"A": 0, "B": 0, "": 0} + c := localcache.New( + localcache.WithCapacity(4), + localcache.WithOnExpire(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := expireCount[item.Key]; ok { + expireCount[item.Key]++ + } + }), + localcache.WithOnDel(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + })) + + defer c.Close() + + elems := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 300}, + {"C", "c", 300}, + } + + for _, elem := range elems { + c.SetWithExpire(elem.key, elem.val, elem.ttl) + } + time.Sleep(1001 * time.Millisecond) + fmt.Printf("del info:%v\n", delCount) + fmt.Printf("expire info:%v\n", expireCount) +} +``` + + + +### TODO + +1. Add Metrics statistics +2. 
increase the control of memory usage +3. Introduce an upgraded version of LRU: W-tinyLRU, which controls key writing and elimination more efficiently. \ No newline at end of file diff --git a/localcache/README.zh_CN.md b/localcache/README.zh_CN.md new file mode 100644 index 0000000..c458a95 --- /dev/null +++ b/localcache/README.zh_CN.md @@ -0,0 +1,271 @@ +[English](README.md) | 中文 + +# tRPC-Go localcache 插件 + +[![Go Reference](https://pkg.go.dev/badge/trpc.group/trpc-go/trpc-database/localcache.svg)](https://pkg.go.dev/trpc.group/trpc-go/trpc-database/localcache) +[![Go Report Card](https://goreportcard.com/badge/trpc.group/trpc-go/trpc-database/localcache)](https://goreportcard.com/report/trpc.group/trpc-go/trpc-database/localcache) +[![Tests](https://github.com/trpc-ecosystem/go-database/actions/workflows/localcache.yml/badge.svg)](https://github.com/trpc-ecosystem/go-database/actions/workflows/localcache.yml) +[![Coverage](https://codecov.io/gh/trpc-ecosystem/go-database/branch/main/graph/badge.svg?flag=localcache&precision=2)](https://app.codecov.io/gh/trpc-ecosystem/go-database/tree/main/localcache) + +localcache是一个单机的本地K-V缓存组件,允许多个goroutine并发访问,支持基于LRU和过期时间的淘汰策略。 +localcache容量达到上限后,会基于LRU进行数据淘汰,过期key-value的删除基于time wheel实现。 + +**适用于读cache场景,而不适用于写cache场景。** + + +## 快速使用 +直接使用localcache包下的功能函数 +```go +package main + +import ( + "trpc.group/trpc-go/trpc-database/localcache" +) + +func LoadData(ctx context.Context, key string) (interface{}, error) { + return "cat", nil +} + +func main() { + // 缓存key-value, 并设置5秒的过期时间 + localcache.Set("foo", "bar", 5) + + // 获取key对应的value + value, found := localcache.Get("foo") + + // 获取key对应的value + // 如果key在cache中不存在,则使用自定义的LoadData函数从数据源加载,并缓存在cache中 + // 同时设置5秒的过期时间 + value, err := localcache.GetWithLoad(context.TODO(), "tom", LoadData, 5) + + // 删除key + localcache.Del("foo") + + // 清空缓存 + localcache.Clear() +} +``` + +## 配置使用 +New()生成Cache实例,调用该实例的功能函数 +### 可选参数 +#### **WithCapacity(capacity int)** +设置cache的最大容量, 
最小值1,最大值1e30。缓存满后,则基于LRU淘汰队尾元素。默认值1e30 +#### **WithExpiration(ttl int64)** +设置元素的过期时间,单位秒。默认值60秒。 +#### **WithLoad(f LoadFunc)** +```go +type LoadFunc func(ctx context.Context, key string) (interface{}, error) +``` +设置数据加载函数, key在cache中不存在时,使用该函数加载对应的value, 并缓存在cache中。和GetWithLoad()搭配使用。 +#### **WithMLoad(f MLoadFunc)** +```go +type MLoadFunc func(ctx context.Context, keys []string) (map[string]interface{}, error) +``` +设置**批量**数据加载函数, 在cache中不存在的keys,使用该函数进行批量加载。和MGetWithLoad()搭配使用。 +#### **WithDelay(duration int64)** +设置cache中过期key的延迟删除的时间间隔,单位秒。默认值0,key过期后立即删除。 + +使用场景:当key过期时,同时业务的数据下游服务异常,希望可以从cache中拿到过期的value在业务上做为兜底数据。 + +#### **WithOnDel(delCallBack ItemCallBackFunc)** + +```go +type ItemCallBackFunc func(item *Item) +``` + +设置元素删除时的回调函数:过期删除、主动删除、LRU 删除都会触发该回调函数 + +#### **WithOnExpire(expireCallback ItemCallbackFunc)** + +```go +type ItemCallBackFunc func(item *Item) +``` + +设置元素过期时的回调函数,元素过期时会触发两个回调函数:过期回调,删除回调 + +#### Cache 接口 + +```go +type Cache interface { + // Get 返回key对应的value值,bool返回true。如果key不存在或过期,bool返回false + Get(key string) (interface{}, bool) + + // GetWithLoad 返回key对应的value, 如果key不存在,使用自定义的加载函数获取数据,并缓存 + GetWithLoad(ctx context.Context, key string) (interface{}, error) + + // MGetWithLoad 返回多个key对应的values。当某些keys不存在时,使用自定义的批量加载函数获取数据,并缓存 + // 对于cache中不存在, 且在mLoad函数的调用结果中不存在的key,MGetWithLoad的返回结果中,会包含该key,且对应的value为nil + MGetWithLoad(ctx context.Context, keys []string) (map[string]interface{}, error) + + // GetWithCustomLoad 返回key对应的value, 如果key不存在,则使用传入的load函数加载并缓存ttl时间 + // load函数不存在会返回err,如果不需要每次get时都传入load函数请在new cache时使用option方式设置load,并使用GetWithLoad方法获取缓存值 + GetWithCustomLoad(ctx context.Context, key string, customLoad LoadFunc, ttl int64) (interface{}, error) + + // MGetWithCustomLoad 返回多个key对应的values。当某些keys不存在时,则使用传入的load函数加载并缓存ttl时间 + // 对于cache中不存在, 且在mLoad函数的调用结果中不存在的key,MGetWithLoad的返回结果中,包含该key,且对应的value为nil + // load函数不存在会返回err,如果不需要每次get时都传入load函数请在new cache时使用option方式设置load,并使用MGetWithLoad方法获取缓存值 + MGetWithCustomLoad(ctx 
context.Context, keys []string, customLoad MLoadFunc, ttl int64) (map[string]interface{}, error) + + // Set 缓存key-value + Set(key string, value interface{}) bool + + // SetWithExpire 缓存key-value, 并为某个key设置特定的ttl(过期时间,单位秒) + SetWithExpire(key string, value interface{}, ttl int64) bool + + // Del 删除key + Del(key string) + + // Clear 清空所有队列和缓存 + Clear() + + // Close 关闭cache + Close() +} +``` + +## 使用示例 +#### 设置容量和过期时间 + +```go +func main() { + var lc localcache.Cache + + // 创建一个容量大小100, 元素过期时间5秒的缓存 + lc = localcache.New(localcache.WithCapacity(100), localcache.WithExpiration(5)) + + // 设置key-value, 过期时间5秒 + lc.Set("foo", "bar") + + // 为key设置特定的过期时间(10秒),不使用New()方法中的过期参数 + lc.SetWithExpire("tom", "cat", 10) + + // 短暂地等待,从缓存中异步处理完成 + time.Sleep(time.Millisecond) + + // 获取value + val, found := lc.Get("foo") + fmt.Println(val, found) + + // 删除key: "foo" + lc.Del("foo") +} +``` + +### 延迟删除 +```go +func main() { + // 设置延迟删除的时间间隔为3秒 + lc := localcache.New(localcache.WithDelay(3)) + // 设置key的过期时间为1秒: key在1秒后过期,4秒后从cache中删除 + lc.SetWithExpire("tom", "cat", 1) + // sleep 2秒 + time.Sleep(time.Second * 2) + + value, ok := lc.Get("tom") + if !ok { + // key 在sleep 2 秒后已过期 + fmt.Printf("key:%s is expired or empty", "tom") + } + + if s, ok := value.(string); ok && s == "cat" { + // 过期的value依旧返回了,业务侧可以决定是否使用过期的value + fmt.Printf("get expired value: %s\n", "cat") + } +} +``` + +#### 自定义加载函数 +设置自定义数据加载函数,并使用GetWithLoad(key)获取value +```go +func main() { + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + return "cat", nil + } + + lc := localcache.New(localcache.WithLoad(loadFunc), localcache.WithExpiration(5)) + + // err 为loadFunc函数直接返回的error信息 + val, err := lc.GetWithLoad(context.TODO(), "tom") + + // 或者可以在get时传入load函数 + otherLoadFunc := func(ctx context.Context, key string) (interface{}, error) { + return "dog", nil + } + + val,err := lc.GetWithCustomLoad(context.TODO(),"tom",otherLoadFunc,10) +} +``` +设置批量数据加载函数,并使用MGetWithLoad(keys)获取values +```go +func 
main() { + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) { + return map[string]interface{} { + "foo": "bar", + "tom": "cat", + }, nil + } + + lc := localcache.New(localcache.WithMLoad(mLoadFunc), localcache.WithExpiration(5)) + + // err 为mLoadFunc函数直接返回的error信息 + val, err := lc.MGetWithLoad(context.TODO(), []string{"foo", "tom"}) + + // 或者可以在get时传入load函数 + val,err := lc.MGetWithCustomLoad(context.TODO(),"tom",mLoadFunc,10) +} +``` + +#### 自定义删除元素过期/删除时的回调 + +```go +func main() { + delCount := map[string]int{"A": 0, "B": 0, "": 0} + expireCount := map[string]int{"A": 0, "B": 0, "": 0} + c := localcache.New( + localcache.WithCapacity(4), + localcache.WithOnExpire(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := expireCount[item.Key]; ok { + expireCount[item.Key]++ + } + }), + localcache.WithOnDel(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + })) + + defer c.Close() + + elems := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 300}, + {"C", "c", 300}, + } + + for _, elem := range elems { + c.SetWithExpire(elem.key, elem.val, elem.ttl) + } + time.Sleep(1001 * time.Millisecond) + fmt.Printf("del info:%v\n", delCount) + fmt.Printf("expire info:%v\n", expireCount) +} +``` + + + +### TODO + +1. 增加Metrics数据统计 +2. 增加对内存使用量的控制 +3. 引入升级版的LRU:W-tinyLRU,更高效地控制key的写入和淘汰 \ No newline at end of file diff --git a/localcache/cache.go b/localcache/cache.go new file mode 100644 index 0000000..1ab3d0c --- /dev/null +++ b/localcache/cache.go @@ -0,0 +1,700 @@ +// Package localcache is a stand-alone local K-V cache component that allows concurrent access by multiple goroutines. +// The cache eliminates elements based on LRU and expiration time strategies. 
+package localcache + +import ( + "container/list" + "context" + "errors" + "fmt" + "sync" + "time" +) + +//go:generate mockgen -source=cache.go -destination=./mocklocalcache/localcache_mock.go -package=mocklocalcache + +const ( + // The maximum size of the cache + maxCapacity = 1 << 30 + + ringBufSize = 16 + + // Asynchronous metadata for cache operations + setBufSize = 1 << 15 + updateBufSize = 1 << 15 + delBufSize = 1 << 15 + expireBufSize = 1 << 15 + + // The default expiration time is 60s + ttl = 60 +) + +// CachedStatus cache status +type CachedStatus int + +const ( + // CacheNotExist cache data does not exist + CacheNotExist CachedStatus = 1 + // CacheExist cache data exists + CacheExist CachedStatus = 2 + // CacheExpire cache data exists but has expired. You can choose whether to use it + CacheExpire CachedStatus = 3 +) + +// ErrCacheExpire the cache exists but has expired +var ErrCacheExpire = errors.New("cache exist, but expired") + +// currentTime is an alias of time.Now, which is convenient for specifying the current time during testing +var currentTime = time.Now + +// Cache is a local K-V memory store that supports expiration time +type Cache interface { + // Get returns the value corresponding to key, bool returns true. + // If the key does not exist or expires, bool returns false + Get(key string) (interface{}, bool) + // GetWithStatus returns the value corresponding to the key and returns the cache status + GetWithStatus(key string) (interface{}, CachedStatus) + // GetWithLoad returns the value corresponding to the key. + // If the key does not exist, use the user-defined loading function to obtain the data and cache it. + GetWithLoad(ctx context.Context, key string) (interface{}, error) + // MGetWithLoad returns values corresponding to multiple keys. 
+ // When some keys do not exist, use a custom batch loading function to obtain data and cache it + // For a key that does not exist in the cache and does not exist in the calling result of the mLoad function, + // the return result of MGetWithLoad includes the key and the corresponding value is nil. + MGetWithLoad(ctx context.Context, keys []string) (map[string]interface{}, error) + // GetWithCustomLoad returns the value corresponding to the key. + // If the key does not exist, the passed in load function is used to load and cache the ttl time. + // If the load function does not exist, err will be returned. + // If you do not need to pass in the load function every time you get it, please use the option method + // to set the load in the new cache, and use the GetWithLoad method to obtain the cache value. + GetWithCustomLoad(ctx context.Context, key string, customLoad LoadFunc, ttl int64) (interface{}, error) + // MGetWithCustomLoad returns values corresponding to multiple keys. + // When some keys do not exist, the passed in load function is used to load and cache the ttl time. + // For a key that does not exist in the cache and does not exist in the calling result of the mLoad function, + // the return result of MGetWithLoad includes the key and the corresponding value is nil. + // If the load function does not exist, err will be returned. + // If you do not need to pass in the load function every time you get it, please use the option method to set + // the load in the new cache, and use the MGetWithLoad method to obtain the cache value. 
+ MGetWithCustomLoad(ctx context.Context, keys []string, customLoad MLoadFunc, ttl int64) (map[string]interface{}, error) + // Set key and value + Set(key string, value interface{}) bool + // SetWithExpire sets key, value, and sets different ttl (expiration time in seconds) for different keys + SetWithExpire(key string, value interface{}, ttl int64) bool + // Del delete key + Del(key string) + // Len key quantity + Len() int + // Clear clears all queues and caches + Clear() + // Close Close cache + Close() +} + +type eleWithFinish struct { + ele *list.Element + finish func() +} + +type entWithFinish struct { + ent *entry + finish func() +} + +type keyWithFinish struct { + key string + finish func() +} + +// cache K-V memory storage +type cache struct { + capacity int + store store + + // key entry and elimination strategies + policy policy + + getBuf *ringBuffer + elementsCh chan []*list.Element + + setBuf chan *entWithFinish + updateBuf chan *eleWithFinish + delBuf chan *keyWithFinish + expireBuf chan string + + g group + load LoadFunc + mLoad MLoadFunc + + // Element expiration time (seconds) + ttl int64 + // Delay time for deleting expired keys + delay int64 + // Delete the task queue of expired key + expireQueue *expireQueue + + stop chan struct{} + + // Triggered when deleted: deletion triggered by element expiration, active deletion, deletion triggered by lru + onDel ItemCallBackFunc + // Triggered on expiration + onExpire ItemCallBackFunc + + // syncUpdateFlag data setting and updating method. + // When it is true, the synchronous method is used to set or update the data. + // Otherwise, the asynchronous method is used to set the data by default. + syncUpdateFlag bool + // settingTimeout timeout for synchronizing setting data or updating data + settingTimeout time.Duration + // syncDelFlag is the data deletion method. + // If it is true, the data will be deleted synchronously. + // Otherwise, the data will be deleted asynchronously by default. 
+ syncDelFlag bool +} + +// LoadFunc loads the value data corresponding to the key and is used to fill the cache +type LoadFunc func(ctx context.Context, key string) (interface{}, error) + +// MLoadFunc loads the value data of multiple keys in batches to fill the cache +type MLoadFunc func(ctx context.Context, keys []string) (map[string]interface{}, error) + +// ItemCallBackFunc callback function triggered when the element expires/deletes +type ItemCallBackFunc func(*Item) + +// ItemFlag The event type that triggers the callback +type ItemFlag int + +const ( + // ItemDelete Triggered when actively deleted/expired + ItemDelete ItemFlag = iota + // ItemLruDel LRU triggered deletion + ItemLruDel +) + +// Item The element that triggered the callback event +type Item struct { + Flag ItemFlag + Key string + Value interface{} +} + +// call refers to the implementation of singleflight, used for loading user-defined data +type call struct { + wg sync.WaitGroup + val interface{} + err error +} + +// group refers to the implementation of singleflight and is used for user-defined data loading +type group struct { + mu sync.Mutex // protects m + m map[string]*call // lazily initialized +} + +// Option parameter tool function +type Option func(*cache) + +// WithCapacity sets the maximum number of keys +func WithCapacity(capacity int) Option { + if capacity > maxCapacity { + capacity = maxCapacity + } + if capacity <= 0 { + capacity = 1 + } + return func(c *cache) { + c.capacity = capacity + } +} + +// WithExpiration sets the expiration time of the element (in seconds). +// If the expiration time is not set, the default is 60 seconds. 
+func WithExpiration(ttl int64) Option { + if ttl <= 0 { + ttl = 1 + } + return func(c *cache) { + c.ttl = ttl + } +} + +// WithLoad sets a custom data loading function +func WithLoad(f LoadFunc) Option { + return func(c *cache) { + c.load = f + } +} + +// WithMLoad sets a custom data batch loading function +func WithMLoad(f MLoadFunc) Option { + return func(c *cache) { + c.mLoad = f + } +} + +// WithDelay delay deletion interval of expired keys (unit seconds) +func WithDelay(duration int64) Option { + return func(c *cache) { + c.delay = duration + } +} + +// WithOnDel sets the callback function when the element is deleted +func WithOnDel(delCallBack ItemCallBackFunc) Option { + return func(c *cache) { + c.onDel = delCallBack + } +} + +// WithOnExpire sets the callback function triggered when the element expires +func WithOnExpire(expireCallback ItemCallBackFunc) Option { + return func(c *cache) { + c.onExpire = expireCallback + } +} + +// WithSettingTimeout causes elements to be written to the store synchronously, and a timeout needs to be set. +func WithSettingTimeout(t time.Duration) Option { + return func(c *cache) { + c.syncUpdateFlag = true + c.settingTimeout = t + } +} + +// WithSyncDelFlag sets the synchronous method to delete elements in the store when deleting elements. 
+func WithSyncDelFlag(flag bool) Option { + return func(c *cache) { + c.syncDelFlag = flag + } +} + +// New generate cache object +func New(opts ...Option) Cache { + // Initialize cache with default values + cache := &cache{ + capacity: maxCapacity, + store: newStore(), + + elementsCh: make(chan []*list.Element, 3), + setBuf: make(chan *entWithFinish, setBufSize), + updateBuf: make(chan *eleWithFinish, updateBufSize), + delBuf: make(chan *keyWithFinish, delBufSize), + expireBuf: make(chan string, expireBufSize), + + ttl: ttl, + expireQueue: newExpireQueue(time.Second, 60), + stop: make(chan struct{}), + } + // Set cache using passed parameters + for _, opt := range opts { + opt(cache) + } + + cache.policy = newPolicy(cache.capacity, cache.store) + cache.getBuf = newRingBuffer(cache, ringBufSize) + + go cache.processEntries() + + return cache +} + +// processEntries asynchronously processes cache operations +func (c *cache) processEntries() { + for { + select { + case elements := <-c.elementsCh: + c.access(elements) + case buf := <-c.setBuf: + c.add(buf.ent) + if buf.finish != nil { + buf.finish() + } + case buf := <-c.updateBuf: + c.update(buf.ele) + if buf.finish != nil { + buf.finish() + } + case buf := <-c.delBuf: + c.del(buf.key) + if buf.finish != nil { + buf.finish() + } + case key := <-c.expireBuf: + c.expire(key) + case <-c.stop: + return + } + } +} + +// Get returns the value corresponding to key, bool returns true. +// If the key does not exist or expires, bool returns false +func (c *cache) Get(key string) (interface{}, bool) { + value, status := c.GetWithStatus(key) + if status == CacheExist { + return value, true + } + return value, false +} + +// GetWithStatus returns the value corresponding to the key. +// Since the user may cache nil and cannot distinguish the data, CachedStatus is used to represent the return status. 
+func (c *cache) GetWithStatus(key string) (interface{}, CachedStatus) { + if c == nil { + return nil, CacheNotExist + } + + value, hit := c.store.get(key) + if hit { + ele, _ := value.(*list.Element) + ent := getEntry(ele) + ent.mux.RLock() + defer ent.mux.RUnlock() + if ent.expireTime.Before(currentTime()) { + return ent.value, CacheExpire + } + c.getBuf.push(ele) + return ent.value, CacheExist + } + + return nil, CacheNotExist +} + +// GetWithLoad returns the value corresponding to the key. +// If the key does not exist, use the user-defined filling function to load the data and return it, and cache it. +func (c *cache) GetWithLoad(ctx context.Context, key string) (interface{}, error) { + return c.GetWithCustomLoad(ctx, key, c.load, c.ttl) +} + +// MGetWithLoad returns values corresponding to multiple keys. +// When some keys do not exist, use a custom batch loading function to obtain data and cache it +// For a key that does not exist in the cache and does not exist in the calling result of the mLoad function, +// the return result of MGetWithLoad includes the key and the corresponding value is nil. +func (c *cache) MGetWithLoad(ctx context.Context, keys []string) (map[string]interface{}, error) { + return c.MGetWithCustomLoad(ctx, keys, c.mLoad, c.ttl) +} + +// GetWithCustomLoad returns the value corresponding to the key. +// If the key does not exist, the passed in load function is used to load and cache the ttl time. +// If the load function does not exist, err will be returned. +// If you do not need to pass in the load function every time you get it, please use the option method to set +// the load in the new cache, and use the GetWithLoad method to obtain the cache value. 
+func (c *cache) GetWithCustomLoad(ctx context.Context, key string, customLoad LoadFunc, ttl int64) ( + interface{}, error) { + if customLoad == nil { + return nil, errors.New("undefined LoadFunc in cache") + } + + val, status := c.GetWithStatus(key) + if status == CacheExist { + return val, nil + } + latest, err := c.loadData(ctx, key, customLoad, ttl) + if err != nil { + if status == CacheExpire { + return val, fmt.Errorf("load key %s err %v, %w", key, err, ErrCacheExpire) + } + return nil, err + } + return latest, nil +} + +// MGetWithCustomLoad returns values corresponding to multiple keys. +// When some keys do not exist, the passed in load function is used to load and cache the ttl time. +// For a key that does not exist in the cache and does not exist in the calling result of the mLoad function, +// the return result of MGetWithLoad includes the key and the corresponding value is nil. +// If the load function does not exist, err will be returned. +// If you do not need to pass in the load function every time you get it, please use the option method to set the +// load in the new cache, and use the MGetWithLoad method to obtain the cache value. 
+func (c *cache) MGetWithCustomLoad(ctx context.Context, keys []string, customMLoad MLoadFunc, ttl int64) ( + map[string]interface{}, error) { + if customMLoad == nil { + return nil, errors.New("undefined MLoadFunc in cache") + } + values := make(map[string]interface{}, len(keys)) + var noCacheKeys []string + for _, key := range keys { + value, ok := c.Get(key) + if !ok { + noCacheKeys = append(noCacheKeys, key) + } + values[key] = value + } + if len(noCacheKeys) == 0 { + return values, nil + } + return c.loadNoCacheKeys(ctx, customMLoad, noCacheKeys, values, ttl) +} + +func (c *cache) loadNoCacheKeys(ctx context.Context, + mLoad MLoadFunc, + noCacheKeys []string, + values map[string]interface{}, + ttl int64) (map[string]interface{}, error) { + latest, err := mLoad(ctx, noCacheKeys) + if err != nil { + return values, err + } + for key, value := range latest { + values[key] = value + c.SetWithExpire(key, value, ttl) + } + + return values, nil +} + +// Set key, value +func (c *cache) Set(key string, value interface{}) bool { + return c.SetWithExpire(key, value, c.ttl) +} + +// SetWithExpire sets key, value, time to live (seconds), and sets different expiration times for different elements +func (c *cache) SetWithExpire(key string, value interface{}, ttl int64) bool { + if c == nil { + return false + } + expireTime := currentTime().Add(time.Second * time.Duration(ttl)) + + val, hit := c.store.get(key) + if hit { + // If the key exists, immediately update the latest value in the storage to prevent Get from obtaining dirty data. 
+ ele, _ := val.(*list.Element) + oldEnt := getEntry(ele) + + oldEnt.mux.Lock() + oldEnt.value = value + oldEnt.expireTime = expireTime + oldEnt.mux.Unlock() + if c.syncUpdateFlag { + waitFinish := make(chan struct{}, 1) + select { + case c.updateBuf <- &eleWithFinish{ + ele, + func() { + close(waitFinish) + }, + }: + <-waitFinish + return true + case <-time.After(c.settingTimeout): + return false + } + } else { + select { + case c.updateBuf <- &eleWithFinish{ele, nil}: + default: + } + } + c.expireQueue.update(key, expireTime.Add(time.Duration(c.delay)*time.Second), c.afterExpire(key)) + return true + } + + ent := &entry{ + key: key, + value: value, + expireTime: expireTime, + } + // Add new key and value. In the extreme case where syncSet is false, the Set operation is not guaranteed to + // be successful. + // Because of asynchronous processing and heavy load, there may be a ms-level delay in the Set result. + if c.syncUpdateFlag { + waitFinish := make(chan struct{}, 1) + select { + case c.setBuf <- &entWithFinish{ + ent, + func() { + close(waitFinish) + }, + }: + <-waitFinish + return true + case <-time.After(c.settingTimeout): + return false + } + } + select { + case c.setBuf <- &entWithFinish{ent, nil}: + return true + default: + return false + } +} + +// Del deletes key, supports synchronous deletion and asynchronous deletion +func (c *cache) Del(key string) { + if c == nil { + return + } + + // Enable synchronous deletion, block and wait for deletion to complete before returning + if c.syncDelFlag { + waitFinish := make(chan struct{}, 1) + c.delBuf <- &keyWithFinish{ + key: key, + finish: func() { + close(waitFinish) + }, + } + <-waitFinish + return + } + c.delBuf <- &keyWithFinish{key: key} +} + +// Clear clears all queues and caches. +// It is a non-atomic operation and should be called after there are no Get and Set operations. 
+func (c *cache) Clear() { + // Block until processEntries goroutine ends + c.stop <- struct{}{} + + c.elementsCh = make(chan []*list.Element, 3) + c.setBuf = make(chan *entWithFinish, setBufSize) + c.updateBuf = make(chan *eleWithFinish, updateBufSize) + c.delBuf = make(chan *keyWithFinish, delBufSize) + c.expireBuf = make(chan string, expireBufSize) + + c.store.clear() + c.policy.clear() + c.expireQueue.clear() + + // Restart processEntries goroutine + go c.processEntries() +} + +// Close cache +func (c *cache) Close() { + // Block until processEntries goroutine ends + c.stop <- struct{}{} + close(c.stop) + + close(c.elementsCh) + close(c.setBuf) + close(c.updateBuf) + close(c.delBuf) + close(c.expireBuf) + + c.expireQueue.stop() +} + +// access is called asynchronously to handle access operations +func (c *cache) access(elements []*list.Element) { + c.policy.push(elements) +} + +// add is called asynchronously to handle new operations +func (c *cache) add(ent *entry) { + // Store new key-value. + // After reaching the upper limit of capacity, return the eliminated entry. 
+ key := ent.key + victimEnt := c.policy.add(ent) + + expireTime := ent.expireTime.Add(time.Second * time.Duration(c.delay)) + c.expireQueue.add(key, expireTime, c.afterExpire(key)) + + // Remove eliminated entries from the expiration queue + if victimEnt != nil { + c.expireQueue.remove(victimEnt.key) + if c.onDel != nil { + c.onDel(&Item{ItemLruDel, victimEnt.key, victimEnt.value}) + } + } +} + +// update is called asynchronously to handle update operations +func (c *cache) update(ele *list.Element) { + c.policy.hit(ele) +} + +// del is called asynchronously to handle the deletion operation +func (c *cache) del(key string) { + delEnt := c.policy.del(key) + c.expireQueue.remove(key) + if delEnt != nil && c.onDel != nil { + c.onDel(&Item{ItemDelete, delEnt.key, delEnt.value}) + } +} + +// expire is called asynchronously to process expired data +func (c *cache) expire(key string) { + delEnt := c.policy.del(key) + + if delEnt != nil && c.onExpire != nil { + c.onExpire(&Item{ItemDelete, delEnt.key, delEnt.value}) + } + + if delEnt != nil && c.onDel != nil { + c.onDel(&Item{ItemDelete, delEnt.key, delEnt.value}) + } +} + +// loadData uses user-defined functions to load and cache data +func (c *cache) loadData(ctx context.Context, key string, load LoadFunc, ttl int64) (interface{}, error) { + // Refer to singleflight implementation to prevent cache breakdown + c.g.mu.Lock() + if c.g.m == nil { + c.g.m = make(map[string]*call) + } + if call, ok := c.g.m[key]; ok { + c.g.mu.Unlock() + call.wg.Wait() + return call.val, call.err + } + call := new(call) + call.wg.Add(1) + c.g.m[key] = call + c.g.mu.Unlock() + if c.syncUpdateFlag { + // Try to read the cache and return if the cache is read. 
Otherwise load load + value, hit := c.store.get(key) + if hit { + ele, _ := value.(*list.Element) + ent := getEntry(ele) + ent.mux.RLock() + if ent.expireTime.After(currentTime()) { + ent.mux.RUnlock() + return ent.value, nil + } + ent.mux.RUnlock() + } + } + call.val, call.err = load(ctx, key) + if call.err == nil { + if ok := c.SetWithExpire(key, call.val, ttl); !ok { + call.err = fmt.Errorf("set key [%s] fail", key) + } + } + + c.g.mu.Lock() + call.wg.Done() + delete(c.g.m, key) + c.g.mu.Unlock() + + return call.val, call.err +} + +// afterExpire returns the callback task after the key expires +func (c *cache) afterExpire(key string) func() { + return func() { + select { + case c.expireBuf <- key: + default: + } + } +} + +// push writes read requests into the channel in batches +func (c *cache) push(elements []*list.Element) bool { + if len(elements) == 0 { + return true + } + select { + case c.elementsCh <- elements: + return true + default: + return false + } +} + +// Len key quantity +func (c *cache) Len() int { + return c.store.len() +} diff --git a/localcache/cache_test.go b/localcache/cache_test.go new file mode 100644 index 0000000..dc97d30 --- /dev/null +++ b/localcache/cache_test.go @@ -0,0 +1,1089 @@ +package localcache + +import ( + "context" + "errors" + "fmt" + "math/rand" + "reflect" + "strings" + "sync" + "testing" + "time" +) + +var wait = 10 * time.Millisecond + +// TestCacheSetRace tests concurrency competition of Cache Set +func TestCacheSetRace(t *testing.T) { + cache := New(WithExpiration(100)) + n := 8128 + var wg sync.WaitGroup + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + cache.Set("foo", "bar") + cache.Get("foo") + wg.Done() + }() + } + wg.Wait() +} + +// TestCacheSetGet tests Cache Set first and then Get +func TestCacheSetGet(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1000}, + {"B", "b", 1000}, + {"", "null", 1000}, + } + C := New() + c := C.(*cache) + defer c.Close() + + for 
_, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + time.Sleep(wait) + for _, mock := range mocks { + val, found := c.Get(mock.key) + if !found || val.(string) != mock.val { + t.Fatalf("Unexpected value: %v (%v) to key: %v", val, found, mock.key) + } + } + + // update + c.SetWithExpire(mocks[0].key, mocks[0].key+"foobar", 1000) + val, found := c.Get(mocks[0].key) + if !found || val.(string) == mocks[0].key { + t.Fatalf("Unexpected value: %v (%v) to key: %v, want: %v", val, found, mocks[0].key, mocks[0].val) + } + + // set struct + type Foo struct { + Name string + Age int + } + valStruct := Foo{ + "Bob", + 18, + } + valPtr := &valStruct + c.SetWithExpire("foo", valStruct, 1000) + c.SetWithExpire("bar", valPtr, 1000) + time.Sleep(wait) + if val, found := c.Get("foo"); !found || val.(Foo) != valStruct { + t.Fatalf("Unexpected value: %v (%v) to key: %v, want: %v", val, found, "foo", valStruct) + } + if val, found := c.Get("bar"); !found || val.(*Foo) != valPtr { + t.Fatalf("Unexpected value: %v (%v) to key: %v, want: %v", val, found, "foo", valPtr) + } +} + +// TestCacheMaxCap tests whether MaxCap meets expectations after Cache WithCapacity +func TestCacheMaxCap(t *testing.T) { + C := New(WithCapacity(2)) + c := C.(*cache) + defer c.Close() + + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1000}, + {"B", "b", 1000}, + {"", "null", 1000}, + } + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + time.Sleep(wait) + if c.store.len() != 2 { + t.Fatalf("unexpected Length:%d, want:%d", c.store.len(), 2) + } +} + +// TestCacheExpire tests the expiration function of Cache +func TestCacheExpire(t *testing.T) { + C := New(WithCapacity(2), WithExpiration(1)) + c := C.(*cache) + defer c.Close() + + c.Set("A", "a") + c.Set("B", "b") + + time.Sleep(1 * time.Second) + if c.store.len() != 0 { + t.Fatalf("unexpected Length:%d, want:%d", c.store.len(), 0) + } + if val, found := c.Get("A"); 
found { + t.Fatalf("unexpected expired value: %v to key: %v", val, "A") + } + if val, found := c.Get("B"); found { + t.Fatalf("unexpected expired value: %v to key: %v", val, "B") + } +} + +// TestCacheSpecExpire tests the expiration function of Cache for the specified key +func TestCacheSpecExpire(t *testing.T) { + C := New(WithCapacity(2)) + c := C.(*cache) + defer c.Close() + + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 1}, + {"", "null", 1}, + } + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + time.Sleep(time.Second) + if c.store.len() != 0 { + t.Fatalf("unexpected Length:%d, want:%d", c.store.len(), 0) + } + for _, mock := range mocks { + val, found := c.Get(mock.key) + if found { + t.Fatalf("unexpected expired value: %v to key: %v", val, mock.key) + } + } +} + +// TestCacheGetWithDelay tests the delayed deletion function of expired keys +func TestCacheGetWithDelay(t *testing.T) { + C := New(WithDelay(2)) + defer C.Close() + + C.SetWithExpire("A", "B", 1) + time.Sleep(time.Second) + value, ok := C.Get("A") + if ok { + t.Fatalf("unexpected found: %v", ok) + } + if s, ok := value.(string); !ok || s != "B" { + t.Fatalf("unexpected value: %v, want: %s", s, "B") + } +} + +// TestCacheGetWithStatus tests the delayed deletion status of expired keys +func TestCacheGetWithStatus(t *testing.T) { + C := New(WithDelay(2)) + defer C.Close() + + C.SetWithExpire("A", "B", 1) + time.Sleep(time.Second) + value, status := C.GetWithStatus("A") + if status != CacheExpire { + t.Fatalf("unexpected status: %v", status) + } + if s, ok := value.(string); !ok || s != "B" { + t.Fatalf("unexpected value: %v, want: %s", s, "B") + } +} + +// TestCacheMGetWithLoad tests the WithMLoad function of Cache +func TestCacheMGetWithLoad(t *testing.T) { + m := map[string]interface{}{ + "A": "a", + "B": "b", + "": "null", + } + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) 
{ + ret := make(map[string]interface{}, 0) + for _, key := range keys { + if v, ok := m[key]; ok { + ret[key] = v + } + } + return ret, nil + } + + C := New(WithMLoad(mLoadFunc)) + c := C.(*cache) + defer c.Close() + + type args struct { + ctx context.Context + keys []string + } + tests := []struct { + name string + args args + want map[string]interface{} + wantErr bool + }{ + {"A", args{context.Background(), []string{"A", "B"}}, map[string]interface{}{"A": "a", "B": "b"}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := c.MGetWithLoad(tt.args.ctx, tt.args.keys) + time.Sleep(wait) + if (err != nil) != tt.wantErr { + t.Errorf("cache.GetWithLoad() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("cache.GetWithLoad() = %v, want %v", got, tt.want) + } + // Get + for _, key := range tt.args.keys { + value, found := c.Get(key) + if !found || !reflect.DeepEqual(value, tt.want[key]) { + t.Errorf("cache.Get() = %v, want %v", got, tt.want) + } + } + }) + } +} + +// TestCacheMGetWithLoadError Load error when testing batch acquisition of Cache +func TestCacheMGetWithLoadError(t *testing.T) { + C := New() + defer C.Close() + if _, err := C.MGetWithLoad(context.Background(), []string{"a"}); !strings.Contains(err.Error(), "undefined MLoadFunc in cache") { + t.Errorf("got unexpected:%s, want contains [undefined MLoadFunc]", err.Error()) + } + + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) { + return nil, errors.New("unknown keys") + } + C = New(WithMLoad(mLoadFunc)) + if _, err := C.MGetWithLoad(context.Background(), []string{"a"}); !strings.Contains(err.Error(), "unknown keys") { + t.Errorf("got unexpected:%s, want contains [unknown keys]", err.Error()) + } +} + +// TestCacheGetWithLoad tests the WithLoad function of Cache +func TestCacheGetWithLoad(t *testing.T) { + m := map[string]interface{}{ + "A": "a", + "B": "b", + "C": nil, + "": 
"null", + } + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + if v, exist := m[key]; exist { + return v, nil + } + return nil, errors.New("key not exist") + } + + C := New(WithLoad(loadFunc)) + c := C.(*cache) + defer c.Close() + + type args struct { + ctx context.Context + key string + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + {"A", args{context.Background(), "A"}, "a", false}, + {"B", args{context.Background(), "B"}, "b", false}, + {"C", args{context.Background(), "C"}, nil, false}, + {"", args{context.Background(), ""}, "null", false}, + {"unrecognized-key", args{context.Background(), "unregonizedKey"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := c.GetWithLoad(tt.args.ctx, tt.args.key) + time.Sleep(wait) + got2, _ := c.Get(tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("cache.GetWithLoad() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("cache.GetWithLoad() = %v, want %v", got, tt.want) + } + // Get + if !reflect.DeepEqual(got2, tt.want) { + t.Errorf("cache.Get() = %v, want %v", got, tt.want) + } + }) + } +} + +// TestCacheGetWithLoadError tests Cache Load error +func TestCacheGetWithLoadError(t *testing.T) { + C := New() + defer C.Close() + if _, err := C.GetWithLoad(context.Background(), "A"); !strings.Contains(err.Error(), "undefined LoadFunc in cache") { + t.Errorf("got unexpected:%s, want contains [undefined LoadFunc]", err.Error()) + } + + loadFunc := func(ctx context.Context, keys string) (interface{}, error) { + return nil, errors.New("load fail") + } + C = New(WithLoad(loadFunc), WithDelay(2)) + C.SetWithExpire("A", "B", 1) + time.Sleep(time.Second) + if value, err := C.GetWithLoad(context.Background(), "A"); !errors.Is(err, ErrCacheExpire) { + t.Errorf("got unexpected:%v, want contains [ErrCacheExpire]", err) + } else { + if s, ok := value.(string); !ok || s 
!= "B" { + t.Fatalf("unexpected value: %v, want: %s", s, "B") + } + } + time.Sleep(2 * time.Second) + if _, err := C.GetWithLoad(context.Background(), "A"); !strings.Contains(err.Error(), "load fail") { + t.Errorf("got unexpected:%s, want contains [load fail]", err.Error()) + } +} + +// TestCacheDel tests the Cache deletion function +func TestCacheDel(t *testing.T) { + C := New(WithCapacity(3)) + c := C.(*cache) + defer c.Close() + + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 10}, + {"B", "b", 10}, + {"", "null", 10}, + } + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + time.Sleep(wait) + c.Del(mocks[0].key) + time.Sleep(wait) + + if val, found := c.Get(mocks[0].key); found { + t.Fatalf("unexpected deleted value: %v to key: %v", val, mocks[0].key) + } + + for i := 1; i < len(mocks); i++ { + if val, found := c.Get(mocks[i].key); !found || val.(string) != mocks[i].val { + t.Fatalf("unexpected deleted value: %v (%v) to key: %v, want: %v", val, found, mocks[i].key, mocks[i].val) + } + } +} + +// TestCacheClear tests the Cache clearing function +func TestCacheClear(t *testing.T) { + C := New() + c := C.(*cache) + defer c.Close() + for i := 0; i < 10; i++ { + k := fmt.Sprint(i) + v := fmt.Sprint(i) + c.SetWithExpire(k, v, 10) + } + time.Sleep(wait) + + c.Clear() + + for i := 0; i < 10; i++ { + k := fmt.Sprint(i) + if _, found := c.Get(k); found { + t.Fatalf("Shouldn't found value from clear cache") + } + } + if c.store.len() != 0 { + t.Fatalf("Length(%d) is not equal to 0, after clear", c.store.len()) + } +} + +// BenchmarkCacheGet Get function of Benchmark Cache +func BenchmarkCacheGet(b *testing.B) { + k := "A" + v := "a" + + C := New() + c := C.(*cache) + c.SetWithExpire(k, v, 100000) + + b.SetBytes(1) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c.Get(k) + } + }) +} + +// BenchmarkCacheGet Set function of Benchmark Cache +func BenchmarkCacheSet(b *testing.B) { + C := New() + c := 
C.(*cache) + rand.New(rand.NewSource(currentTime().Unix())) + + b.SetBytes(1) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + k := fmt.Sprint(rand.Int()) + v := k + c.SetWithExpire(k, v, 10000) + } + }) +} + +func TestCacheGetWithCustomLoad(t *testing.T) { + m := map[string]interface{}{ + "A": "a", + "B": "b", + "": "null", + } + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + if v, exist := m[key]; exist { + return v, nil + } + return nil, errors.New("key not exist") + } + + C := New() + c := C.(*cache) + defer c.Close() + + type args struct { + ctx context.Context + key string + } + tests := []struct { + name string + args args + want interface{} + wantErr bool + }{ + {"A", args{context.Background(), "A"}, "a", false}, + {"B", args{context.Background(), "B"}, "b", false}, + {"", args{context.Background(), ""}, "null", false}, + {"unrecognized-key", args{context.Background(), "unregonizedKey"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := c.GetWithCustomLoad(tt.args.ctx, tt.args.key, loadFunc, 10) + if (err != nil) != tt.wantErr { + t.Errorf("cache.GetWithCustomLoad() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("cache.GetWithCustomLoad() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestCacheMGetWithCustomLoad(t *testing.T) { + m := map[string]interface{}{ + "A": "a", + "B": "b", + "": "null", + } + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) { + ret := make(map[string]interface{}, 4) + for _, key := range keys { + if v, ok := m[key]; ok { + ret[key] = v + } + } + return ret, nil + } + + C := New(WithExpiration(10)) + c := C.(*cache) + defer c.Close() + + type args struct { + ctx context.Context + keys []string + ttl int64 + sleep time.Duration + found map[string]bool + setup func() + after func() + } + tests := []struct { + name string + args args + want1 
map[string]interface{} + want2 map[string]interface{} + wantErr bool + }{ + { + name: "A", + args: args{ + ctx: context.Background(), + keys: []string{"A", "B"}, + ttl: 60, + sleep: time.Millisecond, + found: map[string]bool{ + "A": true, "B": true, + }, + setup: func() { + c.Clear() + }, + after: func() { + c.Clear() + }, + }, + want1: map[string]interface{}{"A": "a", "B": "b"}, + want2: map[string]interface{}{"A": "a", "B": "b"}, + wantErr: false, + }, + { + name: "B", + args: args{ + ctx: context.Background(), + keys: []string{"A", "B"}, + ttl: 1, + sleep: time.Second * 2, + found: map[string]bool{ + "A": false, "B": false, + }, + setup: func() { + c.Clear() + }, + after: func() { + c.Clear() + }, + }, + want1: map[string]interface{}{"A": "a", "B": "b"}, + want2: map[string]interface{}{"A": "", "B": ""}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.args.setup() + got, err := c.MGetWithCustomLoad(tt.args.ctx, tt.args.keys, mLoadFunc, tt.args.ttl) + if (err != nil) != tt.wantErr { + t.Errorf("cache.GetWithLoad() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want1) { + t.Errorf("cache.GetWithLoad() = %v, want1 %v", got, tt.want1) + } + time.Sleep(tt.args.sleep) + // Get + for _, key := range tt.args.keys { + value, found := c.Get(key) + if found != tt.args.found[key] { + t.Errorf("found miss match, cache.Get() = %v, want2 %v", got, tt.want2) + } + if found && !reflect.DeepEqual(value, tt.want2[key]) { + t.Errorf("value not equal, cache.Get() = %v, want2 %v", got, tt.want2) + } + } + tt.args.after() + }) + } +} + +func TestKeysExceedCapacity(t *testing.T) { + rand.New(rand.NewSource(time.Now().UnixNano())) + var testKeys []string + for i := 0; i < 10; i++ { + testKeys = append(testKeys, fmt.Sprintf("test%d", i)) + } + + type args struct { + capacity int + ttl int64 + concurrent int + count int + } + tests := []struct { + name string + args args + }{ + { + name: "cap=1, 
ttl=30, concurrent=1, count=100000", + args: args{ + capacity: 1, + ttl: 30, + concurrent: 1, + count: 100000, + }, + }, + { + name: "cap=9, ttl=30, concurrent=10, count=100000", + args: args{ + capacity: 9, + ttl: 30, + concurrent: 10, + count: 100000, + }, + }, + } + type testObj struct { + key string + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := New(WithCapacity(tt.args.capacity), WithExpiration(tt.args.ttl)) + var wg sync.WaitGroup + for i := 0; i < tt.args.concurrent; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < tt.args.count; i++ { + key := testKeys[rand.Intn(len(testKeys))] + val, ok := c.Get(key) + if !ok { + c.Set(key, &testObj{key: key}) + continue + } + if val.(*testObj).key != key { + t.Errorf("cache.Get() = %v, want %v", val, key) + } + } + }() + } + wg.Wait() + }) + } +} + +// TestCacheLen tests the length query function of Cache +func TestCacheLen(t *testing.T) { + C := New(WithCapacity(4)) + c := C.(*cache) + defer c.Close() + mocks := []*struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 2}, + {"C", "c", 3}, + {"D", "d", 3}, + {"E", "e", 3}, + } + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + testSheet := []*struct { + expire time.Duration + len int + }{ + { + expire: 1 * time.Millisecond, + len: 4, + }, + { + expire: 1010 * time.Millisecond, + len: 4, + }, + { + expire: 1010 * time.Millisecond, + len: 3, + }, + { + expire: 1010 * time.Millisecond, + len: 0, + }, + } + for i, v := range testSheet { + time.Sleep(v.expire) + if s := c.Len(); s != v.len { + t.Errorf("#%d cache.Len() = %v, want %v", i, s, v.len) + } + } +} + +// BenchmarkCacheLen Len function of Benchmark Cache +func BenchmarkCacheLen(b *testing.B) { + C := New() + c := C.(*cache) + defer c.Close() + mocks := []*struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 2}, + {"C", "c", 3}, + {"D", "d", 4}, + {"E", "e", 5}, + } + for _, 
mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + b.SetBytes(1) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + c.Len() + } + }) +} + +// ====================== Callback function test ========================= + +// TestCacheOnExpire tests callback when cache expires +func TestCacheOnExpire(t *testing.T) { + delCount := map[string]int{"A": 0, "B": 0, "": 0} + expireCount := map[string]int{"A": 0, "B": 0, "": 0} + C := New( + WithCapacity(4), + WithOnExpire(func(item *Item) { + if item.Flag != ItemDelete { + return + } + if _, ok := expireCount[item.Key]; ok { + expireCount[item.Key]++ + } + }), + WithOnDel(func(item *Item) { + if item.Flag != ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + })) + + c := C.(*cache) + defer c.Close() + + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 300}, + {"", "null", 300}, + } + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + time.Sleep(1001 * wait) + + for _, mock := range mocks { + key := mock.key + if key == "A" && delCount[key] != 1 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount[key], key, 1) + } + if key != "A" && delCount[key] != 0 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount[key], key, 0) + } + if key == "A" && expireCount[key] != 1 { + t.Fatalf("unexpected expireCount value: %d to key: %v, want: %d", expireCount[key], key, 1) + continue + } + if key != "A" && expireCount[key] != 0 { + t.Fatalf("unexpected expireCount value: %d to key: %v, want: %d", expireCount[key], key, 0) + } + } +} + +// TestCacheOnLruDel tests callback when cache LRU is deleted +func TestCacheOnLruDel(t *testing.T) { + delCount := map[string]int{"A": 0, "B": 0, "": 0} + C := New(WithCapacity(2), WithOnDel(func(item *Item) { + if item.Flag != ItemLruDel { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + 
} + })) + c := C.(*cache) + defer c.Close() + + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + time.Sleep(2 * wait) + + for key, delCount := range delCount { + if key == mocks[0].key && delCount != 1 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 1) + } + if key != mocks[0].key && delCount != 0 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 0) + } + } +} + +// TestCacheOnDel tests callback when cache is deleted +func TestCacheOnDel(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + delCount := map[string]int{"A": 0, "B": 0, "": 0} + C := New(WithCapacity(4), WithOnDel(func(item *Item) { + if item.Flag != ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + })) + c := C.(*cache) + defer c.Close() + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + time.Sleep(2 * wait) + c.Del(mocks[0].key) + time.Sleep(2 * wait) + + for key, delCount := range delCount { + if key == mocks[0].key && delCount != 1 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 1) + } + if key != mocks[0].key && delCount != 0 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 0) + } + } +} + +// generateFunction is executed only once +func generateFunction() func(ctx context.Context, key string) (interface{}, error) { + alreadyExecuted := false + + return func(ctx context.Context, key string) (interface{}, error) { + if alreadyExecuted { + return "", errors.New("function can only be executed once") + } + + alreadyExecuted = true + return "Function executed successfully", nil + } +} + +// Test_GetAndSetMultipleKey concurrently reads and 
writes, observe whether the results are affected +func Test_GetAndSetMultipleKey(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + loadFunc := generateFunction() + // Asynchronous scenarios will have multiple loads + dataCache := New(WithCapacity(10000001), + WithExpiration(int64(1)), WithLoad(loadFunc), WithSettingTimeout(wait)) + dataCache.SetWithExpire(mocks[0].key, mocks[0].val, 1) + time.Sleep(time.Second) + // Whether the reading data observation value can be loaded multiple times + if v, err := dataCache.GetWithLoad(context.TODO(), mocks[0].key); err == nil { + if v != "Function executed successfully" { + t.Fatalf("unexpected value: %v to key: %v, want: %v", v, mocks[0].key, "Function executed successfully") + } + } else { + t.Fatal("error multiple load 1") + } + if v, err := dataCache.GetWithLoad(context.TODO(), mocks[0].key); err == nil { + if v != "Function executed successfully" { + t.Fatalf("unexpected value: %v to key: %v, want: %v", v, mocks[0].key, "Function executed successfully") + } + } else { + t.Fatal("error multiple load 2") + } + if v, err := dataCache.GetWithLoad(context.TODO(), mocks[0].key); err == nil { + if v != "Function executed successfully" { + t.Fatalf("unexpected value: %v to key: %v, want: %v", v, mocks[0].key, "Function executed successfully") + } + } else { + t.Fatal("error multiple load 3") + } +} + +// Test_GetAndSetMultipleKeyAsync concurrent asynchronous reading and writing, observe whether the results are affected +func Test_GetAndSetMultipleKeyAsync(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + loadFunc := generateFunction() + // Asynchronous scenarios will have multiple loads + dataCache := New(WithCapacity(10000001), + WithExpiration(int64(1)), WithLoad(loadFunc)) + dataCache.SetWithExpire(mocks[0].key, mocks[0].val, 1) + 
time.Sleep(time.Second) + // Read the data and observe how many times it is loaded + // From the second time, err will be expected + if v, err := dataCache.GetWithLoad(context.TODO(), mocks[0].key); err == nil { + if v != "Function executed successfully" { + t.Fatalf("unexpected value: %v to key: %v, want: %v", v, mocks[0].key, "Function executed successfully") + } + } else { + // Can be loaded successfully for the first time + t.Fatal("error multiple load 1") + } + if _, err := dataCache.GetWithLoad(context.TODO(), mocks[0].key); err == nil { + // Unable to load successfully the second time + t.Fatal("error multiple load 2") + } +} + +// Test_UpdateKeySync concurrent asynchronous reading and writing, writing is synchronous, +// observe whether the results are affected +func Test_UpdateKeySync(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + return "success", nil + } + // Asynchronous scenarios will have multiple loads + c := New(WithCapacity(4), WithLoad(loadFunc), WithSettingTimeout(time.Duration(0*time.Microsecond)), WithExpiration(1)) + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, 0) + } + n := 1000000 + // Test data update + for i := 0; i < n; i++ { + go func() { + time.Sleep(wait * 2) + for _, mock := range mocks { + // At this time, you will go to update to check whether the update is successful or times out + if val, err := c.GetWithLoad(context.TODO(), mock.key); err == nil { + // success + if val != "success" { + t.Error("update fail") + return + } + } else { + // timeout + if err.Error() != fmt.Sprintf("set key [%s] fail", mock.key) { + t.Error("unexpected error") + return + } + } + + } + }() + } +} + +func Test_DelKeyAsync(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + 
} + delCount := map[string]int{"A": 0, "B": 0, "": 0} + C := New(WithCapacity(4), WithOnDel(func(item *Item) { + if item.Flag != ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + }), WithSettingTimeout(1)) + c := C.(*cache) + defer c.Close() + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + time.Sleep(2 * wait) + c.Del(mocks[0].key) + time.Sleep(2 * wait) + + for key, delCount := range delCount { + if key == mocks[0].key && delCount != 1 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 1) + } + if key != mocks[0].key && delCount != 0 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 0) + } + } +} + +func Test_DelKeySync(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 300}, + {"B", "b", 300}, + {"", "null", 300}, + } + delCount := map[string]int{"A": 0, "B": 0, "": 0} + C := New(WithCapacity(4), WithOnDel(func(item *Item) { + if item.Flag != ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + }), WithSettingTimeout(1), WithSyncDelFlag(true)) + c := C.(*cache) + defer c.Close() + + for _, mock := range mocks { + c.SetWithExpire(mock.key, mock.val, mock.ttl) + } + + c.Del(mocks[0].key) + + for key, delCount := range delCount { + if key == mocks[0].key && delCount != 1 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 1) + } + if key != mocks[0].key && delCount != 0 { + t.Fatalf("unexpected delCount value: %d to key: %v, want: %d", delCount, key, 0) + } + } +} diff --git a/localcache/examples/custom_callback/main.go b/localcache/examples/custom_callback/main.go new file mode 100644 index 0000000..1986b90 --- /dev/null +++ b/localcache/examples/custom_callback/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "fmt" + "time" + + "trpc.group/trpc-go/trpc-database/localcache" +) + +func main() { + delCount 
:= map[string]int{"A": 0, "B": 0, "": 0} + expireCount := map[string]int{"A": 0, "B": 0, "": 0} + c := localcache.New( + localcache.WithCapacity(4), + localcache.WithOnExpire(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := expireCount[item.Key]; ok { + expireCount[item.Key]++ + } + }), + localcache.WithOnDel(func(item *localcache.Item) { + if item.Flag != localcache.ItemDelete { + return + } + if _, ok := delCount[item.Key]; ok { + delCount[item.Key]++ + } + })) + + defer c.Close() + + elems := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1}, + {"B", "b", 300}, + {"C", "c", 300}, + } + + for _, elem := range elems { + c.SetWithExpire(elem.key, elem.val, elem.ttl) + } + time.Sleep(1001 * time.Millisecond) + fmt.Printf("del info:%v\n", delCount) + fmt.Printf("expire info:%v\n", expireCount) +} diff --git a/localcache/examples/custom_load/main.go b/localcache/examples/custom_load/main.go new file mode 100644 index 0000000..abb99ec --- /dev/null +++ b/localcache/examples/custom_load/main.go @@ -0,0 +1,27 @@ +package main + +import ( + "context" + "fmt" + + "trpc.group/trpc-go/trpc-database/localcache" +) + +func main() { + mLoadFunc := func(ctx context.Context, keys []string) (map[string]interface{}, error) { + return map[string]interface{}{ + "foo": "bar", + "tom": "cat", + }, nil + } + + lc := localcache.New(localcache.WithMLoad(mLoadFunc), localcache.WithExpiration(5)) + + // err is the error message returned directly by the mLoadFunc function. 
+ val, err := lc.MGetWithLoad(context.TODO(), []string{"foo", "tom"}) + fmt.Println(val, err) + + // Or you can pass in the load function at get time + val, err = lc.MGetWithCustomLoad(context.TODO(), []string{"foo"}, mLoadFunc, 10) + fmt.Println(val, err) +} diff --git a/localcache/examples/delayed_deletion/main.go b/localcache/examples/delayed_deletion/main.go new file mode 100644 index 0000000..2261c61 --- /dev/null +++ b/localcache/examples/delayed_deletion/main.go @@ -0,0 +1,28 @@ +package main + +import ( + "fmt" + "time" + + "trpc.group/trpc-go/trpc-database/localcache" +) + +func main() { + // Set the delayed deletion interval to 3 seconds + lc := localcache.New(localcache.WithDelay(3)) + // Set key expiration time to 1 second: key expires after 1 second and is removed from cache after 4 seconds + lc.SetWithExpire("tom", "cat", 1) + // sleep 2s + time.Sleep(time.Second * 2) + + value, ok := lc.Get("tom") + if !ok { + // key has expired after 2 seconds of sleep. + fmt.Printf("key:%s is expired or empty\n", "tom") + } + + if s, ok := value.(string); ok && s == "cat" { + // Expired values are still returned, and the business side can decide whether or not to use them. + fmt.Printf("get expired value: %s\n", "cat") + } +} diff --git a/localcache/examples/with_expiration/main.go b/localcache/examples/with_expiration/main.go new file mode 100644 index 0000000..7863886 --- /dev/null +++ b/localcache/examples/with_expiration/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "fmt" + "time" + + "trpc.group/trpc-go/trpc-database/localcache" +) + +func main() { + var lc localcache.Cache + + // Create a cache with a size of 100 and an element expiration time of 5 seconds. 
+ lc = localcache.New(localcache.WithCapacity(100), localcache.WithExpiration(5)) + + // Set key-value, expiration time 5 seconds + lc.Set("foo", "bar") + + // Set a specific expiration time (10 seconds) for the key, without using the expiration parameter in the New() method + lc.SetWithExpire("tom", "cat", 10) + + // Short wait for asynchronous processing to complete from cache + time.Sleep(time.Millisecond) + + // Get value + val, found := lc.Get("foo") + fmt.Println(val, found) + + // Delete key: "foo" + lc.Del("foo") +} diff --git a/localcache/func.go b/localcache/func.go new file mode 100644 index 0000000..dcba558 --- /dev/null +++ b/localcache/func.go @@ -0,0 +1,37 @@ +package localcache + +import "context" + +var defaultLocalCache = New() + +// Get returns the value corresponding to key, bool returns true. +// If the key does not exist or expires, bool returns false +func Get(key string) (interface{}, bool) { + return defaultLocalCache.Get(key) +} + +// GetWithLoad returns the value corresponding to the key. +// If the key does not exist, use the load function to load the return and cache it for ttl seconds. 
+func GetWithLoad(ctx context.Context, key string, load LoadFunc, ttl int64) (interface{}, error) { + value, found := defaultLocalCache.Get(key) + if found { + return value, nil + } + c, _ := defaultLocalCache.(*cache) + return c.loadData(ctx, key, load, ttl) +} + +// Set key, value, expiration time ttl (seconds) +func Set(key string, value interface{}, ttl int64) bool { + return defaultLocalCache.SetWithExpire(key, value, ttl) +} + +// Del Delete key +func Del(key string) { + defaultLocalCache.Del(key) +} + +// Clear all queues and caches +func Clear() { + defaultLocalCache.Clear() +} diff --git a/localcache/func_test.go b/localcache/func_test.go new file mode 100644 index 0000000..5b508f7 --- /dev/null +++ b/localcache/func_test.go @@ -0,0 +1,87 @@ +package localcache + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// TestFuncGetAndSet tests Get and Set +func TestFuncGetAndSet(t *testing.T) { + mocks := []struct { + key string + val string + ttl int64 + }{ + {"A", "a", 1000}, + {"B", "b", 1000}, + {"", "null", 1000}, + } + + for _, mock := range mocks { + Set(mock.key, mock.val, mock.ttl) + } + + time.Sleep(wait) + + for _, mock := range mocks { + val, found := Get(mock.key) + if !found || val.(string) != mock.val { + t.Fatalf("Unexpected value: %v (%v) to key: %v", val, found, mock.key) + } + } +} + +// TestFuncExpireSet tests Set with Expire +func TestFuncExpireSet(t *testing.T) { + Set("Foo", "Bar", 1) + time.Sleep(1 * time.Second) + if val, found := Get("Foo"); found { + t.Fatalf("unexpected expired value: %v to key: %v", val, "Foo") + } +} + +// TestFuncGetWithLoad tests Get in the case of WithLoad +func TestFuncGetWithLoad(t *testing.T) { + m := map[string]interface{}{ + "A1": "a", + "B": "b", + "": "null", + } + loadFunc := func(ctx context.Context, key string) (interface{}, error) { + if v, exist := m[key]; exist { + return v, nil + } + return nil, errors.New("key not exist") + } + value, err := 
GetWithLoad(context.TODO(), "A1", loadFunc, 2) + if err != nil || value.(string) != "a" { + t.Fatalf("unexpected GetWithLoad value: %v, want:a, err:%v", value, err) + } + + time.Sleep(wait) + + got2, found := Get("A1") + if !found || got2.(string) != "a" { + t.Fatalf("unexpected Get value: %v, want:a, found:%v", got2, found) + } +} + +func TestDelAndClear(t *testing.T) { + Set("Foo", "bar", 10) + Set("Foo1", "bar", 10) + time.Sleep(time.Millisecond * 10) + _, ok := Get("Foo") + assert.True(t, ok) + Del("Foo") + time.Sleep(time.Millisecond * 10) + _, ok = Get("Foo") + assert.False(t, ok) + Clear() + time.Sleep(time.Millisecond * 10) + _, ok = Get("Foo1") + assert.False(t, ok) +} diff --git a/localcache/go.mod b/localcache/go.mod new file mode 100644 index 0000000..0537776 --- /dev/null +++ b/localcache/go.mod @@ -0,0 +1,18 @@ +module trpc.group/trpc-go/trpc-database/localcache + +go 1.18 + +require ( + github.com/RussellLuo/timingwheel v0.0.0-20191022104228-f534fd34a762 + github.com/cespare/xxhash v1.1.0 + github.com/golang/mock v1.4.4 + github.com/stretchr/testify v1.7.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/kr/pretty v0.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/localcache/go.sum b/localcache/go.sum new file mode 100644 index 0000000..01ffc95 --- /dev/null +++ b/localcache/go.sum @@ -0,0 +1,35 @@ +github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= +github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/RussellLuo/timingwheel v0.0.0-20191022104228-f534fd34a762 h1:N611cQQA4tgy8FT5MpEFPxSkGk2JwYa1fSYes0dk4Yk= +github.com/RussellLuo/timingwheel v0.0.0-20191022104228-f534fd34a762/go.mod h1:3VIJp8oOAlnDUnPy3kwyBGqsMiJJujqTP6ic9Jv6NbM= +github.com/cespare/xxhash v1.1.0 
h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= +github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72 h1:qLC7fQah7D6K1B0ujays3HV9gkFtllcxhzImRR7ArPQ= +github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys 
v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/localcache/lru.go b/localcache/lru.go new file mode 100644 index 0000000..556ef02 --- /dev/null +++ b/localcache/lru.go @@ -0,0 +1,95 @@ +package localcache + +import ( + "container/list" + "sync" + "time" +) + +// entry store entity +type entry struct { + mux sync.RWMutex + key string + value interface{} + expireTime time.Time +} + +func getEntry(ele *list.Element) *entry { + return ele.Value.(*entry) +} + +func setEntry(ele *list.Element, ent *entry) { + ele.Value = ent +} + +// lru non-concurrency-safe lru queue +type lru struct { + ll *list.List + store store + capacity int +} + +func newLRU(capacity int, store store) *lru { + return &lru{ + ll: list.New(), + store: store, + capacity: capacity, + } +} + +func (l *lru) add(ent *entry) *entry { + val, ok := l.store.get(ent.key) + ele, _ := val.(*list.Element) + if ok { + setEntry(ele, ent) + l.ll.MoveToFront(ele) + return nil + } + if l.capacity <= 0 || l.ll.Len() < l.capacity { + ele := l.ll.PushFront(ent) + l.store.set(ent.key, ele) + return nil + } + // When lru is full, the last element is deleted and the new element is added to the head of the list. 
+ ele = l.ll.Back() + if ele == nil { + return ent + } + l.ll.Remove(ele) + victimEnt := getEntry(ele) + l.store.del(victimEnt.key) + + ele = l.ll.PushFront(ent) + l.store.set(ent.key, ele) + return victimEnt +} + +func (l *lru) hit(ele *list.Element) { + l.ll.MoveToFront(ele) +} + +func (l *lru) push(elements []*list.Element) { + for _, ele := range elements { + l.ll.MoveToFront(ele) + } +} + +func (l *lru) del(key string) *entry { + value, ok := l.store.get(key) + if !ok { + return nil + } + ele, _ := value.(*list.Element) + delEnt := getEntry(ele) + l.ll.Remove(ele) + l.store.del(key) + return delEnt +} + +func (l *lru) len() int { + return l.ll.Len() +} + +func (l *lru) clear() { + l.ll = list.New() +} diff --git a/localcache/lru_test.go b/localcache/lru_test.go new file mode 100644 index 0000000..dd3fae2 --- /dev/null +++ b/localcache/lru_test.go @@ -0,0 +1,102 @@ +package localcache + +import ( + "container/list" + "fmt" + "runtime" + "testing" +) + +func assertLRULen(t *testing.T, l *lru, n int) { + if l.store.len() != n || l.len() != n { + _, file, line, _ := runtime.Caller(1) + t.Fatalf("%s:%d unexpected store length (s-%d l-%d), want: %d", + file, line, l.store.len(), l.len(), n) + } +} + +func assertLRUEntry(t *testing.T, ent *entry, k string, v string) { + if ent.key != k || ent.value.(string) != v { + _, file, line, _ := runtime.Caller(1) + t.Fatalf("%s:%d unexpected entry:%+v, want: {key: %s, value:%s}", + file, line, ent, k, v) + } +} + +func TestLRU(t *testing.T) { + store := newStore() + lru := newLRU(3, store) + ents := make([]*entry, 4) + for i := 0; i < len(ents); i++ { + k := fmt.Sprintf("%d", i) + v := k + ents[i] = &entry{key: k, value: v} + } + + // set 0, lru order: 0 + victim := lru.add(ents[0]) + if victim != nil { + t.Fatalf("unexpected entry removed: %v", victim) + } + assertLRULen(t, lru, 1) + + val, _ := lru.store.get(ents[0].key) + ele0 := val.(*list.Element) + ent0 := getEntry(ele0) + assertLRUEntry(t, ent0, "0", "0") + + // set 1, 
lru order: 1-0 + victim = lru.add(ents[1]) + if victim != nil { + t.Fatalf("unexpected entry removed: %v", victim) + } + assertLRULen(t, lru, 2) + + val, _ = lru.store.get(ents[1].key) + ele1 := val.(*list.Element) + ent1 := getEntry(ele1) + assertLRUEntry(t, ent1, "1", "1") + + // lru order: 0-1 + lru.hit(ele0) + + // set 2, lru order: 2-0-1 + victim = lru.add(ents[2]) + if victim != nil { + t.Fatalf("unexpected entry removed: %v", victim) + } + assertLRULen(t, lru, 3) + + // set 3, lru order: 3-2-0, evict 1 + victim = lru.add(ents[3]) + if victim == nil { + t.Fatal("1 entry should be removed") + } else { + assertLRUEntry(t, victim, "1", "1") + assertLRULen(t, lru, 3) + } + + val, _ = lru.store.get(ents[3].key) + ele3 := val.(*list.Element) + ent3 := getEntry(ele3) + assertLRUEntry(t, ent3, "3", "3") + + // remove 2, lru order 3-0 + lru.del(ents[2].key) + assertLRULen(t, lru, 2) + + // again add 3, lru order 3-0 + victim = lru.add(ents[3]) + if victim != nil { + t.Fatalf("unexpected entry removed: %v", victim) + } + assertLRULen(t, lru, 2) + + // remove not exist key + lru.del("None") + assertLRULen(t, lru, 2) + + lru.clear() + lru.store.clear() + assertLRULen(t, lru, 0) +} diff --git a/localcache/mocklocalcache/localcache_mock.go b/localcache/mocklocalcache/localcache_mock.go new file mode 100644 index 0000000..7255ef5 --- /dev/null +++ b/localcache/mocklocalcache/localcache_mock.go @@ -0,0 +1,204 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: cache.go + +// Package mocklocalcache is a generated GoMock package. +package mocklocalcache + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + localcache "trpc.group/trpc-go/trpc-database/localcache" +) + +// MockCache is a mock of Cache interface. +type MockCache struct { + ctrl *gomock.Controller + recorder *MockCacheMockRecorder +} + +// MockCacheMockRecorder is the mock recorder for MockCache. 
+type MockCacheMockRecorder struct { + mock *MockCache +} + +// NewMockCache creates a new mock instance. +func NewMockCache(ctrl *gomock.Controller) *MockCache { + mock := &MockCache{ctrl: ctrl} + mock.recorder = &MockCacheMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockCache) EXPECT() *MockCacheMockRecorder { + return m.recorder +} + +// Clear mocks base method. +func (m *MockCache) Clear() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Clear") +} + +// Clear indicates an expected call of Clear. +func (mr *MockCacheMockRecorder) Clear() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Clear", reflect.TypeOf((*MockCache)(nil).Clear)) +} + +// Close mocks base method. +func (m *MockCache) Close() { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Close") +} + +// Close indicates an expected call of Close. +func (mr *MockCacheMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockCache)(nil).Close)) +} + +// Del mocks base method. +func (m *MockCache) Del(key string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Del", key) +} + +// Del indicates an expected call of Del. +func (mr *MockCacheMockRecorder) Del(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Del", reflect.TypeOf((*MockCache)(nil).Del), key) +} + +// Get mocks base method. +func (m *MockCache) Get(key string) (interface{}, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", key) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// Get indicates an expected call of Get. 
+func (mr *MockCacheMockRecorder) Get(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockCache)(nil).Get), key) +} + +// GetWithCustomLoad mocks base method. +func (m *MockCache) GetWithCustomLoad(ctx context.Context, key string, customLoad localcache.LoadFunc, ttl int64) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWithCustomLoad", ctx, key, customLoad, ttl) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWithCustomLoad indicates an expected call of GetWithCustomLoad. +func (mr *MockCacheMockRecorder) GetWithCustomLoad(ctx, key, customLoad, ttl interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWithCustomLoad", reflect.TypeOf((*MockCache)(nil).GetWithCustomLoad), ctx, key, customLoad, ttl) +} + +// GetWithLoad mocks base method. +func (m *MockCache) GetWithLoad(ctx context.Context, key string) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWithLoad", ctx, key) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetWithLoad indicates an expected call of GetWithLoad. +func (mr *MockCacheMockRecorder) GetWithLoad(ctx, key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWithLoad", reflect.TypeOf((*MockCache)(nil).GetWithLoad), ctx, key) +} + +// GetWithStatus mocks base method. +func (m *MockCache) GetWithStatus(key string) (interface{}, localcache.CachedStatus) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWithStatus", key) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(localcache.CachedStatus) + return ret0, ret1 +} + +// GetWithStatus indicates an expected call of GetWithStatus. 
+func (mr *MockCacheMockRecorder) GetWithStatus(key interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWithStatus", reflect.TypeOf((*MockCache)(nil).GetWithStatus), key) +} + +// Len mocks base method. +func (m *MockCache) Len() int { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Len") + ret0, _ := ret[0].(int) + return ret0 +} + +// Len indicates an expected call of Len. +func (mr *MockCacheMockRecorder) Len() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockCache)(nil).Len)) +} + +// MGetWithCustomLoad mocks base method. +func (m *MockCache) MGetWithCustomLoad(ctx context.Context, keys []string, customLoad localcache.MLoadFunc, ttl int64) (map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetWithCustomLoad", ctx, keys, customLoad, ttl) + ret0, _ := ret[0].(map[string]interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetWithCustomLoad indicates an expected call of MGetWithCustomLoad. +func (mr *MockCacheMockRecorder) MGetWithCustomLoad(ctx, keys, customLoad, ttl interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetWithCustomLoad", reflect.TypeOf((*MockCache)(nil).MGetWithCustomLoad), ctx, keys, customLoad, ttl) +} + +// MGetWithLoad mocks base method. +func (m *MockCache) MGetWithLoad(ctx context.Context, keys []string) (map[string]interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MGetWithLoad", ctx, keys) + ret0, _ := ret[0].(map[string]interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MGetWithLoad indicates an expected call of MGetWithLoad. 
+func (mr *MockCacheMockRecorder) MGetWithLoad(ctx, keys interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MGetWithLoad", reflect.TypeOf((*MockCache)(nil).MGetWithLoad), ctx, keys) +} + +// Set mocks base method. +func (m *MockCache) Set(key string, value interface{}) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Set", key, value) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Set indicates an expected call of Set. +func (mr *MockCacheMockRecorder) Set(key, value interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Set", reflect.TypeOf((*MockCache)(nil).Set), key, value) +} + +// SetWithExpire mocks base method. +func (m *MockCache) SetWithExpire(key string, value interface{}, ttl int64) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetWithExpire", key, value, ttl) + ret0, _ := ret[0].(bool) + return ret0 +} + +// SetWithExpire indicates an expected call of SetWithExpire. +func (mr *MockCacheMockRecorder) SetWithExpire(key, value, ttl interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetWithExpire", reflect.TypeOf((*MockCache)(nil).SetWithExpire), key, value, ttl) +} diff --git a/localcache/mocklocalcache/localcache_test.go b/localcache/mocklocalcache/localcache_test.go new file mode 100644 index 0000000..91a9273 --- /dev/null +++ b/localcache/mocklocalcache/localcache_test.go @@ -0,0 +1,112 @@ +package mocklocalcache + +import ( + "context" + "testing" + + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "trpc.group/trpc-go/trpc-database/localcache" +) + +func TestLocalCache(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // 1. Generate mock cache + cache := NewMockCache(ctrl) + // 2. Assign a value to the cache variable at the code call or pass it as a parameter + // gLocalcache = cache + // 3. 
Mock corresponding function + cache.EXPECT().Get(gomock.Any()).DoAndReturn(func(key interface{}) (interface{}, bool) { + t.Logf("key:%v", key) + return nil, true + }) + cache.EXPECT().SetWithExpire(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(key, value interface{}, ttl int64) bool { + t.Logf("key:%v", key) + return true + }) + cache.EXPECT().GetWithLoad(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, key string) (interface{}, error) { + t.Logf("GetWithLoad()") + return nil, nil + }) + cache.EXPECT().MGetWithLoad(gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, keys []string) (map[string]interface{}, error) { + t.Logf("MGetWithLoad()") + return nil, nil + }) + cache.EXPECT().Set(gomock.Any(), gomock.Any()). + DoAndReturn(func(key string, value interface{}) bool { + t.Logf("Set()") + return true + }) + cache.EXPECT().Del(gomock.Any()). + DoAndReturn(func(key string) { + t.Logf("Del()") + }) + cache.EXPECT().GetWithCustomLoad(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, key string, customLoad localcache.LoadFunc, ttl int64) ( + interface{}, error) { + t.Logf("GetWithCustomLoad()") + return nil, nil + }) + cache.EXPECT().MGetWithCustomLoad(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, keys []string, customLoad localcache.MLoadFunc, ttl int64) ( + map[string]interface{}, error) { + t.Logf("MGetWithCustomLoad()") + return nil, nil + }) + cache.EXPECT().Len().Return(4) + cache.EXPECT().Clear().DoAndReturn(func() {}) + cache.EXPECT().Close().DoAndReturn(func() {}) + + // 4. 
Call + ctx := context.Background() + ok := cache.SetWithExpire("Foo", "bar", 1) + assert.Equal(t, true, ok) + v, ok := cache.Get("time") + assert.Nil(t, v) + assert.Equal(t, true, ok) + _, err := cache.GetWithLoad(ctx, "Foo") + assert.Nil(t, err) + _, err = cache.MGetWithLoad(ctx, []string{"Foo"}) + assert.Nil(t, err) + ok = cache.Set("abs", "abs") + assert.Equal(t, true, ok) + cache.Del("abs") + _, err = cache.GetWithCustomLoad(ctx, "Foo", + func(ctx context.Context, key string) (interface{}, error) { + return nil, nil + }, 12, + ) + assert.Nil(t, err) + _, err = cache.MGetWithCustomLoad(ctx, []string{"Foo"}, + func(ctx context.Context, key []string) (map[string]interface{}, error) { + return nil, nil + }, 12, + ) + cacheLen := cache.Len() + assert.Equal(t, 4, cacheLen) + assert.Nil(t, err) + cache.Clear() + cache.Close() +} + +func TestLocalCache2(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + // 1. Generate mock cache + cache := NewMockCache(ctrl) + // 2. Assign a value to the cache variable at the code call or pass it as a parameter + // gLocalcache = cache + // 3. Mock corresponding function + cache.EXPECT().GetWithStatus(gomock.Any()).DoAndReturn(func(key interface{}) (interface{}, localcache.CachedStatus) { + t.Logf("key:%v", key) + return "bar", localcache.CacheExpire + }) + // 4. 
Call + val, status := cache.GetWithStatus("foo") + assert.Equal(t, val, "bar") + assert.Equal(t, status, localcache.CacheExpire) +} diff --git a/localcache/policy.go b/localcache/policy.go new file mode 100644 index 0000000..1a006b7 --- /dev/null +++ b/localcache/policy.go @@ -0,0 +1,23 @@ +package localcache + +import ( + "container/list" +) + +// policy store policy +type policy interface { + // add adds an element + add(ent *entry) *entry + // hit handles accessing an element hit + hit(elements *list.Element) + // push processes accessed elements in batches + push(elements []*list.Element) + // del deletes an element based on key + del(key string) *entry + // clear space + clear() +} + +func newPolicy(capacity int, store store) policy { + return newLRU(capacity, store) +} diff --git a/localcache/ring.go b/localcache/ring.go new file mode 100644 index 0000000..39b4eee --- /dev/null +++ b/localcache/ring.go @@ -0,0 +1,63 @@ +package localcache + +import ( + "container/list" + "sync" +) + +// ringConsumer accept and consume data +type ringConsumer interface { + push([]*list.Element) bool +} + +// ringStrip ring buffer caches the metadata of Get requests to batch update the location of elements in the LRU +type ringStripe struct { + consumer ringConsumer + data []*list.Element + capacity int +} + +func newRingStripe(consumer ringConsumer, capacity int) *ringStripe { + return &ringStripe{ + consumer: consumer, + data: make([]*list.Element, 0, capacity), + capacity: capacity, + } +} + +// push records an accessed element to the ring buffer +func (r *ringStripe) push(ele *list.Element) { + r.data = append(r.data, ele) + if len(r.data) >= r.capacity { + if r.consumer.push(r.data) { + r.data = make([]*list.Element, 0, r.capacity) + } else { + r.data = r.data[:0] + } + } +} + +// RingBuffer pools ringStripe, allowing multiple goroutines to write elements to Buff without locking, +// which is more efficient than writing to the same channel concurrently. 
At the same time, objects in +// the pool will be automatically removed without any notification. This random loss of access metadata +// can reduce the cache's operation on LRU and reduce concurrency competition with Set/Expire/Delete +// operations. The cache does not need to be completely strict LRU. It is necessary to actively discard +// some access metadata to reduce concurrency competition and improve write efficiency. +type ringBuffer struct { + pool *sync.Pool +} + +func newRingBuffer(consumer ringConsumer, capacity int) *ringBuffer { + return &ringBuffer{ + pool: &sync.Pool{ + New: func() interface{} { return newRingStripe(consumer, capacity) }, + }, + } +} + +// push records an accessed element +func (b *ringBuffer) push(ele *list.Element) { + ringStripe := b.pool.Get().(*ringStripe) + ringStripe.push(ele) + b.pool.Put(ringStripe) +} diff --git a/localcache/ring_test.go b/localcache/ring_test.go new file mode 100644 index 0000000..3aae8e6 --- /dev/null +++ b/localcache/ring_test.go @@ -0,0 +1,81 @@ +package localcache + +import ( + "container/list" + "sync" + "testing" +) + +// testConsumer test Consumer structure +type testConsumer struct { + pf func([]*list.Element) + save bool +} + +func (c *testConsumer) push(elements []*list.Element) bool { + if c.save { + c.pf(elements) + return true + } + return false +} + +// TestRingBufferPush tests RingBuffer's push method +func TestRingBufferPush(t *testing.T) { + drains := 0 + r := newRingBuffer(&testConsumer{ + pf: func(elements []*list.Element) { + drains++ + }, + save: true, + }, 1) + + for i := 0; i < 100; i++ { + r.push(&list.Element{}) + } + + if drains != 100 { + t.Fatal("elements shouldn't be dropped with capacity == 1") + } +} + +// TestRingReset tests RingBuffer's Reset method +func TestRingReset(t *testing.T) { + drains := 0 + r := newRingBuffer(&testConsumer{ + pf: func(elements []*list.Element) { + drains++ + }, + save: false, + }, 4) + for i := 0; i < 100; i++ { + r.push(&list.Element{}) + } + 
if drains != 0 { + t.Fatal("elements shouldn't be drained") + } +} + +// TestRingConsumer tests RingBuffer's Consumer +func TestRingConsumer(t *testing.T) { + mu := &sync.Mutex{} + drainElements := make(map[*list.Element]struct{}) + + r := newRingBuffer(&testConsumer{ + pf: func(elements []*list.Element) { + mu.Lock() + defer mu.Unlock() + for i := range elements { + drainElements[elements[i]] = struct{}{} + } + }, + save: true, + }, 4) + for i := 0; i < 100; i++ { + r.push(&list.Element{}) + } + l := len(drainElements) + if l == 0 || l > 100 { + t.Fatal("drains not being process correctly") + } +} diff --git a/localcache/store.go b/localcache/store.go new file mode 100644 index 0000000..c0fc71e --- /dev/null +++ b/localcache/store.go @@ -0,0 +1,113 @@ +package localcache + +import ( + "sync" + + "github.com/cespare/xxhash" +) + +// store is a storage for storing key-value data concurrently and safely. +// This file temporarily uses the fragmented map implementation. +type store interface { + // get returns the value corresponding to key + get(string) (interface{}, bool) + // set adds a new key-value to the storage + set(string, interface{}) + // del delete key-value + del(string) + // clear Clear all contents in storage + clear() + // len returns the size of the storage + len() int +} + +// newStore returns the default implementation of storage +func newStore() store { + return newShardedMap() +} + +const numShards uint64 = 256 + +// shardedMap storage sharding +type shardedMap struct { + shards []*lockedMap +} + +func newShardedMap() *shardedMap { + sm := &shardedMap{ + shards: make([]*lockedMap, int(numShards)), + } + for i := range sm.shards { + sm.shards[i] = newLockedMap() + } + return sm +} + +func (sm *shardedMap) get(key string) (interface{}, bool) { + return sm.shards[xxhash.Sum64String(key)&(numShards-1)].get(key) +} + +func (sm *shardedMap) set(key string, value interface{}) { + sm.shards[xxhash.Sum64String(key)&(numShards-1)].set(key, value) +} + +func 
(sm *shardedMap) del(key string) { + sm.shards[xxhash.Sum64String(key)&(numShards-1)].del(key) +} + +func (sm *shardedMap) clear() { + for i := uint64(0); i < numShards; i++ { + sm.shards[i].clear() + } +} + +func (sm *shardedMap) len() int { + length := 0 + for i := uint64(0); i < numShards; i++ { + length += sm.shards[i].len() + } + return length +} + +// lockedMap concurrently safe map +type lockedMap struct { + sync.RWMutex + data map[string]interface{} +} + +func newLockedMap() *lockedMap { + return &lockedMap{ + data: make(map[string]interface{}), + } +} + +func (m *lockedMap) get(key string) (interface{}, bool) { + m.RLock() + val, ok := m.data[key] + m.RUnlock() + return val, ok +} + +func (m *lockedMap) set(key string, value interface{}) { + m.Lock() + m.data[key] = value + m.Unlock() +} + +func (m *lockedMap) del(key string) { + m.Lock() + delete(m.data, key) + m.Unlock() +} + +func (m *lockedMap) clear() { + m.Lock() + m.data = make(map[string]interface{}) + m.Unlock() +} + +func (m *lockedMap) len() int { + m.RLock() + defer m.RUnlock() + return len(m.data) +} diff --git a/localcache/store_test.go b/localcache/store_test.go new file mode 100644 index 0000000..2df42d7 --- /dev/null +++ b/localcache/store_test.go @@ -0,0 +1,118 @@ +package localcache + +import ( + "testing" +) + +// TestStoreSetGet tests the Set and Get methods of Store +func TestStoreSetGet(t *testing.T) { + store := newStore() + mocks := []struct { + key string + val string + }{ + {"A", "a"}, + {"B", "b"}, + {"", "null"}, + } + + for _, mock := range mocks { + store.set(mock.key, mock.val) + } + + for _, mock := range mocks { + val, ok := store.get(mock.key) + if !ok || val.(string) != mock.val { + t.Fatalf("unexpected value: %v (%v) to key: %v", val, ok, mock.key) + } + } +} + +// TestStoreSetNil tests the situation when Store Set nil +func TestStoreSetNil(t *testing.T) { + store := newStore() + store.set("no", nil) + val, ok := store.get("no") + if !ok || val != nil { + 
t.Fatalf("unexpected value: %v (%v)", val, ok) + } +} + +// TestStoreDel tests Store's Del method +func TestStoreDel(t *testing.T) { + store := newStore() + mocks := []struct { + key string + val interface{} + }{ + {"A", "a"}, + {"B", "b"}, + {"C", nil}, + {"", "null"}, + } + + for _, mock := range mocks { + store.set(mock.key, mock.val) + } + + for _, mock := range mocks { + store.del(mock.key) + val, ok := store.get(mock.key) + if ok || val != nil { + t.Fatalf("del error, key: %v value: %v (%v)", mock.key, val, ok) + } + } +} + +// TestStoreClear tests Store's Clear method +func TestStoreClear(t *testing.T) { + store := newStore() + mocks := []struct { + key string + val interface{} + }{ + {"A", "a"}, + {"B", "b"}, + {"C", nil}, + {"", "null"}, + } + for _, mock := range mocks { + store.set(mock.key, mock.val) + } + store.clear() + for _, mock := range mocks { + val, ok := store.get(mock.key) + if ok || val != nil { + t.Fatalf("clear error, key: %v value: %v (%v)", mock.key, val, ok) + } + } +} + +// BenchmarkStoreGet Benchmark Store's Get method +func BenchmarkStoreGet(b *testing.B) { + k := "A" + v := "a" + + s := newStore() + s.set(k, v) + b.SetBytes(1) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + s.get(k) + } + }) +} + +// BenchmarkStoreGet Benchmark Store's Set method +func BenchmarkStoreSet(b *testing.B) { + k := "A" + v := "a" + + s := newStore() + b.SetBytes(1) + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + s.set(k, v) + } + }) +} diff --git a/localcache/timer.go b/localcache/timer.go new file mode 100644 index 0000000..6a8bc6f --- /dev/null +++ b/localcache/timer.go @@ -0,0 +1,98 @@ +package localcache + +import ( + "sync" + "time" + + "github.com/RussellLuo/timingwheel" +) + +// expireQueue stores tasks that are automatically deleted after the key expires +type expireQueue struct { + tick time.Duration + wheelSize int64 + // The time wheel stores tasks that are scheduled to expire and be deleted. 
+ tw *timingwheel.TimingWheel + + mu sync.Mutex + timers map[string]*timingwheel.Timer +} + +// newExpireQueue generates an expireQueue object. +// The queue is implemented through a time wheel. +// The elements in the queue are deleted regularly according to the expiration time. +func newExpireQueue(tick time.Duration, wheelSize int64) *expireQueue { + queue := &expireQueue{ + tick: tick, + wheelSize: wheelSize, + + tw: timingwheel.NewTimingWheel(tick, wheelSize), + timers: make(map[string]*timingwheel.Timer), + } + + // Start a goroutine to handle expired entries + queue.tw.Start() + return queue +} + +// add scheduled expired tasks. +// When each scheduled task expires, it will be executed as an independent goroutine. +func (q *expireQueue) add(key string, expireTime time.Time, f func()) { + q.mu.Lock() + defer q.mu.Unlock() + + d := expireTime.Sub(currentTime()) + timer := q.tw.AfterFunc(d, q.task(key, f)) + q.timers[key] = timer + + return +} + +// update the expiration time of the key element +func (q *expireQueue) update(key string, expireTime time.Time, f func()) { + q.mu.Lock() + defer q.mu.Unlock() + + if timer, ok := q.timers[key]; ok { + timer.Stop() + } + + d := expireTime.Sub(currentTime()) + timer := q.tw.AfterFunc(d, q.task(key, f)) + q.timers[key] = timer +} + +// remove element key +func (q *expireQueue) remove(key string) { + q.mu.Lock() + defer q.mu.Unlock() + + if timer, ok := q.timers[key]; ok { + timer.Stop() + delete(q.timers, key) + } +} + +// clear the queue +func (q *expireQueue) clear() { + q.tw.Stop() + q.tw = timingwheel.NewTimingWheel(q.tick, q.wheelSize) + q.timers = make(map[string]*timingwheel.Timer) + + // Restart a goroutine to process expired entries + q.tw.Start() +} + +// stop the running of the time wheel queue +func (q *expireQueue) stop() { + q.tw.Stop() +} + +func (q *expireQueue) task(key string, f func()) func() { + return func() { + f() + q.mu.Lock() + delete(q.timers, key) + q.mu.Unlock() + } +} diff --git 
a/localcache/timer_test.go b/localcache/timer_test.go new file mode 100644 index 0000000..5fcf774 --- /dev/null +++ b/localcache/timer_test.go @@ -0,0 +1,163 @@ +package localcache + +import ( + "testing" + "time" +) + +// TestExpireQueue_Add tests the Add method of the expireQueue +func TestExpireQueue_Add(t *testing.T) { + q := newExpireQueue(time.Second, 60) + mocks := []struct { + k string + ttl time.Duration + }{ + {"A", time.Second}, + {"B", 1 * time.Second}, + {"", 2 * time.Second}, + } + for _, mock := range mocks { + t.Run("", func(t *testing.T) { + exitC := make(chan time.Time) + + start := currentTime() + q.add(mock.k, start.Add(mock.ttl), func() { + exitC <- currentTime() + }) + + got := (<-exitC).Truncate(time.Second) + min := start.Add(mock.ttl).Truncate(time.Second) + + err := time.Second + if got.Before(min) || got.After(min.Add(err)) { + t.Fatalf("Timer(%s) expiration: want [%s, %s], got %s", mock.ttl, min, min.Add(err), got) + } + }) + } + if len(q.timers) != 0 { + t.Fatalf("Length(%d) of timers is not equal to 0", len(q.timers)) + } +} + +// TestExpireQueue_Remove tests the Remove method of the expireQueue +func TestExpireQueue_Remove(t *testing.T) { + q := newExpireQueue(time.Second, 60) + mocks := []struct { + k string + ttl time.Duration + }{ + {"B", 1 * time.Second}, + {"", 1 * time.Second}, + } + for _, mock := range mocks { + t.Run("", func(t *testing.T) { + exitC := make(chan time.Time) + + start := currentTime() + q.add(mock.k, start.Add(mock.ttl), func() { + exitC <- currentTime() + }) + q.remove(mock.k) + + timerC := time.NewTimer(time.Second * 2).C + select { + case <-exitC: + t.Fatalf("Failed to remove timer(%s, %v)", mock.k, mock.ttl) + case <-timerC: + return + } + }) + } + if len(q.timers) != 0 { + t.Fatalf("Length(%d) of timers is not equal to 0", len(q.timers)) + } +} + +// TestExpireQueue_Update tests the Update method of the expireQueue +func TestExpireQueue_Update(t *testing.T) { + q := newExpireQueue(time.Second, 60) + + 
exitC := make(chan time.Time) + + start := currentTime() + q.add("A", start.Add(2*time.Second), func() { + exitC <- currentTime() + }) + + q.update("A", start.Add(2*time.Second), func() { + exitC <- currentTime() + }) + + got := (<-exitC).Truncate(time.Second) + min := start.Add(2 * time.Second).Truncate(time.Second) + + err := time.Second + if got.Before(min) || got.After(min.Add(err)) { + t.Fatalf("Timer(%s) expiration: want [%s, %s], got %s", "5", min, min.Add(err), got) + } + + if len(q.timers) != 0 { + t.Fatalf("Length(%d) of timers is not equal to 0", len(q.timers)) + } + +} + +// TestExpireQueue_Clear tests the Clear method of the expireQueue +func TestExpireQueue_Clear(t *testing.T) { + q := newExpireQueue(time.Second, 60) + + exitC := make(chan time.Time) + + start := currentTime() + q.add("A", start.Add(time.Second), func() { + exitC <- currentTime() + }) + q.clear() + + // After clearing, determine whether it will be cleared + timerC := time.NewTimer(2 * time.Second).C + select { + case <-exitC: + t.Fatalf("Failed to remove timer(%s, %v)", "A", "2s") + case <-timerC: + } + + // re-add + start = currentTime() + q.add("B", start.Add(2*time.Second), func() { + exitC <- currentTime() + }) + + got := (<-exitC).Truncate(time.Second) + min := start.Add(2 * time.Second).Truncate(time.Second) + + err := time.Second + if got.Before(min) || got.After(min.Add(err)) { + t.Fatalf("Timer(%s) expiration: want [%s, %s], got %s", "5", min, min.Add(err), got) + } + + if len(q.timers) != 0 { + t.Fatalf("Length(%d) of timers is not equal to 0", len(q.timers)) + } +} + +// TestExpireQueue_Stop tests the stop method of the expireQueue +func TestExpireQueue_Stop(t *testing.T) { + q := newExpireQueue(time.Second, 60) + + exitC := make(chan time.Time) + + start := currentTime() + q.add("A", start.Add(2*time.Second), func() { + exitC <- currentTime() + }) + q.stop() + + // After clearing, determine whether it will be cleared + timerC := time.NewTimer(3 * time.Second).C + select { 
+ case <-exitC: + t.Fatalf("Failed to remove timer(%s, %v)", "A", "2s") + case <-timerC: + } +} From 1898adcb7717ab747b12a938643c365bd5d1f49b Mon Sep 17 00:00:00 2001 From: goodliu Date: Thu, 23 May 2024 10:45:34 +0800 Subject: [PATCH 3/3] add mongodb (#28) (#29) * add mongodb * revert cover.out * update LICENSE and fix variable name * update comments (cherry picked from commit d476e18030df7b03ff19bbd1aea5bef69ec2e6f0) Co-authored-by: MengYinlei <43751910+YoungFr@users.noreply.github.com> --- .github/workflows/mongodb.yml | 33 + LICENSE | 6 +- README.md | 3 +- README.zh_CN.md | 2 + mongodb/CHANGELOG.md | 1 + mongodb/README.md | 97 +++ mongodb/README.zh_CN.md | 94 +++ mongodb/client.go | 841 ++++++++++++++++++++ mongodb/client_test.go | 1185 ++++++++++++++++++++++++++++ mongodb/codec.go | 38 + mongodb/codec_test.go | 65 ++ mongodb/curd.go | 718 +++++++++++++++++ mongodb/curd_test.go | 1073 +++++++++++++++++++++++++ mongodb/error.go | 25 + mongodb/error_test.go | 48 ++ mongodb/go.mod | 56 ++ mongodb/go.sum | 164 ++++ mongodb/mockmongodb/client_mock.go | 665 ++++++++++++++++ mongodb/mongodb_test.go | 156 ++++ mongodb/options.go | 15 + mongodb/options_test.go | 21 + mongodb/plugin.go | 66 ++ mongodb/plugin_test.go | 134 ++++ mongodb/test.sh | 4 + mongodb/transport.go | 429 ++++++++++ mongodb/transport_test.go | 397 ++++++++++ 26 files changed, 6333 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/mongodb.yml create mode 100644 mongodb/CHANGELOG.md create mode 100644 mongodb/README.md create mode 100644 mongodb/README.zh_CN.md create mode 100644 mongodb/client.go create mode 100644 mongodb/client_test.go create mode 100644 mongodb/codec.go create mode 100644 mongodb/codec_test.go create mode 100644 mongodb/curd.go create mode 100644 mongodb/curd_test.go create mode 100644 mongodb/error.go create mode 100644 mongodb/error_test.go create mode 100644 mongodb/go.mod create mode 100644 mongodb/go.sum create mode 100644 mongodb/mockmongodb/client_mock.go 
create mode 100644 mongodb/mongodb_test.go create mode 100644 mongodb/options.go create mode 100644 mongodb/options_test.go create mode 100644 mongodb/plugin.go create mode 100644 mongodb/plugin_test.go create mode 100644 mongodb/test.sh create mode 100644 mongodb/transport.go create mode 100644 mongodb/transport_test.go diff --git a/.github/workflows/mongodb.yml b/.github/workflows/mongodb.yml new file mode 100644 index 0000000..fb8573c --- /dev/null +++ b/.github/workflows/mongodb.yml @@ -0,0 +1,33 @@ +name: Mongodb Pull Request Check +on: + pull_request: + paths: + - 'mongodb/**' + - '.github/workflows/mongodb.yml' + push: + paths: + - 'mongodb/**' + - '.github/workflows/mongodb.yml' + workflow_dispatch: +permissions: + contents: read +jobs: + build: + name: build + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version: 1.19 + - name: Build + run: cd mongodb && go build -v ./... + - name: Test + run: cd mongodb && go test -v -coverprofile=coverage.out -gcflags=all=-l -run Unit ./... + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./mongodb/coverage.out + flags: mongodb + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/LICENSE b/LICENSE index a6f6d0e..23b8c58 100644 --- a/LICENSE +++ b/LICENSE @@ -74,6 +74,9 @@ Source code of this software can be obtained from: github.com/golang/mock Copyright 2018 by David A. Golden. All rights reserved. Source code of this software can be obtained from: github.com/xdg-go/scram +3. mongo-driver +Copyright (C) MongoDB, Inc. 2017-present. +Source code of this software can be obtained from: github.com/mongodb/mongo-go-driver Terms of the Apache License Version 2.0: -------------------------------------------------------------------- @@ -157,7 +160,6 @@ Copyright (c) supermonkey original author and authors 6. cos-go-sdk-v5 Copyright (c) 2017 mozillazg - 7. 
miniredis Copyright (c) 2014 Harmen @@ -187,7 +189,7 @@ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLI -Open Source Software Licensed under the MIT and Apache 2.0: +Open Source Software Licensed under the MIT and Apache 2.0: -------------------------------------------------------------------- 1. yaml.v3 Copyright 2011-2016 Canonical Ltd. diff --git a/README.md b/README.md index d009cfa..70c70fd 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,11 @@ Considering that tRPC-Go provides a variety of plugins for naming routing, monit | bigcache | Wraps the open-source local caching database [Bigcache](https://github.com/allegro/bigcache) | | clickhouse | Wraps the open-source database [Clickhouse SDK](https://github.com/ClickHouse/clickhouse-go) | | cos | Wraps Tencent Cloud Object Storage [COS SDK](https://github.com/tencentyun/cos-go-sdk-v5) | +| goes | Wraps the open-source official Go [ElasticSearch client](https://github.com/elastic/go-elasticsearch) | | goredis | Wraps the in-memory database [Redis SDK](https://github.com/redis/go-redis) | | gorm | Wraps the Golang ORM library [GORM](https://github.com/go-gorm/gorm) | | hbase | Wraps the open-source database [HBase SDK](https://github.com/tsuna/gohbase) | | kafka | Wraps the open-source Kafka message queue SDK [Sarama](https://github.com/IBM/sarama) | +| mongodb | Wraps the open-source database [MongoDB Driver](https://go.mongodb.org/mongo-driver/mongo) | | mysql | Wraps the open-source database [MySQL Driver](https://github.com/go-sql-driver/mysql) | -| goes | Wraps the open-source official go [ElasticSearch client](https://github.com/elastic/go-elasticsearch) | | timer | Local/distributed timer functionality | \ No newline at end of file diff --git a/README.zh_CN.md b/README.zh_CN.md index 4ded09c..b85cc2e 100644 --- a/README.zh_CN.md +++ b/README.zh_CN.md @@ -13,9 +13,11 @@ | bigcache | 封装开源本地缓存数据库 [Bigcache](https://github.com/allegro/bigcache) | | clickhouse | 封装开源数据库 
[Clickhouse SDK](https://github.com/ClickHouse/clickhouse-go) | | cos | 封装腾讯云对象存储 [COS SDK](https://github.com/tencentyun/cos-go-sdk-v5) | +| goes | 封装开源官方 Go [ElasticSearch client](https://github.com/elastic/go-elasticsearch) | | goredis | 封装内存数据库 [Redis SDK](https://github.com/redis/go-redis) | | gorm | 封装 Golang ORM 库 [GORM](https://github.com/go-gorm/gorm) | | hbase | 封装开源数据库 [HBase SDK](https://github.com/tsuna/gohbase) | | kafka | 封装开源消息队列 Kafka SDK [Sarama](https://github.com/IBM/sarama) | +| mongodb | 封装开源数据库 [MongoDB Driver](https://go.mongodb.org/mongo-driver/mongo) | | mysql | 封装开源数据库 [Mysql Driver](https://github.com/go-sql-driver/mysql) | | timer | 本地/分布式定时器 | \ No newline at end of file diff --git a/mongodb/CHANGELOG.md b/mongodb/CHANGELOG.md new file mode 100644 index 0000000..fa4d35e --- /dev/null +++ b/mongodb/CHANGELOG.md @@ -0,0 +1 @@ +# Change Log \ No newline at end of file diff --git a/mongodb/README.md b/mongodb/README.md new file mode 100644 index 0000000..4166a6f --- /dev/null +++ b/mongodb/README.md @@ -0,0 +1,97 @@ +English | [中文](README.zh_CN.md) + +# tRPC-Go mongodb plugin +[![BK Pipelines Status](https://api.bkdevops.qq.com/process/api/external/pipelines/projects/pcgtrpcproject/p-d7b163d3830a429e976bf77e2409c6d3/badge?X-DEVOPS-PROJECT-ID=pcgtrpcproject)](http://devops.oa.com/ms/process/api-html/user/builds/projects/pcgtrpcproject/pipelines/p-d7b163d3830a429e976bf77e2409c6d3/latestFinished?X-DEVOPS-PROJECT-ID=pcgtrpcproject) + +Base on community [mongo](https://go.mongodb.org/mongo-driver/mongo), used with trpc. + +## mongodb client +```yaml +client: # Backend configuration for client calls. + service: # Configuration for the backend. + - name: trpc.mongodb.xxx.xxx + target: mongodb://user:passwd@vip:port # mongodb standard uri:mongodb://[username:password@]host1[:port1][,host2[:port2],...[,hostN[:portN]]][/[database][?options]] + timeout: 800 # The maximum processing time of the current request. 
+ - name: trpc.mongodb.xxx.xxx1 + target: mongodb+polaris://user:passwd@polaris_name # mongodb+polaris means that the host in the mongodb uri will perform Polaris analysis. + timeout: 800 # The maximum processing time of the current request. +``` +```go +package main + +import ( + "context" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "trpc.group/trpc-go/trpc-database/mongodb" + "trpc.group/trpc-go/trpc-go/log" +) + +// BattleFlow is battle information. +type BattleInfo struct { + Id string `bson:"_id,omitempty"` + Ctime uint32 `bson:"ctime,omitempty" json:"ctime,omitempty"` +} + +func (s *server) SayHello(ctx context.Context, req *pb.ReqBody, rsp *pb.RspBody) (err error) { + proxy := mongodb.NewClientProxy("trpc.mongodb.xxx.xxx") // Your custom service name,used for monitoring, reporting and mapping configuration. + + // mongodb insert + _, err = proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key2", "value": "v2"}) + + // mongodb ReplaceOne + opts := options.Replace().SetUpsert(true) + filter := bson.D{{"_id", "key1"}} + _, err := proxy.ReplaceOne(ctx, "database", "table", filter, &BattleInfo{}, opts) + if err != nil { + log.Errorf("err=%v, data=%v", err, *battleInfo) + return err + } + + // mongodb FindOne + rst := proxy.FindOne(ctx, "database", "table", bson.D{{"_id", "key1"}}) + battleInfo = &BattleInfo{} + err = rst.Decode(battleInfo) + if err != nil { + return nil, err + } + + // mongodb transaction + err = proxy.Transaction(ctx, func(sc mongo.SessionContext) error { + // The same proxy instance needs to be used during transaction execution. 
+ _, tErr := proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key1", "value": "v1"}) + if tErr != nil { + return tErr + } + _, tErr = proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key2", "value": "v2"}) + if tErr != nil { + return tErr + } + return nil + }, nil) + + // mongodb RunCommand + cmdDB := bson.D{} + cmdDB = append(cmdDB, bson.E{Key: "enableSharding", Value: "dbName"}) + err = proxy.RunCommand(ctx, "admin", cmdDB).Err() + if err != nil { + return nil, err + } + + cmdColl := bson.D{} + cmdColl = append(cmdColl, bson.E{Key: "shardCollection", Value: "dbName.collectionName"}) + cmdColl = append(cmdColl, bson.E{Key: "key", Value: bson.D{{"openId", "hashed"}}}) + cmdColl = append(cmdColl, bson.E{Key: "unique", Value: false}) + cmdColl = append(cmdColl, bson.E{Key: "numInitialChunks", Value: 10}) + err = proxy.RunCommand(ctx, "admin", cmdColl).Err() + if err != nil { + return nil, err + } + // Business logic. +} +``` +## Frequently Asked Questions (FAQs) +- Q1: How to configure ClientOptions: +- A1: When creating a Transport, you can use WithOptionInterceptor to configure ClientOptions. You can refer to options_test.go for more information. 
diff --git a/mongodb/README.zh_CN.md b/mongodb/README.zh_CN.md new file mode 100644 index 0000000..db37a8a --- /dev/null +++ b/mongodb/README.zh_CN.md @@ -0,0 +1,94 @@ +# tRPC-Go mongodb 插件 +[![BK Pipelines Status](https://api.bkdevops.qq.com/process/api/external/pipelines/projects/pcgtrpcproject/p-d7b163d3830a429e976bf77e2409c6d3/badge?X-DEVOPS-PROJECT-ID=pcgtrpcproject)](http://devops.oa.com/ms/process/api-html/user/builds/projects/pcgtrpcproject/pipelines/p-d7b163d3830a429e976bf77e2409c6d3/latestFinished?X-DEVOPS-PROJECT-ID=pcgtrpcproject) + +封装社区的 [mongo](https://go.mongodb.org/mongo-driver/mongo) ,配合 trpc 使用。 + +## mongodb client +```yaml +client: #客户端调用的后端配置 + service: #针对后端的配置 + - name: trpc.mongodb.xxx.xxx + target: mongodb://user:passwd@vip:port #mongodb 标准uri:mongodb://[username:password@]host1[:port1][,host2[:port2],...[,hostN[:portN]]][/[database][?options]] + timeout: 800 #当前这个请求最长处理时间 + - name: trpc.mongodb.xxx.xxx1 + target: mongodb+polaris://user:passwd@polaris_name # mongodb+polaris表示mongodb uri中的host会进行北极星解析 + timeout: 800 # 当前这个请求最长处理时间 +``` +```go +package main + +import ( + "time" + "context" + + "trpc.group/trpc-go/trpc-database/mongodb" + "trpc.group/trpc-go/trpc-go/client" +) + +// BattleFlow 对局信息 +type BattleInfo struct { + Id string `bson:"_id,omitempty" ` + Ctime uint32 `bson:"ctime,omitempty" json:"ctime,omitempty"` +} + +func (s *server) SayHello(ctx context.Context, req *pb.ReqBody, rsp *pb.RspBody) (err error) { + proxy := mongodb.NewClientProxy("trpc.mongodb.xxx.xxx") // service name自己随便填,主要用于监控上报和寻址配置项 + + // mongodb insert + _, err = proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key2", "value": "v2"}) + + // mongodb ReplaceOne + opts := options.Replace().SetUpsert(true) + filter := bson.D{{"_id", "key1"}} + _, err := proxy.ReplaceOne(ctx, "database", "table", filter, &BattleInfo{}, opts) + if err != nil { + log.Errorf("err=%v, data=%v", err, *battleInfo) + return err + } + + // mongodb FindOne + rst := proxy.FindOne(ctx, 
"database", "table", bson.D{{"_id", "key1"}}) + battleInfo = &BattleInfo{} + err = rst.Decode(battleInfo) + if err != nil { + return nil, err + } + + // mongodb transaction + err = proxy.Transaction(ctx, func(sc mongo.SessionContext) error { + //事务执行过程中需要使用同一proxy实例执行 + _, tErr := proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key1", "value": "v1"}) + if tErr != nil { + return tErr + } + _, tErr = proxy.InsertOne(sc, "database", "table", bson.M{"_id": "key2", "value": "v2"}) + if tErr != nil { + return tErr + } + return nil + }, nil) + + // mongodb RunCommand + cmdDB := bson.D{} + cmdDB = append(cmdDB, bson.E{Key: "enableSharding", Value: "dbName"}) + err = proxy.RunCommand(ctx, "admin", cmdDB).Err() + if err != nil { + return nil, err + } + + cmdColl := bson.D{} + cmdColl = append(cmdColl, bson.E{Key: "shardCollection", Value: "dbName.collectionName"}) + cmdColl = append(cmdColl, bson.E{Key: "key", Value: bson.D{{"openId", "hashed"}}}) + cmdColl = append(cmdColl, bson.E{Key: "unique", Value: false}) + cmdColl = append(cmdColl, bson.E{Key: "numInitialChunks", Value: 10}) + err = proxy.RunCommand(ctx, "admin", cmdColl).Err() + if err != nil { + return nil, err + } + // 业务逻辑 +} +``` +## 常见问题 + +- Q1: 如何配置 ClientOptions +- A1: 创建Transport时可以使用WithOptionInterceptor对ClientOptions进行配置,可以参考options_test.go diff --git a/mongodb/client.go b/mongodb/client.go new file mode 100644 index 0000000..158485e --- /dev/null +++ b/mongodb/client.go @@ -0,0 +1,841 @@ +// Package mongodb encapsulates standard library mongodb. +package mongodb + +import ( + "context" + "fmt" + "reflect" + "strings" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "trpc.group/trpc-go/trpc-go" + "trpc.group/trpc-go/trpc-go/client" + "trpc.group/trpc-go/trpc-go/codec" + "trpc.group/trpc-go/trpc-go/log" +) + +// TxFunc mongo is a transaction logic function, +// if an error is returned, it will be rolled back. 
+type TxFunc func(sc mongo.SessionContext) error + +//go:generate mockgen -source=client.go -destination=./mockmongodb/client_mock.go -package=mockmongodb + +// Client is mongodb request interface. +type Client interface { + BulkWrite(ctx context.Context, database string, coll string, models []mongo.WriteModel, + opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) + InsertOne(ctx context.Context, database string, coll string, document interface{}, + opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) + InsertMany(ctx context.Context, database string, coll string, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) + DeleteOne(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) + DeleteMany(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) + UpdateOne(ctx context.Context, database string, coll string, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) + UpdateMany(ctx context.Context, database string, coll string, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) + ReplaceOne(ctx context.Context, database string, coll string, filter interface{}, + replacement interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) + Aggregate(ctx context.Context, database string, coll string, pipeline interface{}, + opts ...*options.AggregateOptions) (*mongo.Cursor, error) + CountDocuments(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.CountOptions) (int64, error) + EstimatedDocumentCount(ctx context.Context, database string, coll string, + opts ...*options.EstimatedDocumentCountOptions) (int64, error) + Distinct(ctx context.Context, database string, coll string, fieldName string, 
filter interface{}, + opts ...*options.DistinctOptions) ([]interface{}, error) + Find(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) + FindOne(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.FindOneOptions) *mongo.SingleResult + FindOneAndDelete(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult + FindOneAndReplace(ctx context.Context, database string, coll string, filter interface{}, + replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult + FindOneAndUpdate(ctx context.Context, database string, coll string, filter interface{}, + update interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult + Watch(ctx context.Context, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) + WatchDatabase(ctx context.Context, database string, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) + WatchCollection(ctx context.Context, database string, collection string, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) + Transaction(ctx context.Context, sf TxFunc, tOpts []*options.TransactionOptions, + opts ...*options.SessionOptions) error + Disconnect(ctx context.Context) error + RunCommand(ctx context.Context, database string, runCommand interface{}, + opts ...*options.RunCmdOptions) *mongo.SingleResult + Indexes(ctx context.Context, database string, collection string) (mongo.IndexView, error) + Database(ctx context.Context, database string) (*mongo.Database, error) + Collection(ctx context.Context, database string, collection string) (*mongo.Collection, error) + StartSession(ctx context.Context) (mongo.Session, error) + //Deprecated + Do(ctx context.Context, cmd string, db string, coll string, args 
map[string]interface{}) (interface{}, error) +} + +// IndexViewer is the interface definition of the index. +// Refer to the naming of the community open source library, +// define the index interface separately, and divide the interface according to the function. +type IndexViewer interface { + // CreateMany creates the interface definition of the index. + CreateMany(ctx context.Context, database string, coll string, + models []mongo.IndexModel, opts ...*options.CreateIndexesOptions) ([]string, error) + CreateOne(ctx context.Context, database string, coll string, + model mongo.IndexModel, opts ...*options.CreateIndexesOptions) (string, error) + DropOne(ctx context.Context, database string, coll string, + name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) + DropAll(ctx context.Context, database string, coll string, + opts ...*options.DropIndexesOptions) (bson.Raw, error) +} + +// mongodbCli is a backend request structure. +type mongodbCli struct { + ServiceName string + Client client.Client + opts []client.Option +} + +// NewClientProxy creates a new mongo backend request proxy. +// The required parameter mongo service name: trpc.mongo.xxx.xxx. +func NewClientProxy(name string, opts ...client.Option) Client { + c := &mongodbCli{ + ServiceName: name, + Client: client.DefaultClient, + } + + c.opts = make([]client.Option, 0, len(opts)+2) + c.opts = append(c.opts, opts...) + c.opts = append(c.opts, client.WithProtocol("mongodb"), client.WithDisableServiceRouter()) + return c +} + +// NewInsertOneModel creates a new InsertOneModel. +// InsertOneModel is used to insert a single document in a BulkWrite operation. +func NewInsertOneModel() *mongo.InsertOneModel { + return mongo.NewInsertOneModel() +} + +// NewUpdateManyModel creates a new UpdateManyModel. +// UpdateManyModel is used to update multiple documents in a BulkWrite operation. 
+func NewUpdateManyModel() *mongo.UpdateManyModel { + return mongo.NewUpdateManyModel() +} + +// NewUpdateOneModel creates a new UpdateOneModel. +// UpdateOneModel is used to update at most one document in a BulkWrite operation. +func NewUpdateOneModel() *mongo.UpdateOneModel { + return mongo.NewUpdateOneModel() +} + +// NewReplaceOneModel creates a new ReplaceOneModel. +// ReplaceOneModel is used to replace at most one document in a BulkWrite operation. +func NewReplaceOneModel() *mongo.ReplaceOneModel { + return mongo.NewReplaceOneModel() +} + +// NewSessionContext creates a new SessionContext associated with the given Context and Session parameters. +func NewSessionContext(ctx context.Context, sess mongo.Session) mongo.SessionContext { + return mongo.NewSessionContext(ctx, sess) +} + +// mongodb cmd definition +var ( + Find = "find" + FindOne = "findone" + FindOneAndReplace = "findoneandreplace" + FindC = "findc" // Return mongo.Cursor type interface, use cursor.All/Decode to parse to structure. + DeleteOne = "deleteone" + DeleteMany = "deletemany" + FindOneAndDelete = "findoneanddelete" + FindOneAndUpdate = "findoneandupdate" + FindOneAndUpdateS = "findoneandupdates" // Return mongo.SingleResult type interface, + // use Decode to parse to structure. + InsertOne = "insertone" + InsertMany = "insertmany" + UpdateOne = "updateone" + UpdateMany = "updatemany" + ReplaceOne = "replaceone" + Count = "count" + Aggregate = "aggregate" // Polymerization + AggregateC = "aggregatec" // Return mongo.Cursor type interface, + // use cursor.All/Decode to parse to structure. 
+ Distinct = "distinct" + BulkWrite = "bulkwrite" + CountDocuments = "countdocuments" + EstimatedDocumentCount = "estimateddocumentcount" + Watch = "watch" + WatchDatabase = "watchdatabase" + WatchCollection = "watchcollection" + Transaction = "transaction" + Disconnect = "disconnect" + RunCommand = "runcommand" // Execute commands sequentially + IndexCreateOne = "indexcreateone" // Create index + IndexCreateMany = "indexcreatemany" // Create indexes in batches + IndexDropOne = "indexdropone" // Delete index + IndexDropAll = "indexdropall" // Delete all indexes + Indexes = "indexes" // Get the original index object + DatabaseCmd = "database" // Get the original database + CollectionCmd = "collection" // Get the original collection + StartSession = "startsession" // Create a new Session and SessionContext +) + +// Request mongodb request body +type Request struct { + Command string + Database string + Collection string + Arguments map[string]interface{} + + DriverProxy bool //driver transparent transmission + Filter interface{} //driver filter + CommArg interface{} //general parameters + Opts interface{} //option parameter +} + +// Response mongodb response body +type Response struct { + Result interface{} + txClient *mongo.Client //Use transparent mongo client in transaction execution. +} + +// Do is a general execution interface, +// which executes different curd operations according to cmd. +func (c *mongodbCli) Do(ctx context.Context, cmd string, database string, collection string, + args map[string]interface{}) (interface{}, error) { + cmd = strings.ToLower(cmd) + req := &Request{ + Command: cmd, + Database: database, + Collection: collection, + Arguments: args, + } + rsp := &Response{} + + ctx, msg := codec.WithCloneMessage(ctx) + defer codec.PutBackMessage(msg) + msg.WithClientRPCName(fmt.Sprintf("/%s/%s", c.ServiceName, cmd)) + msg.WithCalleeServiceName(c.ServiceName) + msg.WithSerializationType(-1) // Not serialized. 
+ msg.WithCompressType(0) // Not compressed. + msg.WithClientReqHead(req) + msg.WithClientRspHead(rsp) + + err := c.Client.Invoke(ctx, req, rsp, c.opts...) + return rsp.Result, err +} + +// invoke is a universal execution interface, execute different curd operations according to cmd. +func (c *mongodbCli) invoke(ctx context.Context, req *Request) ( + interface{}, error) { + req.DriverProxy = true + + // If there is a transparently transmitted response when executing a transaction, + // use the transparently transmitted instance directly. + var rsp *Response + rspHead := trpc.Message(ctx).ClientRspHead() + if rspHead != nil { + if rspIns, ok := rspHead.(*Response); ok { + rsp = rspIns + } + } + + // If there is no specified mongo instance, create a new response. + if rsp == nil { + rsp = &Response{} + } + + ctx, msg := codec.WithCloneMessage(ctx) + defer codec.PutBackMessage(msg) + msg.WithClientRPCName(fmt.Sprintf("/%s/driver.%s", c.ServiceName, req.Command)) + msg.WithCalleeServiceName(c.ServiceName) + msg.WithSerializationType(-1) //Not serialized. + msg.WithCompressType(0) //Not compressed. + msg.WithClientReqHead(req) + msg.WithClientRspHead(rsp) + + err := c.Client.Invoke(ctx, req, rsp, c.opts...) + return rsp.Result, err +} + +// Transaction executes a Transaction. +func (c *mongodbCli) Transaction(ctx context.Context, sf TxFunc, tOpts []*options.TransactionOptions, + opts ...*options.SessionOptions) error { + request := &Request{ + Command: Transaction, + CommArg: sf, + Filter: tOpts, + Opts: opts, + } + + _, err := c.invoke(ctx, request) + if err != nil { + return err + } + return nil +} + +// InsertOne executes an insert command to insert a single document into the collection. 
+func (c *mongodbCli) InsertOne(ctx context.Context, database string, coll string, document interface{}, + opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) { + request := &Request{ + Command: InsertOne, + Database: database, + Collection: coll, + CommArg: document, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.InsertOneResult), nil +} + +// InsertMany executes an insert command to insert multiple documents into the collection. If write errors occur +// during the operation (e.g. duplicate key error), this method returns a BulkWriteException error. +func (c *mongodbCli) InsertMany(ctx context.Context, database string, coll string, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + request := &Request{ + Command: InsertMany, + Database: database, + Collection: coll, + CommArg: documents, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if rsp != nil { + return rsp.(*mongo.InsertManyResult), err + } + return nil, err +} + +// DeleteOne executes a delete command to delete at most one document from the collection. +func (c *mongodbCli) DeleteOne(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + request := &Request{ + Command: DeleteOne, + Database: database, + Collection: coll, + Filter: filter, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.DeleteResult), nil +} + +// DeleteMany executes a delete command to delete documents from the collection. 
+func (c *mongodbCli) DeleteMany(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + request := &Request{ + Command: DeleteMany, + Database: database, + Collection: coll, + Filter: filter, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.DeleteResult), nil +} + +// UpdateOne executes an update command to update at most one document in the collection. +func (c *mongodbCli) UpdateOne(ctx context.Context, database string, coll string, filter interface{}, + update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + request := &Request{ + Command: UpdateOne, + Database: database, + Collection: coll, + Filter: filter, + CommArg: update, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.UpdateResult), nil +} + +// UpdateMany executes an update command to update documents in the collection. +func (c *mongodbCli) UpdateMany(ctx context.Context, database string, coll string, filter interface{}, + update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + request := &Request{ + Command: UpdateMany, + Database: database, + Collection: coll, + Filter: filter, + CommArg: update, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.UpdateResult), nil +} + +// ReplaceOne executes an update command to replace at most one document in the collection. 
+func (c *mongodbCli) ReplaceOne(ctx context.Context, database string, coll string, filter interface{}, + replacement interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) { + request := &Request{ + Command: ReplaceOne, + Database: database, + Collection: coll, + Filter: filter, + CommArg: replacement, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.UpdateResult), nil +} + +// Aggregate executes an aggregate command against the collection and returns a cursor over the resulting documents. +func (c *mongodbCli) Aggregate(ctx context.Context, database string, coll string, pipeline interface{}, + opts ...*options.AggregateOptions) (*mongo.Cursor, error) { + request := &Request{ + Command: Aggregate, + Database: database, + Collection: coll, + CommArg: pipeline, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.Cursor), nil +} + +// CountDocuments returns the number of documents in the collection. For a fast count of the documents in the +// collection, see the EstimatedDocumentCount method. +func (c *mongodbCli) CountDocuments(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.CountOptions) (int64, error) { + request := &Request{ + Command: CountDocuments, + Database: database, + Collection: coll, + Filter: filter, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return 0, err + } + return rsp.(int64), nil +} + +// EstimatedDocumentCount executes a count command and returns an estimate of the number of documents in the collection +// using collection metadata. 
+func (c *mongodbCli) EstimatedDocumentCount(ctx context.Context, database string, coll string,
+	opts ...*options.EstimatedDocumentCountOptions) (int64, error) {
+	request := &Request{
+		Command:    EstimatedDocumentCount,
+		Database:   database,
+		Collection: coll,
+		Opts:       opts,
+	}
+	rsp, err := c.invoke(ctx, request)
+	if err != nil {
+		return 0, err
+	}
+	return rsp.(int64), nil
+}
+
+// Distinct executes a distinct command to find the unique values for a specified field in the collection.
+func (c *mongodbCli) Distinct(ctx context.Context, database string, coll string, fieldName string, filter interface{},
+	opts ...*options.DistinctOptions) ([]interface{}, error) {
+	request := &Request{
+		Command:    Distinct,
+		Database:   database,
+		Collection: coll,
+		Filter:     filter,
+		CommArg:    fieldName,
+		Opts:       opts,
+	}
+	rsp, err := c.invoke(ctx, request)
+	if err != nil {
+		return nil, err
+	}
+	return rsp.([]interface{}), nil
+}
+
+// Find executes a find command and returns a Cursor over the matching documents in the collection.
+func (c *mongodbCli) Find(ctx context.Context, database string, coll string, filter interface{},
+	opts ...*options.FindOptions) (*mongo.Cursor, error) {
+	request := &Request{
+		Command:    Find,
+		Database:   database,
+		Collection: coll,
+		Filter:     filter,
+		Opts:       opts,
+	}
+	rsp, err := c.invoke(ctx, request)
+	if err != nil {
+		return nil, err
+	}
+	return rsp.(*mongo.Cursor), nil
+}
+
+// FindOne executes a find command and returns a SingleResult for one document in the collection.
+func (c *mongodbCli) FindOne(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.FindOneOptions) *mongo.SingleResult { + request := &Request{ + Command: FindOne, + Database: database, + Collection: coll, + Filter: filter, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + log.Errorf("client invoke error: %v", err) + return mongo.NewSingleResultFromDocument(bson.D{}, err, nil) + } + return rsp.(*mongo.SingleResult) +} + +// FindOneAndDelete executes a findAndModify command to delete at most one document in the collection. and returns the +// document as it appeared before deletion. +func (c *mongodbCli) FindOneAndDelete(ctx context.Context, database string, coll string, filter interface{}, + opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult { + request := &Request{ + Command: FindOneAndDelete, + Database: database, + Collection: coll, + Filter: filter, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + log.Errorf("client invoke error: %v", err) + return mongo.NewSingleResultFromDocument(bson.D{}, err, nil) + } + return rsp.(*mongo.SingleResult) +} + +// FindOneAndReplace executes a findAndModify command to replace at most one document in the collection +// and returns the document as it appeared before replacement. 
+func (c *mongodbCli) FindOneAndReplace(ctx context.Context, database string, coll string, filter interface{}, + replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult { + request := &Request{ + Command: FindOneAndReplace, + Database: database, + Collection: coll, + Filter: filter, + CommArg: replacement, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + log.Errorf("client invoke error: %v", err) + return mongo.NewSingleResultFromDocument(bson.D{}, err, nil) + } + return rsp.(*mongo.SingleResult) +} + +// FindOneAndUpdate executes a findAndModify command to update at most one document in the collection and returns the +// document as it appeared before updating. +func (c *mongodbCli) FindOneAndUpdate(ctx context.Context, database string, coll string, filter interface{}, + update interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult { + request := &Request{ + Command: FindOneAndUpdate, + Database: database, + Collection: coll, + Filter: filter, + CommArg: update, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + log.Errorf("client invoke error: %v", err) + return mongo.NewSingleResultFromDocument(bson.D{}, err, nil) + } + return rsp.(*mongo.SingleResult) +} + +// BulkWrite performs a bulk write operation (https://docs.mongodb.com/manual/core/bulk-write-operations/) +func (c *mongodbCli) BulkWrite(ctx context.Context, database string, coll string, models []mongo.WriteModel, + opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) { + request := &Request{ + Command: BulkWrite, + Database: database, + Collection: coll, + CommArg: models, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if rsp != nil { + return rsp.(*mongo.BulkWriteResult), err + } + return nil, err +} + +// Watch returns a change stream for all changes on the deployment. 
+func (c *mongodbCli) Watch(ctx context.Context, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + + request := &Request{ + Command: Watch, + CommArg: pipeline, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.ChangeStream), nil +} + +// WatchDatabase returns a change stream for all changes to the corresponding database. +func (c *mongodbCli) WatchDatabase(ctx context.Context, database string, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + request := &Request{ + Command: WatchDatabase, + Database: database, + CommArg: pipeline, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.ChangeStream), nil +} + +// WatchCollection returns a change stream for all changes on the corresponding collection. +func (c *mongodbCli) WatchCollection(ctx context.Context, database string, collection string, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + request := &Request{ + Command: WatchCollection, + Database: database, + Collection: collection, + CommArg: pipeline, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + return rsp.(*mongo.ChangeStream), nil +} + +// Disconnect closes the mongo client under service name. +func (c *mongodbCli) Disconnect(ctx context.Context) error { + request := &Request{ + Command: Disconnect, + } + _, err := c.invoke(ctx, request) + return err +} + +// RunCommand executes the given command against the database. This function does not obey the Database's read +// preference. To specify a read preference, the RunCmdOptions.ReadPreference option must be used. +// The runCommand parameter must be a document for the command to be executed. It cannot be nil. +// This must be an order-preserving type such as bson.D. 
Map types such as bson.M are not valid. +// The shardCollection command must be run against the admin database. +func (c *mongodbCli) RunCommand(ctx context.Context, database string, runCommand interface{}, + opts ...*options.RunCmdOptions) *mongo.SingleResult { + request := &Request{ + Command: RunCommand, + Database: database, + CommArg: runCommand, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + log.Errorf("client invoke error: %v", err) + return mongo.NewSingleResultFromDocument(bson.D{}, err, nil) + } + return rsp.(*mongo.SingleResult) +} + +// CreateMany executes a createIndexes command to create multiple indexes on the collection and returns +// the names of the new indexes. +func (c *mongodbCli) CreateMany(ctx context.Context, database string, collection string, + models []mongo.IndexModel, opts ...*options.CreateIndexesOptions) ([]string, error) { + request := &Request{ + Command: IndexCreateMany, + Database: database, + Collection: collection, + CommArg: models, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if sliceStr, ok := rsp.([]string); ok { + return sliceStr, nil + } + return nil, buildUnMatchKindError(reflect.Slice, rsp) +} + +// CreateOne executes a createIndexes command to create an index on the collection and returns the name of the new +// index. See the IndexView.CreateMany documentation for more information and an example. 
+func (c *mongodbCli) CreateOne(ctx context.Context, database string, collection string, + model mongo.IndexModel, opts ...*options.CreateIndexesOptions) (string, error) { + request := &Request{ + Command: IndexCreateOne, + Database: database, + Collection: collection, + CommArg: model, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return "", err + } + if s, ok := rsp.(string); ok { + return s, nil + } + return "", buildUnMatchKindError(reflect.String, rsp) +} + +// DropOne executes a dropIndexes operation to drop an index on the collection. If the operation succeeds, this returns +// a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the number of +// indexes that existed prior to the drop. +// +// The name parameter should be the name of the index to drop. If the name is "*", ErrMultipleIndexDrop will be returned +// without running the command because doing so would drop all indexes. +// +// The opts parameter can be used to specify options for this operation (see the options.DropIndexesOptions +// documentation). +// +// For more information about the command, see https://docs.mongodb.com/manual/reference/command/dropIndexes/. +func (c *mongodbCli) DropOne(ctx context.Context, database string, collection string, + name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { + request := &Request{ + Command: IndexDropOne, + Database: database, + Collection: collection, + CommArg: name, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if raw, ok := rsp.(bson.Raw); ok { + return raw, nil + } + return nil, buildUnMatchKindError(reflect.Slice, rsp) +} + +// DropAll executes a dropIndexes operation to drop all indexes on the collection. If the operation succeeds, this +// returns a BSON document in the form {nIndexesWas: }. The "nIndexesWas" field in the response contains the +// number of indexes that existed prior to the drop. 
+// +// The opts parameter can be used to specify options for this operation (see the options.DropIndexesOptions +// documentation). +// +// For more information about the command, see https://docs.mongodb.com/manual/reference/command/dropIndexes/. +func (c *mongodbCli) DropAll(ctx context.Context, database string, collection string, + opts ...*options.DropIndexesOptions) (bson.Raw, error) { + request := &Request{ + Command: IndexDropAll, + Database: database, + Collection: collection, + CommArg: nil, + Opts: opts, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if raw, ok := rsp.(bson.Raw); ok { + return raw, nil + } + return nil, buildUnMatchKindError(reflect.Slice, rsp) +} + +// buildUnMatchKindError builds a type mismatch error, +// the result type returned when invoke is not what we expected. +func buildUnMatchKindError(want reflect.Kind, actual interface{}) error { + val := reflect.ValueOf(actual) + return fmt.Errorf("the result kind of got is not expect, expect is %s but actural is %s", + want.String(), val.Kind().String()) +} + +// Indexes gets the original index operation object. +func (c *mongodbCli) Indexes(ctx context.Context, database string, collection string) (mongo.IndexView, error) { + request := &Request{ + Command: Indexes, + Database: database, + Collection: collection, + CommArg: nil, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return mongo.IndexView{}, err + } + if raw, ok := rsp.(mongo.IndexView); ok { + return raw, nil + } + return mongo.IndexView{}, buildUnMatchKindError(reflect.Struct, rsp) +} + +// Database gets the original database used to call the original method. 
+func (c *mongodbCli) Database(ctx context.Context, database string) (*mongo.Database, error) { + request := &Request{ + Command: DatabaseCmd, + Database: database, + CommArg: nil, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if raw, ok := rsp.(*mongo.Database); ok { + return raw, nil + } + return nil, buildUnMatchKindError(reflect.Ptr, rsp) +} + +// Collection gets the original collection used to call the original method. +func (c *mongodbCli) Collection(ctx context.Context, database string, collection string) (*mongo.Collection, error) { + request := &Request{ + Command: CollectionCmd, + Database: database, + Collection: collection, + CommArg: nil, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if raw, ok := rsp.(*mongo.Collection); ok { + return raw, nil + } + return nil, buildUnMatchKindError(reflect.Ptr, rsp) +} + +// StartSession starts a new session configured with the given options. +// StartSession does not actually communicate with the server and will not error if the client is disconnected. +// If the DefaultReadConcern, DefaultWriteConcern, or DefaultReadPreference options are not set, +// the client's read concern, write concern, or read preference will be used, respectively. +func (c *mongodbCli) StartSession(ctx context.Context) (mongo.Session, error) { + request := &Request{ + Command: StartSession, + CommArg: nil, + } + rsp, err := c.invoke(ctx, request) + if err != nil { + return nil, err + } + if raw, ok := rsp.(mongo.Session); ok { + return raw, nil + } + return nil, buildUnMatchKindError(reflect.Ptr, rsp) +} diff --git a/mongodb/client_test.go b/mongodb/client_test.go new file mode 100644 index 0000000..17f065d --- /dev/null +++ b/mongodb/client_test.go @@ -0,0 +1,1185 @@ +package mongodb + +import ( + "context" + "errors" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/golang/mock/gomock" + . 
"github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "trpc.group/trpc-go/trpc-go" + "trpc.group/trpc-go/trpc-go/client" + "trpc.group/trpc-go/trpc-go/client/mockclient" + "trpc.group/trpc-go/trpc-go/errs" + "trpc.group/trpc-go/trpc-go/transport" +) + +const fakeError = "fake error" + +func TestUnitNewClientProxy(t *testing.T) { + Convey("TestUnit_NewClientProxy_P0", t, func() { + mysqlClient := NewClientProxy("trpc.mongo.xxx.xxx") + rawClient, ok := mysqlClient.(*mongodbCli) + So(ok, ShouldBeTrue) + So(rawClient.Client, ShouldResemble, client.DefaultClient) + So(rawClient.ServiceName, ShouldEqual, "trpc.mongo.xxx.xxx") + So(len(rawClient.opts), ShouldEqual, 2) + }) +} + +func TestInsert(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + t.Run("InsertOne success", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertOne", + func(coll *mongo.Collection, ctx context.Context, document interface{}, + opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").InsertOne(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + }) + t.Run("InsertOne fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertOne", + func(coll *mongo.Collection, ctx context.Context, document interface{}, + opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").InsertOne(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + }) + 
t.Run("InsertMany succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertMany", + func(coll *mongo.Collection, ctx context.Context, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").InsertMany(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + }) + t.Run("InsertMany fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertMany", + func(coll *mongo.Collection, ctx context.Context, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").InsertMany(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + }) + t.Run("InsertMany duplicate succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertMany", + func(coll *mongo.Collection, ctx context.Context, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + return &mongo.InsertManyResult{ + InsertedIDs: []interface{}{"test"}, + }, mongo.BulkWriteException{ + WriteConcernError: &mongo.WriteConcernError{Name: "name", Code: 100, Message: "bar"}, + WriteErrors: []mongo.BulkWriteError{ + { + WriteError: mongo.WriteError{Code: 11000, Message: "blah E11000 blah"}, + Request: &mongo.InsertOneModel{}}, + }, + Labels: []string{"otherError"}, + } + }, + ).Reset() + result, err := NewClientProxy("trpc.mongodb.app.demo").InsertMany(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + assert.NotNil(t, result) + }) + + t.Run("BulkWrite duplicate succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "BulkWrite", + func(coll *mongo.Collection, ctx context.Context, models []mongo.WriteModel, + 
opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) { + return &mongo.BulkWriteResult{ + InsertedCount: 1, + }, mongo.BulkWriteException{ + WriteConcernError: &mongo.WriteConcernError{Name: "name", Code: 100, Message: "bar"}, + WriteErrors: []mongo.BulkWriteError{ + { + WriteError: mongo.WriteError{Code: 11000, Message: "blah E11000 blah"}, + Request: &mongo.InsertOneModel{}}, + }, + Labels: []string{"otherError"}, + } + }, + ).Reset() + models := make([]mongo.WriteModel, 0, 2) + model := mongo.NewUpdateOneModel() + model.SetUpsert(true) + model.SetFilter(bson.M{"test": 1}) + model.SetUpdate(bson.M{ + "$set": bson.M{"test": 1}, + "$setOnInsert": bson.M{"test": 1}, + }) + models = append(models, model) + model = mongo.NewUpdateOneModel() + model.SetUpsert(true) + model.SetFilter(bson.M{"test": 2}) + model.SetUpdate(bson.M{ + "$set": bson.M{"test": 2}, + "$setOnInsert": bson.M{"test": 2}, + }) + models = append(models, model) + result, err := NewClientProxy("trpc.mongodb.app.demo").BulkWrite(context.Background(), "demo", "test", models) + assert.NotNil(t, err) + assert.NotNil(t, result) + }) + t.Run("Transaction succ", func(t *testing.T) { + client, _ := mongo.Connect(trpc.BackgroundContext(), + options.Client().ApplyURI("mongodb://127.0.0.1:27017")) + assert.NotNil(t, client) + + sess, _ := client.StartSession() + assert.NotNil(t, sess) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "StartSession", + func(coll *mongo.Client, opts ...*options.SessionOptions) (mongo.Session, error) { + return sess, nil + }, + ).Reset() + err := NewClientProxy("trpc.mongodb.app.demo").Transaction(context.Background(), + func(sc mongo.SessionContext) error { + return nil + }, nil) + assert.Nil(t, err) + + err = NewClientProxy("trpc.mongodb.app.demo").Transaction(context.Background(), + func(sc mongo.SessionContext) error { + return errs.New(-1, "failed test") + }, nil) + assert.NotNil(t, err) + + }) + + t.Run("Transaction fail", func(t *testing.T) { + 
client, _ := mongo.Connect(trpc.BackgroundContext(), + options.Client().ApplyURI("mongodb://127.0.0.1:27017")) + assert.NotNil(t, client) + + sess, _ := client.StartSession() + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "StartSession", + func(coll *mongo.Client, opts ...*options.SessionOptions) (mongo.Session, error) { + return sess, nil + }, + ).Reset() + err := NewClientProxy("trpc.mongodb.app.demo").Transaction(context.Background(), + func(sc mongo.SessionContext) error { + return nil + }, nil) + assert.Nil(t, err) + err = NewClientProxy("trpc.mongodb.app.demo").Transaction(context.Background(), + func(sc mongo.SessionContext) error { + return errs.New(-1, "failed test") + }, nil) + assert.NotNil(t, err) + }) +} + +func TestDelete(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + t.Run("DeleteOne succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").DeleteOne(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + }) + t.Run("DeleteOne fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").DeleteOne(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + }) + t.Run("DeleteMany succ", func(t *testing.T) { + defer 
gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").DeleteMany(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + }) + t.Run("DeleteMany fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").DeleteMany(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + }) +} + +func TestUpdate(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + t.Run("UpdateOne succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").UpdateOne(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, err) + }) + t.Run("UpdateOne fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").UpdateOne(context.Background(), "demo", "test", nil, nil) + 
assert.NotNil(t, err) + }) + t.Run("UpdateMany succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").UpdateMany(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, err) + }) + t.Run("UpdateMany fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").UpdateMany(context.Background(), "demo", "test", nil, nil) + assert.NotNil(t, err) + }) +} + +func TestFind(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + t.Run("Find succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").Find(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + }) + t.Run("Find fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, errors.New(fakeError) + }, + ).Reset() + _, err := 
NewClientProxy("trpc.mongodb.app.demo").Find(context.Background(), "demo", "test", nil) + assert.NotNil(t, err) + }) + t.Run("FindOne succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOneOptions) *mongo.SingleResult { + return nil + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOne(context.Background(), "demo", "test", nil) + assert.Nil(t, rst) + }) + t.Run("FindOne fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "RoundTrip", + func(ct *ClientTransport, ctx context.Context, _ []byte, callOpts ...transport.RoundTripOption) ([]byte, + error) { + return nil, errors.New(fakeError) + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOne(context.Background(), "demo", "test", nil) + assert.Equal(t, rst.Err().Error(), fakeError) + }) + t.Run("FindOneAndDelete succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOneAndDelete", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult { + return nil + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndDelete(context.Background(), "demo", "test", nil) + assert.Nil(t, rst) + }) + t.Run("FindOneAndDelete fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "RoundTrip", + func(ct *ClientTransport, ctx context.Context, _ []byte, callOpts ...transport.RoundTripOption) ([]byte, + error) { + return nil, errors.New(fakeError) + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndDelete(context.Background(), "demo", "test", nil) + assert.Equal(t, rst.Err().Error(), fakeError) + }) + t.Run("FindOneAndReplace succ", func(t *testing.T) { + defer 
gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOneAndReplace", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult { + return nil + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndReplace(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, rst) + }) + t.Run("FindOneAndReplace fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "RoundTrip", + func(ct *ClientTransport, ctx context.Context, _ []byte, callOpts ...transport.RoundTripOption) ([]byte, + error) { + return nil, errors.New(fakeError) + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndReplace(context.Background(), "demo", "test", nil, nil) + assert.Equal(t, rst.Err().Error(), fakeError) + }) + t.Run("FindOneAndUpdate succ", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOneAndUpdate", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + replacement interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult { + return nil + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndUpdate(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, rst) + }) + t.Run("FindOneAndUpdate fail", func(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "RoundTrip", + func(ct *ClientTransport, ctx context.Context, _ []byte, callOpts ...transport.RoundTripOption) ([]byte, + error) { + return nil, errors.New(fakeError) + }, + ).Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").FindOneAndUpdate(context.Background(), "demo", "test", nil, nil) + assert.Equal(t, rst.Err().Error(), fakeError) + }) + +} +func TestUnitMongodbCli_DriverProxy(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport 
*ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "BulkWrite", + func(coll *mongo.Collection, ctx context.Context, models []mongo.WriteModel, + opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) { + return nil, nil + }, + ).Reset() + _, err := NewClientProxy("trpc.mongodb.app.demo").BulkWrite(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "ReplaceOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + replacement interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").ReplaceOne(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Aggregate", + func(coll *mongo.Collection, ctx context.Context, pipeline interface{}, + opts ...*options.AggregateOptions) (*mongo.Cursor, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").Aggregate(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "CountDocuments", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.CountOptions) (int64, error) { + return 0, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").CountDocuments(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "EstimatedDocumentCount", + func(coll *mongo.Collection, ctx context.Context, + opts ...*options.EstimatedDocumentCountOptions) (int64, error) { + return 0, nil + }, + ).Reset() + _, err = 
NewClientProxy("trpc.mongodb.app.demo").EstimatedDocumentCount(context.Background(), "demo", "test", nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Distinct", + func(coll *mongo.Collection, ctx context.Context, fieldName string, filter interface{}, + opts ...*options.DistinctOptions) ([]interface{}, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").Distinct(context.Background(), "demo", "test", "", nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "Watch", + func(coll *mongo.Client, ctx context.Context, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").Watch(context.Background(), nil, nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Database)), "Watch", + func(coll *mongo.Database, ctx context.Context, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").WatchDatabase(context.Background(), "demo", nil, nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Watch", + func(coll *mongo.Collection, ctx context.Context, pipeline interface{}, + opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + return nil, nil + }, + ).Reset() + _, err = NewClientProxy("trpc.mongodb.app.demo").WatchCollection(context.Background(), "demo", "test", nil, nil) + assert.Nil(t, err) + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "Disconnect", + func(coll *mongo.Client, ctx context.Context) error { + return nil + }, + ).Reset() + err = NewClientProxy("trpc.mongodb.app.demo").Disconnect(context.Background()) + assert.Nil(t, err) +} + +// Test the watch interface. 
+func TestUnitMongodbCli_Watch(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(errors.New("cli err")).AnyTimes() + + cli := &mongodbCli{ + Client: mockCli, + } + _, err := cli.Watch(context.Background(), nil, nil) + assert.NotNil(t, err) + + _, err = cli.WatchDatabase(context.Background(), "demo", nil, nil) + assert.NotNil(t, err) + + _, err = cli.WatchCollection(context.Background(), "demo", "test", nil, nil) + assert.NotNil(t, err) + +} + +func TestUnitMongodbCli_Do(t *testing.T) { + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = "result" + return nil + }) + gomock.InOrder(m1, m2) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + ctx context.Context + cmd string + database string + collection string + args map[string]interface{} + } + tests := []struct { + name string + fields fields + args args + want interface{} + wantErr bool + }{ + {"err", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + {"succ", fields{Client: mockCli}, + args{ctx: context.Background()}, "result", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.Do(tt.args.ctx, tt.args.cmd, tt.args.database, tt.args.collection, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("Do() error = %v, wantErr %v", err, tt.wantErr) + 
return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Do() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_CreateMany(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = []string{"index_1", "index_2"} + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = "not match kind" + return nil + }) + gomock.InOrder(m1, m2, m3) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + ctx context.Context + database string + collection string + models []mongo.IndexModel + opts []*options.CreateIndexesOptions + } + tests := []struct { + name string + fields fields + args args + want []string + wantErr bool + }{ + {"err", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + {"succ", fields{Client: mockCli}, + args{ctx: context.Background()}, []string{"index_1", "index_2"}, false}, + {"not match kind", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.CreateMany(tt.args.ctx, tt.args.database, tt.args.collection, tt.args.models, tt.args.opts...) 
+ if (err != nil) != tt.wantErr { + t.Errorf("CreateMany() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("CreateMany() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_RunCommand(t *testing.T) { + patch := gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(transport *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ) + defer patch.Reset() + + t.Run("RunCommand succ", func(t *testing.T) { + patch2 := gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Database)), "RunCommand", + // There are more than 126 columns here, + // which does not meet the specification and blocks the upload of the code, + // and the line break is processed. + func(db *mongo.Database, ctx context.Context, + runCommand interface{}, opts ...*options.RunCmdOptions) *mongo.SingleResult { + return &mongo.SingleResult{} + }, + ) + defer patch2.Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").RunCommand(context.Background(), "admin", nil) + assert.NotNil(t, rst) + }) + t.Run("RunCommand fail", func(t *testing.T) { + patch3 := gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "RoundTrip", + func(ct *ClientTransport, ctx context.Context, _ []byte, callOpts ...transport.RoundTripOption) ([]byte, + error) { + return nil, errors.New(fakeError) + }, + ) + defer patch3.Reset() + rst := NewClientProxy("trpc.mongodb.app.demo").RunCommand(context.Background(), "admin", nil) + assert.Equal(t, rst.Err().Error(), fakeError) + }) + +} + +func Test_mongodbCli_CreateOne(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody 
interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = "create_one" + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + gomock.InOrder(m1, m2, m3) + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + ctx context.Context + database string + collection string + model mongo.IndexModel + opts []*options.CreateIndexesOptions + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + {"err", fields{Client: mockCli}, + args{ctx: context.Background()}, "", true}, + {"succ", fields{Client: mockCli}, + args{ctx: context.Background()}, "create_one", false}, + {"not match kind", fields{Client: mockCli}, + args{ctx: context.Background()}, "", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.CreateOne(tt.args.ctx, tt.args.database, tt.args.collection, tt.args.model, tt.args.opts...) 
+ if (err != nil) != tt.wantErr { + t.Errorf("CreateOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("CreateOne() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_DropOne(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = bson.Raw("drop_one") + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + gomock.InOrder(m1, m2, m3) + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + ctx context.Context + database string + collection string + name string + opts []*options.DropIndexesOptions + } + tests := []struct { + name string + fields fields + args args + want bson.Raw + wantErr bool + }{ + {"err", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + {"succ", fields{Client: mockCli}, + args{ctx: context.Background()}, bson.Raw("drop_one"), false}, + {"not match kind", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.DropOne(tt.args.ctx, tt.args.database, tt.args.collection, tt.args.name, tt.args.opts...) 
+ if (err != nil) != tt.wantErr { + t.Errorf("DropOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DropOne() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_DropAll(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = bson.Raw("drop_all") + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + ctx context.Context + database string + collection string + opts []*options.DropIndexesOptions + } + gomock.InOrder(m1, m2, m3) + tests := []struct { + name string + fields fields + args args + want bson.Raw + wantErr bool + }{ + {"err", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + {"succ", fields{Client: mockCli}, + args{ctx: context.Background()}, bson.Raw("drop_all"), false}, + {"not match kind", fields{Client: mockCli}, + args{ctx: context.Background()}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.DropAll(tt.args.ctx, tt.args.database, tt.args.collection, tt.args.opts...) 
+ if (err != nil) != tt.wantErr { + t.Errorf("DropAll() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("DropAll() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_Collection(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = &mongo.Collection{} + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + gomock.InOrder(m1, m2, m3) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + database string + collection string + } + tests := []struct { + name string + fields fields + args args + want *mongo.Collection + wantErr bool + }{ + {"err", fields{Client: mockCli}, args{}, nil, true}, + {"suc", fields{Client: mockCli}, args{}, &mongo.Collection{}, false}, + {"not match kind", fields{Client: mockCli}, args{}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.Collection(trpc.BackgroundContext(), tt.args.database, tt.args.collection) + if (err != nil) != tt.wantErr { + t.Errorf("Collection() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Collection() got = %v, want %v", got, tt.want) + } + }) + } +} + +func 
Test_mongodbCli_Database(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = &mongo.Database{} + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + gomock.InOrder(m1, m2, m3) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + database string + } + tests := []struct { + name string + fields fields + args args + want *mongo.Database + wantErr bool + }{ + {"err", fields{Client: mockCli}, args{}, nil, true}, + {"suc", fields{Client: mockCli}, args{}, &mongo.Database{}, false}, + {"not match kind", fields{Client: mockCli}, args{}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.Database(trpc.BackgroundContext(), tt.args.database) + if (err != nil) != tt.wantErr { + t.Errorf("Database() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Database() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_Indexes(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m2 := 
mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = mongo.IndexView{} + return nil + }) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + r.Result = 123 + return nil + }) + gomock.InOrder(m1, m2, m3) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + type args struct { + database string + collection string + } + tests := []struct { + name string + fields fields + args args + want mongo.IndexView + wantErr bool + }{ + {"err", fields{Client: mockCli}, args{}, mongo.IndexView{}, true}, + {"suc", fields{Client: mockCli}, args{}, mongo.IndexView{}, false}, + {"not match kind", fields{Client: mockCli}, args{}, mongo.IndexView{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.Indexes(trpc.BackgroundContext(), tt.args.database, tt.args.collection) + if (err != nil) != tt.wantErr { + t.Errorf("Indexes() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Indexes() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_mongodbCli_StartSession(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + mockCli := mockclient.NewMockClient(ctrl) + m1 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("err")) + m3 := mockCli.EXPECT().Invoke(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( + func(ctx context.Context, reqbody interface{}, rspbody interface{}, opt ...client.Option) error { + r, _ := rspbody.(*Response) + 
r.Result = 123 + return nil + }) + gomock.InOrder(m1, m3) + + type fields struct { + ServiceName string + Client client.Client + opts []client.Option + } + tests := []struct { + name string + fields fields + want mongo.Session + wantErr bool + }{ + {"err", fields{Client: mockCli}, nil, true}, + {"not match kind", fields{Client: mockCli}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &mongodbCli{ + ServiceName: tt.fields.ServiceName, + Client: tt.fields.Client, + opts: tt.fields.opts, + } + got, err := c.StartSession(trpc.BackgroundContext()) + if (err != nil) != tt.wantErr { + t.Errorf("StartSession() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("StartSession() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewInsertOneModel(t *testing.T) { + mc := gomonkey.ApplyFunc(mongo.NewInsertOneModel, func() *mongo.InsertOneModel { return &mongo.InsertOneModel{} }) + defer mc.Reset() + + tests := []struct { + name string + want *mongo.InsertOneModel + }{ + {"suc", &mongo.InsertOneModel{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewInsertOneModel(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewInsertOneModel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewReplaceOneModel(t *testing.T) { + mc := gomonkey.ApplyFunc(mongo.NewReplaceOneModel, func() *mongo.ReplaceOneModel { return &mongo.ReplaceOneModel{} }) + defer mc.Reset() + + tests := []struct { + name string + want *mongo.ReplaceOneModel + }{ + {"suc", &mongo.ReplaceOneModel{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewReplaceOneModel(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewReplaceOneModel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewSessionContext(t *testing.T) { + type args struct { + ctx context.Context + sess mongo.Session + } + tests := []struct { + name string + 
args args + want mongo.SessionContext + }{ + {"suc", args{ctx: trpc.BackgroundContext(), sess: nil}, mongo.NewSessionContext(trpc.BackgroundContext(), nil)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewSessionContext(tt.args.ctx, tt.args.sess); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewSessionContext() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewUpdateManyModel(t *testing.T) { + mc := gomonkey.ApplyFunc(mongo.NewUpdateManyModel, func() *mongo.UpdateManyModel { return &mongo.UpdateManyModel{} }) + defer mc.Reset() + + tests := []struct { + name string + want *mongo.UpdateManyModel + }{ + {"suc", &mongo.UpdateManyModel{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewUpdateManyModel(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewUpdateManyModel() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestNewUpdateOneModel(t *testing.T) { + mc := gomonkey.ApplyFunc(mongo.NewUpdateOneModel, func() *mongo.UpdateOneModel { return &mongo.UpdateOneModel{} }) + defer mc.Reset() + + tests := []struct { + name string + want *mongo.UpdateOneModel + }{ + {"suc", &mongo.UpdateOneModel{}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := NewUpdateOneModel(); !reflect.DeepEqual(got, tt.want) { + t.Errorf("NewUpdateOneModel() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mongodb/codec.go b/mongodb/codec.go new file mode 100644 index 0000000..3113267 --- /dev/null +++ b/mongodb/codec.go @@ -0,0 +1,38 @@ +package mongodb + +import ( + "fmt" + "os" + "path" + + "trpc.group/trpc-go/trpc-go/codec" +) + +func init() { + codec.Register("mongodb", nil, defaultClientCodec) +} + +// default codec +var ( + defaultClientCodec = &ClientCodec{} +) + +// ClientCodec decodes mongodb client request. +type ClientCodec struct{} + +// Encode sets the metadata requested by the mongodb client. 
+func (c *ClientCodec) Encode(msg codec.Msg, _ []byte) ([]byte, error) { + + //Itself. + if msg.CallerServiceName() == "" { + msg.WithCallerServiceName(fmt.Sprintf("trpc.mongodb.%s.service", path.Base(os.Args[0]))) + } + + return nil, nil +} + +// Decode parses the metadata in the mongodb client return packet. +func (c *ClientCodec) Decode(msg codec.Msg, _ []byte) ([]byte, error) { + + return nil, nil +} diff --git a/mongodb/codec_test.go b/mongodb/codec_test.go new file mode 100644 index 0000000..18e6949 --- /dev/null +++ b/mongodb/codec_test.go @@ -0,0 +1,65 @@ +package mongodb + +import ( + "context" + "reflect" + "testing" + + "trpc.group/trpc-go/trpc-go/codec" +) + +func TestUnitClientCodec_Decode(t *testing.T) { + type args struct { + msg codec.Msg + in1 []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {"succ", args{}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ClientCodec{} + got, err := c.Decode(tt.args.msg, tt.args.in1) + if (err != nil) != tt.wantErr { + t.Errorf("Decode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Decode() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnitClientCodec_Encode(t *testing.T) { + type args struct { + msg codec.Msg + in1 []byte + } + tests := []struct { + name string + args args + want []byte + wantErr bool + }{ + {"succ", args{msg: codec.Message(context.Background())}, nil, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &ClientCodec{} + got, err := c.Encode(tt.args.msg, tt.args.in1) + if (err != nil) != tt.wantErr { + t.Errorf("Encode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Encode() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mongodb/curd.go b/mongodb/curd.go new file mode 100644 index 0000000..17537f1 --- /dev/null +++ 
b/mongodb/curd.go @@ -0,0 +1,718 @@ +package mongodb + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// mongo operation instructions +const ( + mongoArgFilter = "filter" + mongoArgSort = "sort" + mongoArgSkip = "skip" + mongoArgLimit = "limit" + mongoArgProjection = "projection" + mongoArgBatchSize = "batchSize" + mongoArgCollation = "collation" + mongoArgUpdate = "update" + mongoArgUpSert = "upsert" + mongoArgDoc = "returnDocument" + mongoArgFieldName = "fieldName" + mongoArgArrayFilters = "arrayFilters" + mongoArgPipeline = "pipeline" + mongoArgMaxTimeMS = "maxTimeMS" + mongoArgAllowDiskUse = "allowDiskUse" + mongoArgDropIndexesOptions = "DropIndexesOptions" + mongoArgCreateIndexesOptions = "CreateIndexesOptions" + mongoArgIndexModels = "IndexModels" + mongoArgIndexModel = "IndexModel" + mongoArgName = "name" +) + +func executeFind(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) ([]map[string]interface{}, error) { + + filter, opts := handleArgsForExecuteFind(args) + + cur, err := coll.Find(ctx, filter, opts) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + + result := []map[string]interface{}{} + for cur.Next(ctx) { + temp := map[string]interface{}{} + if err := bson.Unmarshal(cur.Current, &temp); err != nil { + return result, err + } + result = append(result, temp) + } + return result, nil +} + +// executeFindC returns the cursor type, and uses cursor.All/Decode to parse to the structure. +func executeFindC(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (*mongo.Cursor, error) { + + filter, opts := handleArgsForExecuteFind(args) + return coll.Find(ctx, filter, opts) + +} + +// handleArgsForExecuteFind handles args for executeFind and executeFindC. 
+func handleArgsForExecuteFind(args map[string]interface{}) (filter map[string]interface{}, + opts *options.FindOptions) { + opts = options.Find() + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgSort: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetSort(v) + } + if v, ok := opt.(bson.D); ok { + opts = opts.SetSort(v) + } + case mongoArgSkip: + if v, ok := opt.(float64); ok { + opts = opts.SetSkip(int64(v)) + } + case mongoArgLimit: + if v, ok := opt.(float64); ok { + opts = opts.SetLimit(int64(v)) + } + case mongoArgProjection: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetProjection(v) + } + case mongoArgBatchSize: + if v, ok := opt.(float64); ok { + opts = opts.SetBatchSize(int32(v)) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + return +} + +func executeDeleteOne(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + opts := options.Delete() + var filter map[string]interface{} + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + res, err := coll.DeleteOne(ctx, filter, opts) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(res) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeDeleteMany(ctx context.Context, coll *mongo.Collection, + margs map[string]interface{}) (map[string]interface{}, error) { + mopts := options.Delete() + var filter map[string]interface{} + for name, opt := range margs { + switch name { + case mongoArgFilter: + if v, 
ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + mopts = mopts.SetCollation(collationFromMap(v)) + } + default: + } + } + res, err := coll.DeleteMany(ctx, filter, mopts) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(res) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeFindOneAndDelete(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + opts := options.FindOneAndDelete() + var filter map[string]interface{} + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgSort: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetSort(v) + } + if v, ok := opt.(bson.D); ok { + opts = opts.SetSort(v) + } + case mongoArgProjection: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetProjection(v) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + cur := coll.FindOneAndDelete(ctx, filter, opts) + + result := map[string]interface{}{} + + if err := cur.Decode(&result); err != nil { + return nil, err + } + return result, nil +} + +func executeFindOneAndUpdate(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + + cur, _ := executeFindOneAndUpdateS(ctx, coll, args) + + result := map[string]interface{}{} + + if err := cur.Decode(&result); err != nil { + return nil, err + } + return result, nil +} + +// executeFindOneAndUpdateC returns mongo.SingleResult type, and uses Decode to parse to structure. 
+func executeFindOneAndUpdateS(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (*mongo.SingleResult, error) { + filter, update, fupdatePipe, fopts := handleArgsForOneAndUpdateS(args) + var cur *mongo.SingleResult + if fupdatePipe != nil { + cur = coll.FindOneAndUpdate(ctx, filter, fupdatePipe, fopts) + } else { + cur = coll.FindOneAndUpdate(ctx, filter, update, fopts) + } + + return cur, nil +} + +// handleArgsForOneAndUpdateS is an auxiliary function that handles the args passed in executeFindOneAndUpdateS. +func handleArgsForOneAndUpdateS(args map[string]interface{}) (filter map[string]interface{}, + update map[string]interface{}, fupdatePipe []interface{}, fopts *options.FindOneAndUpdateOptions) { + fopts = options.FindOneAndUpdate() + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgUpdate: + var ok bool + update, ok = opt.(map[string]interface{}) + if !ok { + if v, ok := opt.([]interface{}); ok { + fupdatePipe = v + } + } + case mongoArgArrayFilters: + if v, ok := opt.([]interface{}); ok { + fopts = fopts.SetArrayFilters(options.ArrayFilters{ + Filters: v, + }) + } + case mongoArgSort: + if v, ok := opt.(map[string]interface{}); ok { + fopts = fopts.SetSort(v) + } + if v, ok := opt.(bson.D); ok { + fopts = fopts.SetSort(v) + } + case mongoArgProjection: + if v, ok := opt.(map[string]interface{}); ok { + fopts = fopts.SetProjection(v) + } + case mongoArgUpSert: + if v, ok := opt.(bool); ok { + fopts = fopts.SetUpsert(v) + } + case mongoArgDoc: + if v, ok := opt.(string); ok { + fopts = setReturnDocument(fopts, v) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + fopts = fopts.SetCollation(collationFromMap(v)) + } + default: + } + } + return +} + +func setReturnDocument(fopts *options.FindOneAndUpdateOptions, rdType string) *options.FindOneAndUpdateOptions { + switch rdType { + case "After": + fopts = 
fopts.SetReturnDocument(options.After) + case "Before": + fopts = fopts.SetReturnDocument(options.Before) + default: + // do nothing + } + return fopts +} + +func executeInsertOne(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + if _, ok := args["document"]; !ok { + return nil, fmt.Errorf("InsertOne args error,need key document") + } + + document, ok := args["document"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("InsertOne document type error,need map") + } + + res, err := coll.InsertOne(ctx, document) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(res) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeInsertMany(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + + if _, ok := args["documents"]; !ok { + return nil, fmt.Errorf("InsertMany args error,need key document") + } + documents, ok := args["documents"].([]interface{}) + if !ok { + return nil, fmt.Errorf("InsertMany document type error,need slice") + } + for i, doc := range documents { + docM, ok := doc.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("InsertMany document element type error,need map[string]interface{}") + } + + documents[i] = docM + } + res, err := coll.InsertMany(ctx, documents) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(res) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeUpdateOne(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + uopts := options.Update() + var ufilter map[string]interface{} + var update map[string]interface{} + var updatePipe []interface{} + var ok bool + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok1 := 
opt.(map[string]interface{}); ok1 { + ufilter = v + } + case mongoArgUpdate: + update, ok = opt.(map[string]interface{}) + if !ok { + if v, ok := opt.([]interface{}); ok { + updatePipe = v + } + } + case mongoArgArrayFilters: + if v, ok := opt.([]interface{}); ok { + uopts = uopts.SetArrayFilters(options.ArrayFilters{ + Filters: v, + }) + } + case mongoArgUpSert: + if v, ok := opt.(bool); ok { + uopts = uopts.SetUpsert(v) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + uopts = uopts.SetCollation(collationFromMap(v)) + } + default: + } + } + + var upRes *mongo.UpdateResult + var err error + if updatePipe != nil { + upRes, err = coll.UpdateOne(ctx, ufilter, updatePipe, uopts) + } else { + upRes, err = coll.UpdateOne(ctx, ufilter, update, uopts) + } + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(upRes) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeUpdateMany(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) (map[string]interface{}, error) { + opts := options.Update() + var filter map[string]interface{} + var update map[string]interface{} + var updatePipe []interface{} + var ok bool + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok1 := opt.(map[string]interface{}); ok1 { + filter = v + } + case mongoArgUpdate: + update, ok = opt.(map[string]interface{}) + if !ok { + if v, ok := opt.([]interface{}); ok { + updatePipe = v + } + } + case mongoArgArrayFilters: + if v, ok := opt.([]interface{}); ok { + opts = opts.SetArrayFilters(options.ArrayFilters{ + Filters: v, + }) + } + case mongoArgUpSert: + if v, ok := opt.(bool); ok { + opts = opts.SetUpsert(v) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + + var upRes *mongo.UpdateResult + var err error + if updatePipe != nil { + 
upRes, err = coll.UpdateMany(ctx, filter, updatePipe, opts) + } else { + upRes, err = coll.UpdateMany(ctx, filter, update, opts) + } + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(upRes) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeCount(ctx context.Context, coll *mongo.Collection, args map[string]interface{}) (int64, error) { + var filter map[string]interface{} + opts := options.Count() + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgSkip: + if v, ok := opt.(float64); ok { + opts = opts.SetSkip(int64(v)) + } + case mongoArgLimit: + if v, ok := opt.(float64); ok { + opts = opts.SetLimit(int64(v)) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + return coll.CountDocuments(ctx, filter, opts) +} + +func executeDistinct(ctx context.Context, coll *mongo.Collection, + args map[string]interface{}) ([]interface{}, error) { + var fieldName string + var filter map[string]interface{} + opts := options.Distinct() + for name, opt := range args { + switch name { + case mongoArgFilter: + if v, ok := opt.(map[string]interface{}); ok { + filter = v + } + case mongoArgFieldName: + if v, ok := opt.(string); ok { + fieldName = v + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + default: + } + } + return coll.Distinct(ctx, fieldName, filter, opts) +} + +func executeAggregate(ctx context.Context, coll *mongo.Collection, args map[string]interface{}) ( + []map[string]interface{}, error) { + + cur, err := executeAggregateC(ctx, coll, args) + if err != nil { + return nil, err + } + + result := []map[string]interface{}{} + for cur.Next(ctx) { + temp := map[string]interface{}{} + if err := 
bson.Unmarshal(cur.Current, &temp); err != nil { + return result, err + } + result = append(result, temp) + } + return result, nil +} + +func executeAggregateC(ctx context.Context, coll *mongo.Collection, args map[string]interface{}) ( + *mongo.Cursor, error) { + + var pipeline []interface{} + opts := options.Aggregate() + for name, opt := range args { + switch name { + + case mongoArgPipeline: + p, ok := opt.([]interface{}) + if !ok { + return nil, fmt.Errorf("Aggregate args error,args value need slice") + } + pipeline = p + case mongoArgBatchSize: + if v, ok := opt.(float64); ok { + opts = opts.SetBatchSize(int32(v)) + } + case mongoArgCollation: + if v, ok := opt.(map[string]interface{}); ok { + opts = opts.SetCollation(collationFromMap(v)) + } + case mongoArgMaxTimeMS: + if v, ok := opt.(float64); ok { + opts = opts.SetMaxTime(time.Duration(v) * time.Millisecond) + } + default: + } + } + return coll.Aggregate(ctx, pipeline, opts) +} +func collationFromMap(m map[string]interface{}) *options.Collation { + var collation options.Collation + + if locale, found := m["locale"]; found { + if v, ok := locale.(string); ok { + collation.Locale = v + } + } + + if caseLevel, found := m["caseLevel"]; found { + if v, ok := caseLevel.(bool); ok { + collation.CaseLevel = v + } + } + + if caseFirst, found := m["caseFirst"]; found { + if v, ok := caseFirst.(string); ok { + collation.CaseFirst = v + } + } + + if strength, found := m["strength"]; found { + if v, ok := strength.(float64); ok { + collation.Strength = int(v) + } + } + + if numericOrdering, found := m["numericOrdering"]; found { + if v, ok := numericOrdering.(bool); ok { + collation.NumericOrdering = v + } + } + + if alternate, found := m["alternate"]; found { + if v, ok := alternate.(string); ok { + collation.Alternate = v + } + } + + if maxVariable, found := m["maxVariable"]; found { + if v, ok := maxVariable.(string); ok { + collation.MaxVariable = v + } + } + + if normalization, found := m["normalization"]; found { 
+ if v, ok := normalization.(bool); ok { + collation.Normalization = v + } + } + + if backwards, found := m["backwards"]; found { + if v, ok := backwards.(bool); ok { + collation.Backwards = v + } + } + + return &collation +} + +func executeBulkWrite(ctx context.Context, coll *mongo.Collection, args map[string]interface{}, +) (map[string]interface{}, error) { + if _, ok := args["documents"]; !ok { + return nil, fmt.Errorf("InsertMany args error,need key document") + } + operations, ok := args["documents"].([]mongo.WriteModel) + if !ok { + return nil, fmt.Errorf("InsertMany document type error,need slice") + } + bulkOption := &options.BulkWriteOptions{} + bulkOption.SetOrdered(false) + if optRaw, ok := args["BulkWriteOptions"]; ok { + if opt, ok := optRaw.(*options.BulkWriteOptions); ok { + bulkOption = opt + } + } + res, err := coll.BulkWrite(ctx, operations, bulkOption) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + j, _ := json.Marshal(res) + if err := json.Unmarshal(j, &m); err != nil { + return nil, err + } + return m, nil +} + +func executeIndexCreateOne(ctx context.Context, iv mongo.IndexView, args map[string]interface{}) (string, error) { + indexModel, ok := args[mongoArgIndexModel] + if !ok { + return "", errors.New("CreateOne args error,need key IndexModel") + } + model, ok := indexModel.(mongo.IndexModel) + if !ok { + return "", errors.New("IndexModel document type error,need mongo.IndexModel") + } + createIndexesOption := &options.CreateIndexesOptions{} + if optRaw, ok := args[mongoArgCreateIndexesOptions]; ok { + if opt, ok := optRaw.(*options.CreateIndexesOptions); ok { + createIndexesOption = opt + } + } + return iv.CreateOne(ctx, model, createIndexesOption) +} + +func executeIndexCreateMany(ctx context.Context, iv mongo.IndexView, args map[string]interface{}) ([]string, error) { + indexModels, ok := args[mongoArgIndexModels] + if !ok { + return nil, errors.New("CreateMany args error,need key IndexModels") + } + models, ok 
:= indexModels.([]mongo.IndexModel) + if !ok { + return nil, errors.New("IndexModel document type error,need []mongo.IndexModel") + } + createIndexesOption := &options.CreateIndexesOptions{} + if optRaw, ok := args[mongoArgCreateIndexesOptions]; ok { + if opt, ok := optRaw.(*options.CreateIndexesOptions); ok { + createIndexesOption = opt + } + } + return iv.CreateMany(ctx, models, createIndexesOption) +} + +func executeIndexDropOne(ctx context.Context, iv mongo.IndexView, args map[string]interface{}) (bson.Raw, error) { + nameObj, ok := args[mongoArgName] + if !ok { + return nil, errors.New("DropOn args error,need key name") + } + name, ok := nameObj.(string) + if !ok { + return nil, errors.New("IndexModel document type error,need string") + } + dropIndexesOption := &options.DropIndexesOptions{} + if optRaw, ok := args[mongoArgDropIndexesOptions]; ok { + if opt, ok := optRaw.(*options.DropIndexesOptions); ok { + dropIndexesOption = opt + } + } + return iv.DropOne(ctx, name, dropIndexesOption) +} + +func executeIndexDropAll(ctx context.Context, iv mongo.IndexView, args map[string]interface{}) (bson.Raw, error) { + dropIndexesOption := &options.DropIndexesOptions{} + if optRaw, ok := args[mongoArgDropIndexesOptions]; ok { + if opt, ok := optRaw.(*options.DropIndexesOptions); ok { + dropIndexesOption = opt + } + } + return iv.DropAll(ctx, dropIndexesOption) +} diff --git a/mongodb/curd_test.go b/mongodb/curd_test.go new file mode 100644 index 0000000..b6e7bb9 --- /dev/null +++ b/mongodb/curd_test.go @@ -0,0 +1,1073 @@ +package mongodb + +import ( + "context" + "fmt" + "reflect" + "testing" + + "github.com/agiledragon/gomonkey/v2" + "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/x/bsonx/bsoncore" + "trpc.group/trpc-go/trpc-go" +) + +var succArgs = map[string]interface{}{ + "filter": 
map[string]interface{}{ + "key": "123", + }, + "sort": map[string]interface{}{ + "order": "asc", + }, + "skip": 2.1, + "limit": 3.2, + "projection": map[string]interface{}{ + "key": "123", + }, + "batchSize": 4.2, + "collation": map[string]interface{}{ + "key": "123", + }, + "update": []interface{}{1, 2}, + "arrayFilters": []interface{}{1, 2}, + "upsert": true, + "returnDocument": "After", + "fieldName": "123", +} + +var sortArgs = map[string]interface{}{ + "sort": bson.D{{Key: "value", Value: 1}, {Key: "name", Value: -1}}, +} + +var errArgs = map[string]interface{}{ + "filter": 2, + "sort": 2, + "skip": 2, + "limit": 2, + "projection": 2, + "batchSize": 2, + "collation": 2, + "update": 2, + "arrayFilters": 2, + "upsert": 2, + "returnDocument": 2, + "fieldName": 123, +} + +func TestUnitExecuteFind(t *testing.T) { + ops := options.Find() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return new(mongo.Cursor), nil + }, + ).Reset() + seqDoc := bsoncore.BuildDocument( + bsoncore.BuildDocument( + nil, + bsoncore.AppendDoubleElement(nil, "pi", 3.14159), + ), + bsoncore.AppendStringElement(nil, "hello", "world"), + ) + bs := &bsoncore.DocumentSequence{ + Style: bsoncore.SequenceStyle, + Data: seqDoc, + Pos: 0, + } + + var index int + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Cursor)), "Next", + func(cursor *mongo.Cursor, ctx context.Context) bool { + if index == 2 { + return false + } + d, _ := bs.Next() + cursor.Current = bson.Raw(d) + index++ + return true + }, + ).Reset() + + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Cursor)), "Close", + func(cursor *mongo.Cursor, ctx context.Context) error { + return nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + 
} + tests := []struct { + name string + args args + want []map[string]interface{} + opts *options.FindOptions + wantErr bool + }{ + {"err", args{ + ctx: trpc.BackgroundContext(), + coll: &mongo.Collection{}, + }, nil, options.Find(), true}, + {"args err", args{ + ctx: context.Background(), + coll: &mongo.Collection{}, + args: errArgs, + }, []map[string]interface{}{{ + "pi": 3.14159, + }, { + "hello": "world", + }}, &options.FindOptions{}, false}, + {"succ", args{ + ctx: context.Background(), + coll: &mongo.Collection{}, + args: succArgs, + }, []map[string]interface{}{}, &options.FindOptions{ + Sort: map[string]interface{}{ + "order": "asc", + }, + Skip: getInt64Point(int64(2)), + Limit: getInt64Point(int64(3)), + Projection: map[string]interface{}{ + "key": "123", + }, + BatchSize: getInt32Point(4), + Collation: new(options.Collation), + }, false}, + {"succ", args{ + ctx: context.Background(), + coll: &mongo.Collection{}, + args: sortArgs, + }, []map[string]interface{}{}, &options.FindOptions{ + Sort: bson.D{{Key: "value", Value: 1}, {Key: "name", Value: -1}}, + Skip: getInt64Point(int64(2)), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, _ = executeFindC(tt.args.ctx, tt.args.coll, tt.args.args) + _, err := executeFind(tt.args.ctx, tt.args.coll, tt.args.args) + assert.Nil(t, err) + }) + } +} + +func getInt64Point(i int64) *int64 { + return &i +} +func getInt32Point(i int32) *int32 { + return &i +} + +func TestUnitExecuteDeleteMany(t *testing.T) { + ops := options.Delete() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.DeleteResult{ + DeletedCount: 1, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args 
map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.DeleteOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + coll: new(mongo.Collection), + }, nil, options.Delete(), true}, + {"type err", args{ + ctx: context.Background(), + coll: new(mongo.Collection), + args: errArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, options.Delete(), false}, + {"succ", args{ + ctx: context.Background(), + coll: new(mongo.Collection), + args: succArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, &options.DeleteOptions{ + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeDeleteMany(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeDeleteMany() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeDeleteMany() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecute(t *testing.T) { + ops := options.Delete() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.DeleteResult{ + DeletedCount: 1, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.DeleteOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + coll: new(mongo.Collection), + }, nil, options.Delete(), true}, + {"type err", args{ + ctx: context.Background(), + coll: new(mongo.Collection), 
+ args: errArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, options.Delete(), false}, + {"succ", args{ + ctx: context.Background(), + coll: new(mongo.Collection), + args: succArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, &options.DeleteOptions{ + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeDeleteOne(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeDeleteOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeDeleteOne() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteDeleteOne(t *testing.T) { + ops := options.Delete() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.DeleteResult{ + DeletedCount: 1, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.DeleteOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + coll: new(mongo.Collection), + }, nil, options.Delete(), true}, + {"type err", args{ + ctx: context.Background(), + coll: new(mongo.Collection), + args: errArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, options.Delete(), false}, + {"succ", args{ + ctx: context.Background(), + coll: new(mongo.Collection), + args: succArgs, + }, map[string]interface{}{ + "DeletedCount": float64(1), + }, &options.DeleteOptions{ + Collation: 
new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeDeleteOne(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeDeleteOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeDeleteOne() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteFindOneAndDelete(t *testing.T) { + ops := options.FindOneAndDelete() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOneAndDelete", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult { + *ops = *opts[0] + return new(mongo.SingleResult) + }, + ).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.SingleResult)), "Decode", + func(coll *mongo.SingleResult, v interface{}) error { + value, _ := v.(*map[string]interface{}) + *value = map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + } + return nil + }, + ).Reset() + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.FindOneAndDeleteOptions + wantErr bool + }{ + {"type err", args{ + args: errArgs, + }, map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, options.FindOneAndDelete(), false}, + {"succ", args{ + args: succArgs, + }, map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, &options.FindOneAndDeleteOptions{ + Sort: map[string]interface{}{ + "order": "asc", + }, + Projection: map[string]interface{}{ + "key": "123", + }, + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := 
executeFindOneAndDelete(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeFindOneAndDelete() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeFindOneAndDelete() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteFindOneAndUpdate(t *testing.T) { + convey.Convey("TestUnitExecuteFindOneAndUpdate", t, func() { + ops := options.FindOneAndUpdate() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "FindOneAndUpdate", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + update interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult { + *ops = *opts[0] + return new(mongo.SingleResult) + }, + ).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.SingleResult)), "Decode", + func(coll *mongo.SingleResult, v interface{}) error { + value, _ := v.(*map[string]interface{}) + *value = map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + } + return nil + }, + ).Reset() + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.FindOneAndUpdateOptions + wantErr bool + }{ + {"type err", args{ + args: errArgs, + }, map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, options.FindOneAndUpdate(), false}, + {"succ", args{ + args: succArgs, + }, map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, &options.FindOneAndUpdateOptions{ + ArrayFilters: &options.ArrayFilters{ + Filters: []interface{}{1, 2}, + }, + Upsert: getBoolPoint(true), + ReturnDocument: getReturnDocumentPoint(options.After), + Collation: new(options.Collation), + Projection: map[string]interface{}{ + "key": "123", + }, + Sort: map[string]interface{}{ + 
"order": "asc", + }, + }, false}, + } + for _, tt := range tests { + convey.Convey(tt.name, func() { + _, _ = executeFindOneAndUpdateS(tt.args.ctx, tt.args.coll, tt.args.args) + got, err := executeFindOneAndUpdate(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeFindOneAndUpdate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeFindOneAndUpdate() got = %v, want %v", got, tt.want) + } + convey.So(ops, convey.ShouldResemble, tt.opts) + + }) + } + }) +} + +func getReturnDocumentPoint(document options.ReturnDocument) *options.ReturnDocument { + return &document +} + +func getBoolPoint(b bool) *bool { + return &b +} + +func TestUnitExecuteInsertOne(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertOne", + func(coll *mongo.Collection, ctx context.Context, document interface{}, + opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + return &mongo.InsertOneResult{ + InsertedID: "123", + }, nil + }, + ).Reset() + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + wantErr bool + }{ + {"not document", args{ + ctx: nil, + }, nil, true}, + {"document type error", args{ + ctx: context.Background(), + args: map[string]interface{}{ + "document": "123", + }, + }, nil, true}, + {"insert error", args{ + ctx: nil, + args: map[string]interface{}{ + "document": map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, + }, + }, nil, true}, + {"succ", args{ + ctx: context.Background(), + args: map[string]interface{}{ + "document": map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, + }, + }, map[string]interface{}{ + "InsertedID": "123", + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + 
got, err := executeInsertOne(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeInsertOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeInsertOne() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnitExecuteInsertMany(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "InsertMany", + func(coll *mongo.Collection, ctx context.Context, documents []interface{}, + opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + return &mongo.InsertManyResult{ + InsertedIDs: []interface{}{ + "123", + "456", + }, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + wantErr bool + }{ + {"not documents", args{ + ctx: nil, + }, nil, true}, + {"documents type error", args{ + ctx: context.Background(), + args: map[string]interface{}{ + "documents": "123", + }, + }, nil, true}, + {"document type error", args{ + ctx: context.Background(), + args: map[string]interface{}{ + "documents": []interface{}{ + 123, + }, + }, + }, nil, true}, + {"insert err", args{ + ctx: nil, + args: map[string]interface{}{ + "documents": []interface{}{ + map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, + }, + }, + }, nil, true}, + {"succ", args{ + ctx: context.Background(), + args: map[string]interface{}{ + "documents": []interface{}{ + map[string]interface{}{ + "pi": 3.1415, + "hello": "world", + }, + }, + }, + }, map[string]interface{}{ + "InsertedIDs": []interface{}{"123", "456"}, + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeInsertMany(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeInsertMany() error = %v, 
wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeInsertMany() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUnitExecuteUpdateOne(t *testing.T) { + ops := options.Update() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateOne", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.UpdateResult{ + MatchedCount: 1, + ModifiedCount: 2, + UpsertedCount: 3, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.UpdateOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + }, nil, options.Update(), true}, + {"type err", args{ + ctx: context.Background(), + args: errArgs, + }, map[string]interface{}{ + "MatchedCount": float64(1), + "ModifiedCount": float64(2), + "UpsertedCount": float64(3), + "UpsertedID": nil, + }, options.Update(), false}, + {"succ", args{ + ctx: context.Background(), + args: succArgs, + }, map[string]interface{}{ + "MatchedCount": float64(1), + "ModifiedCount": float64(2), + "UpsertedCount": float64(3), + "UpsertedID": nil, + }, &options.UpdateOptions{ + ArrayFilters: &options.ArrayFilters{ + Filters: []interface{}{1, 2}, + }, + Upsert: getBoolPoint(true), + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeUpdateOne(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeUpdateOne() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeUpdateOne() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, 
tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteUpdateMany(t *testing.T) { + ops := options.Update() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "UpdateMany", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, update interface{}, + opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.UpdateResult{ + MatchedCount: 1, + ModifiedCount: 2, + UpsertedCount: 3, + }, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want map[string]interface{} + opts *options.UpdateOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + }, nil, options.Update(), true}, + {"type err", args{ + ctx: context.Background(), + args: errArgs, + }, map[string]interface{}{ + "MatchedCount": float64(1), + "ModifiedCount": float64(2), + "UpsertedCount": float64(3), + "UpsertedID": nil, + }, options.Update(), false}, + {"succ", args{ + ctx: context.Background(), + args: succArgs, + }, map[string]interface{}{ + "MatchedCount": float64(1), + "ModifiedCount": float64(2), + "UpsertedCount": float64(3), + "UpsertedID": nil, + }, &options.UpdateOptions{ + ArrayFilters: &options.ArrayFilters{ + Filters: []interface{}{1, 2}, + }, + Upsert: getBoolPoint(true), + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeUpdateMany(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeUpdateMany() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeUpdateMany() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) 
+ } +} + +func TestUnitExecuteCount(t *testing.T) { + ops := options.Count() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "CountDocuments", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.CountOptions) (int64, error) { + if ctx == nil { + return 0, fmt.Errorf("err") + } + *ops = *opts[0] + return 1, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want int64 + opts *options.CountOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + }, 0, options.Count(), true}, + {"type err", args{ + ctx: context.Background(), + args: errArgs, + }, 1, options.Count(), false}, + {"succ", args{ + ctx: context.Background(), + args: succArgs, + }, 1, &options.CountOptions{ + Skip: getInt64Point(2), + Limit: getInt64Point(3), + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeCount(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeCount() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("executeCount() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteDistinct(t *testing.T) { + ops := options.Distinct() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Distinct", + func(coll *mongo.Collection, ctx context.Context, fieldName string, filter interface{}, + opts ...*options.DistinctOptions) ([]interface{}, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return []interface{}{"1", 23}, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args 
args + want []interface{} + opts *options.DistinctOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + }, nil, options.Distinct(), true}, + {"type err", args{ + ctx: context.Background(), + args: errArgs, + }, []interface{}{"1", 23}, options.Distinct(), false}, + {"succ", args{ + ctx: context.Background(), + args: succArgs, + }, []interface{}{"1", 23}, &options.DistinctOptions{ + Collation: new(options.Collation), + }, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := executeDistinct(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeDistinct() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("executeDistinct() got = %v, want %v", got, tt.want) + } + if !reflect.DeepEqual(ops, tt.opts) { + t.Errorf("executeFind() got = %v, want %v", ops, tt.opts) + } + }) + } +} + +func TestUnitExecuteAggregate(t *testing.T) { + ops := options.Aggregate() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Aggregate", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.AggregateOptions) (*mongo.Cursor, error) { + if ctx == nil { + return nil, fmt.Errorf("err") + } + *ops = *opts[0] + return &mongo.Cursor{}, nil + }, + ).Reset() + + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tArgs := make(map[string]interface{}) + tArgs["pipeline"] = make([]int, 0) + tArgs["batchSize"] = float64(1) + tArgs["collation"] = make(map[string]interface{}) + tArgs["maxTimeMS"] = float64(1) + tests := []struct { + name string + args args + want []interface{} + opts *options.AggregateOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + args: tArgs, + }, nil, options.Aggregate(), true}, + {"err_not_nil", args{ + ctx: trpc.BackgroundContext(), + args: tArgs, + }, nil, options.Aggregate(), true}, + } + for _, tt := range tests { + t.Run(tt.name, 
func(t *testing.T) { + _, _ = executeAggregateC(tt.args.ctx, tt.args.coll, tt.args.args) + _, err := executeAggregate(tt.args.ctx, tt.args.coll, tt.args.args) + if (err != nil) != tt.wantErr { + t.Errorf("executeAggregate() error = %v, wantErr %v", err, tt.wantErr) + return + } + + }) + } +} +func TestUnitCollationFromMap(t *testing.T) { + type args struct { + m map[string]interface{} + } + tests := []struct { + name string + args args + want *options.Collation + }{ + {"nil map", args{m: nil}, new(options.Collation)}, + {"empty map", args{m: map[string]interface{}{}}, new(options.Collation)}, + {"type err", args{m: map[string]interface{}{ + "locale": 1, + "caseLevel": 2, + "caseFirst": 3, + "strength": "str", + "numericOrdering": 1, + "alternate": 2, + "maxVariable": 3, + "normalization": 4, + "backwards": 1, + }}, new(options.Collation)}, + {"succ", args{m: map[string]interface{}{ + "locale": "locale", + "caseLevel": true, + "caseFirst": "caseFirst", + "strength": 3.15, + "numericOrdering": false, + "alternate": "alternate", + "maxVariable": "maxVariable", + "normalization": true, + "backwards": false, + }}, &options.Collation{ + Locale: "locale", + CaseLevel: true, + CaseFirst: "caseFirst", + Strength: 3, + NumericOrdering: false, + Alternate: "alternate", + MaxVariable: "maxVariable", + Normalization: true, + Backwards: false, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := collationFromMap(tt.args.m); !reflect.DeepEqual(got, tt.want) { + t.Errorf("collationFromMap() = %v, want %v", got, tt.want) + } + }) + } +} +func TestUnitBulkWrite(t *testing.T) { + ops := options.BulkWrite() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "BulkWrite", + func(coll *mongo.Collection, ctx context.Context, model []mongo.WriteModel, + opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) { + *ops = *opts[0] + return &mongo.BulkWriteResult{}, nil + }, + ).Reset() + + tArgs := make(map[string]interface{}) + 
tArgs["documents"] = []mongo.WriteModel{} + tArgs["BulkWriteOptions"] = &options.BulkWriteOptions{} + type args struct { + ctx context.Context + coll *mongo.Collection + args map[string]interface{} + } + tests := []struct { + name string + args args + want []interface{} + opts *options.BulkWriteOptions + wantErr bool + }{ + {"err", args{ + ctx: nil, + args: tArgs, + }, nil, options.BulkWrite(), true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := executeBulkWrite(tt.args.ctx, tt.args.coll, tt.args.args) + if err != nil { + t.Errorf("executeBulkWrite() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} diff --git a/mongodb/error.go b/mongodb/error.go new file mode 100644 index 0000000..6deac70 --- /dev/null +++ b/mongodb/error.go @@ -0,0 +1,25 @@ +package mongodb + +import ( + "go.mongodb.org/mongo-driver/mongo" + "trpc.group/trpc-go/trpc-go/errs" +) + +// error code, refer: https://github.com/mongodb/mongo/blob/master/src/mongo/base/error_codes.yml +const ( + RetDuplicateKeyErr = 11000 // key conflict error +) + +// IsDuplicateKeyError handles whether it is a key conflict error. 
+func IsDuplicateKeyError(err error) bool { + if e, ok := err.(*errs.Error); ok && e.Code == RetDuplicateKeyErr { + return true + } + if e, ok := err.(mongo.BulkWriteError); ok && e.Code == RetDuplicateKeyErr { + return true + } + if mongo.IsDuplicateKeyError(err) { + return true + } + return false +} diff --git a/mongodb/error_test.go b/mongodb/error_test.go new file mode 100644 index 0000000..0d72b2e --- /dev/null +++ b/mongodb/error_test.go @@ -0,0 +1,48 @@ +package mongodb + +import ( + "fmt" + "testing" + + "go.mongodb.org/mongo-driver/mongo" + "trpc.group/trpc-go/trpc-go/errs" +) + +func TestIsDuplicateKeyError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "repeated key err", + err: errs.NewFrameError(RetDuplicateKeyErr, "key duplicated"), + want: true, + }, + { + name: "bulk write repeated key err", + err: mongo.BulkWriteException{ + WriteConcernError: &mongo.WriteConcernError{Name: "name", Code: 100, Message: "bar"}, + WriteErrors: []mongo.BulkWriteError{ + { + WriteError: mongo.WriteError{Code: 11000, Message: "blah E11000 blah"}, + Request: &mongo.InsertOneModel{}}, + }, + Labels: []string{"otherError"}, + }, + want: true, + }, + { + name: "other err", + err: fmt.Errorf("E 11001, timeout"), + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsDuplicateKeyError(tt.err); got != tt.want { + t.Errorf("IsDuplicateKeyError() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mongodb/go.mod b/mongodb/go.mod new file mode 100644 index 0000000..394ae2c --- /dev/null +++ b/mongodb/go.mod @@ -0,0 +1,56 @@ +module trpc.group/trpc-go/trpc-database/mongodb + +go 1.18 + +require ( + github.com/agiledragon/gomonkey/v2 v2.6.0 + github.com/golang/mock v1.5.0 + github.com/smartystreets/goconvey v1.6.4 + github.com/stretchr/testify v1.8.4 + go.mongodb.org/mongo-driver v1.11.6 + gopkg.in/yaml.v3 v3.0.1 + trpc.group/trpc-go/trpc-go v1.0.3 + 
trpc.group/trpc-go/trpc-selector-dsn v1.1.0 +) + +require ( + github.com/BurntSushi/toml v0.3.1 // indirect + github.com/andybalholm/brotli v1.0.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/golang/snappy v0.0.4 // indirect + github.com/google/flatbuffers v2.0.0+incompatible // indirect + github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect + github.com/hashicorp/errwrap v1.0.0 // indirect + github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/jtolds/gls v4.20.0+incompatible // indirect + github.com/klauspost/compress v1.17.0 // indirect + github.com/lestrrat-go/strftime v1.0.6 // indirect + github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/panjf2000/ants/v2 v2.4.6 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect + github.com/spf13/cast v1.3.1 // indirect + github.com/tidwall/pretty v1.0.1 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect + github.com/valyala/fasthttp v1.51.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.1 // indirect + github.com/xdg-go/stringprep v1.0.3 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/automaxprocs v1.3.0 // indirect + go.uber.org/multierr v1.6.0 // indirect + go.uber.org/zap v1.24.0 // indirect + golang.org/x/crypto v0.16.0 // indirect + golang.org/x/sync v0.1.0 // indirect + golang.org/x/sys v0.18.0 // indirect + golang.org/x/text v0.14.0 // indirect + google.golang.org/protobuf v1.33.0 // indirect + 
trpc.group/trpc-go/tnet v1.0.1 // indirect + trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 // indirect +) diff --git a/mongodb/go.sum b/mongodb/go.sum new file mode 100644 index 0000000..e69fd11 --- /dev/null +++ b/mongodb/go.sum @@ -0,0 +1,164 @@ +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/agiledragon/gomonkey/v2 v2.6.0 h1:RzdlW1ibfVipfXKy9U4zYumdHTIY7RoZwyXY3tXLYd8= +github.com/agiledragon/gomonkey/v2 v2.6.0/go.mod h1:ap1AmDzcVOAz1YpeJ3TCzIgstoaWLA6jbbgxfB4w2iY= +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= +github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= +github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= +github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/flatbuffers v2.0.0+incompatible 
h1:dicJ2oXwypfwUGnB2/TYWYEKiuk9eYQlQO/AnOHl5mI= +github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA= +github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= +github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= +github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= 
+github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc h1:RKf14vYWi2ttpEmkA4aQ3j4u9dStX2t4M8UM6qqNsG8= +github.com/lestrrat-go/envload v0.0.0-20180220234015-a3eb8ddeffcc/go.mod h1:kopuH9ugFRkIXf3YoqHKyrJ9YfUFsckUU9S7B+XP+is= +github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ= +github.com/lestrrat-go/strftime v1.0.6/go.mod h1:f7jQKgV5nnJpYgdEasS+/y7EsTb8ykN2z68n3TtcTaw= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/panjf2000/ants/v2 v2.4.6 h1:drmj9mcygn2gawZ155dRbo+NfXEfAssjZNU1qoIb4gQ= +github.com/panjf2000/ants/v2 v2.4.6/go.mod h1:f6F0NZVFsGCp5A7QW/Zj/m92atWwOkY0OIhFxRNFr4A= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod 
h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/tidwall/pretty v1.0.1 h1:WE4RBSZ1x6McVVC8S/Md+Qse8YUv6HRObAx6ke00NY8= +github.com/tidwall/pretty v1.0.1/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= +github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= 
+github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.1 h1:VOMT+81stJgXW3CpHyqHN3AXDYIMsx56mEFrB37Mb/E= +github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= +github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +go.mongodb.org/mongo-driver v1.11.6 h1:XM7G6PjiGAO5betLF13BIa5TlLUUE3uJ/2Ox3Lz1K+o= +go.mongodb.org/mongo-driver v1.11.6/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/automaxprocs v1.3.0 h1:II28aZoGdaglS5vVNnspf28lnZpXScxtIozx1lAjdb0= +go.uber.org/automaxprocs v1.3.0/go.mod h1:9CWT6lKIep8U41DDaPiH6eFscnTyjfTANNQNx6LrIcA= +go.uber.org/goleak v1.1.11 h1:wy28qYRKZgnJTxGxvye5/wgWr1EKjmUDGYox5mGlRlI= +go.uber.org/multierr v1.6.0 h1:y6IPFStTAIT5Ytl7/XYmHvzXQ7S3g/IeZW9hyZ5thw4= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/zap v1.24.0 h1:FiJd5l1UOLj0wCgbSE0rwwXHzEdAZS6hiiSnxJN/D60= +go.uber.org/zap v1.24.0/go.mod h1:2kMP+WWQ8aoFoedH3T2sq6iJ2yDWpHbP0f6MQbS9Gkg= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod 
h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= +golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= +golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= +golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4= +golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191125144606-a911d9008d1f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.33.0 h1:uNO2rsAINq/JlFpSdYEKIZ0uKD/R9cpdv0T+yoGwGmI= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 
v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.7/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +trpc.group/trpc-go/tnet v1.0.1 h1:Yzqyrgyfm+W742FzGr39c4+OeQmLi7PWotJxrOBtV9o= +trpc.group/trpc-go/tnet v1.0.1/go.mod h1:s/webUFYWEFBHErKyFmj7LYC7XfC2LTLCcwfSnJ04M0= +trpc.group/trpc-go/trpc-go v1.0.3 h1:X4RhPmJOkVoK6EGKoV241dvEpB6EagBeyu3ZrqkYZQY= +trpc.group/trpc-go/trpc-go v1.0.3/go.mod h1:82O+G2rD5ST+JAPuPPSqvsr6UI59UxV27iAILSkAIlQ= +trpc.group/trpc-go/trpc-selector-dsn v1.1.0 h1:z3VqiboZq60MBu0cHVlRe5q7VydGbBdrX9xAfzsTVIQ= +trpc.group/trpc-go/trpc-selector-dsn v1.1.0/go.mod h1:78NOrldaWxLJd2M+VCm4OABphAYzx98dZWTLDFSzeQg= +trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0 h1:rMtHYzI0ElMJRxHtT5cD99SigFE6XzKK4PFtjcwokI0= +trpc.group/trpc/trpc-protocol/pb/go/trpc v1.0.0/go.mod h1:K+a1K/Gnlcg9BFHWx30vLBIEDhxODhl25gi1JjA54CQ= diff --git a/mongodb/mockmongodb/client_mock.go b/mongodb/mockmongodb/client_mock.go new file mode 100644 index 0000000..3c7361a --- /dev/null +++ b/mongodb/mockmongodb/client_mock.go @@ -0,0 +1,665 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: client.go + +// Package mockmongodb is a generated GoMock package. 
+package mockmongodb + +import ( + context "context" + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + bson "go.mongodb.org/mongo-driver/bson" + mongo "go.mongodb.org/mongo-driver/mongo" + options "go.mongodb.org/mongo-driver/mongo/options" + mongodb "trpc.group/trpc-go/trpc-database/mongodb" +) + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Aggregate mocks base method. +func (m *MockClient) Aggregate(ctx context.Context, database, coll string, pipeline interface{}, opts ...*options.AggregateOptions) (*mongo.Cursor, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, pipeline} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Aggregate", varargs...) + ret0, _ := ret[0].(*mongo.Cursor) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Aggregate indicates an expected call of Aggregate. +func (mr *MockClientMockRecorder) Aggregate(ctx, database, coll, pipeline interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, pipeline}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Aggregate", reflect.TypeOf((*MockClient)(nil).Aggregate), varargs...) +} + +// BulkWrite mocks base method. 
+func (m *MockClient) BulkWrite(ctx context.Context, database, coll string, models []mongo.WriteModel, opts ...*options.BulkWriteOptions) (*mongo.BulkWriteResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, models} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "BulkWrite", varargs...) + ret0, _ := ret[0].(*mongo.BulkWriteResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BulkWrite indicates an expected call of BulkWrite. +func (mr *MockClientMockRecorder) BulkWrite(ctx, database, coll, models interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, models}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BulkWrite", reflect.TypeOf((*MockClient)(nil).BulkWrite), varargs...) +} + +// Collection mocks base method. +func (m *MockClient) Collection(ctx context.Context, database, collection string) (*mongo.Collection, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Collection", ctx, database, collection) + ret0, _ := ret[0].(*mongo.Collection) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Collection indicates an expected call of Collection. +func (mr *MockClientMockRecorder) Collection(ctx, database, collection interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Collection", reflect.TypeOf((*MockClient)(nil).Collection), ctx, database, collection) +} + +// CountDocuments mocks base method. +func (m *MockClient) CountDocuments(ctx context.Context, database, coll string, filter interface{}, opts ...*options.CountOptions) (int64, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CountDocuments", varargs...) 
+ ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountDocuments indicates an expected call of CountDocuments. +func (mr *MockClientMockRecorder) CountDocuments(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountDocuments", reflect.TypeOf((*MockClient)(nil).CountDocuments), varargs...) +} + +// Database mocks base method. +func (m *MockClient) Database(ctx context.Context, database string) (*mongo.Database, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Database", ctx, database) + ret0, _ := ret[0].(*mongo.Database) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Database indicates an expected call of Database. +func (mr *MockClientMockRecorder) Database(ctx, database interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Database", reflect.TypeOf((*MockClient)(nil).Database), ctx, database) +} + +// DeleteMany mocks base method. +func (m *MockClient) DeleteMany(ctx context.Context, database, coll string, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeleteMany", varargs...) + ret0, _ := ret[0].(*mongo.DeleteResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteMany indicates an expected call of DeleteMany. +func (mr *MockClientMockRecorder) DeleteMany(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMany", reflect.TypeOf((*MockClient)(nil).DeleteMany), varargs...) 
+} + +// DeleteOne mocks base method. +func (m *MockClient) DeleteOne(ctx context.Context, database, coll string, filter interface{}, opts ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DeleteOne", varargs...) + ret0, _ := ret[0].(*mongo.DeleteResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DeleteOne indicates an expected call of DeleteOne. +func (mr *MockClientMockRecorder) DeleteOne(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteOne", reflect.TypeOf((*MockClient)(nil).DeleteOne), varargs...) +} + +// Disconnect mocks base method. +func (m *MockClient) Disconnect(ctx context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Disconnect", ctx) + ret0, _ := ret[0].(error) + return ret0 +} + +// Disconnect indicates an expected call of Disconnect. +func (mr *MockClientMockRecorder) Disconnect(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockClient)(nil).Disconnect), ctx) +} + +// Distinct mocks base method. +func (m *MockClient) Distinct(ctx context.Context, database, coll, fieldName string, filter interface{}, opts ...*options.DistinctOptions) ([]interface{}, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, fieldName, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Distinct", varargs...) + ret0, _ := ret[0].([]interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Distinct indicates an expected call of Distinct. 
+func (mr *MockClientMockRecorder) Distinct(ctx, database, coll, fieldName, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, fieldName, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Distinct", reflect.TypeOf((*MockClient)(nil).Distinct), varargs...) +} + +// Do mocks base method. +func (m *MockClient) Do(ctx context.Context, cmd, db, coll string, args map[string]interface{}) (interface{}, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Do", ctx, cmd, db, coll, args) + ret0, _ := ret[0].(interface{}) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Do indicates an expected call of Do. +func (mr *MockClientMockRecorder) Do(ctx, cmd, db, coll, args interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Do", reflect.TypeOf((*MockClient)(nil).Do), ctx, cmd, db, coll, args) +} + +// EstimatedDocumentCount mocks base method. +func (m *MockClient) EstimatedDocumentCount(ctx context.Context, database, coll string, opts ...*options.EstimatedDocumentCountOptions) (int64, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "EstimatedDocumentCount", varargs...) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// EstimatedDocumentCount indicates an expected call of EstimatedDocumentCount. +func (mr *MockClientMockRecorder) EstimatedDocumentCount(ctx, database, coll interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "EstimatedDocumentCount", reflect.TypeOf((*MockClient)(nil).EstimatedDocumentCount), varargs...) +} + +// Find mocks base method. 
+func (m *MockClient) Find(ctx context.Context, database, coll string, filter interface{}, opts ...*options.FindOptions) (*mongo.Cursor, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Find", varargs...) + ret0, _ := ret[0].(*mongo.Cursor) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Find indicates an expected call of Find. +func (mr *MockClientMockRecorder) Find(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Find", reflect.TypeOf((*MockClient)(nil).Find), varargs...) +} + +// FindOne mocks base method. +func (m *MockClient) FindOne(ctx context.Context, database, coll string, filter interface{}, opts ...*options.FindOneOptions) *mongo.SingleResult { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FindOne", varargs...) + ret0, _ := ret[0].(*mongo.SingleResult) + return ret0 +} + +// FindOne indicates an expected call of FindOne. +func (mr *MockClientMockRecorder) FindOne(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOne", reflect.TypeOf((*MockClient)(nil).FindOne), varargs...) +} + +// FindOneAndDelete mocks base method. 
+func (m *MockClient) FindOneAndDelete(ctx context.Context, database, coll string, filter interface{}, opts ...*options.FindOneAndDeleteOptions) *mongo.SingleResult { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FindOneAndDelete", varargs...) + ret0, _ := ret[0].(*mongo.SingleResult) + return ret0 +} + +// FindOneAndDelete indicates an expected call of FindOneAndDelete. +func (mr *MockClientMockRecorder) FindOneAndDelete(ctx, database, coll, filter interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndDelete", reflect.TypeOf((*MockClient)(nil).FindOneAndDelete), varargs...) +} + +// FindOneAndReplace mocks base method. +func (m *MockClient) FindOneAndReplace(ctx context.Context, database, coll string, filter, replacement interface{}, opts ...*options.FindOneAndReplaceOptions) *mongo.SingleResult { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter, replacement} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FindOneAndReplace", varargs...) + ret0, _ := ret[0].(*mongo.SingleResult) + return ret0 +} + +// FindOneAndReplace indicates an expected call of FindOneAndReplace. +func (mr *MockClientMockRecorder) FindOneAndReplace(ctx, database, coll, filter, replacement interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter, replacement}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndReplace", reflect.TypeOf((*MockClient)(nil).FindOneAndReplace), varargs...) +} + +// FindOneAndUpdate mocks base method. 
+func (m *MockClient) FindOneAndUpdate(ctx context.Context, database, coll string, filter, update interface{}, opts ...*options.FindOneAndUpdateOptions) *mongo.SingleResult { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter, update} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "FindOneAndUpdate", varargs...) + ret0, _ := ret[0].(*mongo.SingleResult) + return ret0 +} + +// FindOneAndUpdate indicates an expected call of FindOneAndUpdate. +func (mr *MockClientMockRecorder) FindOneAndUpdate(ctx, database, coll, filter, update interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter, update}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FindOneAndUpdate", reflect.TypeOf((*MockClient)(nil).FindOneAndUpdate), varargs...) +} + +// Indexes mocks base method. +func (m *MockClient) Indexes(ctx context.Context, database, collection string) (mongo.IndexView, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Indexes", ctx, database, collection) + ret0, _ := ret[0].(mongo.IndexView) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Indexes indicates an expected call of Indexes. +func (mr *MockClientMockRecorder) Indexes(ctx, database, collection interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Indexes", reflect.TypeOf((*MockClient)(nil).Indexes), ctx, database, collection) +} + +// InsertMany mocks base method. +func (m *MockClient) InsertMany(ctx context.Context, database, coll string, documents []interface{}, opts ...*options.InsertManyOptions) (*mongo.InsertManyResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, documents} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "InsertMany", varargs...) 
+ ret0, _ := ret[0].(*mongo.InsertManyResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertMany indicates an expected call of InsertMany. +func (mr *MockClientMockRecorder) InsertMany(ctx, database, coll, documents interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, documents}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertMany", reflect.TypeOf((*MockClient)(nil).InsertMany), varargs...) +} + +// InsertOne mocks base method. +func (m *MockClient) InsertOne(ctx context.Context, database, coll string, document interface{}, opts ...*options.InsertOneOptions) (*mongo.InsertOneResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, document} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "InsertOne", varargs...) + ret0, _ := ret[0].(*mongo.InsertOneResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// InsertOne indicates an expected call of InsertOne. +func (mr *MockClientMockRecorder) InsertOne(ctx, database, coll, document interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, document}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertOne", reflect.TypeOf((*MockClient)(nil).InsertOne), varargs...) +} + +// ReplaceOne mocks base method. +func (m *MockClient) ReplaceOne(ctx context.Context, database, coll string, filter, replacement interface{}, opts ...*options.ReplaceOptions) (*mongo.UpdateResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter, replacement} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "ReplaceOne", varargs...) + ret0, _ := ret[0].(*mongo.UpdateResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReplaceOne indicates an expected call of ReplaceOne. 
+func (mr *MockClientMockRecorder) ReplaceOne(ctx, database, coll, filter, replacement interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter, replacement}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceOne", reflect.TypeOf((*MockClient)(nil).ReplaceOne), varargs...) +} + +// RunCommand mocks base method. +func (m *MockClient) RunCommand(ctx context.Context, database string, runCommand interface{}, opts ...*options.RunCmdOptions) *mongo.SingleResult { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, runCommand} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "RunCommand", varargs...) + ret0, _ := ret[0].(*mongo.SingleResult) + return ret0 +} + +// RunCommand indicates an expected call of RunCommand. +func (mr *MockClientMockRecorder) RunCommand(ctx, database, runCommand interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, runCommand}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RunCommand", reflect.TypeOf((*MockClient)(nil).RunCommand), varargs...) +} + +// StartSession mocks base method. +func (m *MockClient) StartSession(ctx context.Context) (mongo.Session, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "StartSession", ctx) + ret0, _ := ret[0].(mongo.Session) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// StartSession indicates an expected call of StartSession. +func (mr *MockClientMockRecorder) StartSession(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "StartSession", reflect.TypeOf((*MockClient)(nil).StartSession), ctx) +} + +// Transaction mocks base method. 
+func (m *MockClient) Transaction(ctx context.Context, sf mongodb.TxFunc, tOpts []*options.TransactionOptions, opts ...*options.SessionOptions) error { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, sf, tOpts} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Transaction", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Transaction indicates an expected call of Transaction. +func (mr *MockClientMockRecorder) Transaction(ctx, sf, tOpts interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, sf, tOpts}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Transaction", reflect.TypeOf((*MockClient)(nil).Transaction), varargs...) +} + +// UpdateMany mocks base method. +func (m *MockClient) UpdateMany(ctx context.Context, database, coll string, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter, update} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateMany", varargs...) + ret0, _ := ret[0].(*mongo.UpdateResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateMany indicates an expected call of UpdateMany. +func (mr *MockClientMockRecorder) UpdateMany(ctx, database, coll, filter, update interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter, update}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateMany", reflect.TypeOf((*MockClient)(nil).UpdateMany), varargs...) +} + +// UpdateOne mocks base method. 
+func (m *MockClient) UpdateOne(ctx context.Context, database, coll string, filter, update interface{}, opts ...*options.UpdateOptions) (*mongo.UpdateResult, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, filter, update} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "UpdateOne", varargs...) + ret0, _ := ret[0].(*mongo.UpdateResult) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateOne indicates an expected call of UpdateOne. +func (mr *MockClientMockRecorder) UpdateOne(ctx, database, coll, filter, update interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, filter, update}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateOne", reflect.TypeOf((*MockClient)(nil).UpdateOne), varargs...) +} + +// Watch mocks base method. +func (m *MockClient) Watch(ctx context.Context, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, pipeline} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Watch", varargs...) + ret0, _ := ret[0].(*mongo.ChangeStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Watch indicates an expected call of Watch. +func (mr *MockClientMockRecorder) Watch(ctx, pipeline interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, pipeline}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Watch", reflect.TypeOf((*MockClient)(nil).Watch), varargs...) +} + +// WatchCollection mocks base method. 
+func (m *MockClient) WatchCollection(ctx context.Context, database, collection string, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, collection, pipeline} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WatchCollection", varargs...) + ret0, _ := ret[0].(*mongo.ChangeStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WatchCollection indicates an expected call of WatchCollection. +func (mr *MockClientMockRecorder) WatchCollection(ctx, database, collection, pipeline interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, collection, pipeline}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchCollection", reflect.TypeOf((*MockClient)(nil).WatchCollection), varargs...) +} + +// WatchDatabase mocks base method. +func (m *MockClient) WatchDatabase(ctx context.Context, database string, pipeline interface{}, opts ...*options.ChangeStreamOptions) (*mongo.ChangeStream, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, pipeline} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WatchDatabase", varargs...) + ret0, _ := ret[0].(*mongo.ChangeStream) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// WatchDatabase indicates an expected call of WatchDatabase. +func (mr *MockClientMockRecorder) WatchDatabase(ctx, database, pipeline interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, pipeline}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WatchDatabase", reflect.TypeOf((*MockClient)(nil).WatchDatabase), varargs...) +} + +// MockIndexViewer is a mock of IndexViewer interface. 
+type MockIndexViewer struct { + ctrl *gomock.Controller + recorder *MockIndexViewerMockRecorder +} + +// MockIndexViewerMockRecorder is the mock recorder for MockIndexViewer. +type MockIndexViewerMockRecorder struct { + mock *MockIndexViewer +} + +// NewMockIndexViewer creates a new mock instance. +func NewMockIndexViewer(ctrl *gomock.Controller) *MockIndexViewer { + mock := &MockIndexViewer{ctrl: ctrl} + mock.recorder = &MockIndexViewerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockIndexViewer) EXPECT() *MockIndexViewerMockRecorder { + return m.recorder +} + +// CreateMany mocks base method. +func (m *MockIndexViewer) CreateMany(ctx context.Context, database, coll string, models []mongo.IndexModel, opts ...*options.CreateIndexesOptions) ([]string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, models} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateMany", varargs...) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateMany indicates an expected call of CreateMany. +func (mr *MockIndexViewerMockRecorder) CreateMany(ctx, database, coll, models interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, models}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateMany", reflect.TypeOf((*MockIndexViewer)(nil).CreateMany), varargs...) +} + +// CreateOne mocks base method. +func (m *MockIndexViewer) CreateOne(ctx context.Context, database, coll string, model mongo.IndexModel, opts ...*options.CreateIndexesOptions) (string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, model} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateOne", varargs...) 
+ ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateOne indicates an expected call of CreateOne. +func (mr *MockIndexViewerMockRecorder) CreateOne(ctx, database, coll, model interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, model}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateOne", reflect.TypeOf((*MockIndexViewer)(nil).CreateOne), varargs...) +} + +// DropAll mocks base method. +func (m *MockIndexViewer) DropAll(ctx context.Context, database, coll string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DropAll", varargs...) + ret0, _ := ret[0].(bson.Raw) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DropAll indicates an expected call of DropAll. +func (mr *MockIndexViewerMockRecorder) DropAll(ctx, database, coll interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropAll", reflect.TypeOf((*MockIndexViewer)(nil).DropAll), varargs...) +} + +// DropOne mocks base method. +func (m *MockIndexViewer) DropOne(ctx context.Context, database, coll, name string, opts ...*options.DropIndexesOptions) (bson.Raw, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, database, coll, name} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "DropOne", varargs...) + ret0, _ := ret[0].(bson.Raw) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DropOne indicates an expected call of DropOne. 
+func (mr *MockIndexViewerMockRecorder) DropOne(ctx, database, coll, name interface{}, opts ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, database, coll, name}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOne", reflect.TypeOf((*MockIndexViewer)(nil).DropOne), varargs...) +} diff --git a/mongodb/mongodb_test.go b/mongodb/mongodb_test.go new file mode 100644 index 0000000..4004b13 --- /dev/null +++ b/mongodb/mongodb_test.go @@ -0,0 +1,156 @@ +package mongodb_test + +import ( + "context" + "flag" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "trpc.group/trpc-go/trpc-database/mongodb" + "trpc.group/trpc-go/trpc-go/client" +) + +var ctx = context.Background() + +var database = "test" +var collection = "test" + +// var target = flag.String("target", "mongodb://user:passwd@127.0.0.1:27017", "mongodb server target dsn address") +var cmd = flag.String("cmd", "find", "cmd") +var timeout = flag.Duration("timeout", 1*time.Millisecond, "timeout") +var findArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "name": "tony", + "age": map[string]interface{}{ + "$gt": 17, + }, + }, + "sort": map[string]interface{}{ + "age": -1, + }, +} +var deleteArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "age": map[string]interface{}{ + "$gt": 220, + }, + "name": "teacher", + }, +} + +var findOneAndUpdateArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "name": "lucy", + "age": map[string]interface{}{ + "$lt": 17, + }, + }, + "update": map[string]interface{}{ + "$set": map[string]interface{}{ + "name": "tommmmmmmmmmmmm", + "age": 00000000, + }, + }, + "sort": map[string]interface{}{ + "age": -1, + }, + "returnDocument": "After", +} + +var inserOneArgs = map[string]interface{}{ + "document": map[string]interface{}{ + "name": "tony", + "age": 18, + "sex": "female", + }, +} + +var document1 = map[string]interface{}{ + "name": "lucy", + 
"age": 0, + "sex": "m", +} +var document2 = map[string]interface{}{ + "name": "lucy2", + "age": 1, + "sex": "m", +} +var document = map[string]interface{}{ + "name": "lucy3", + "age": 2, + "sex": "m", +} +var insertManyArgs = map[string]interface{}{ + "documents": []interface{}{document, document1, document2}, +} + +var updateOneArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "name": "tom", + "age": map[string]interface{}{ + "$et": 17, + }, + }, + "update": map[string]interface{}{ + "$set": map[string]interface{}{ + "sex": "male", + }, + }, +} +var updateManyArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "name": "tom", + "age": map[string]interface{}{ + "$gt": 17, + }, + }, + "update": map[string]interface{}{ + "$set": map[string]interface{}{ + "sex": "male", + }, + }, +} + +var distinctArgs = map[string]interface{}{ + "filter": map[string]interface{}{ + "age": map[string]interface{}{ + "$gt": 17, + }, + }, + "fieldName": "name", +} + +func TestMongodbDo(t *testing.T) { + flag.Parse() + var args = map[string]interface{}{} + switch strings.ToLower(*cmd) { + case "find": + args = findArgs + case "deleteone": + args = deleteArgs + case "deletemany": + args = deleteArgs + case "findoneanddelete": + args = findArgs + case "findoneandupdate": + args = findOneAndUpdateArgs + case "insertone": + args = inserOneArgs + case "insertmany": + args = insertManyArgs + case "updateone": + args = updateOneArgs + case "updatemany": + args = updateManyArgs + case "count": + args = findArgs + case "distinct": + args = distinctArgs + } + proxy := mongodb.NewClientProxy("trpc.mongodb.server.service", client.WithTimeout(*timeout)) + + _, err := proxy.Do(ctx, *cmd, database, collection, args) + assert.NotNil(t, err) +} diff --git a/mongodb/options.go b/mongodb/options.go new file mode 100644 index 0000000..ac2c278 --- /dev/null +++ b/mongodb/options.go @@ -0,0 +1,15 @@ +package mongodb + +import ( + "go.mongodb.org/mongo-driver/mongo/options" +) 
+ +// ClientTransportOption sets client transport parameter. +type ClientTransportOption func(opt *ClientTransport) + +// WithOptionInterceptor returns an ClientTransportOption which sets mongo client option interceptor +func WithOptionInterceptor(f func(dsn string, opts *options.ClientOptions)) ClientTransportOption { + return func(ct *ClientTransport) { + ct.optionInterceptor = f + } +} diff --git a/mongodb/options_test.go b/mongodb/options_test.go new file mode 100644 index 0000000..20f4b1a --- /dev/null +++ b/mongodb/options_test.go @@ -0,0 +1,21 @@ +package mongodb + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo/options" +) + +func TestProduceOption(t *testing.T) { + pm := &event.PoolMonitor{} + i := func(dsn string, opts *options.ClientOptions) { + opts.SetPoolMonitor(pm) + } + transport := NewMongoTransport(WithOptionInterceptor(i)) + ct := transport.(*ClientTransport) + opts := options.ClientOptions{} + ct.optionInterceptor("", &opts) + assert.Equal(t, opts.PoolMonitor, pm) +} diff --git a/mongodb/plugin.go b/mongodb/plugin.go new file mode 100644 index 0000000..9508e3e --- /dev/null +++ b/mongodb/plugin.go @@ -0,0 +1,66 @@ +package mongodb + +import ( + "time" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/readpref" + "trpc.group/trpc-go/trpc-go/log" + "trpc.group/trpc-go/trpc-go/plugin" + "trpc.group/trpc-go/trpc-go/transport" +) + +const ( + pluginType = "database" + pluginName = "mongodb" +) + +func init() { + plugin.Register(pluginName, &mongoPlugin{}) +} + +// Config mongo is a proxy configuration structure declaration. 
+type Config struct { + MinOpen uint64 `yaml:"min_open"` // Minimum number of simultaneous online connections + MaxOpen uint64 `yaml:"max_open"` // The maximum number of simultaneous online connections + MaxIdleTime time.Duration `yaml:"max_idle_time"` // Maximum idle time per link + ReadPreference string `yaml:"read_preference"` // reference on read +} + +// mongoPlugin is used for plug-in default initialization, +// used to load mongo proxy connection parameter configuration. +type mongoPlugin struct{} + +// Type is plugin type. +func (m *mongoPlugin) Type() string { + return pluginType +} + +// Setup is plugin initialization. +func (m *mongoPlugin) Setup(name string, configDesc plugin.Decoder) (err error) { + var config Config // yaml database:mongo connection configuration parameters + if err = configDesc.Decode(&config); err != nil { + return + } + readMode, err := readpref.ModeFromString(config.ReadPreference) + if err != nil { + log.Errorf("readpref.ModeFromString failed, err=%v, set mod to primary", err) + readMode = readpref.PrimaryMode + } + readF, err := readpref.New(readMode) + if err != nil { + log.Errorf("readpref.New failed, err=%v, set mod to primary", err) + readF = readpref.Primary() + } + DefaultClientTransport = &ClientTransport{ + mongoDB: make(map[string]*mongo.Client), + MinOpenConns: config.MinOpen, + MaxOpenConns: config.MaxOpen, + MaxConnIdleTime: config.MaxIdleTime, + ReadPreference: readF, + ServiceNameURIs: make(map[string][]string), + } + // You need to explicitly call register, otherwise the configuration will not take effect. + transport.RegisterClientTransport(pluginName, DefaultClientTransport) + return nil +} diff --git a/mongodb/plugin_test.go b/mongodb/plugin_test.go new file mode 100644 index 0000000..aa31907 --- /dev/null +++ b/mongodb/plugin_test.go @@ -0,0 +1,134 @@ +package mongodb + +import ( + "errors" + "testing" + + "github.com/agiledragon/gomonkey/v2" + . 
"github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/mongo/readpref" + "gopkg.in/yaml.v3" + "trpc.group/trpc-go/trpc-go" + "trpc.group/trpc-go/trpc-go/plugin" +) + +// TestUnit_MongoPlugin_Type_P0 MongoPlugin.Type is test case. +func TestUnit_MongoPlugin_Type_P0(t *testing.T) { + Convey("TestUnit_MongoPlugin_Type_P0", t, func() { + mongoPlugin := new(mongoPlugin) + So(mongoPlugin.Type(), ShouldEqual, pluginType) + }) +} + +// TestUnit_MongoPlugin_Setup_P0 MongoPlugin.Setup is test case. +func TestUnit_MongoPlugin_Setup_P0(t *testing.T) { + Convey("TestUnit_MongoPlugin_Setup_P0", t, func() { + mp := &mongoPlugin{} + Convey("Config Decode Fail", func() { + err := mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: nil}) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "yaml node empty") + }) + Convey("Setup Success", func() { + var bts = ` +plugins: + database: + mongodb: + min_open: 20 + max_open: 100 + max_idle_time: 1s + read_preference: secondar +` + + var cfg = trpc.Config{} + err := yaml.Unmarshal([]byte(bts), &cfg) + assert.Nil(t, err) + var yamlNode *yaml.Node + if configP, ok := cfg.Plugins[pluginType]; ok { + if node, ok := configP[pluginName]; ok { + yamlNode = &node + } + } + err = mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: yamlNode}) + So(err, ShouldBeNil) + }) + }) +} + +// TestUnit_MongoPlugin_Setup_Fail MongoPlugin.Setup is typo test case. 
+func TestUnit_MongoPlugin_Setup_Fail(t *testing.T) { + + Convey("TestUnit_MongoPlugin_Setup_Fail", t, func() { + mp := &mongoPlugin{} + Convey("Config Decode Fail", func() { + err := mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: nil}) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "yaml node empty") + }) + Convey("Setup Success", func() { + var bts = ` +plugins: + database: + mongodb: + min_open: 20 + max_open: 100 + max_idle_time: 1s + read_preference: secondary +` + + var cfg = trpc.Config{} + err := yaml.Unmarshal([]byte(bts), &cfg) + assert.Nil(t, err) + var yamlNode *yaml.Node + if configP, ok := cfg.Plugins[pluginType]; ok { + if node, ok := configP[pluginName]; ok { + yamlNode = &node + } + } + err = mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: yamlNode}) + So(err, ShouldBeNil) + }) + }) +} + +// TestUnit_MongoPlugin_Setup_New_READF_Failed MongoPlugin.Setup is test case. +func TestUnit_MongoPlugin_Setup_New_READF_Failed(t *testing.T) { + defer gomonkey.ApplyFunc(readpref.New, + func(mode readpref.Mode, opts ...readpref.Option) (*readpref.ReadPref, error) { + return nil, errors.New("test fail") + }, + ).Reset() + + Convey("TestUnit_MongoPlugin_Setup_P0", t, func() { + mp := &mongoPlugin{} + Convey("Config Decode Fail", func() { + err := mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: nil}) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, "yaml node empty") + }) + Convey("Setup Success", func() { + var bts = ` +plugins: + database: + mongodb: + min_open: 20 + max_open: 100 + max_idle_time: 1s + read_preference: secondary +` + + var cfg = trpc.Config{} + err := yaml.Unmarshal([]byte(bts), &cfg) + assert.Nil(t, err) + var yamlNode *yaml.Node + if configP, ok := cfg.Plugins[pluginType]; ok { + if node, ok := configP[pluginName]; ok { + yamlNode = &node + } + } + err = mp.Setup(pluginName, &plugin.YamlNodeDecoder{Node: yamlNode}) + So(err, ShouldBeNil) + }) + }) +} diff --git a/mongodb/test.sh b/mongodb/test.sh new file mode 100644 
index 0000000..a855d95 --- /dev/null +++ b/mongodb/test.sh @@ -0,0 +1,4 @@ +#!/bin/sh +go clean --modcache +# Only run unit tests of type *Unit*, filtering some test cases with network calls. +go test ./... --covermode=count -coverprofile=cover.out -v -gcflags=all=-l -run Unit \ No newline at end of file diff --git a/mongodb/transport.go b/mongodb/transport.go new file mode 100644 index 0000000..02c29de --- /dev/null +++ b/mongodb/transport.go @@ -0,0 +1,429 @@ +package mongodb + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" + "trpc.group/trpc-go/trpc-go" + "trpc.group/trpc-go/trpc-go/codec" + "trpc.group/trpc-go/trpc-go/errs" + "trpc.group/trpc-go/trpc-go/naming/selector" + "trpc.group/trpc-go/trpc-go/transport" + dsn "trpc.group/trpc-go/trpc-selector-dsn" +) + +func init() { + selector.Register("mongodb", dsn.NewDsnSelector(true)) + selector.Register("mongodb+polaris", dsn.NewResolvableSelectorWithOpts("polaris", + dsn.WithEnableParseAddr(true), dsn.WithExtractor(&hostExtractor{}))) + transport.RegisterClientTransport("mongodb", DefaultClientTransport) +} + +// ClientTransport is a client-side mongodb transport. +type ClientTransport struct { + mongoDB map[string]*mongo.Client + optionInterceptor func(dsn string, opts *options.ClientOptions) + mongoDBLock sync.RWMutex + MaxOpenConns uint64 + MinOpenConns uint64 + MaxConnIdleTime time.Duration + ReadPreference *readpref.ReadPref + ServiceNameURIs map[string][]string +} + +// DefaultClientTransport is a default client mongodb transport. +var DefaultClientTransport = NewMongoTransport() + +// NewClientTransport creates a mongodb transport. +// Deprecated,use NewMongoTransport instead. +func NewClientTransport(opt ...transport.ClientTransportOption) transport.ClientTransport { + return NewMongoTransport() +} + +// NewMongoTransport creates a mongodb transport. 
+func NewMongoTransport(opt ...ClientTransportOption) transport.ClientTransport { + ct := &ClientTransport{ + optionInterceptor: func(dsn string, opts *options.ClientOptions) {}, + mongoDB: make(map[string]*mongo.Client), + MaxOpenConns: 100, // The maximum number of connections in the connection pool. + MinOpenConns: 5, + MaxConnIdleTime: 5 * time.Minute, // Connection pool idle connection time. + ReadPreference: readpref.Primary(), + ServiceNameURIs: make(map[string][]string), + } + for _, o := range opt { + o(ct) + } + return ct +} + +// RoundTrip sends and receives mongodb packets, +// returns the mongodb response and puts it in ctx, there is no need to return rspbuf here. +func (ct *ClientTransport) RoundTrip(ctx context.Context, _ []byte, + callOpts ...transport.RoundTripOption) (rspBytes []byte, + err error) { + + msg := codec.Message(ctx) + + req, ok := msg.ClientReqHead().(*Request) + if !ok { + return nil, errs.NewFrameError(errs.RetClientEncodeFail, + "mongodb client transport: ReqHead should be type of *mongodb.Request") + } + rsp, ok := msg.ClientRspHead().(*Response) + if !ok { + return nil, errs.NewFrameError(errs.RetClientEncodeFail, + "mongodb client transport: RspHead should be type of *mongodb.Response") + } + + opts := &transport.RoundTripOptions{} + for _, o := range callOpts { + o(opts) + } + + if req.Command == Disconnect { + return nil, ct.disconnect(ctx) + } + // Determine whether to use the mgo instance specified by the client. 
+ var mgo *mongo.Client + if rsp.txClient != nil { + mgo = rsp.txClient + } else { + mgo, err = ct.GetMgoClient(ctx, opts.Address) + if err != nil { + return nil, errs.NewFrameError(errs.RetClientNetErr, + fmt.Sprintf("get mongo client failed: %s", err.Error())) + } + } + + var result interface{} + if req.DriverProxy { + result, err = handleDriverReq(ctx, mgo, req) + } else { + result, err = handleReq(ctx, mgo, req) + } + rsp.Result = result + if err != nil { + if mongo.IsDuplicateKeyError(err) { + return nil, errs.Wrap(err, RetDuplicateKeyErr, err.Error()) + } else if mongo.IsTimeout(err) { + return nil, errs.Wrap(err, errs.RetClientTimeout, err.Error()) + } else if mongo.IsNetworkError(err) { + return nil, errs.Wrap(err, errs.RetClientNetErr, err.Error()) + } else { + return nil, errs.Wrap(err, errs.RetUnknown, err.Error()) + } + } + return nil, nil +} + +// handleDriverReq handles transparent transmission. +func handleDriverReq(ctx context.Context, mgoCli *mongo.Client, req *Request) (result interface{}, err error) { + collection := mgoCli.Database(req.Database).Collection(req.Collection) + switch req.Command { + case InsertOne: + return collection.InsertOne(ctx, req.CommArg, req.Opts.([]*options.InsertOneOptions)...) + case InsertMany: + return collection.InsertMany(ctx, req.CommArg.([]interface{}), req.Opts.([]*options.InsertManyOptions)...) + case DeleteOne: + return collection.DeleteOne(ctx, req.Filter, req.Opts.([]*options.DeleteOptions)...) + case DeleteMany: + return collection.DeleteMany(ctx, req.Filter, req.Opts.([]*options.DeleteOptions)...) + case UpdateOne: + return collection.UpdateOne(ctx, req.Filter, req.CommArg, req.Opts.([]*options.UpdateOptions)...) + case UpdateMany: + return collection.UpdateMany(ctx, req.Filter, req.CommArg, req.Opts.([]*options.UpdateOptions)...) + case ReplaceOne: + return collection.ReplaceOne(ctx, req.Filter, req.CommArg, req.Opts.([]*options.ReplaceOptions)...) 
+ case Aggregate: + return collection.Aggregate(ctx, req.CommArg, req.Opts.([]*options.AggregateOptions)...) + case CountDocuments: + return collection.CountDocuments(ctx, req.Filter, req.Opts.([]*options.CountOptions)...) + case EstimatedDocumentCount: + return collection.EstimatedDocumentCount(ctx, req.Opts.([]*options.EstimatedDocumentCountOptions)...) + case Distinct: + return collection.Distinct(ctx, req.CommArg.(string), req.Filter, req.Opts.([]*options.DistinctOptions)...) + case Find: + return collection.Find(ctx, req.Filter, req.Opts.([]*options.FindOptions)...) + case FindOne: + return extractSingleResultErr(collection.FindOne(ctx, req.Filter, req.Opts.([]*options.FindOneOptions)...)) + case FindOneAndDelete: + return extractSingleResultErr(collection.FindOneAndDelete(ctx, req.Filter, + req.Opts.([]*options.FindOneAndDeleteOptions)...)) + case FindOneAndReplace: + return extractSingleResultErr(collection.FindOneAndReplace(ctx, req.Filter, req.CommArg, + req.Opts.([]*options.FindOneAndReplaceOptions)...)) + case FindOneAndUpdate: + return extractSingleResultErr(collection.FindOneAndUpdate(ctx, req.Filter, req.CommArg, + req.Opts.([]*options.FindOneAndUpdateOptions)...)) + case BulkWrite: + return collection.BulkWrite(ctx, req.CommArg.([]mongo.WriteModel), req.Opts.([]*options.BulkWriteOptions)...) + case Watch: + return mgoCli.Watch(ctx, req.CommArg, req.Opts.([]*options.ChangeStreamOptions)...) + case WatchDatabase: + return mgoCli.Database(req.Database).Watch(ctx, req.CommArg, req.Opts.([]*options.ChangeStreamOptions)...) + case WatchCollection: + return mgoCli.Database(req.Database).Collection(req.Collection).Watch(ctx, + req.CommArg, req.Opts.([]*options.ChangeStreamOptions)...) + case Transaction: + return execMongoTransaction(ctx, mgoCli, req.CommArg.(TxFunc), req.Filter.([]*options.TransactionOptions), + req.Opts.([]*options.SessionOptions)...) 
+ case RunCommand: + return extractSingleResultErr(mgoCli.Database(req.Database).RunCommand(ctx, + req.CommArg, req.Opts.([]*options.RunCmdOptions)...)) + case IndexCreateOne: + return collection.Indexes().CreateOne(ctx, + req.CommArg.(mongo.IndexModel), req.Opts.([]*options.CreateIndexesOptions)...) + case IndexCreateMany: + return collection.Indexes().CreateMany(ctx, + req.CommArg.([]mongo.IndexModel), req.Opts.([]*options.CreateIndexesOptions)...) + case IndexDropOne: + return collection.Indexes().DropOne(ctx, + req.CommArg.(string), req.Opts.([]*options.DropIndexesOptions)...) + case IndexDropAll: + return collection.Indexes().DropAll(ctx, + req.Opts.([]*options.DropIndexesOptions)...) + case Indexes: + return collection.Indexes(), nil + case DatabaseCmd: + return mgoCli.Database(req.Database), nil + case CollectionCmd: + return collection, nil + case StartSession: + return mgoCli.StartSession() + default: + return nil, errs.New(errs.RetClientDecodeFail, "error mongo command") + } +} + +// execMongoTransaction executes mongo transactions. +func execMongoTransaction(ctx context.Context, mgoCli *mongo.Client, sf TxFunc, tOpts []*options.TransactionOptions, + opts ...*options.SessionOptions) (result interface{}, err error) { + + rspHead := trpc.Message(ctx).ClientRspHead() + // Bind client before transaction execution. + if rspHead == nil { + return nil, errs.New(errs.RetClientDecodeFail, "rspHead can not be nil") + } + mCliRsp, ok := rspHead.(*Response) + if !ok { + return nil, errs.New(errs.RetClientDecodeFail, "conversion from rspHead to Respons failed") + } + mCliRsp.txClient = mgoCli + + // Obtain session. + sess, err := mgoCli.StartSession(opts...) + if err != nil { + return nil, err + } + + // Close session when finished. + defer func() { + sess.EndSession(ctx) + }() + + transactionFn := func(sessCtx mongo.SessionContext) (interface{}, error) { + return nil, sf(sessCtx) + } + + return sess.WithTransaction(ctx, transactionFn, tOpts...) 
+} + +// handleReq is an auxiliary function that handles the Req passed in by RoundTrip. +func handleReq(ctx context.Context, mgoCli *mongo.Client, + req *Request) (result interface{}, err error) { + collection := mgoCli.Database(req.Database).Collection(req.Collection) + + switch strings.ToLower(req.Command) { + case Find: + result, err = executeFind(ctx, collection, req.Arguments) + case FindC: + result, err = executeFindC(ctx, collection, req.Arguments) + case DeleteOne: + result, err = executeDeleteOne(ctx, collection, req.Arguments) + case DeleteMany: + result, err = executeDeleteMany(ctx, collection, req.Arguments) + case FindOneAndDelete: + result, err = executeFindOneAndDelete(ctx, collection, req.Arguments) + case FindOneAndUpdate: + result, err = executeFindOneAndUpdate(ctx, collection, req.Arguments) + case FindOneAndUpdateS: + result, err = executeFindOneAndUpdateS(ctx, collection, req.Arguments) + case InsertOne: + result, err = executeInsertOne(ctx, collection, req.Arguments) + case InsertMany: + result, err = executeInsertMany(ctx, collection, req.Arguments) + case UpdateOne: + result, err = executeUpdateOne(ctx, collection, req.Arguments) + case UpdateMany: + result, err = executeUpdateMany(ctx, collection, req.Arguments) + case Count: + result, err = executeCount(ctx, collection, req.Arguments) + case Aggregate: + result, err = executeAggregate(ctx, collection, req.Arguments) + case AggregateC: + result, err = executeAggregateC(ctx, collection, req.Arguments) + case Distinct: + result, err = executeDistinct(ctx, collection, req.Arguments) + case BulkWrite: + result, err = executeBulkWrite(ctx, collection, req.Arguments) + case IndexCreateOne: + result, err = executeIndexCreateOne(ctx, collection.Indexes(), req.Arguments) + case IndexCreateMany: + result, err = executeIndexCreateMany(ctx, collection.Indexes(), req.Arguments) + case IndexDropOne: + result, err = executeIndexDropOne(ctx, collection.Indexes(), req.Arguments) + case IndexDropAll: + 
result, err = executeIndexDropAll(ctx, collection.Indexes(), req.Arguments) + case Indexes: + result = collection.Indexes() + case DatabaseCmd: + result = mgoCli.Database(req.Database) + case CollectionCmd: + result = collection + case StartSession: + result, err = mgoCli.StartSession() + default: + err = fmt.Errorf("error mongo command") + } + return result, err +} + +// GetMgoClient obtains mongodb client, cache dsn=>client, +// save some initialization steps such as reparsing parameters, generating topology server, etc. +func (ct *ClientTransport) GetMgoClient(ctx context.Context, dsn string) (*mongo.Client, error) { + ct.mongoDBLock.RLock() + mgo, ok := ct.mongoDB[dsn] + ct.mongoDBLock.RUnlock() + + if ok { + return mgo, nil + } + ct.mongoDBLock.Lock() + defer ct.mongoDBLock.Unlock() + + mgo, ok = ct.mongoDB[dsn] + if ok { + return mgo, nil + } + clientOptions := ct.getClientOptions(dsn) + + // Based on the uri parameter, if it is not set, use the default value. + if clientOptions.MaxPoolSize == nil { + clientOptions.SetMaxPoolSize(ct.MaxOpenConns) + } + if clientOptions.MinPoolSize == nil { + clientOptions.SetMinPoolSize(ct.MinOpenConns) + } + if clientOptions.MaxConnIdleTime == nil { + clientOptions.SetMaxConnIdleTime(ct.MaxConnIdleTime) + } + if clientOptions.ReadPreference == nil { + clientOptions.SetReadPreference(ct.ReadPreference) + } + + if ct.optionInterceptor != nil { + ct.optionInterceptor(dsn, clientOptions) + } + + // The mongo-driver manages the connection itself, once Connect is initialized and used multiple times. 
+ mgo, err := mongo.Connect(ctx, clientOptions) + if err != nil { + return nil, err + } + + err = mgo.Ping(ctx, ct.ReadPreference) + if err != nil { + _ = mgo.Disconnect(ctx) + return nil, fmt.Errorf("ping mongo failed: %w", err) + } + + ct.mongoDB[dsn] = mgo + serviceName := codec.Message(ctx).CalleeServiceName() + ct.ServiceNameURIs[serviceName] = append(ct.ServiceNameURIs[serviceName], dsn) + return mgo, nil +} + +func (ct *ClientTransport) getClientOptions(dsn string) *options.ClientOptions { + uri := "mongodb://" + dsn + clientOptions := options.Client().ApplyURI(uri) + if clientOptions.MaxPoolSize == nil { + clientOptions.SetMaxPoolSize(ct.MaxOpenConns) + } + if clientOptions.MinPoolSize == nil { + clientOptions.SetMinPoolSize(ct.MinOpenConns) + } + if clientOptions.MaxConnIdleTime == nil { + clientOptions.SetMaxConnIdleTime(ct.MaxConnIdleTime) + } + return clientOptions +} + +func (ct *ClientTransport) disconnect(ctx context.Context) error { + serviceName := codec.Message(ctx).CalleeServiceName() + ct.mongoDBLock.RLock() + uris, ok := ct.ServiceNameURIs[serviceName] + ct.mongoDBLock.RUnlock() + if !ok { + return nil + } + + ct.mongoDBLock.Lock() + defer ct.mongoDBLock.Unlock() + var funcs []func() error + for _, uri := range uris { + mgoCli, ok := ct.mongoDB[uri] + if !ok { + continue + } + delete(ct.mongoDB, uri) + funcs = append(funcs, func() error { + return mgoCli.Disconnect(ctx) + }) + } + + delete(ct.ServiceNameURIs, serviceName) + return trpc.GoAndWait(funcs...) +} + +type hostExtractor struct { +} + +// Extract extractHost is used to remove the "://" of uri and the part before it. 
+func (e *hostExtractor) Extract(uri string) (begin int, length int, err error) { + // mongodb+polaris://user:pswd@xxx.mongodb.com + offset := 0 + + if idx := strings.Index(uri, "@"); idx != -1 { + uri = uri[idx+1:] + offset += idx + 1 + } + + begin = offset + length = len(uri) + if idx := strings.IndexAny(uri, "/?@"); idx != -1 { + if uri[idx] == '@' { + return 0, 0, errs.NewFrameError(errs.RetClientRouteErr, "unescaped @ sign in user info") + } + if uri[idx] == '?' { + return 0, 0, errs.NewFrameError(errs.RetClientRouteErr, "must have a / before the query ?") + } + length = idx + } + + return +} + +func extractSingleResultErr(ms *mongo.SingleResult) (*mongo.SingleResult, error) { + if ms != nil && ms.Err() != nil { + return nil, ms.Err() + } + + return ms, nil +} diff --git a/mongodb/transport_test.go b/mongodb/transport_test.go new file mode 100644 index 0000000..7fc5921 --- /dev/null +++ b/mongodb/transport_test.go @@ -0,0 +1,397 @@ +package mongodb + +import ( + "context" + "fmt" + "reflect" + "testing" + "time" + + "github.com/agiledragon/gomonkey/v2" + . "github.com/smartystreets/goconvey/convey" + "github.com/stretchr/testify/assert" + "go.mongodb.org/mongo-driver/event" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/readpref" + "trpc.group/trpc-go/trpc-go" + "trpc.group/trpc-go/trpc-go/codec" + "trpc.group/trpc-go/trpc-go/errs" +) + +func TestUnitExtractHost(t *testing.T) { + assert := assert.New(t) + + cases := []struct { + uri string // The uri here has removed "://" and the part before it. 
+ host string + err string + }{ + { + uri: "localhost", + host: "localhost", + }, + { + uri: "admin:123456@localhost/", + host: "localhost", + }, + { + uri: "admin:123456@localhost", + host: "localhost", + }, + { + uri: "example1.com:27017,example2.com:27017", + host: "example1.com:27017,example2.com:27017", + }, + { + uri: "host1,host2,host3/?slaveOk=true", + host: "host1,host2,host3", + }, + { + uri: "admin:@123456@localhost/", + err: errs.NewFrameError(errs.RetClientRouteErr, "unescaped @ sign in user info").Error(), + }, + { + uri: "admin:123456@localhost?", + err: errs.NewFrameError(errs.RetClientRouteErr, "must have a / before the query ?").Error(), + }, + } + + for _, c := range cases { + pos, length, err := new(hostExtractor).Extract(c.uri) + if len(c.err) != 0 { + assert.EqualErrorf(err, c.err, "case: %+v ", c) + } else { + assert.Equalf(c.host, c.uri[pos:pos+length], "case: %+v", c) + } + } +} + +func TestUnitClientTransport_GetMgoClient(t *testing.T) { + Convey("TestUnitClientTransport_GetMgoClient", t, func() { + pm := &event.PoolMonitor{} + i := func(dsn string, opts *options.ClientOptions) { + opts.SetPoolMonitor(pm) + } + cli := &ClientTransport{ + optionInterceptor: i, + mongoDB: make(map[string]*mongo.Client), + MaxOpenConns: 100, // The maximum number of connections in the connection pool. + MinOpenConns: 5, + MaxConnIdleTime: 5 * time.Minute, // Connection pool idle connection time. 
+ ServiceNameURIs: make(map[string][]string), + } + Convey("err", func() { + defer gomonkey.ApplyFunc(mongo.NewClient, func(opts ...*options.ClientOptions) (*mongo.Client, error) { + return nil, fmt.Errorf("err") + }).Reset() + ctx := context.TODO() + mCli, err := cli.GetMgoClient(ctx, "addr1") + So(mCli, ShouldBeNil) + So(err, ShouldNotBeNil) + }) + Convey("succ", func() { + defer gomonkey.ApplyFunc(mongo.NewClient, func(opts ...*options.ClientOptions) (*mongo.Client, error) { + return new(mongo.Client), nil + }).Reset() + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "Ping", + func(client *mongo.Client, ctx context.Context, rp *readpref.ReadPref) error { + return nil + }, + ).Reset() + ctx := context.TODO() + mCli1, err := cli.GetMgoClient(ctx, "addr1") + So(mCli1, ShouldNotBeNil) + So(err, ShouldBeNil) + So(len(cli.mongoDB), ShouldEqual, 1) + mCli2, err := cli.GetMgoClient(ctx, "addr2") + So(mCli2, ShouldNotBeNil) + So(err, ShouldBeNil) + So(len(cli.mongoDB), ShouldEqual, 2) + mCli3, err := cli.GetMgoClient(ctx, "addr1") + So(mCli3, ShouldNotBeNil) + So(err, ShouldBeNil) + So(len(cli.mongoDB), ShouldEqual, 2) + So(mCli1, ShouldEqual, mCli3) + So(mCli1, ShouldNotEqual, mCli2) + }) + }) +} + +// TestUnitRoundTrip tests the RoundTrip function. 
+func TestUnitRoundTrip(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(ClientTransport)), "GetMgoClient", + func(ct *ClientTransport, ctx context.Context, dsn string) (*mongo.Client, error) { + return &mongo.Client{}, nil + }, + ).Reset() + + clientTransport := DefaultClientTransport + ctx, msg := codec.WithNewMessage(trpc.BackgroundContext()) + _, err := clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{}) + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientRspHead(&Response{}) + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{ + Command: Find, + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, errs.New(-1, "find cover") + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{ + Command: Find, + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, mongo.WriteException{WriteErrors: mongo.WriteErrors{mongo.WriteError{Code: 11000}}} + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{ + Command: Find, + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, context.DeadlineExceeded + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{ + Command: Find, + }) + defer 
gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return nil, mongo.CommandError{Labels: []string{"NetworkError"}} + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.NotNil(t, err) + + msg.WithClientReqHead(&Request{ + Command: FindC, + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "Find", + func(coll *mongo.Collection, ctx context.Context, filter interface{}, + opts ...*options.FindOptions) (*mongo.Cursor, error) { + return &mongo.Cursor{}, nil + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.Nil(t, err) + + msg.WithClientReqHead(&Request{ + Command: DeleteOne, + }) + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Collection)), "DeleteOne", + func(*mongo.Collection, context.Context, interface{}, ...*options.DeleteOptions) (*mongo.DeleteResult, error) { + return &mongo.DeleteResult{}, nil + }, + ).Reset() + _, err = clientTransport.RoundTrip(ctx, nil) + assert.Nil(t, err) +} + +func TestClientTransport_disconnect(t *testing.T) { + defer gomonkey.ApplyMethod(reflect.TypeOf(new(mongo.Client)), "Disconnect", + func(coll *mongo.Client, ctx context.Context) error { + return nil + }, + ).Reset() + type fields struct { + mongoDB map[string]*mongo.Client + MaxOpenConns uint64 + MinOpenConns uint64 + MaxConnIdleTime time.Duration + ReadPreference *readpref.ReadPref + ServiceNameURIs map[string][]string + } + + ctx := trpc.BackgroundContext() + msg := trpc.Message(ctx) + msg.WithCalleeServiceName("uri") + + type args struct { + ctx context.Context + } + tests := []struct { + name string + fields fields + args args + wantErr bool + }{ + { + name: "succ", + fields: fields{ + ServiceNameURIs: map[string][]string{ + "uri": {"uri1"}, + }, + mongoDB: map[string]*mongo.Client{ + "uri1": nil, + }, + }, + args: args{ + ctx: ctx, + }, + }, + } + for _, tt := range 
tests { + t.Run(tt.name, func(t *testing.T) { + ct := &ClientTransport{ + mongoDB: tt.fields.mongoDB, + MaxOpenConns: tt.fields.MaxOpenConns, + MinOpenConns: tt.fields.MinOpenConns, + MaxConnIdleTime: tt.fields.MaxConnIdleTime, + ReadPreference: tt.fields.ReadPreference, + ServiceNameURIs: tt.fields.ServiceNameURIs, + } + if err := ct.disconnect(tt.args.ctx); (err != nil) != tt.wantErr { + t.Errorf("ClientTransport.disconnect() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_handleReq(t *testing.T) { + c := &mongo.Client{} + db := &mongo.Database{} + col := &mongo.Collection{} + m1 := gomonkey.ApplyMethod(c, "Database", + func(c *mongo.Client, name string, opts ...*options.DatabaseOptions) *mongo.Database { + return db + }) + defer m1.Reset() + m2 := gomonkey.ApplyMethod(db, "Collection", + func(c *mongo.Database, name string, opts ...*options.CollectionOptions) *mongo.Collection { + return col + }) + defer m2.Reset() + ctx := trpc.BackgroundContext() + t.Run("indexes", func(t *testing.T) { + mm := gomonkey.ApplyMethod(col, "Indexes", func(c *mongo.Collection) mongo.IndexView { return mongo.IndexView{} }) + defer mm.Reset() + gotResult, err := handleReq(ctx, c, &Request{Command: Indexes}) + if err != nil { + t.Errorf("handleReq() error = %v, wantErr %v", err, nil) + return + } + if !reflect.DeepEqual(gotResult, mongo.IndexView{}) { + t.Errorf("handleReq() gotResult = %v, want %v", gotResult, mongo.IndexView{}) + } + }) + t.Run("DatabaseCmd", func(t *testing.T) { + gotResult, err := handleReq(ctx, c, &Request{Command: DatabaseCmd}) + if err != nil { + t.Errorf("handleReq() error = %v, wantErr %v", err, nil) + return + } + if !reflect.DeepEqual(gotResult, db) { + t.Errorf("handleReq() gotResult = %v, want %v", gotResult, db) + } + }) + t.Run("CollectionCmd", func(t *testing.T) { + gotResult, err := handleReq(ctx, c, &Request{Command: CollectionCmd}) + if err != nil { + t.Errorf("handleReq() error = %v, wantErr %v", err, nil) + return + } + if 
!reflect.DeepEqual(gotResult, col) {
+			t.Errorf("handleReq() gotResult = %v, want %v", gotResult, col)
+		}
+	})
+	t.Run("StartSession", func(t *testing.T) {
+		mm := gomonkey.ApplyMethod(c, "StartSession",
+			func(c *mongo.Client, opts ...*options.SessionOptions) (mongo.Session, error) { return nil, nil })
+		defer mm.Reset()
+		gotResult, err := handleReq(ctx, c, &Request{Command: StartSession})
+		if err != nil {
+			t.Errorf("handleReq() error = %v, wantErr %v", err, nil)
+			return
+		}
+		if !reflect.DeepEqual(gotResult, nil) {
+			t.Errorf("handleReq() gotResult = %v, want %v", gotResult, nil)
+		}
+	})
+}
+
+func Test_handleDriverReq(t *testing.T) {
+	c := &mongo.Client{}
+	db := &mongo.Database{}
+	col := &mongo.Collection{}
+	m1 := gomonkey.ApplyMethod(c, "Database",
+		func(c *mongo.Client, name string, opts ...*options.DatabaseOptions) *mongo.Database {
+			return db
+		})
+	defer m1.Reset()
+	m2 := gomonkey.ApplyMethod(db, "Collection",
+		func(c *mongo.Database, name string, opts ...*options.CollectionOptions) *mongo.Collection {
+			return col
+		})
+	defer m2.Reset()
+	ctx := trpc.BackgroundContext()
+
+	t.Run("indexes", func(t *testing.T) {
+		mm := gomonkey.ApplyMethod(col, "Indexes", func(c *mongo.Collection) mongo.IndexView { return mongo.IndexView{} })
+		defer mm.Reset()
+		gotResult, err := handleDriverReq(ctx, c, &Request{Command: Indexes})
+		if err != nil {
+			t.Errorf("handleDriverReq() error = %v, wantErr %v", err, nil)
+			return
+		}
+		if !reflect.DeepEqual(gotResult, mongo.IndexView{}) {
+			t.Errorf("handleDriverReq() gotResult = %v, want %v", gotResult, mongo.IndexView{})
+		}
+	})
+	t.Run("DatabaseCmd", func(t *testing.T) {
+		gotResult, err := handleDriverReq(ctx, c, &Request{Command: DatabaseCmd})
+		if err != nil {
+			t.Errorf("handleDriverReq() error = %v, wantErr %v", err, nil)
+			return
+		}
+		if !reflect.DeepEqual(gotResult, db) {
+			t.Errorf("handleDriverReq() gotResult = %v, want %v", gotResult, db)
+		}
+	})
+	t.Run("CollectionCmd", func(t *testing.T) {
+		gotResult, err := 
handleDriverReq(ctx, c, &Request{Command: CollectionCmd})
+		if err != nil {
+			t.Errorf("handleDriverReq() error = %v, wantErr %v", err, nil)
+			return
+		}
+		if !reflect.DeepEqual(gotResult, col) {
+			t.Errorf("handleDriverReq() gotResult = %v, want %v", gotResult, col)
+		}
+	})
+	t.Run("StartSession", func(t *testing.T) {
+		mm := gomonkey.ApplyMethod(c, "StartSession",
+			func(c *mongo.Client, opts ...*options.SessionOptions) (mongo.Session, error) { return nil, nil })
+		defer mm.Reset()
+		gotResult, err := handleDriverReq(ctx, c, &Request{Command: StartSession})
+		if err != nil {
+			t.Errorf("handleDriverReq() error = %v, wantErr %v", err, nil)
+			return
+		}
+		if !reflect.DeepEqual(gotResult, nil) {
+			t.Errorf("handleDriverReq() gotResult = %v, want %v", gotResult, nil)
+		}
+	})
+}