#!/usr/bin/env python3
# Compiler and tester for kattis problems
#
# kat expects a ./sample directory with files named /input_[0-9]+/ and a
# corresponding /output_[0-9]+/ file. These will be run in sequence from lowest
# to highest number
#
# Supported languages:
#  - Rust: Expects a ./src/main.rs file, as well as a ./Cargo.toml
#  - Python: Expects a single .py file, a ./main.py, or a ./src/main.py
#  - C: Expects a ./main.c file or ./src/main.c
#  - C++: Expects a ./main.cc or ./main.cpp file, or the same under ./src
import argparse, sys, os, pathlib, time, subprocess, re
from abc import ABC, abstractmethod
from typing import Tuple
from pathlib import Path

class ProgrammingLang(ABC):
    """Interface implemented by every supported language backend.

    A backend knows how to build a solution in a "debug" and a "release"
    flavour and how to run the resulting program with its standard input and
    output bound to already opened files.
    """

    @abstractmethod
    def __init__(self, root: Path, main_file: Path):
        """Store the project root and the solution's main source file.

        Args:
            root: Directory of the problem being solved.
            main_file: Source file that was detected for the solution.
        """

    @abstractmethod
    def compile_debug(self):
        """Build the solution in its debug flavour."""

    @abstractmethod
    def compile_release(self):
        """Build the solution in its release flavour."""

    # The run methods report a duration: the backends in this file return the
    # difference of two clock readings, which is a float, not an int.
    @abstractmethod
    def run_debug(self, input_file, output_file) -> float:
        """Run the debug build.

        Args:
            input_file: Open file used as the program's stdin.
            output_file: Open file used as the program's stdout.

        Returns:
            Elapsed run time in seconds.
        """

    @abstractmethod
    def run_release(self, input_file, output_file) -> float:
        """Run the release build.

        Args:
            input_file: Open file used as the program's stdin.
            output_file: Open file used as the program's stdout.

        Returns:
            Elapsed run time in seconds.
        """

class Rust(ProgrammingLang):
    """Backend for a Cargo project located at ``root``.

    The executable is looked up under ``target/debug`` or ``target/release``
    using the name of the project directory.
    NOTE(review): this presumes the Cargo package is named after its
    directory -- confirm against Cargo.toml.
    """

    def __init__(self, root: Path, main_file: Path):
        self.root = root
        self.main_file = main_file
        self.bin_name = root.name
        target_dir = os.path.join(root, 'target')
        self.release_bin = os.path.join(target_dir, 'release', self.bin_name)
        self.debug_bin = os.path.join(target_dir, 'debug', self.bin_name)

    def __build(self, *extra_args):
        """Run ``cargo build`` in the project root; raise if it fails."""
        result = subprocess.run(['cargo', 'build', *extra_args], cwd=self.root)
        if result.returncode != 0:
            raise Exception(f'Compile time error: code {result.returncode}')

    def compile_debug(self):
        """Build the unoptimised binary with ``cargo build``."""
        self.__build()

    def compile_release(self):
        """Build the optimised binary with ``cargo build --release``."""
        self.__build('--release')

    def __run(self, input_file, output_file, sub_dir: Path) -> int:
        """Execute the binary from ``target/<sub_dir>`` and time it.

        Returns the elapsed seconds measured with ``time.time()``; raises
        when the program exits with a non-zero status.
        """
        executable = os.path.join(self.root, 'target', sub_dir, self.bin_name)
        started = time.time()
        result = subprocess.run(
            executable, stdin=input_file, stdout=output_file)
        elapsed = time.time() - started
        if result.returncode != 0:
            raise Exception(f'Runtime error: code {result.returncode}')
        return elapsed

    def run_debug(self, input_file, output_file) -> int:
        """Run the debug binary; see ``__run``."""
        return self.__run(input_file, output_file, 'debug')

    def run_release(self, input_file, output_file) -> int:
        """Run the release binary; see ``__run``."""
        return self.__run(input_file, output_file, 'release')

