From 2b865a5f582418ba4389e8a6c7f55d612dcc5064 Mon Sep 17 00:00:00 2001
From: Akemi Izuko
Date: Sat, 30 Nov 2024 13:34:05 -0700
Subject: [PATCH] Torchlira: add additional networks

---
 lira-pytorch/train.py | 143 +++++++++++++++++++++++++++++-------------
 1 file changed, 101 insertions(+), 42 deletions(-)

diff --git a/lira-pytorch/train.py b/lira-pytorch/train.py
index d092913..e1a1c06 100644
--- a/lira-pytorch/train.py
+++ b/lira-pytorch/train.py
@@ -41,6 +41,60 @@ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("mp
 EPOCHS = args.epochs
 
 
+class DewisNet(nn.Module):
+    def __init__(self):
+        super(DewisNet, self).__init__()
+        # I started my model from the tutorial: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html,
+        # then modified it.
+
+        # 2 convolutional layers, with pooling after each
+        self.conv1 = nn.Conv2d(3, 12, 5)
+        self.conv2 = nn.Conv2d(12, 32, 5)
+        self.pool = nn.MaxPool2d(2, 2)
+
+        # 3 linear layers
+        self.fc1 = nn.Linear(32 * 5 * 5, 120)
+        self.fc2 = nn.Linear(120, 84)
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        x = self.pool(F.relu(self.conv1(x)))
+        x = self.pool(F.relu(self.conv2(x)))
+        x = torch.flatten(x, 1)
+        x = F.relu(self.fc1(x))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class JagielskiNet(nn.Module):
+    def __init__(self, input_shape, num_classes, l2=0.01):
+        super(JagielskiNet, self).__init__()
+        self.flatten = nn.Flatten()
+
+        input_dim = 1
+        for dim in input_shape:
+            input_dim *= dim
+
+        self.dense1 = nn.Linear(input_dim, 32)
+        self.relu1 = nn.ReLU()
+        self.dense2 = nn.Linear(32, num_classes)
+
+        # Initialize weights with Glorot Normal (Xavier Normal)
+        torch.nn.init.xavier_normal_(self.dense1.weight)
+        torch.nn.init.xavier_normal_(self.dense2.weight)
+
+        # L2 regularization (weight decay)
+        self.l2 = l2
+
+    def forward(self, x):
+        x = self.flatten(x)
+        x = self.dense1(x)
+        x = self.relu1(x)
+        x = self.dense2(x)
+        return x
+
+
 def run():
     seed = np.random.randint(0, 1000000000)
     seed ^= int(time.time())
@@ -92,18 +146,19 @@ def run():
         keep_bool[keep] = True
 
         train_ds = torch.utils.data.Subset(train_ds, keep)
-    train_dl = DataLoader(train_ds, batch_size=128, shuffle=True, num_workers=4)
+    train_dl = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4)
     test_dl = DataLoader(test_ds, batch_size=128, shuffle=False, num_workers=4)
 
     # Model
-    if args.model == "wresnet28-2":
+    if args.model == "dewisnet":
+        m = DewisNet()
+    elif args.model == "jnet":
+        m = JagielskiNet((3,32,32), 10)
+    elif args.model == "wresnet28-2":
         m = WideResNet(28, 2, 0.0, 10)
-        print("one")
     elif args.model == "wresnet28-10":
         m = WideResNet(28, 10, 0.3, 10)
-        print("two")
     elif args.model == "resnet18":
-        print("three")
         m = models.resnet18(weights=None, num_classes=10)
         m.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
         m.maxpool = nn.Identity()
@@ -114,6 +169,7 @@ def run():
     m = ModuleValidator.fix(m)
     ModuleValidator.validate(m, strict=True)
 
+    print(f"Device: {DEVICE}")
     optim = torch.optim.SGD(m.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)
     #optim = pyvacy.DPSGD(
     #    params=m.parameters(),
@@ -123,45 +179,31 @@ def run():
     #)
     sched = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=args.epochs)
 
-    privacy_engine = PrivacyEngine()
-    m, optim, train_dl = privacy_engine.make_private_with_epsilon(
-        module=m,
-        optimizer=optim,
-        data_loader=train_dl,
-        epochs=args.epochs,
-        target_epsilon=1,
-        target_delta=1e-4,
-        max_grad_norm=1.0,
-        batch_first=True,
-    )
-
-    print(f"Device: {DEVICE}")
-    accumulation_steps = 10
-    # Train
-    with BatchMemoryManager(
-        data_loader=train_dl,
-        max_physical_batch_size=1000,
-        optimizer=optim
-    ) as memory_safe_data_loader:
-        for i in tqdm(range(args.epochs)):
-            m.train()
-            loss_total = 0
-            pbar = tqdm(memory_safe_data_loader, leave=False)
-            #pbar = tqdm(train_dl, leave=False)
-            for itr, (x, y) in enumerate(pbar):
-                x, y = x.to(DEVICE), y.to(DEVICE)
-                if False:
-                    loss = F.cross_entropy(m(x), y) / accumulation_steps
-                    loss_norm = loss / accumulation_steps
-                    loss_total += loss_norm
-                    loss_norm.backward()
-                    pbar.set_postfix_str(f"loss: {loss:.2f}")
+    if False:
+        privacy_engine = PrivacyEngine()
+        m, optim, train_dl = privacy_engine.make_private_with_epsilon(
+            module=m,
+            optimizer=optim,
+            data_loader=train_dl,
+            epochs=args.epochs,
+            target_epsilon=8,
+            target_delta=1e-4,
+            max_grad_norm=1.0,
+            batch_first=True,
+        )
 
-                    if ((itr + 1) % accumulation_steps == 0) or (itr + 1 == len(memory_safe_data_loader)):
-                        optim.step()
-                        optim.zero_grad()
-                else:
+        with BatchMemoryManager(
+            data_loader=train_dl,
+            max_physical_batch_size=1000,
+            optimizer=optim
+        ) as memory_safe_data_loader:
+            for i in tqdm(range(args.epochs)):
+                m.train()
+                loss_total = 0
+                pbar = tqdm(memory_safe_data_loader, leave=False)
+                for itr, (x, y) in enumerate(pbar):
+                    x, y = x.to(DEVICE), y.to(DEVICE)
                     loss = F.cross_entropy(m(x), y)
                     loss_total += loss
@@ -169,6 +211,23 @@ def run():
                     pbar.set_postfix_str(f"loss: {loss:.2f}")
                     optim.zero_grad()
                     loss.backward()
                     optim.step()
+                sched.step()
+
+                wandb.log({"loss": loss_total / len(train_dl)})
+    else:
+        for i in tqdm(range(args.epochs)):
+            m.train()
+            loss_total = 0
+            pbar = tqdm(train_dl, leave=False)
+            for itr, (x, y) in enumerate(pbar):
+                x, y = x.to(DEVICE), y.to(DEVICE)
+                loss = F.cross_entropy(m(x), y)
+                loss_total += loss
+
+                pbar.set_postfix_str(f"loss: {loss:.2f}")
+                optim.zero_grad()
+                loss.backward()
+                optim.step()
         sched.step()
 
         wandb.log({"loss": loss_total / len(train_dl)})
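
Note (not part of the patch): the 32 * 5 * 5 input size of DewisNet's fc1 follows from CIFAR-10-sized inputs (3x32x32), which is what the rest of train.py feeds the model. Each 5x5 convolution shrinks the spatial size by 4 and each 2x2 max-pool halves it, so 32 -> 28 -> 14 -> 10 -> 5 with 32 channels after conv2. A minimal standalone sketch of that arithmetic, assuming only that torch is installed:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Trace DewisNet's convolutional path on a CIFAR-10-shaped batch.
x = torch.randn(4, 3, 32, 32)
x = F.max_pool2d(F.relu(nn.Conv2d(3, 12, 5)(x)), 2)   # -> (4, 12, 14, 14)
x = F.max_pool2d(F.relu(nn.Conv2d(12, 32, 5)(x)), 2)  # -> (4, 32, 5, 5)
print(x.flatten(1).shape)  # torch.Size([4, 800]) == 32 * 5 * 5 features per image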