diff --git a/tests/run_tests.py b/tests/run_tests.py new file mode 100644 index 0000000..a73d16e --- /dev/null +++ b/tests/run_tests.py @@ -0,0 +1,127 @@ +import argparse +import datetime +import difflib +import subprocess +import sys +import tempfile +from pathlib import Path + +parser = argparse.ArgumentParser(description="Process some integers.") +parser.add_argument("src", default="src", type=Path, help="src directory", nargs="?") +parser.add_argument( + "tests", default="tests", type=Path, help="tests directory", nargs="?" +) +args = parser.parse_args() + + +def make_log_file(): + time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + return (args.tests / time).with_suffix(".log") + + +def log_new_test(in_path): + global LOG_FILE + with open(LOG_FILE, "a") as f: + f.write("# ======================================\n") + f.write(f"# Test: {in_path}\n") + f.write("# ======================================\n") + + +def log_message(message): + global LOG_FILE + with open(LOG_FILE, "a") as f: + f.write(message) + f.write("\n") + + +def log_return(ret): + if ret[0]: + log_message("Passed") + else: + log_message(f"Error: {ret[1]}") + + log_message("# ======================================\n") + return ret + + +def attempt_decode(out): + try: + return out.decode("utf-8") + except UnicodeDecodeError: + return "Invalid utf-8 in output" + + +def diff_files(file1, file2): + with open(file1, "r") as f1, open(file2, "r") as f2: + diff = difflib.unified_diff( + f1.readlines(), + f2.readlines(), + fromfile=str(file1), + tofile=str(file2), + ) + return "".join(diff) + + +def run_test(exe_path, in_path): + expected_file = in_path.with_suffix(".out") + log_new_test(in_path) + + with tempfile.TemporaryDirectory() as tmp: + temp = Path(tmp) + out_path = temp / "output" + + test = subprocess.Popen( + [sys.executable, str(exe_path), str(in_path), "--out", str(out_path)], + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + + try: + _, stderr = test.communicate(1) + except subprocess.TimeoutExpired: + return [False, "Timed out"] + + if test.returncode != 0: + log_message(attempt_decode(stderr)) + return log_return( + [ + False, + f"Expected returncode 0 got: {test.returncode}", + ] + ) + + diff = diff_files(expected_file, out_path) + + if diff != "": + log_message(diff) + return log_return([False, "Output did not match"]) + + return log_return([True]) + + +if __name__ == "__main__": + LOG_FILE = make_log_file() + PASS_MESSAGE = "✅ Passed" + ERROR_MESSAGE = "❌ Error:" + + tests = list(Path(args.tests).glob("**/*.in")) + max_len = max(len(str(t)) for t in tests) + + for test in tests: + print(f"{test}...", end="") + res = run_test(args.src / "main.py", test) + + if res[0]: + dots = 80 - len(str(test)) - 3 - len(PASS_MESSAGE) + + if dots > 0: + print("." * dots, end="") + print(PASS_MESSAGE) + else: + dots = 80 - len(str(test)) - 3 - len(ERROR_MESSAGE) - len(res[1]) + + if dots > 0: + print("." * dots, end="") + print(f"{ERROR_MESSAGE} {res[1]}") + + print(f"Log file at: {LOG_FILE}")