class Clang(ProgrammingLang):
    """Backend for a single-file C solution compiled with gcc.

    Debug and release builds share one output file, ``<root>/clang_out``;
    the debug build only differs by turning warnings into errors.
    """

    def __init__(self, root: Path, main_file: Path):
        self.main_file = main_file
        self.root = root
        self.bin_name = 'clang_out'
        self.release_bin = self.debug_bin = os.path.join(root, self.bin_name)

    def __compile(self, extra_flags, out_path):
        """Invoke gcc with ``extra_flags`` prepended; raise if it fails."""
        cmd = subprocess.run([
            'gcc', *extra_flags,
            '-g', '-O2', '-std=gnu11', '-static',
            self.main_file,
            '-o', out_path, '-lm'
        ])
        if cmd.returncode != 0:
            raise Exception(f'Compile time error: code {cmd.returncode}')

    def compile_debug(self):
        """Compile with ``-Wall -Wextra -Werror`` on top of the release flags."""
        self.__compile(['-Wall', '-Wextra', '-Werror'], self.debug_bin)

    def compile_release(self):
        """Compile with the plain release flags."""
        self.__compile([], self.release_bin)

    def run_debug(self, input_file, output_file) -> float:
        """Run the compiled program (debug and release share one binary)."""
        return self.run_release(input_file, output_file)

    def run_release(self, input_file, output_file) -> float:
        """Run the compiled program and return the elapsed seconds.

        Args:
            input_file: Open file used as the program's stdin.
            output_file: Open file used as the program's stdout.

        Raises:
            Exception: If the program exits with a non-zero status, matching
                what the Rust and Python backends do. Without this check a
                crash was silently treated as a normal run.
        """
        # perf_counter is monotonic, so clock adjustments cannot skew timings.
        start = time.perf_counter()
        cmd = subprocess.run(self.release_bin,
                             stdin=input_file, stdout=output_file)
        elapsed = time.perf_counter() - start
        if cmd.returncode != 0:
            raise Exception(f'Runtime error: code {cmd.returncode}')
        return elapsed

class CPlusPlus(ProgrammingLang):
    """Backend for a single-file C++ solution compiled with g++.

    Debug and release builds share one output file, ``<root>/clang_out``;
    the debug build only differs by turning warnings into errors.
    """

    def __init__(self, root: Path, main_file: Path):
        self.main_file = main_file
        self.root = root
        self.bin_name = 'clang_out'
        self.release_bin = self.debug_bin = os.path.join(root, self.bin_name)

    def __compile(self, extra_flags):
        """Invoke g++ with ``extra_flags`` placed before the source file."""
        cmd = subprocess.run([
            'g++', '-g', '-O2', '-std=gnu++17', '-static',
            '-lrt', '-Wl,--whole-archive', '-lpthread',
            '-Wl,--no-whole-archive',
            *extra_flags,
            self.main_file,
            '-o', self.release_bin
        ])
        if cmd.returncode != 0:
            raise Exception(f'Compile time error: code {cmd.returncode}')

    def compile_debug(self):
        """Compile with ``-Wall -Wextra -Werror`` on top of the release flags."""
        self.__compile(['-Wall', '-Wextra', '-Werror'])

    def compile_release(self):
        """Compile with the plain release flags."""
        self.__compile([])

    def run_debug(self, input_file, output_file) -> float:
        """Run the compiled program (debug and release share one binary)."""
        return self.run_release(input_file, output_file)

    def run_release(self, input_file, output_file) -> float:
        """Run the compiled program and return the elapsed seconds.

        Args:
            input_file: Open file used as the program's stdin.
            output_file: Open file used as the program's stdout.

        Raises:
            Exception: If the program exits with a non-zero status, matching
                what the Rust and Python backends do. Without this check a
                crash was silently treated as a normal run.
        """
        # perf_counter is monotonic, so clock adjustments cannot skew timings.
        start = time.perf_counter()
        cmd = subprocess.run(self.release_bin,
                             stdin=input_file, stdout=output_file)
        elapsed = time.perf_counter() - start
        if cmd.returncode != 0:
            raise Exception(f'Runtime error: code {cmd.returncode}')
        return elapsed

class Python(ProgrammingLang):
    """Backend for a Python script; nothing is compiled.

    "Debug" runs the script with ``/usr/bin/python3`` and "release" runs it
    with ``/usr/bin/pypy3``.
    """

    def __init__(self, root: Path, main_file: Path):
        self.root = root
        self.main_file = main_file
        self.bin_name = main_file.name
        # The script itself plays the role of both "binaries".
        self.debug_bin = main_file
        self.release_bin = main_file

    def compile_debug(self):
        """No build step is needed for Python."""

    def compile_release(self):
        """No build step is needed for Python."""

    def __run(self, input_file, output_file, interpreter) -> int:
        """Run the script under ``interpreter`` and time it.

        Returns the elapsed seconds measured with ``time.time()``; raises
        when the interpreter exits with a non-zero status.
        """
        command = [interpreter, self.release_bin]
        began = time.time()
        proc = subprocess.run(command, stdin=input_file, stdout=output_file)
        duration = time.time() - began

        if proc.returncode != 0:
            raise Exception(f'Runtime error: {proc.returncode}')
        return duration

    def run_debug(self, input_file, output_file) -> int:
        """Run the script with CPython."""
        return self.__run(input_file, output_file, '/usr/bin/python3')

    def run_release(self, input_file, output_file) -> int:
        """Run the script with PyPy."""
        return self.__run(input_file, output_file, '/usr/bin/pypy3')

