From 7d89fb71e8dd44176302bbf754b8611378d9933a Mon Sep 17 00:00:00 2001 From: fruffy Date: Sat, 17 Jun 2023 14:00:14 -0400 Subject: [PATCH] Migrate STF scripts to the STF parser library. Modernize stale run-bmv2-test.py code. --- backends/bmv2/bmv2stf.py | 541 ++++++++++--------------------- backends/bmv2/run-bmv2-test.py | 21 +- backends/ebpf/targets/ebpfstf.py | 6 +- tools/stf/stf_lexer.py | 20 ++ tools/stf/stf_parser.py | 110 ++++++- tools/stf/stf_test.py | 7 - tools/testutils.py | 22 +- 7 files changed, 331 insertions(+), 396 deletions(-) diff --git a/backends/bmv2/bmv2stf.py b/backends/bmv2/bmv2stf.py index cb286f324ba..1ffa2a23475 100755 --- a/backends/bmv2/bmv2stf.py +++ b/backends/bmv2/bmv2stf.py @@ -15,33 +15,30 @@ # Runs the BMv2 behavioral model simulator with input from an stf file -import difflib -import errno import json import os import random import re -import shutil import signal import socket -import stat import subprocess import sys -import tempfile import time from collections import OrderedDict from glob import glob -from subprocess import Popen -from threading import Thread +from pathlib import Path try: from scapy.layers.all import * - from scapy.utils import * + import scapy.utils as scapy_util except ImportError: pass -SUCCESS = 0 -FAILURE = 1 +FILE_DIR = Path(__file__).resolve().parent +# Append tools to the import path. +sys.path.append(str(FILE_DIR.joinpath("../../tools"))) +import testutils +from stf.stf_parser import STFParser class TimeoutException(Exception): @@ -64,19 +61,6 @@ def __init__(self): self.usePsa = False -def nextWord(text, sep=None): - # Split a text at the indicated separator. - # Note that the separator can be a string. - # Separator is discarded. - spl = text.split(sep, 1) - if len(spl) == 0: - return "", "" - elif len(spl) == 1: - return spl[0].strip(), "" - else: - return spl[0].strip(), spl[1].strip() - - def ByteToHex(byteStr): return "".join([("%02X" % x) for x in byteStr]) @@ -89,10 +73,6 @@ def convert_packet_stf2hexstr(pkt_stf_text): return "".join(pkt_stf_text.split()).upper() -def reportError(*message): - print("***", *message) - - class Local(object): # object to hold local vars accessable to nested functions pass @@ -118,38 +98,6 @@ def FindExe(dirname, exe): return exe -def run_timeout(verbose, args, timeout, stderr): - if verbose: - print("Executing ", " ".join(args)) - local = Local() - local.process = None - - def target(): - procstderr = None - if stderr is not None: - procstderr = open(stderr, "w") - local.process = Popen(args, stderr=procstderr) - local.process.wait() - - thread = Thread(target=target) - thread.start() - thread.join(timeout) - if thread.is_alive(): - print("Timeout ", " ".join(args), file=sys.stderr) - local.process.terminate() - thread.join() - if local.process is None: - # never even started - reportError("Process failed to start") - return -1 - if verbose: - print("Exit code ", local.process.returncode) - return local.process.returncode - - -timeout = 10 * 60 - - class ConcurrentInteger(object): # Generates exclusive integers in a range 0-max # in a way which is safe across multiple processes. @@ -254,7 +202,7 @@ def set(self, key, value): key = "$valid$" found = True if not found: - print(self.key.fields) + testutils.log.info(self.key.fields) raise Exception("Unexpected key field " + key) if self.key.fields[key] == "ternary": self.values[key] = self.makeMask(value) @@ -398,7 +346,6 @@ def __init__(self, folder, options, jsonfile): self.pcapPrefix = "pcap" self.interfaces = {} self.expected = {} # for each interface number of packets expected - self.expectedAny = [] # interface on which any number of packets is fine self.packetDelay = 0 self.options = options self.json = None @@ -426,141 +373,116 @@ def interface_of_filename(self, f): return int(os.path.basename(f).rstrip(".pcap").lstrip(self.pcapPrefix).rsplit("_", 1)[0]) def do_cli_command(self, cmd): - if self.options.verbose: - print(cmd) + testutils.log.info("Sending '%s'", cmd) self.cli_stdin.write(bytes(cmd + "\n", encoding="utf8")) self.cli_stdin.flush() self.packetDelay = 1 - def do_command(self, cmd): - if self.options.verbose and cmd != "": - print("STF Command:", cmd) - first, cmd = nextWord(cmd) - if first == "": + def execute_stf_command(self, stf_entry): + if self.options.verbose and stf_entry: + testutils.log.info("STF Command: %s", stf_entry) + cmd = stf_entry[0] + if cmd == "": pass - elif first == "add": - self.do_cli_command(self.parse_table_add(cmd)) - elif first == "setdefault": - self.do_cli_command(self.parse_table_set_default(cmd)) - elif ( - first == "mirroring_add" - or first == "mirroring_add_mc" - or first == "mirroring_delete" - or first == "mirroring_get" + elif cmd == "add": + self.do_cli_command(self.parse_table_add(stf_entry)) + elif cmd == "setdefault": + self.do_cli_command(self.parse_table_set_default(stf_entry)) + elif cmd in ( + "mirroring_add", + "mirroring_add_mc", + "mirroring_delete", + "mirroring_get", + "mc_mgrp_create", + "mc_node_create", + "mc_node_associate", + "counter_read", + "counter_write", + "register_read", + "register_write", + "register_reset", + "meter_get_rates", + "meter_set_rates", + "meter_array_set_rates", ): - # Pass through mirroring commands unchanged, with same - # arguments as expected by simple_switch_CLI - self.do_cli_command(first + " " + cmd) - elif first == "mc_mgrp_create" or first == "mc_node_create" or first == "mc_node_associate": # Pass through multicast group commands unchanged, with # same arguments as expected by simple_switch_CLI - self.do_cli_command(first + " " + cmd) - elif first == "counter_read" or first == "counter_write": - # Pass through counter commands unchanged, with - # same arguments as expected by simple_switch_CLI - self.do_cli_command(first + " " + cmd) - elif first == "register_read" or first == "register_write" or first == "register_reset": - # Pass through register commands unchanged, with - # same arguments as expected by simple_switch_CLI - self.do_cli_command(first + " " + cmd) - elif ( - first == "meter_get_rates" - or first == "meter_set_rates" - or first == "meter_array_set_rates" - ): - # Pass through meter commands unchanged, with - # same arguments as expected by simple_switch_CLI - self.do_cli_command(first + " " + cmd) - elif first == "packet": - interface, data = nextWord(cmd) - interface = int(interface) - data = "".join(data.split()) + self.do_cli_command(cmd + " " + " ".join(stf_entry[1:])) + elif cmd == "packet": + interface = int(stf_entry[1]) + data = stf_entry[2] time.sleep(self.packetDelay) try: self.interfaces[interface]._write_packet(bytes.fromhex(data)) except ValueError: - reportError("Invalid packet data", data) - return FAILURE + testutils.log.error("Invalid packet data %s", data) + return testutils.FAILURE self.interfaces[interface].flush() self.packetDelay = 0 - elif first == "expect": - interface, data = nextWord(cmd) - interface = int(interface) - data = "".join(data.split()) - if data != "": - self.expected.setdefault(interface, []).append(data) + elif cmd == "expect": + interface = int(stf_entry[1]) + pkt_data = stf_entry[2] + self.expected.setdefault(interface, {}) + if pkt_data != "": + self.expected[interface]["any"] = False + self.expected[interface].setdefault("pkts", []).append(pkt_data) else: - self.expectedAny.append(interface) + self.expected[interface]["any"] = True else: - if self.options.verbose: - print("ignoring stf command:", first, cmd) + testutils.log.info("ignoring stf command: %s %s", cmd, stf_entry) def parse_table_set_default(self, cmd): - tableName, cmd = nextWord(cmd) - table = self.tableByName(tableName) - actionName, cmd = nextWord(cmd, "(") - action = self.actionByName(table, actionName) - actionArgs = action.makeArgsInstance() - cmd = cmd.strip(")") - while cmd != "": - word, cmd = nextWord(cmd, ",") - k, v = nextWord(word, ":") - actionArgs.set(k, v) + # Gather objects + tableName = cmd[1] + actionCall = cmd[2] + actionName = actionCall[0] + actionArgs = actionCall[1] + + # Instantiate. + tableInstance = self.tableByName(tableName) + actionInstance = self.actionByName(tableInstance, actionName) + actionArgsInstance = actionInstance.makeArgsInstance() + for actionArg in actionArgs: + actionArgsInstance.set(actionArg[0], actionArg[1]) command = "table_set_default " + tableName + " " + actionName - if actionArgs.size(): - command += " " + str(actionArgs) + if actionArgsInstance.size(): + command += " " + str(actionArgsInstance) return command def parse_table_add(self, cmd): - tableName, cmd = nextWord(cmd) - table = self.tableByName(tableName) - key = table.makeKeyInstance() - actionArgs = None - actionName = None - prio, cmd = nextWord(cmd) - number = re.compile("[0-9]+") - if not number.match(prio): - # not a priority; push back - cmd = prio + " " + cmd - prio = "" - while cmd != "": - if actionName != None: - # parsing action arguments - word, cmd = nextWord(cmd, ",") - k, v = nextWord(word, ":") - actionArgs.set(k, v) - else: - # parsing table key - word, cmd = nextWord(cmd) - if cmd.find("=") >= 0: - # This command retrieves a handle for the key - # This feature is currently not supported, so we just ignore the handle part - cmd = cmd.split("=")[0] - if word.find("(") >= 0: - # found action - actionName, arg = nextWord(word, "(") - action = self.actionByName(table, actionName) - actionArgs = action.makeArgsInstance() - cmd = arg + cmd - cmd = cmd.strip("()") - else: - k, v = nextWord(word, ":") - key.set(k, v) - if prio != "": + # Gather objects + tableName = cmd[1] + prio = cmd[2] + keyList = cmd[3] + actionCall = cmd[4] + actionName = actionCall[0] + actionArgs = actionCall[1] + + # Instantiate. + tableInstance = self.tableByName(tableName) + keyInstance = tableInstance.makeKeyInstance() + for keyTuple in keyList: + keyInstance.set(keyTuple[0], keyTuple[1]) + + actionInstance = self.actionByName(tableInstance, actionName) + actionArgsInstance = actionInstance.makeArgsInstance() + for actionArg in actionArgs: + actionArgsInstance.set(actionArg[0], actionArg[1]) + if prio: # Priorities in BMV2 seem to be reversed with respect to the stf file # Hopefully 10000 is large enough prio = str(10000 - int(prio)) command = ( "table_add " - + table.name + + tableInstance.name + " " - + action.name + + actionInstance.name + " " - + str(key) + + str(keyInstance) + " => " - + str(actionArgs) + + str(actionArgsInstance) ) - if table.match_type == "ternary": + if tableInstance.match_type == "ternary": command += " " + prio return command @@ -614,20 +536,16 @@ def interfaceArgs(self): result.append("-i " + str(interface) + "@" + self.pcapPrefix + str(interface)) return result - def generate_model_inputs(self, stffile): - self.stffile = stffile - with open(stffile) as i: - for line in i: - line, comment = nextWord(line, "#") - first, cmd = nextWord(line) - if first == "packet" or first == "expect": - interface, cmd = nextWord(cmd) - interface = int(interface) - if not interface in self.interfaces: - # Can't open the interfaces yet, as that would block - ifname = self.interfaces[interface] = self.filename(interface, "in") - os.mkfifo(ifname) - return SUCCESS + def generate_model_inputs(self, stf_map): + for entry in stf_map: + cmd = entry[0] + if cmd == "packet" or cmd == "expect": + interface = int(entry[1]) + if not interface in self.interfaces: + # Can't open the interfaces yet, as that would block + ifname = self.interfaces[interface] = self.filename(interface, "in") + os.mkfifo(ifname) + return testutils.SUCCESS def check_switch_server_ready(self, proc, thriftPort): """While the process is running, we check if the Thrift server has been @@ -643,9 +561,8 @@ def check_switch_server_ready(self, proc, thriftPort): if result == 0: return True - def run(self): - if self.options.verbose: - print("Running model") + def run(self, stf_map): + testutils.log.info("Running model") wait = 0 # Time to wait before model starts running if self.options.usePsa: @@ -658,10 +575,10 @@ def run(self): concurrent = ConcurrentInteger(os.getcwd(), 1000) rand = concurrent.generate() if rand is None: - reportError("Could not find a free port for Thrift") - return FAILURE + testutils.log.error("Could not find a free port for Thrift") + return testutils.FAILURE thriftPort = str(9090 + rand) - rv = SUCCESS + rv = testutils.SUCCESS try: os.remove("/tmp/bmv2-%d-notifications.ipc" % rand) except OSError: @@ -689,12 +606,11 @@ def run(self): runswitch += [ "--", ] + self.target_specific_cmd_line_args - if self.options.verbose: - print("Running", " ".join(runswitch)) + testutils.log.info("Running %s", " ".join(runswitch)) sw = subprocess.Popen(runswitch, cwd=self.folder) def openInterface(ifname): - fp = self.interfaces[interface] = RawPcapWriter(ifname, linktype=0) + fp = self.interfaces[interface] = scapy_util.RawPcapWriter(ifname, linktype=0) fp._write_header(None) # Try to open input interfaces. Each time, we set a 2 second @@ -702,7 +618,7 @@ def openInterface(ifname): # not running anymore. If it is, we check if we have exceeded the # one minute timeout (exceeding this timeout is very unlikely and # could mean the system is very slow for some reason). If one of the - # 2 conditions above is met, the test is considered a FAILURE. + # 2 conditions above is met, the test is considered a failure. start = time.time() sw_timeout = 60 # open input interfaces @@ -718,9 +634,9 @@ def openInterface(ifname): signal.alarm(0) except TimeoutException: if time.time() - start > sw_timeout: - return FAILURE + return testutils.FAILURE if sw.poll() is not None: - return FAILURE + return testutils.FAILURE else: break @@ -731,7 +647,7 @@ def openInterface(ifname): self.check_switch_server_ready(sw, int(thriftPort)) signal.alarm(0) except TimeoutException: - return FAILURE + return testutils.FAILURE time.sleep(0.1) runcli = [ @@ -739,16 +655,13 @@ def openInterface(ifname): "--thrift-port", thriftPort, ] - if self.options.verbose: - print("Running", " ".join(runcli)) + testutils.log.info("Running %s", " ".join(runcli)) try: cli = subprocess.Popen(runcli, cwd=self.folder, stdin=subprocess.PIPE) self.cli_stdin = cli.stdin - with open(self.stffile) as i: - for line in i: - line, comment = nextWord(line, "#") - self.do_command(line) + for stf_cmd in stf_map: + self.execute_stf_command(stf_cmd) cli.stdin.close() for interface, fp in self.interfaces.items(): fp.close() @@ -765,215 +678,87 @@ def openInterface(ifname): # This only works on Unix: negative returncode is # minus the signal number that killed the process. if sw.returncode != 0 and sw.returncode != -15: # 15 is SIGTERM - reportError(switch, "died with return code", sw.returncode) - rv = FAILURE - elif self.options.verbose: - print(switch, "exit code", sw.returncode) + testutils.log.error("%s died with return code %s", switch, sw.returncode) + rv = testutils.FAILURE + else: + testutils.log.info("%s: exit code %s", switch, sw.returncode) cli.wait() if cli.returncode != 0 and cli.returncode != -15: - reportError("CLI process failed with exit code", cli.returncode) - rv = FAILURE + testutils.log.error("CLI process failed with exit code %s", cli.returncode) + rv = testutils.FAILURE finally: try: os.remove("/tmp/bmv2-%d-notifications.ipc" % rand) except OSError: pass concurrent.release(rand) - if self.options.verbose: - print("Execution completed") + testutils.log.info("Execution completed") return rv - def comparePacket(self, expected, received): - received = convert_packet_bin2hexstr(received) - expected = convert_packet_stf2hexstr(expected) - strict_length_check = False - if expected[-1] == "$": - strict_length_check = True - expected = expected[:-1] - if len(received) < len(expected): - reportError( - "Received packet too short", - len(received), - "vs", - len(expected), - "(in units of hex digits)", - ) - reportError("Full expected packet is ", expected) - reportError("Full received packet is ", received) - return FAILURE - for i in range(0, len(expected)): - if expected[i] == "*": - continue - if expected[i] != received[i]: - reportError("Received packet ", received) - reportError( - "Packet different at position", - i, - ": expected", - expected[i], - ", received", - received[i], - ) - reportError("Full expected packet is ", expected) - reportError("Full received packet is ", received) - return FAILURE - if strict_length_check and len(received) > len(expected): - reportError( - "Received packet too long", - len(received), - "vs", - len(expected), - "(in units of hex digits)", - ) - reportError("Full expected packet is ", expected) - reportError("Full received packet is ", received) - return FAILURE - return SUCCESS - def showLog(self): with open(self.folder + "/" + self.switchLogFile + ".txt") as a: log = a.read() - print("Log file:") - print(log) + testutils.log.info("Log file:\n%s", log) def checkOutputs(self): - if self.options.verbose: - print("Comparing outputs") + """Checks if the output of the filter matches expectations""" + testutils.log.info("Comparing outputs") direction = "out" for file in glob(self.filename("*", direction)): + testutils.log.info("Checking file %s", file) interface = self.interface_of_filename(file) if os.stat(file).st_size == 0: packets = [] else: try: - packets = rdpcap(file) - except: - reportError("Corrupt pcap file", file) - self.showLog() - return FAILURE - - # Log packets. - if self.options.observationLog: - observationLog = open(self.options.observationLog, "w") - for pkt in packets: - observationLog.write("%d %s\n" % (interface, convert_packet_bin2hexstr(pkt))) - observationLog.close() - - # Check for expected packets. - if interface in self.expectedAny: - if interface in self.expected: - reportError(f"Interface {interface} has both expected with packets and without") - continue + packets = scapy_util.rdpcap(file) + except Exception as e: + testutils.log.error("Corrupt pcap file %s\n%s", file, e) + return testutils.FAILURE + if interface not in self.expected: expected = [] else: - expected = self.expected[interface] + # Check for expected packets. + if self.expected[interface]["any"]: + if self.expected[interface]["pkts"]: + testutils.log.error( + ( + "Interface %s has both expected with packets and without", + interface, + ) + ) + continue + expected = self.expected[interface]["pkts"] if len(expected) != len(packets): - reportError( - "Expected", + testutils.log.error( + "Expected %s packets on port %s got %s", len(expected), - "packets on port", - str(interface), - "got", + interface, len(packets), ) - reportError( - "Full list of %d expected packets on port %d:" % (len(expected), interface) - ) - for i in range(len(expected)): - reportError( - " packet #%2d: %s" % (i + 1, convert_packet_stf2hexstr(expected[i])) - ) - reportError( - "Full list of %d received packets on port %d:" % (len(packets), interface) - ) - for i in range(len(packets)): - reportError( - " packet #%2d: %s" % (i + 1, convert_packet_bin2hexstr(packets[i])) - ) - self.showLog() - return FAILURE - for i in range(0, len(expected)): - cmp = self.comparePacket(expected[i], packets[i]) - if cmp != SUCCESS: - reportError("Packet", i, "on port", str(interface), "differs") - return FAILURE - # remove successfully checked interfaces + return testutils.FAILURE + for idx, expected_pkt in enumerate(expected): + # If the expected_pkt is None, the content does not matter. + # We only care that the packet was received on this particular port. + if not expected_pkt: + continue + cmp_result = testutils.compare_pkt(expected_pkt, packets[idx]) + if cmp_result != testutils.SUCCESS: + testutils.log.error("Packet %s on port %s differs", idx, interface) + return cmp_result + # Remove successfully checked interfaces if interface in self.expected: del self.expected[interface] if len(self.expected) != 0: - # didn't find all the expects we were expecting - reportError("Expected packets on ports", list(self.expected.keys()), "not received") - return FAILURE - else: - return SUCCESS - - -def run_model(options, tmpdir, jsonfile, testfile): - bmv2 = RunBMV2(tmpdir, options, jsonfile) - result = bmv2.generate_model_inputs(testfile) - if result != SUCCESS: - return result - result = bmv2.run() - if result != SUCCESS: - return result - result = bmv2.checkOutputs() - return result - - -######################### main - - -def usage(options): - print( - "usage:", - options.binary, - "[-v] [-p] [-observation-log ] ", - ) - - -def main(argv): - options = Options() - options.binary = argv[0] - argv = argv[1:] - while len(argv) > 0 and argv[0][0] == "-": - if argv[0] == "-b": - options.preserveTmp = True - elif argv[0] == "-v": - options.verbose = True - elif argv[0] == "-p": - options.usePsa = True - elif argv[0] == "-observation-log": - if len(argv) == 1: - reportError("Missing argument", argv[0]) - usage(options) - sys.exit(1) - options.observationLog = argv[1] - argv = argv[1:] - else: - reportError("Unknown option ", argv[0]) - usage(options) - argv = argv[1:] - if len(argv) < 2: - usage(options) - return FAILURE - if not os.path.isfile(argv[0]) or not os.path.isfile(argv[1]): - usage(options) - return FAILURE - - tmpdir = tempfile.mkdtemp(dir=".") - result = run_model(options, tmpdir, argv[0], argv[1]) - if options.preserveTmp: - print("preserving", tmpdir) - else: - shutil.rmtree(tmpdir) - if options.verbose: - if result == SUCCESS: - print("SUCCESS") - else: - print("FAILURE", result) - return result - - -if __name__ == "__main__": - sys.exit(main(sys.argv)) + # Didn't find all the expects we were expecting + testutils.log.error("Expected packets on port(s) %s not received", self.expected.keys()) + return testutils.FAILURE + testutils.log.info("All went well.") + return testutils.SUCCESS + + def parse_stf_file(self, testfile): + with open(testfile) as raw_stf: + parser = STFParser() + stf_str = raw_stf.read() + return parser.parse(stf_str) diff --git a/backends/bmv2/run-bmv2-test.py b/backends/bmv2/run-bmv2-test.py index 268cd432acd..1cb3f4b517e 100755 --- a/backends/bmv2/run-bmv2-test.py +++ b/backends/bmv2/run-bmv2-test.py @@ -23,6 +23,7 @@ import tempfile from subprocess import Popen from threading import Thread +import logging from bmv2stf import RunBMV2 from scapy.layers.all import * @@ -233,10 +234,15 @@ def run_model(options, tmpdir, jsonfile): # If no empty.stf present, don't try to run the model at all return SUCCESS bmv2 = RunBMV2(tmpdir, options, jsonfile) - result = bmv2.generate_model_inputs(testFile) + + stf_map, result = bmv2.parse_stf_file(testFile) + if result != SUCCESS: + return result + + result = bmv2.generate_model_inputs(stf_map) if result != SUCCESS: return result - result = bmv2.run() + result = bmv2.run(stf_map) if result != SUCCESS: return result result = bmv2.checkOutputs() @@ -375,6 +381,17 @@ def main(argv): print("Error parsing config.h") sys.exit(FAILURE) + # Configure logging. + logging.basicConfig( + filename="test.log", + format="%(levelname)s: %(message)s", + level=getattr(logging, "INFO"), + filemode="w", + ) + stderr_log = logging.StreamHandler() + stderr_log.setFormatter(logging.Formatter("%(levelname)s: %(message)s")) + logging.getLogger().addHandler(stderr_log) + options.hasBMv2 = "HAVE_SIMPLE_SWITCH" in config.vars if not options.hasBMv2: reportError("config.h indicates that BMv2 is not installedwill skip running BMv2 tests") diff --git a/backends/ebpf/targets/ebpfstf.py b/backends/ebpf/targets/ebpfstf.py index 6be36e14ace..a165308b4ce 100644 --- a/backends/ebpf/targets/ebpfstf.py +++ b/backends/ebpf/targets/ebpfstf.py @@ -127,9 +127,9 @@ def parse_stf_file(raw_stf): expected = {} for stf_entry in stf_map: if stf_entry[0] == "packet": - input_pkts.setdefault(stf_entry[1], []).append( - bytes.fromhex("".join(stf_entry[2].split())) - ) + interface = int(stf_entry[1]) + data = stf_entry[2] + input_pkts.setdefault(interface, []).append(bytes.fromhex("".join(data.split()))) elif stf_entry[0] == "expect": interface = int(stf_entry[1]) pkt_data = stf_entry[2] diff --git a/tools/stf/stf_lexer.py b/tools/stf/stf_lexer.py index 3727607ba39..a008be7eaa3 100644 --- a/tools/stf/stf_lexer.py +++ b/tools/stf/stf_lexer.py @@ -87,6 +87,21 @@ def _error(self, s, token): "REMOVE", "SETDEFAULT", "WAIT", + "MIRRORING_ADD", + "MIRRORING_ADD_MC", + "MIRRORING_DELETE", + "MIRRORING_GET", + "MC_MGRP_CREATE", + "MC_NODE_CREATE", + "MC_NODE_ASSOCIATE", + "COUNTER_READ", + "COUNTER_WRITE", + "REGISTER_READ", + "REGISTER_WRITE", + "REGISTER_RESET", + "METER_GET_RATES", + "METER_SET_RATES", + "METER_ARRAY_SET_RATES", ) keywords_map = {} @@ -102,6 +117,7 @@ def _error(self, s, token): "DATA_DEC", "DATA_HEX", "DATA_TERN", + "DATA_EXACT", "DOT", "ID", "INT_CONST_BIN", @@ -228,6 +244,10 @@ def t_packetdata_DATA_TERN(self, t): r"\*" return t + def t_packetdata_DATA_EXACT(self, t): + r"\$" + return t + def t_packetdata_newline(self, t): r"\n+" t.lexer.lineno += len(t.value) diff --git a/tools/stf/stf_parser.py b/tools/stf/stf_parser.py index c5d9d045f81..048fb822e60 100644 --- a/tools/stf/stf_parser.py +++ b/tools/stf/stf_parser.py @@ -34,7 +34,27 @@ # | PACKET port packet_data # | SETDEFAULT qualified_name action # | WAIT +# | direct_cmd # +# direct_cmd : MIRRORING_ADD number number +# | MIRRORING_ADD_MC number number +# | MIRRORING_DELETE number +# | MIRRORING_GET number +# | MC_MGRP_CREATE number +# | MC_NODE_CREATE number number +# | MC_NODE_ASSOCIATE number number +# | COUNTER_READ qualified_name number +# | COUNTER_WRITE qualified_name number number number +# | REGISTER_READ qualified_name number +# | REGISTER_WRITE qualified_name number number +# | REGISTER_RESET qualified_name +# | METER_GET_RATES qualified_name number +# | METER_SET_RATES qualified_name number meter_rate +# | METER_ARRAY_SET_RATES qualified_name meter_rate + +# meter_rate : number COLON number meter_rate +# | number COLON number + # match_list : match # | match_list match # match : qualified_name COLON number_or_lpm @@ -73,11 +93,13 @@ # # expect_data : expect_datum # | expect_data expect_datum +# | exact_datum # packet_data : packet_datum # | packet_data packet_datum # # expect_datum : packet_datum | DATA_TERN # packet_datum : DATA_DEC | DATA_HEX +# exact_datum : DATA_EXACT # PARSER ---------------------------------------------------------------------- @@ -127,7 +149,7 @@ def p_error(self, p): self.print_error( p.lineno, self.lexer.get_colno(), - "Syntax error while parsing at token '%s' (of type %s)." % (p.value, p.type), + "Unexpected token '%s' (of type %s)." % (p.value, p.type), ) # Skip to the next statement. while True: @@ -202,6 +224,78 @@ def p_statement_wait(self, p): "statement : WAIT" p[0] = (p[1].lower(),) + def p_statement_direct_cmd(self, p): + "statement : direct_cmd" + p[0] = p[1] + + def p_direct_cmd_mirroring_add(self, p): + "direct_cmd : MIRRORING_ADD number number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_mirroring_add_mc(self, p): + "direct_cmd : MIRRORING_ADD_MC number number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_mirroring_delete(self, p): + "direct_cmd : MIRRORING_DELETE number" + p[0] = (p[1].lower(), p[2]) + + def p_direct_cmd_mirroring_get(self, p): + "direct_cmd : MIRRORING_GET number" + p[0] = (p[1].lower(), p[2]) + + def p_direct_cmd_mc_mgrp_create(self, p): + "direct_cmd : MC_MGRP_CREATE number" + p[0] = (p[1].lower(), p[2]) + + def p_direct_cmd_mc_node_create(self, p): + "direct_cmd : MC_NODE_CREATE number number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_mc_node_associate(self, p): + "direct_cmd : MC_NODE_ASSOCIATE number number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_counter_read(self, p): + "direct_cmd : COUNTER_READ qualified_name number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_counter_write(self, p): + "direct_cmd : COUNTER_WRITE qualified_name number number number" + p[0] = (p[1].lower(), p[2], p[3], p[4], p[5]) + + def p_direct_cmd_register_read(self, p): + "direct_cmd : REGISTER_READ qualified_name number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_register_write(self, p): + "direct_cmd : REGISTER_WRITE qualified_name number number" + p[0] = (p[1].lower(), p[2], p[3], p[4]) + + def p_direct_cmd_register_reset(self, p): + "direct_cmd : REGISTER_RESET qualified_name" + p[0] = (p[1].lower(), p[2]) + + def p_direct_cmd_meter_get_rates(self, p): + "direct_cmd : METER_GET_RATES qualified_name number" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_direct_cmd_meter_set_rates(self, p): + "direct_cmd : METER_SET_RATES qualified_name number meter_rate" + p[0] = (p[1].lower(), p[2], p[3], p[4]) + + def p_direct_cmd_meter_array_set_rates(self, p): + "direct_cmd : METER_ARRAY_SET_RATES qualified_name meter_rate" + p[0] = (p[1].lower(), p[2], p[3]) + + def p_meter_rate_many(self, p): + "meter_rate : number COLON number meter_rate" + p[0] = (p[1], p[3]) + [p[4]] + + def p_meter_rate_one(self, p): + "meter_rate : number COLON number" + p[0] = (p[1], p[3]) + def p_id_or_index(self, p): """id_or_index : ID | number @@ -315,14 +409,22 @@ def p_expect_data_one(self, p): p[0] = [p[1]] def p_expect_data_many(self, p): - "expect_data : expect_data expect_datum" - p[0] = p[1] + [p[2]] + "expect_data : expect_datum expect_data" + p[0] = [p[1]] + p[2] - def p_expect_dataum(self, p): + def p_expect_data_exact_one(self, p): + "expect_data : exact_datum" + p[0] = [p[1]] + + def p_expect_datum(self, p): """expect_datum : packet_datum | DATA_TERN""" p[0] = p[1] + def p_exact_datum(self, p): + """exact_datum : DATA_EXACT""" + p[0] = p[1] + # TESTING --------------------------------------------------------------------- # diff --git a/tools/stf/stf_test.py b/tools/stf/stf_test.py index 7bec7e55dcf..6899a49a3bb 100755 --- a/tools/stf/stf_test.py +++ b/tools/stf/stf_test.py @@ -17,16 +17,9 @@ # Example of usign STFRunner that simply prints the STF statement import argparse -import logging -import math import os import os.path -import random -import re import sys -import time -import traceback -import unittest from .stf_parser import STFParser from .stf_runner import STFRunner diff --git a/tools/testutils.py b/tools/testutils.py index 0acb0cb2ef3..faf1b464e33 100644 --- a/tools/testutils.py +++ b/tools/testutils.py @@ -37,6 +37,9 @@ SKIPPED: int = 999 +import scapy.packet + + class LogPipe(threading.Thread): """A log utility class that allows subprocesses to directly write into a log. Derived from https://codereview.stackexchange.com/a/17959.""" @@ -89,11 +92,26 @@ def hex_to_byte(hex_str: str) -> str: return "".join(byte_vals) -def compare_pkt(expected: str, received: str) -> int: +def compare_pkt(expected: str, received: scapy.packet.Packet) -> int: """Compare two given byte sequences and check if they are the same. Report errors if this is not the case.""" - received = bytes(received).hex().upper() + + # If the expected packet string ends with a '$' it means that the packets are only equal, + # if they are the exact same length. + strict_length_check = False + if expected[-1] == '$': + strict_length_check = True + expected = expected[:-1] + + received = received.build().hex().upper() expected = "".join(expected.split()).upper() + if strict_length_check and len(received) > len(expected): + log.error( + "Received packet too long %s vs %s (in units of hex digits)", + len(received), + len(expected), + ) + return FAILURE if len(received) < len(expected): log.error("Received packet too short %s vs %s", len(received), len(expected)) return FAILURE