diff --git a/one_run_audit/audit.py b/one_run_audit/audit.py
index 19a2bcd..0565830 100644
--- a/one_run_audit/audit.py
+++ b/one_run_audit/audit.py
@@ -18,6 +18,7 @@ import random
 from tqdm import tqdm
 from opacus.validators import ModuleValidator
 from opacus.utils.batch_memory_manager import BatchMemoryManager
+from concurrent.futures import ProcessPoolExecutor
 from WideResNet import WideResNet
 from equations import get_eps_audit
 import student_model
@@ -33,6 +34,32 @@ DTYPE = None
 DATADIR = Path("./data")
 
 
+def get_k_audit(k, scores, hp):
+    """Return the epsilon lower bound from ``get_eps_audit`` for one ``k``.
+
+    ``correct`` is the sum of ``~scores[:k]`` plus the sum of
+    ``scores[-k:]``; it is passed to ``get_eps_audit`` along with ``2*k``
+    and the ``target_points``, ``delta`` and ``p_value`` entries of ``hp``.
+
+    NOTE(review): ``~scores`` presumably relies on ``scores`` being a
+    boolean NumPy array ordered by score -- confirm against the caller.
+    Kept at module level so ``ProcessPoolExecutor`` workers can pickle it.
+
+    Returns:
+        ``(eps_lb, k, correct, len(scores))``.
+    """
+    correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
+
+    eps_lb = get_eps_audit(
+        hp['target_points'],
+        2*k,
+        correct,
+        hp['delta'],
+        hp['p_value']
+    )
+    return eps_lb, k, correct, len(scores)
+
+
 def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
     seed = np.random.randint(0, 1e9)
     seed ^= int(time.time())
@@ -853,20 +880,22 @@ def main():
     k_schedule = np.linspace(1, hp['target_points']//2, 40)
     k_schedule = np.floor(k_schedule).astype(int)
 
-    for k in tqdm(k_schedule):
-        correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
-        total = len(scores)
+    with ProcessPoolExecutor() as executor:
+        futures = [
+            executor.submit(get_k_audit, k, scores, hp) for k in k_schedule
+        ]
 
-        eps_lb = get_eps_audit(
-            hp['target_points'],
-            2*k,
-            correct,
-            hp['delta'],
-            hp['p_value']
-        )
-
-        if eps_lb > audits[0]:
-            audits = (eps_lb, k, correct, total)
+        # Consume results in schedule order (not completion order) so ties in
+        # eps_lb resolve to the smallest k, exactly as the serial loop did.
+        progress = tqdm(zip(k_schedule, futures), total=len(futures))
+        for k, future in progress:
+            try:
+                eps_lb, _, correct, total = future.result()
+            except Exception as exc:
+                print(f"'k={k}' generated an exception: {exc}")
+                continue
+            if eps_lb > audits[0]:
+                audits = (eps_lb, k, correct, total)
 
     print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
     print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")