def _language_in(root: Path, directory: Path):
    """Return the backend for the sources found in ``directory``, or None.

    Well-known ``main.*`` files win over an arbitrary ``.py`` script. The
    source path handed to the backend is anchored at ``directory`` so it stays
    valid no matter which directory kat was started from.
    """
    entries = os.listdir(directory)

    if 'main.rs' in entries:
        return Rust(root, directory / 'main.rs')
    if 'main.c' in entries:
        return Clang(root, directory / 'main.c')
    for cpp_main in ('main.cc', 'main.cpp'):
        if cpp_main in entries:
            return CPlusPlus(root, directory / cpp_main)
    if 'main.py' in entries:
        return Python(root, directory / 'main.py')

    # Match on the real extension: a substring test would also accept names
    # such as "x.pyc" or "x.py.bak". Sorting makes the choice deterministic.
    scripts = sorted(name for name in entries if name.endswith('.py'))
    if scripts:
        return Python(root, directory / scripts[0])
    return None

def find_language(root: Path):
    """Detect the solution's language and return its backend object.

    ``<root>/src`` is searched first (when it is a directory), then ``root``
    itself. In each directory the priority is main.rs, main.c, main.cc,
    main.cpp, main.py and finally any other ``*.py`` file.

    Args:
        root: Project directory of the problem.

    Returns:
        A Rust, Clang, CPlusPlus or Python instance.

    Raises:
        Exception: If no supported source file is found.
    """
    root = Path(root)
    src_dir = root / 'src'

    search_dirs = [src_dir] if src_dir.is_dir() else []
    search_dirs.append(root)

    for directory in search_dirs:
        lang = _language_in(root, directory)
        if lang is not None:
            return lang

    raise Exception('Language not found or not supported')

def project_root() -> Path:
    """Return the project directory for the current working directory.

    When the working directory is itself named ``sample`` or ``src`` its
    parent is returned; otherwise the working directory is returned as is.
    """
    here = Path.cwd()
    return here.parent if here.name in ('sample', 'src') else here

def test_files(root: Path, nb: int) -> Tuple[Path, Path, Path]:
    """Build the three file paths that belong to test case ``nb``.

    Returns, in order: the sample input, the expected output and the file the
    program's actual output is written to, all inside ``<root>/sample``. The
    paths are produced with ``os.path.join``.
    """
    stems = ('input', 'output', 'run_output')
    return tuple(os.path.join(root, f'sample/{stem}_{nb}') for stem in stems)

def is_test_exists(n: str, root: Path) -> bool:
    """Check that both sample files for test case ``n`` are present.

    Looks in ``<root>/sample`` for ``input_<n>`` and then ``output_<n>``. The
    first missing file is reported on stdout and False is returned; True is
    returned only when both exist.
    """
    present = os.listdir(os.path.join(root, 'sample'))

    for required in (f'input_{n}', f'output_{n}'):
        if required not in present:
            print(f'sample/{required} not found')
            return False

    return True

def run_test(lang, n: int, root: Path, is_time: bool, is_debug: bool) -> bool:
    """Run one sample through the solution and compare its output.

    The program's stdout is written to ``sample/run_output_<n>`` and compared
    line by line with ``sample/output_<n>``; only the first difference is
    reported.

    Args:
        lang: Language backend providing ``run_debug`` / ``run_release``.
        n: Number of the test case as it appears in the sample file names.
        root: Project directory containing ``sample``.
        is_time: Print the measured run time after the program's output.
        is_debug: Use ``run_debug`` instead of ``run_release``.

    Returns:
        True when the produced output matches the expected output.

    Raises:
        Whatever the backend raises for a failed run, and OSError when a
        sample file cannot be opened. Every file is opened in a ``with``
        block, so no handle is left open on those paths.
    """
    t_in_name, t_out_name, r_out_name = test_files(root, n)

    # Read the expectation up front so a missing file fails before anything
    # is executed.
    with open(t_out_name, 'r') as t_out:
        lines_out = t_out.read().splitlines()

    print(f"==== Testcase {n} ====")
    run = lang.run_debug if is_debug else lang.run_release
    with open(t_in_name, 'r') as t_in, open(r_out_name, 'w') as r_out:
        run_time = run(t_in, r_out)

    if is_time:
        print("^^^^^^^^^^^^^^^^^^^^^^")
        print(f"Time:   {run_time : 4.3f}s")

    with open(r_out_name, 'r') as r_out:
        lines_real = r_out.read().splitlines()

    mismatch = None
    if len(lines_real) != len(lines_out):
        mismatch = (f"!!! Mismatched output\n"
                    f"Real line count:     {len(lines_real)}\n"
                    f"Expected line count: {len(lines_out)}")
    else:
        for real, expected in zip(lines_real, lines_out):
            if real != expected:
                mismatch = (f"!!! Mismatched output\n"
                            f"Real  Output: `{real}`\n"
                            f"Expected Out: `{expected}`")
                break

    if mismatch is None:
        return True

    # The separator was already printed together with the timing.
    if not is_time:
        print("^^^^^^^^^^^^^^^^^^^^^^")
    print(mismatch)
    print(f'!!! Failed on test {n}. See ./sample/run_output_{n}')
    return False

