Skip to content

Commit

Permalink
Merge pull request #80 from TykTechnologies/TT-9464-2
Browse files Browse the repository at this point in the history
[TT-9464] New encoding URL function
  • Loading branch information
mativm02 authored Aug 7, 2023
2 parents 10200b7 + 01b6eff commit 0871150
Show file tree
Hide file tree
Showing 3 changed files with 293 additions and 10 deletions.
166 changes: 161 additions & 5 deletions persistent/internal/driver/mongo/life_cycle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ package mongo
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"

"github.com/TykTechnologies/storage/persistent/internal/helper"
Expand All @@ -12,7 +15,6 @@ import (
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"go.mongodb.org/mongo-driver/mongo/readpref"
"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"

"github.com/TykTechnologies/storage/persistent/internal/types"
)
Expand All @@ -26,17 +28,23 @@ type lifeCycle struct {

var _ types.StorageLifecycle = &lifeCycle{}

const (
MongoPrefix = "mongodb://"
MongoSRVPrefix = "mongodb+srv://"
)

// Connect connects to the mongo database given the ClientOpts.
func (lc *lifeCycle) Connect(opts *types.ClientOpts) error {
var err error
var client *mongo.Client

// we check if the connection string is valid before building the connOpts.
cs, err := connstring.ParseAndValidate(opts.ConnectionString)
url, cs, err := parseURL(opts.ConnectionString)
if err != nil {
return errors.New("invalid connection string")
return err
}

opts.ConnectionString = url

connOpts, err := mongoOptsBuilder(opts)
if err != nil {
return errors.New(err.Error())
Expand All @@ -50,12 +58,160 @@ func (lc *lifeCycle) Connect(opts *types.ClientOpts) error {
}

lc.connectionString = opts.ConnectionString
lc.database = cs.Database
lc.database = cs.db
lc.client = client

return lc.client.Ping(context.Background(), nil)
}

type urlInfo struct {
addrs []string
user string
pass string
db string
options []urlOptions
}

// urlOptions is a key/value pair representing a single option in a URL.
// we need to use this struct instead of a map to avoid flaky tests due to the order of the options
type urlOptions struct {
key string
val string
}

func isOptSep(c rune) bool {
return c == ';' || c == '&'
}

func parseURL(s string) (string, *urlInfo, error) {
var info *urlInfo
prefix := ""

if strings.HasPrefix(s, MongoPrefix) {
prefix = MongoPrefix
} else if strings.HasPrefix(s, MongoSRVPrefix) {
prefix = MongoSRVPrefix
}

switch prefix {
case MongoPrefix:
s = strings.TrimPrefix(s, MongoPrefix)
case MongoSRVPrefix:
s = strings.TrimPrefix(s, MongoSRVPrefix)
default:
return "", info, errors.New("invalid connection string, no prefix found")
}

info, err := extractURL(s)
if err != nil {
return "", info, err
}

var connString string
connString += prefix

if info.user != "" {
info.user = url.QueryEscape(info.user)
connString += info.user

if info.pass != "" {
info.pass = url.QueryEscape(info.pass)
connString += ":" + info.pass
}

connString += "@"
}

connString += strings.Join(info.addrs, ",")

connString += "/" + info.db

if len(info.options) > 0 {
connString += "?"
for _, v := range info.options {
connString += v.key + "=" + v.val + "&"
}

connString = connString[:len(connString)-1]
}

return connString, info, nil
}

func extractURL(s string) (*urlInfo, error) {
info := &urlInfo{options: make([]urlOptions, 0)}
var err error

if s, err = extractOptions(s, info); err != nil {
return nil, err
}

if s, err = extractCredentials(s, info); err != nil {
return nil, err
}

if s, err = extractDatabase(s, info); err != nil {
return nil, err
}

info.addrs = strings.Split(s, ",")

return info, nil
}

func extractOptions(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "?"); c != -1 {
for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) {
l := strings.SplitN(pair, "=", 2)
if len(l) != 2 || l[0] == "" || l[1] == "" {
return s, errors.New("connection option must be key=value: " + pair)
}

info.options = append(info.options, urlOptions{key: l[0], val: l[1]})
}

s = s[:c]
}

return s, nil
}

func extractCredentials(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "@"); c != -1 {
pair := strings.SplitN(s[:c], ":", 2)
if len(pair) > 2 || pair[0] == "" {
return s, errors.New("credentials must be provided as user:pass@host")
}

var err error

info.user, err = url.QueryUnescape(pair[0])
if err != nil {
return s, fmt.Errorf("cannot unescape username in URL: %q", pair[0])
}

if len(pair) > 1 {
info.pass, err = url.QueryUnescape(pair[1])
if err != nil {
return s, fmt.Errorf("cannot unescape password in URL")
}
}

s = s[c+1:]
}

return s, nil
}

func extractDatabase(s string, info *urlInfo) (string, error) {
if c := strings.Index(s, "/"); c != -1 {
info.db = s[c+1:]
s = s[:c]
}

return s, nil
}

