O1: add epsilon audit

This commit is contained in:
Akemi Izuko 2024-12-03 23:26:50 -07:00
parent 5d6f7e2916
commit 369249ce69
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -17,6 +17,7 @@ import opacus
from opacus.validators import ModuleValidator from opacus.validators import ModuleValidator
from opacus.utils.batch_memory_manager import BatchMemoryManager from opacus.utils.batch_memory_manager import BatchMemoryManager
from WideResNet import WideResNet from WideResNet import WideResNet
from equations import get_eps_audit
import warnings import warnings
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -225,9 +226,10 @@ def main():
"delta": 1e-5, "delta": 1e-5,
"norm": args.norm, "norm": args.norm,
"batch_size": 4096, "batch_size": 4096,
"epochs": 2, "epochs": 100,
"k+": 300, "k+": 300,
"k-": 300, "k-": 300,
"p_value": 0.05,
} }
hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format( hp['logfile'] = Path('WideResNet_{}_{}_{}_{}s_x{}_{}e_{}d_{}C.txt'.format(
@ -277,7 +279,17 @@ def main():
correct = np.sum(~scores[:hp['k-']]) + np.sum(scores[-hp['k+']:]) correct = np.sum(~scores[:hp['k-']]) + np.sum(scores[-hp['k+']:])
total = len(scores) total = len(scores)
eps_lb = get_eps_audit(
hp['target_points'],
hp['k+'] + hp['k-'],
correct,
hp['delta'],
hp['p_value']
)
print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Audit total: {correct}/{total} = {round(correct/total*100, 2)}")
print(f"p[ε < {eps_lb}] < {hp['p_value']}")
correct, total = evaluate_on(model_init, train_dl) correct, total = evaluate_on(model_init, train_dl)
print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}") print(f"Init model accuracy: {correct}/{total} = {round(correct/total*100, 2)}")