O1: multithread k-search
This commit is contained in:
parent
70d4e4dfdc
commit
99ba0b3f6d
1 changed files with 27 additions and 13 deletions
|
@ -18,6 +18,7 @@ import random
|
|||
from tqdm import tqdm
|
||||
from opacus.validators import ModuleValidator
|
||||
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from WideResNet import WideResNet
|
||||
from equations import get_eps_audit
|
||||
import student_model
|
||||
|
@ -33,6 +34,19 @@ DTYPE = None
|
|||
DATADIR = Path("./data")
|
||||
|
||||
|
||||
def get_k_audit(k, scores, hp):
|
||||
correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
|
||||
|
||||
eps_lb = get_eps_audit(
|
||||
hp['target_points'],
|
||||
2*k,
|
||||
correct,
|
||||
hp['delta'],
|
||||
hp['p_value']
|
||||
)
|
||||
return eps_lb, k, correct, len(scores)
|
||||
|
||||
|
||||
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||
seed = np.random.randint(0, 1e9)
|
||||
seed ^= int(time.time())
|
||||
|
@ -853,20 +867,20 @@ def main():
|
|||
k_schedule = np.linspace(1, hp['target_points']//2, 40)
|
||||
k_schedule = np.floor(k_schedule).astype(int)
|
||||
|
||||
for k in tqdm(k_schedule):
|
||||
correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
|
||||
total = len(scores)
|
||||
with ProcessPoolExecutor() as executor:
|
||||
futures = {
|
||||
executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
|
||||
}
|
||||
|
||||
eps_lb = get_eps_audit(
|
||||
hp['target_points'],
|
||||
2*k,
|
||||
correct,
|
||||
hp['delta'],
|
||||
hp['p_value']
|
||||
)
|
||||
|
||||
if eps_lb > audits[0]:
|
||||
audits = (eps_lb, k, correct, total)
|
||||
# Iterate through completed futures with a progress bar
|
||||
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||
try:
|
||||
eps_lb, k, correct, total = future.result()
|
||||
if eps_lb > audits[0]:
|
||||
audits = (eps_lb, k, correct, total)
|
||||
except Exception as exc:
|
||||
k = futures[future]
|
||||
print(f"'k={k}' generated an exception: {exc}")
|
||||
|
||||
print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
|
||||
print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")
|
||||
|
|
Loading…
Reference in a new issue