This is a helper function to store payload records (and JSON equivalents) in .got files. The code it replaces missed to insert a newline before the new entry and also did not check for existing records in all spots. Signed-off-by: Phil Sutter --- tests/py/nft-test.py | 117 ++++++++++++++++++++++++++----------------- 1 file changed, 71 insertions(+), 46 deletions(-) diff --git a/tests/py/nft-test.py b/tests/py/nft-test.py index 019c828f957a5..dc074d4c3872a 100755 --- a/tests/py/nft-test.py +++ b/tests/py/nft-test.py @@ -16,6 +16,7 @@ from __future__ import print_function import sys import os +import io import argparse import signal import json @@ -741,6 +742,66 @@ def set_delete_elements(set_element, set_name, table, filename=None, return i > 0 +def payload_record(path, rule, payload, desc="payload"): + ''' + Record payload for @rule in file at @path + + - @payload may be a file handle, a string or an array of strings + - Avoid duplicate entries by searching for a match first + - Separate entries by a single empty line, so check for trailing newlines + before writing + - @return False if already existing, True otherwise + ''' + try: + with open(path, 'r') as f: + lines = f.readlines() + except: + lines = [] + + plines = [] + if isinstance(payload, io.TextIOWrapper): + payload.seek(0, 0) + while True: + line = payload.readline() + if line.startswith("family "): + continue + if line == "": + break + plines.append(line) + elif isinstance(payload, str): + plines = [l + "\n" for l in payload.split("\n")] + elif isinstance(payload, list): + plines = payload + else: + raise Exception + + found = False + for i in range(len(lines)): + if lines[i] == rule + "\n": + found = True + for pline in plines: + i += 1 + if lines[i] != pline: + found = False + break + if found: + return False + + try: + with open(path, 'a') as f: + if len(lines) > 0 and lines[-1] != "\n": + f.write("\n") + f.write("# %s\n" % rule) + f.writelines(plines) + except: + warnfmt = "Failed to write %s for rule %s" + else: + warnfmt = "Wrote %s for rule %s" + + print_warning(warnfmt % (desc, rule[0]), os.path.basename(path), 1) + return True + + def json_dump_normalize(json_string, human_readable = False): json_obj = json.loads(json_string) @@ -867,28 +928,8 @@ def set_delete_elements(set_element, set_name, table, filename=None, if state == "ok" and not payload_check(table_payload_expected, payload_log, cmd): error += 1 - - try: - gotf = open("%s.got" % table_payload_path) - gotf_payload_expected = payload_find_expected(gotf, rule[0]) - gotf.close() - except: - gotf_payload_expected = None - payload_log.seek(0, 0) - if not payload_check(gotf_payload_expected, payload_log, cmd): - gotf = open("%s.got" % table_payload_path, 'a') - payload_log.seek(0, 0) - gotf.write("# %s\n" % rule[0]) - while True: - line = payload_log.readline() - if line.startswith("family "): - continue - if line == "": - break - gotf.write(line) - gotf.close() - print_warning("Wrote payload for rule %s" % rule[0], - gotf.name, 1) + payload_record("%s.got" % table_payload_path, + rule[0], payload_log) # Check for matching ruleset listing numeric_proto_old = nftables.set_numeric_proto_output(True) @@ -979,13 +1020,9 @@ def set_delete_elements(set_element, set_name, table, filename=None, json_output = item["rule"] break json_input = json.dumps(json_output["expr"], sort_keys = True) - - gotf = open("%s.json.got" % filename_path, 'a') - jdump = json_dump_normalize(json_input, True) - gotf.write("# %s\n%s\n\n" % (rule[0], jdump)) - gotf.close() - print_warning("Wrote JSON equivalent for rule %s" % rule[0], - gotf.name, 1) + payload_record("%s.json.got" % filename_path, rule[0], + json_dump_normalize(json_input, True), + "JSON equivalent") table_flush(table, filename, lineno) payload_log = tempfile.TemporaryFile(mode="w+") @@ -1013,17 +1050,8 @@ def set_delete_elements(set_element, set_name, table, filename=None, # Check for matching payload if not payload_check(table_payload_expected, payload_log, cmd): error += 1 - gotf = open("%s.json.payload.got" % filename_path, 'a') - payload_log.seek(0, 0) - gotf.write("# %s\n" % rule[0]) - while True: - line = payload_log.readline() - if line == "": - break - gotf.write(line) - gotf.close() - print_warning("Wrote JSON payload for rule %s" % rule[0], - gotf.name, 1) + payload_record("%s.json.payload.got" % filename_path, + rule[0], payload_log, "JSON payload") # Check for matching ruleset listing numeric_proto_old = nftables.set_numeric_proto_output(True) @@ -1049,12 +1077,9 @@ def set_delete_elements(set_element, set_name, table, filename=None, print_differences_warning(filename, lineno, json_input, json_output, cmd) error += 1 - gotf = open("%s.json.output.got" % filename_path, 'a') - jdump = json_dump_normalize(json_output, True) - gotf.write("# %s\n%s\n\n" % (rule[0], jdump)) - gotf.close() - print_warning("Wrote JSON output for rule %s" % rule[0], - gotf.name, 1) + payload_record("%s.json.output.got" % filename_path, rule[0], + json_dump_normalize(json_output, True), + "JSON output") # prevent further warnings and .got file updates json_expected = json_output elif json_expected and json_output != json_expected: -- 2.51.0