// Close finish the session.
func (lc *lifeCycle) Close() error {
if lc.client != nil {
Expand Down
135 changes: 131 additions & 4 deletions persistent/internal/driver/mongo/life_cycle_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
//go:build mongo
// +build mongo

package mongo

import (
Expand Down Expand Up @@ -160,7 +157,7 @@ func TestConnect(t *testing.T) {
UseSSL: false,
Type: "mongodb",
},
want: errors.New("invalid connection string"),
want: errors.New("invalid connection string, no prefix found"),
},
{
name: "valid connection_string and invalid tls config",
Expand All @@ -185,6 +182,112 @@ func TestConnect(t *testing.T) {
}
}

func TestParseURL(t *testing.T) {
tests := []struct {
name string
url string
want string
wantErr bool
}{
{
name: "valid connection_string with special characters",
url: "mongodb://lt_tyk:6}3cZQU.9KvM/hVR4qkm-hHqZTu3yg=G@localhost:27017/tyk_analytics",
want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
},
{
name: "already encoded valid url",
url: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
want: "mongodb://lt_tyk:6%7D3cZQU.9KvM%2FhVR4qkm-hHqZTu3yg%3DG@localhost:27017/tyk_analytics",
},
{
name: "invalid connection_string",
url: "invalid_conn_string",
want: "",
wantErr: true,
},
{
name: "valid connection string with @",
url: "mongodb://user:p@ssword@localhost:27017",
want: "mongodb://user:p@ssword@localhost:27017/",
},
{
name: "valid connection string with @ and /",
url: "mongodb://u=s@r:p@sswor/d@localhost:27017/test",
want: "mongodb://u%3Ds@r:p@sswor/d@localhost:27017/test",
},
{
name: "valid connection string with @ and / and '?' outside of the credentials part",
url: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin",
want: "mongodb://user:p@sswor/d@localhost:27017/test?authSource=admin",
},
{
name: "special characters and multiple hosts",
url: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin",
want: "mongodb://user:p@sswor/d@localhost:27017,localhost:27018/test?authSource=admin",
},
{
name: "url without credentials",
url: "mongodb://localhost:27017/test?authSource=admin",
want: "mongodb://localhost:27017/test?authSource=admin",
},
{
name: "invalid connection string",
url: "test",
want: "",
wantErr: true,
},
{
name: "srv connection string",
url: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority",
want: "mongodb+srv://tyk:tyk@clur0.zlgl.mongodb.net/tyk?w=majority",
},
{
name: "srv connection string with special characters",
url: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority",
want: "mongodb+srv://tyk:p@ssword@clur0.zlgl.mongodb.net/tyk?w=majority",
},
{
name: "connection string without username",
url: "mongodb://:password@localhost:27017/test",
want: "",
wantErr: true,
},
{
name: "connection string without password",
url: "mongodb://user:@localhost:27017/test",
want: "mongodb://user@localhost:27017/test",
},
{
name: "connection string without host",
url: "mongodb://user:password@/test",
want: "mongodb://user:password@/test",
},
{
name: "connection string without database",
url: "mongodb://user:password@localhost:27017",
want: "mongodb://user:password@localhost:27017/",
},
{
name: "cosmosdb url",
url: "mongodb+srv://4-0-qa:zFAQ==@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000",
want: "mongodb+srv://4-0-qa:zFAQ%3D%3D@4-0-qa.azure:10/a1?appName=@4-testing@&maxIdleTimeMS=120000",
},
{
name: "cosmosdb url without database with options",
url: "mongodb+srv://tyk:6}3c.9KvM/hVR4qkm-hu3yg=G@clu0.zl.mongodb.net/?retryWrites=true&w=majority",
want: "mongodb+srv://tyk:6%7D3c.9KvM%2FhVR4qkm-hu3yg%3DG@clu0.zl.mongodb.net/?retryWrites=true&w=majority",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
parsedURL, _, err := parseURL(test.url)
assert.Equal(t, test.want, parsedURL)
assert.Equal(t, test.wantErr, err != nil)
})
}
}

func TestClose(t *testing.T) {
lc := &lifeCycle{}
opts := &types.ClientOpts{
Expand Down Expand Up @@ -216,3 +319,27 @@ func TestDBType(t *testing.T) {
dbType := lc.DBType()
assert.Equal(t, utils.StandardMongo, dbType)
}

func TestIsOptSep(t *testing.T) {
tests := []struct {
input rune
want bool
}{
{';', true},
{'&', true},
{':', false},
{'a', false},
{'1', false},
{' ', false},
{'\t', false},
{'\n', false},
{'!', false},
}

for _, test := range tests {
got := isOptSep(test.input)
if got != test.want {
t.Errorf("isOptSep(%q) = %v, want %v", test.input, got, test.want)
}
}
}
2 changes: 1 addition & 1 deletion persistent/internal/driver/mongo/mongo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func TestNewMongoDriver(t *testing.T) {
})

assert.NotNil(t, err)
assert.Equal(t, "invalid connection string", err.Error())
assert.Equal(t, "invalid connection string, no prefix found", err.Error())
assert.Nil(t, newDriver)
})
t.Run("new driver without connection string", func(t *testing.T) {
Expand Down

0 comments on commit 0871150

Please sign in to comment.