O1: multithread k-search

This commit is contained in:
Akemi Izuko 2024-12-07 17:46:07 -07:00
parent 70d4e4dfdc
commit 99ba0b3f6d
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -18,6 +18,7 @@ import random
from tqdm import tqdm from tqdm import tqdm
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager from opacus.utils.batch_memory_manager import BatchMemoryManager
from concurrent.futures import ProcessPoolExecutor, as_completed
from WideResNet import WideResNet from WideResNet import WideResNet
from equations import get_eps_audit from equations import get_eps_audit
import student_model import student_model
@ -33,6 +34,19 @@ DTYPE = None
DATADIR = Path("./data") DATADIR = Path("./data")
def get_k_audit(k, scores, hp):
    """Evaluate the audit lower bound for a single cut-off ``k``.

    The first ``k`` entries of ``scores`` are inverted with ``~`` and summed,
    the last ``k`` entries are summed as-is, and the two sums are added to
    give the number of correct guesses out of ``2*k``. That count is handed
    to ``get_eps_audit`` together with the audit hyperparameters.

    Args:
        k: Number of entries taken from each end of ``scores``. Expected to
            be >= 1; with ``k == 0`` the slice ``scores[-k:]`` would cover
            the whole array instead of being empty.
        scores: Sliceable array supporting ``~``.
            NOTE(review): presumably a boolean NumPy array ordered by
            confidence -- confirm against the caller.
        hp: Mapping providing the keys ``'target_points'``, ``'delta'`` and
            ``'p_value'``.

    Returns:
        A tuple ``(eps_lb, k, correct, total)`` where ``eps_lb`` is the value
        returned by ``get_eps_audit``, ``correct`` is the combined count
        described above and ``total`` is ``len(scores)``.
    """
    # Count from the leading slice (inverted) and the trailing slice.
    head_hits = np.sum(~scores[:k])
    tail_hits = np.sum(scores[-k:])
    n_correct = head_hits + tail_hits

    # 2*k guesses were made in total: k from each end of the array.
    n_guesses = 2*k
    bound = get_eps_audit(
        hp['target_points'],
        n_guesses,
        n_correct,
        hp['delta'],
        hp['p_value'],
    )

    return bound, k, n_correct, len(scores)
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10): def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
seed = np.random.randint(0, 1e9) seed = np.random.randint(0, 1e9)
seed ^= int(time.time()) seed ^= int(time.time())
@ -853,20 +867,20 @@ def main():
k_schedule = np.linspace(1, hp['target_points']//2, 40) k_schedule = np.linspace(1, hp['target_points']//2, 40)
k_schedule = np.floor(k_schedule).astype(int) k_schedule = np.floor(k_schedule).astype(int)
for k in tqdm(k_schedule): with ProcessPoolExecutor() as executor:
correct = np.sum(~scores[:k]) + np.sum(scores[-k:]) futures = {
total = len(scores) executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
}
eps_lb = get_eps_audit( # Iterate through completed futures with a progress bar
hp['target_points'], for future in tqdm(as_completed(futures), total=len(futures)):
2*k, try:
correct, eps_lb, k, correct, total = future.result()
hp['delta'], if eps_lb > audits[0]:
hp['p_value'] audits = (eps_lb, k, correct, total)
) except Exception as exc:
k = futures[future]
if eps_lb > audits[0]: print(f"'k={k}' generated an exception: {exc}")
audits = (eps_lb, k, correct, total)
print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}") print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}") print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")