O1: slight cleanup

This commit is contained in:
Akemi Izuko 2024-12-07 17:37:19 -07:00
parent f407827ac1
commit 70d4e4dfdc
Signed by: akemi
GPG key ID: 8DE0764E1809E9FC

View file

@ -596,25 +596,12 @@ def train_convnet(hp, train_dl, test_dl):
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
#if hp['epochs'] <= 10:
# optimizer = optim.Adam(model.parameters(), lr=lr)
#elif hp['epochs'] > 10 and hp['epochs'] <= 25:
# optimizer = optim.Adam(model.parameters(), lr=(lr/10))
#else:
# optimizer = optim.Adam(model.parameters(), lr=(lr/50))
scheduler = MultiStepLR(optimizer, milestones=[10, 25], gamma=0.1)
# scheduler = MultiStepLR(
# optimizer,
# milestones=[int(i * hp['epochs']) for i in [0.3, 0.6, 0.8]],
# gamma=0.2
# )
print(f"Training with {hp['epochs']} epochs")
if hp['epsilon'] is not None:
privacy_engine = opacus.PrivacyEngine()
privacy_engine = opacus.PrivacyEngine(accountant='rdp')
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
@ -764,7 +751,7 @@ def main():
"wrn_depth": 16,
"wrn_width": 1,
"epsilon": args.epsilon,
"delta": 1e-5,
"delta": 1e-6,
"norm": args.norm,
"batch_size": 50 if args.convnet else 4096,
"epochs": args.epochs,