From 1b099f4dad4d19605fa0c2c3f34b5f12cfcd905e Mon Sep 17 00:00:00 2001 From: Akemi Izuko Date: Fri, 29 Nov 2024 19:43:38 -0700 Subject: [PATCH] Torchlira: attempt microbatch --- lira-pytorch/train.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) diff --git a/lira-pytorch/train.py b/lira-pytorch/train.py index c28eee1..d092913 100644 --- a/lira-pytorch/train.py +++ b/lira-pytorch/train.py @@ -21,6 +21,8 @@ from tqdm import tqdm from opacus.validators import ModuleValidator from opacus import PrivacyEngine from opacus.utils.batch_memory_manager import BatchMemoryManager +import pyvacy +#from pyvacy import optim#, analysis, sampling from wide_resnet import WideResNet @@ -113,6 +115,12 @@ def run(): ModuleValidator.validate(m, strict=True) optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4) + #optim = pyvacy.DPSGD( + # params=m.parameters(), + # lr=args.lr, + # momentum=0.9, + # weight_decay=5e-4, + #) sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs) privacy_engine = PrivacyEngine() @@ -124,12 +132,13 @@ def run(): target_epsilon=1, target_delta=1e-4, max_grad_norm=1.0, + batch_first=True, ) print(f"Device: {DEVICE}") + accumulation_steps = 10 # Train - # max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE, with BatchMemoryManager( data_loader=train_dl, max_physical_batch_size=1000, @@ -142,14 +151,24 @@ def run(): #pbar = tqdm(train_dl, leave=False) for itr, (x, y) in enumerate(pbar): x, y = x.to(DEVICE), y.to(DEVICE) + if False: + loss = F.cross_entropy(m(x), y) / accumulation_steps + loss_norm = loss / accumulation_steps + loss_total += loss_norm + loss_norm.backward() + pbar.set_postfix_str(f"loss: {loss:.2f}") - loss = F.cross_entropy(m(x), y) - loss_total += loss + if ((itr + 1) % accumulation_steps == 0) or (itr + 1 == len(memory_safe_data_loader)): + optim.step() + optim.zero_grad() + else: + loss = F.cross_entropy(m(x), y) + loss_total += loss - pbar.set_postfix_str(f"loss: {loss:.2f}") - optim.zero_grad() - loss.backward() - optim.step() + pbar.set_postfix_str(f"loss: {loss:.2f}") + optim.zero_grad() + loss.backward() + optim.step() sched.step() wandb.log({"loss": loss_total / len(train_dl)})