Skip to content

Commit

Permalink
Fixed TLS --tls-protocols parsing. (#241)
Browse files Browse the repository at this point in the history
* Fixed tls arg tests

* Fixed tls arg tests

* Verbose output on CI runs

* Verbose output on CI runs

* Fixed tls-protocols parsing

* using c++ std::strtok instead of c strktok

* Skipping rate-limit + test time on cluster setups
  • Loading branch information
filipecosta90 authored Dec 2, 2023
1 parent 9ddfcff commit a5e6f19
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 13 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,13 @@ jobs:
if: matrix.platform == 'ubuntu-latest'
timeout-minutes: 10
run: |
TLS_PROTOCOLS="tlsv1.2" TLS=1 ./tests/run_tests.sh
TLS_PROTOCOLS='TLSv1.2' VERBOSE=1 TLS=1 ./tests/run_tests.sh
- name: Test OSS TCP TLS v1.3
if: matrix.platform == 'ubuntu-latest'
timeout-minutes: 10
run: |
TLS_PROTOCOLS="tlsv1.3" TLS=1 ./tests/run_tests.sh
TLS_PROTOCOLS='TLSv1.3' VERBOSE=1 TLS=1 ./tests/run_tests.sh
- name: Test OSS-CLUSTER TCP
timeout-minutes: 10
Expand Down
13 changes: 7 additions & 6 deletions memtier_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@

#endif

#include <cstring>
#include <stdexcept>

#include "client.h"
Expand Down Expand Up @@ -896,9 +897,9 @@ static int config_parse_args(int argc, char *argv[], struct benchmark_config *cf
break;
case o_tls_protocols:
{
const char tls_delimiter = ',';
char* tls_token = strtok(optarg, &tls_delimiter);
while (tls_token != 0) {
const char* tls_delimiter = ",";
char* tls_token = std::strtok(optarg, tls_delimiter);
while (tls_token != NULL) {
if (!strcasecmp(tls_token, "tlsv1"))
cfg->tls_protocols |= REDIS_TLS_PROTO_TLSv1;
else if (!strcasecmp(tls_token, "tlsv1.1"))
Expand All @@ -913,12 +914,12 @@ static int config_parse_args(int argc, char *argv[], struct benchmark_config *cf
return -1;
#endif
} else {
fprintf(stderr, "Invalid tls-protocols specified. "
"Use a combination of 'TLSv1', 'TLSv1.1', 'TLSv1.2' and 'TLSv1.3'.");
fprintf(stderr, "Invalid tls-protocols specified %s. "
"Use a combination of 'TLSv1', 'TLSv1.1', 'TLSv1.2' and 'TLSv1.3'.", tls_token);
return -1;
break;
}
tls_token = strtok(0, &tls_delimiter);
tls_token = std::strtok(NULL, tls_delimiter);
}
break;
}
Expand Down
5 changes: 5 additions & 0 deletions tests/include.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import glob
import os
import logging

MEMTIER_BINARY = os.environ.get("MEMTIER_BINARY", "memtier_benchmark")
TLS_CERT = os.environ.get("TLS_CERT", "")
TLS_KEY = os.environ.get("TLS_KEY", "")
TLS_CACERT = os.environ.get("TLS_CACERT", "")
TLS_PROTOCOLS = os.environ.get("TLS_PROTOCOLS", "")
VERBOSE = bool(int(os.environ.get("VERBOSE","0")))


def ensure_tls_protocols(master_nodes_connections):
Expand Down Expand Up @@ -35,6 +37,9 @@ def assert_minimum_memtier_outcomes(config, env, memtier_ok, overall_expected_re
debugPrintMemtierOnError(config, env)

def add_required_env_arguments(benchmark_specs, config, env, master_nodes_list):
if VERBOSE:
logging.basicConfig(level=logging.DEBUG)

# if we've specified TLS_PROTOCOLS ensure we configure it on redis
master_nodes_connections = env.getOSSMasterNodesConnectionList()
ensure_tls_protocols(master_nodes_connections)
Expand Down
21 changes: 16 additions & 5 deletions tests/tests_oss_simple_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,10 @@ def test_default_set_get_3_runs(env):


def test_default_arbitrary_command_pubsub(env):
benchmark_specs = {"name": env.testName, "args": ['--command=publish \"__key__\" \"__data__\"']}
benchmark_specs = {"name": env.testName, "args": []}
addTLSArgs(benchmark_specs, env)
# on arbitrary command args should be the last one
benchmark_specs["args"].append('--command=publish \"__key__\" \"__data__\"')
config = get_default_memtier_config()
master_nodes_list = env.getMasterNodesList()

Expand All @@ -281,8 +283,10 @@ def test_default_arbitrary_command_pubsub(env):


def test_default_arbitrary_command_keyless(env):
benchmark_specs = {"name": env.testName, "args": ['--command=PING']}
benchmark_specs = {"name": env.testName, "args": []}
addTLSArgs(benchmark_specs, env)
# on arbitrary command args should be the last one
benchmark_specs["args"].append('--command=PING')
config = get_default_memtier_config()
master_nodes_list = env.getMasterNodesList()

Expand All @@ -301,8 +305,10 @@ def test_default_arbitrary_command_keyless(env):


def test_default_arbitrary_command_set(env):
benchmark_specs = {"name": env.testName, "args": ['--command=SET __key__ __data__']}
benchmark_specs = {"name": env.testName, "args": []}
addTLSArgs(benchmark_specs, env)
# on arbitrary command args should be the last one
benchmark_specs["args"].append('--command=SET __key__ __data__')
config = get_default_memtier_config()
master_nodes_list = env.getMasterNodesList()
overall_expected_request_count = get_expected_request_count(config)
Expand All @@ -327,8 +333,10 @@ def test_default_arbitrary_command_set(env):


def test_default_arbitrary_command_hset(env):
benchmark_specs = {"name": env.testName, "args": ['--command=HSET __key__ field1 __data__']}
benchmark_specs = {"name": env.testName, "args": []}
addTLSArgs(benchmark_specs, env)
# on arbitrary command args should be the last one
benchmark_specs["args"].append('--command=HSET __key__ field1 __data__')
config = get_default_memtier_config()
master_nodes_list = env.getMasterNodesList()
overall_expected_request_count = get_expected_request_count(config)
Expand All @@ -353,8 +361,10 @@ def test_default_arbitrary_command_hset(env):


def test_default_arbitrary_command_hset_multi_data_placeholders(env):
benchmark_specs = {"name": env.testName, "args": ['--command=HSET __key__ field1 __data__ field2 __data__ field3 __data__']}
benchmark_specs = {"name": env.testName, "args": []}
addTLSArgs(benchmark_specs, env)
# on arbitrary command args should be the last one
benchmark_specs["args"].append('--command=HSET __key__ field1 __data__ field2 __data__ field3 __data__')
config = get_default_memtier_config()
master_nodes_list = env.getMasterNodesList()
overall_expected_request_count = get_expected_request_count(config)
Expand All @@ -380,6 +390,7 @@ def test_default_arbitrary_command_hset_multi_data_placeholders(env):
overall_request_count)

def test_default_set_get_rate_limited(env):
env.skipOnCluster()
master_nodes_list = env.getMasterNodesList()
for client_count in [1,2,4]:
for thread_count in [1,2]:
Expand Down

0 comments on commit a5e6f19

Please sign in to comment.