def run_tests(lang, root: Path, is_time: bool, is_debug: bool) -> Tuple[bool, int]:
    """Run every sample in ``<root>/sample`` in ascending numeric order.

    A sample is any file named ``input_<digits>``; each one needs a matching
    ``output_<digits>``. All samples are validated before the first one is
    run, and execution stops at the first failing case.

    Args:
        lang: Language backend used to run the solution.
        root: Project directory containing ``sample``.
        is_time: Print the run time of every case.
        is_debug: Run the debug build instead of the release build.

    Returns:
        ``(passed, count)`` where ``count`` is the number of samples found.
        ``(False, 0)`` is returned when there is no ``sample`` directory.
    """
    try:
        sample_dir = os.listdir(os.path.join(root, 'sample'))
    except FileNotFoundError:
        print("No ./sample directory found. Cannot test script")
        return False, 0

    matches = (re.match(r'^input_([0-9]+)$', f) for f in sample_dir)
    # Sort on the numeric value: a plain string sort would run input_10
    # before input_2.
    numbers = sorted((m.group(1) for m in matches if m is not None), key=int)
    N = len(numbers)

    for number in numbers:
        # is_test_exists already reports which file is missing, so fail
        # cleanly instead of raising a message-less exception.
        if not is_test_exists(number, root):
            return False, N

    for number in numbers:
        try:
            if not run_test(lang, number, root, is_time, is_debug):
                return False, N
        except Exception as e:
            print(e)
            return False, N

    return True, N


def main():
    """Entry point: detect the language, build the solution and test it.

    With a test number on the command line only that case is run, using the
    debug build and with timing enabled. Otherwise every sample is run, in
    release mode unless ``--debug`` is given.

    The process exits with status 1 when the language cannot be detected, the
    build fails, the requested sample is missing or any test fails.
    """
    parser = argparse.ArgumentParser(
        description='Compiler and tester for kattis problems')

    parser.add_argument("-t", "--time", action="store_true",
                        help="Time the runtime of each test case")
    parser.add_argument("-d", "--debug", action="store_true",
                        help="Use the debugging compiler")
    parser.add_argument("test_case", type=str, metavar='T', nargs='?',
                        help="Time one specific test case in debug")

    args = parser.parse_args()
    root = project_root()
    single_case = args.test_case is not None

    try:
        lang = find_language(root)
    except Exception as e:
        print(f'Failed to detect language: {e}')
        sys.exit(1)

    try:
        if args.debug or single_case:
            lang.compile_debug()
        else:
            lang.compile_release()
    except Exception as e:
        print(e)
        sys.exit(1)

    if single_case:
        if not is_test_exists(args.test_case, root):
            sys.exit(1)
        # Report runtime errors the same way run_tests does instead of
        # letting them escape as a traceback.
        try:
            is_passed = run_test(lang, args.test_case, root, True, True)
        except Exception as e:
            print(e)
            is_passed = False
        if is_passed:
            print(f'Passed test case ({args.test_case})')
    else:
        is_passed, n = run_tests(lang, root, args.time, args.debug)
        if is_passed:
            print(f'All test ({n}) cases passed!')

    # Let shells and scripts see a failed run through the exit status.
    if not is_passed:
        sys.exit(1)

# Script entry point: Ctrl-C is turned into a short message and exit status 1;
# a normal return from main() exits with status 0.
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        print("\nKeyboard interrupt. Program killed")
        sys.exit(1)
    sys.exit(0)