O1: multithread k-search
This commit is contained in:
parent
70d4e4dfdc
commit
99ba0b3f6d
1 changed files with 27 additions and 13 deletions
|
@ -18,6 +18,7 @@ import random
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from opacus.validators import ModuleValidator
|
from opacus.validators import ModuleValidator
|
||||||
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
from opacus.utils.batch_memory_manager import BatchMemoryManager
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
from WideResNet import WideResNet
|
from WideResNet import WideResNet
|
||||||
from equations import get_eps_audit
|
from equations import get_eps_audit
|
||||||
import student_model
|
import student_model
|
||||||
|
@ -33,6 +34,19 @@ DTYPE = None
|
||||||
DATADIR = Path("./data")
|
DATADIR = Path("./data")
|
||||||
|
|
||||||
|
|
||||||
|
def get_k_audit(k, scores, hp):
|
||||||
|
correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
|
||||||
|
|
||||||
|
eps_lb = get_eps_audit(
|
||||||
|
hp['target_points'],
|
||||||
|
2*k,
|
||||||
|
correct,
|
||||||
|
hp['delta'],
|
||||||
|
hp['p_value']
|
||||||
|
)
|
||||||
|
return eps_lb, k, correct, len(scores)
|
||||||
|
|
||||||
|
|
||||||
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
def get_dataloaders(m=1000, train_batch_size=128, test_batch_size=10):
|
||||||
seed = np.random.randint(0, 1e9)
|
seed = np.random.randint(0, 1e9)
|
||||||
seed ^= int(time.time())
|
seed ^= int(time.time())
|
||||||
|
@ -853,20 +867,20 @@ def main():
|
||||||
k_schedule = np.linspace(1, hp['target_points']//2, 40)
|
k_schedule = np.linspace(1, hp['target_points']//2, 40)
|
||||||
k_schedule = np.floor(k_schedule).astype(int)
|
k_schedule = np.floor(k_schedule).astype(int)
|
||||||
|
|
||||||
for k in tqdm(k_schedule):
|
with ProcessPoolExecutor() as executor:
|
||||||
correct = np.sum(~scores[:k]) + np.sum(scores[-k:])
|
futures = {
|
||||||
total = len(scores)
|
executor.submit(get_k_audit, k, scores, hp): k for k in k_schedule
|
||||||
|
}
|
||||||
|
|
||||||
eps_lb = get_eps_audit(
|
# Iterate through completed futures with a progress bar
|
||||||
hp['target_points'],
|
for future in tqdm(as_completed(futures), total=len(futures)):
|
||||||
2*k,
|
try:
|
||||||
correct,
|
eps_lb, k, correct, total = future.result()
|
||||||
hp['delta'],
|
if eps_lb > audits[0]:
|
||||||
hp['p_value']
|
audits = (eps_lb, k, correct, total)
|
||||||
)
|
except Exception as exc:
|
||||||
|
k = futures[future]
|
||||||
if eps_lb > audits[0]:
|
print(f"'k={k}' generated an exception: {exc}")
|
||||||
audits = (eps_lb, k, correct, total)
|
|
||||||
|
|
||||||
print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
|
print(f"Audit total: {audits[2]}/{2*audits[1]}/{audits[3]}")
|
||||||
print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")
|
print(f"p[ε < {audits[0]}] < {hp['p_value']} for true epsilon {hp['epsilon']}")
|
||||||
|
|
Loading…
Reference in a new issue