Torchlira: attempt microbatch
parent c7eee3cdc2
commit 1b099f4dad
1 changed file with 26 additions and 7 deletions
@@ -21,6 +21,8 @@ from tqdm import tqdm
from opacus.validators import ModuleValidator
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
import pyvacy
#from pyvacy import optim#, analysis, sampling

from wide_resnet import WideResNet

@@ -113,6 +115,12 @@ def run():
ModuleValidator.validate(m, strict=True)

optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
#optim = pyvacy.DPSGD(
# params=m.parameters(),
# lr=args.lr,
# momentum=0.9,
# weight_decay=5e-4,
#)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)

privacy_engine = PrivacyEngine()
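
Note: before a model can be wrapped by the PrivacyEngine, Opacus expects it to pass ModuleValidator, and the common idiom is to let the validator also replace unsupported layers first. A minimal sketch of that setup, reusing the `m` and `args` names from this file; the `ModuleValidator.fix` call is the usual pattern but is not part of this commit:

    import torch
    from opacus.validators import ModuleValidator

    # Replace layers Opacus cannot privatize (e.g. BatchNorm -> GroupNorm),
    # then re-check; validate(..., strict=True) raises if anything is left.
    m = ModuleValidator.fix(m)
    ModuleValidator.validate(m, strict=True)

    # Plain SGD is enough here: make_private_* later wraps it in a DPOptimizer
    # that handles the per-sample gradient clipping and noise addition.
    optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)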
@ -124,12 +132,13 @@ def run():
|
|||
target_epsilon=1,
|
||||
target_delta=1e-4,
|
||||
max_grad_norm=1.0,
|
||||
batch_first=True,
|
||||
)
|
||||
|
||||
print(f"Device: {DEVICE}")
|
||||
accumulation_steps = 10
|
||||
|
||||
# Train
|
||||
# max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
|
||||
with BatchMemoryManager(
|
||||
data_loader=train_dl,
|
||||
max_physical_batch_size=1000,
|
||||
|
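
The hunk above only shows the tail of the PrivacyEngine call. For reference, the usual full form of `make_private_with_epsilon` looks like the sketch below; the arguments before `target_epsilon` fall outside the hunk, so the `module`/`optimizer`/`data_loader`/`epochs` values are assumed from the surrounding code:

    # make_private_with_epsilon returns wrapped versions of all three objects:
    # the data loader switches to Poisson sampling, and the optimizer clips
    # per-sample gradients and adds noise calibrated to (target_epsilon, target_delta).
    m, optim, train_dl = privacy_engine.make_private_with_epsilon(
        module=m,
        optimizer=optim,
        data_loader=train_dl,
        epochs=args.epochs,
        target_epsilon=1,
        target_delta=1e-4,
        max_grad_norm=1.0,
        batch_first=True,
    )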
@@ -142,14 +151,24 @@ def run():
#pbar = tqdm(train_dl, leave=False)
for itr, (x, y) in enumerate(pbar):
x, y = x.to(DEVICE), y.to(DEVICE)
if False:
loss = F.cross_entropy(m(x), y) / accumulation_steps
loss_norm = loss / accumulation_steps
loss_total += loss_norm
loss_norm.backward()
pbar.set_postfix_str(f"loss: {loss:.2f}")

loss = F.cross_entropy(m(x), y)
loss_total += loss
if ((itr + 1) % accumulation_steps == 0) or (itr + 1 == len(memory_safe_data_loader)):
optim.step()
optim.zero_grad()
else:
loss = F.cross_entropy(m(x), y)
loss_total += loss

pbar.set_postfix_str(f"loss: {loss:.2f}")
optim.zero_grad()
loss.backward()
optim.step()
pbar.set_postfix_str(f"loss: {loss:.2f}")
optim.zero_grad()
loss.backward()
optim.step()
sched.step()

wandb.log({"loss": loss_total / len(train_dl)})
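
Worth noting: when the loader comes from BatchMemoryManager, manual gradient accumulation is not required. The wrapped DPOptimizer skips the real parameter update on partial (physical) batches and only applies it once a full logical batch has been processed, so the loop can call `step()` every iteration. A minimal sketch of that loop shape, reusing the names from this file; the epoch loop, the `optimizer=optim` argument, and the wandb/scheduler placement are assumptions since they fall outside the hunks shown:

    with BatchMemoryManager(
        data_loader=train_dl,
        max_physical_batch_size=1000,
        optimizer=optim,
    ) as memory_safe_data_loader:
        for epoch in range(args.epochs):
            loss_total = 0.0
            pbar = tqdm(memory_safe_data_loader, leave=False)
            for x, y in pbar:
                x, y = x.to(DEVICE), y.to(DEVICE)
                optim.zero_grad()
                loss = F.cross_entropy(m(x), y)
                loss_total += loss.item()
                loss.backward()
                # DPOptimizer defers the actual update until the full logical
                # batch has been accumulated from the physical sub-batches.
                optim.step()
                pbar.set_postfix_str(f"loss: {loss:.2f}")
            sched.step()
            wandb.log({"loss": loss_total / len(memory_safe_data_loader)})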