Add files via upload
commit 239827251a
parent 2ef5c6e332
3 changed files with 21 additions and 25 deletions
@@ -27,8 +27,6 @@ import tensorflow as tf
 import pandas as pd
 from sklearn.model_selection import KFold

-# from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
-# from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 from tensorflow_privacy.privacy.optimizers import dp_optimizer

 from tensorflow_privacy.privacy.analysis.gdp_accountant import *

@@ -46,9 +44,10 @@ flags.DEFINE_integer('max_mu', 2, 'GDP upper limit')
 flags.DEFINE_string('model_dir', None, 'Model directory')

 microbatches = 256
+num_examples = 29305

 def nn_model_fn(features, labels, mode):
-  ''' Define CNN architecture using tf.keras.layers.'''
+  '''Define CNN architecture using tf.keras.layers.'''
   input_layer = tf.reshape(features['x'], [-1, 123])
   y = tf.keras.layers.Dense(16, activation='relu').apply(input_layer)
   logits = tf.keras.layers.Dense(2).apply(y)

@@ -137,12 +136,12 @@ def main(unused_argv):
       shuffle=False)

   # Training loop.
-  steps_per_epoch = 29305 // 256
+  steps_per_epoch = num_examples // microbatches
   test_accuracy_list = []
   for epoch in range(1, FLAGS.epochs + 1):
     for step in range(steps_per_epoch):
-      whether = np.random.random_sample(29305) > (1-256/29305)
-      subsampling = [i for i in np.arange(29305) if whether[i]]
+      whether = np.random.random_sample(num_examples) > (1-microbatches/num_examples)
+      subsampling = [i for i in np.arange(num_examples) if whether[i]]
       global microbatches
       microbatches = len(subsampling)

@@ -163,8 +162,8 @@ def main(unused_argv):

     # Compute the privacy budget expended so far.
     if FLAGS.dpsgd:
-      eps = compute_eps_Poisson(epoch, FLAGS.noise_multiplier, 29305, 256, 1e-5)
-      mu = compute_mu_Poisson(epoch, FLAGS.noise_multiplier, 29305, 256)
+      eps = compute_eps_poisson(epoch, FLAGS.noise_multiplier, num_examples, 256, 1e-5)
+      mu = compute_mu_poisson(epoch, FLAGS.noise_multiplier, num_examples, 256)
       print('For delta=1e-5, the current epsilon is: %.2f' % eps)
       print('For delta=1e-5, the current mu is: %.2f' % mu)
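The training loop in the hunks above draws each batch by Poisson subsampling: every record is kept independently with probability microbatches/num_examples, so the realized batch size varies around 256 and is written back into the module-level microbatches via the global statement. A minimal standalone sketch of that sampling step, using this file's numbers (the helper name poisson_subsample is illustrative, not part of the commit):

    import numpy as np

    num_examples = 29305   # training set size used in this file
    microbatches = 256     # expected batch size used in this file

    def poisson_subsample(n, expected_batch):
      # Keep each of the n records independently with probability expected_batch / n.
      prob = expected_batch / n
      keep = np.random.random_sample(n) > (1 - prob)   # True with probability prob
      return np.arange(n)[keep]

    subsampling = poisson_subsample(num_examples, microbatches)
    print(len(subsampling))  # varies around 256 from step to step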
@@ -26,8 +26,6 @@ import numpy as np
 import tensorflow as tf
 from keras.preprocessing import sequence

-#from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
-#from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 from tensorflow_privacy.privacy.optimizers import dp_optimizer

 from tensorflow_privacy.privacy.analysis.gdp_accountant import *

@@ -51,7 +49,7 @@ microbatches = 512
 max_features = 10000
 # cut texts after this number of words (among top max_features most common words)
 maxlen = 256
-
+num_examples = 25000

 def nn_model_fn(features, labels, mode):
   '''Define NN architecture using tf.keras.layers.'''

@@ -139,13 +137,13 @@ def main(unused_argv):
       shuffle=False)

   # Training loop.
-  steps_per_epoch = 25000 // 512
+  steps_per_epoch = num_examples // microbatches
   test_accuracy_list = []

   for epoch in range(1, FLAGS.epochs + 1):
     for step in range(steps_per_epoch):
-      whether = np.random.random_sample(25000) > (1-512/25000)
-      subsampling = [i for i in np.arange(25000) if whether[i]]
+      whether = np.random.random_sample(num_examples) > (1-microbatches/num_examples)
+      subsampling = [i for i in np.arange(num_examples) if whether[i]]
       global microbatches
       microbatches = len(subsampling)

@@ -166,8 +164,8 @@ def main(unused_argv):

     # Compute the privacy budget expended so far.
     if FLAGS.dpsgd:
-      eps = compute_eps_Poisson(epoch, FLAGS.noise_multiplier, 25000, 512, 1e-5)
-      mu = compute_mu_Poisson(epoch, FLAGS.noise_multiplier, 25000, 512)
+      eps = compute_eps_poisson(epoch, FLAGS.noise_multiplier, num_examples, microbatches, 1e-5)
+      mu = compute_mu_poisson(epoch, FLAGS.noise_multiplier, num_examples, microbatches)
       print('For delta=1e-5, the current epsilon is: %.2f' % eps)
       print('For delta=1e-5, the current mu is: %.2f' % mu)
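After every epoch the scripts report the privacy budget with the Gaussian-differential-privacy accountant; this commit switches to the lower-case names compute_eps_poisson / compute_mu_poisson and passes the dataset size rather than a literal. A small usage sketch with the IMDB numbers, assuming the gdp_accountant module exports those two functions as the wildcard import above implies; the noise multiplier value below is an assumption for illustration, not a flag from the script:

    from tensorflow_privacy.privacy.analysis.gdp_accountant import (
        compute_eps_poisson, compute_mu_poisson)

    noise_multiplier = 0.55   # assumed value, for illustration only
    num_examples = 25000      # IMDB training set size used in this file
    batch_size = 512          # expected batch size used in this file
    delta = 1e-5

    for epoch in range(1, 4):
      # Same argument order as the calls above:
      # (epochs elapsed, noise multiplier, dataset size, batch size[, delta]).
      eps = compute_eps_poisson(epoch, noise_multiplier, num_examples, batch_size, delta)
      mu = compute_mu_poisson(epoch, noise_multiplier, num_examples, batch_size)
      print('Epoch %d: epsilon = %.2f, mu = %.2f' % (epoch, eps, mu))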
@@ -28,8 +28,6 @@ import pandas as pd
 from scipy.stats import rankdata
 from sklearn.model_selection import train_test_split

-#from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
-#from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
 from tensorflow_privacy.privacy.optimizers import dp_optimizer

 from tensorflow_privacy.privacy.analysis.gdp_accountant import *

@@ -48,6 +46,7 @@ flags.DEFINE_string('model_dir', None, 'Model directory')


 microbatches = 10000
+num_examples = 800167

 def nn_model_fn(features, labels, mode):
   '''NN adapted from github.com/hexiangnan/neural_collaborative_filtering'''

@@ -132,7 +131,7 @@ def nn_model_fn(features, labels, mode):
   return None


-def load_adult():
+def load_movielens():
   """Loads MovieLens 1M as from https://grouplens.org/datasets/movielens/1m"""
   data = pd.read_csv('ratings.dat', sep='::', header=None,
                      names=["userId", "movieId", "rating", "timestamp"])

@@ -158,7 +157,7 @@ def main(unused_argv):
   tf.compat.v1.logging.set_verbosity(3)

   # Load training and test data.
-  train_data, test_data, mean = load_adult()
+  train_data, test_data, mean = load_movielens()

   # Instantiate the tf.Estimator.
   ml_classifier = tf.estimator.Estimator(model_fn=nn_model_fn,

@@ -172,12 +171,12 @@ def main(unused_argv):
       shuffle=False)

   # Training loop.
-  steps_per_epoch = 800167 // 10000
+  steps_per_epoch = num_examples // microbatches
   test_accuracy_list = []
   for epoch in range(1, FLAGS.epochs + 1):
     for step in range(steps_per_epoch):
-      whether = np.random.random_sample(800167) > (1-10000/800167)
-      subsampling = [i for i in np.arange(800167) if whether[i]]
+      whether = np.random.random_sample(num_examples) > (1-microbatches/num_examples)
+      subsampling = [i for i in np.arange(num_examples) if whether[i]]
       global microbatches
       microbatches = len(subsampling)

@@ -198,8 +197,8 @@ def main(unused_argv):

     # Compute the privacy budget expended so far.
     if FLAGS.dpsgd:
-      eps = compute_eps_Poisson(epoch, FLAGS.noise_multiplier, 800167, 10000, 1e-6)
-      mu = compute_mu_Poisson(epoch, FLAGS.noise_multiplier, 800167, 10000)
+      eps = compute_eps_poisson(epoch, FLAGS.noise_multiplier, num_examples, microbatches, 1e-6)
+      mu = compute_mu_poisson(epoch, FLAGS.noise_multiplier, num_examples, microbatches)
       print('For delta=1e-6, the current epsilon is: %.2f' % eps)
       print('For delta=1e-6, the current mu is: %.2f' % mu)
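The renamed load_movielens() reads the MovieLens 1M ratings file directly with pandas; num_examples = 800167 corresponds to an 80% training split of the dataset's 1,000,209 ratings. A standalone sketch of that read (the explicit engine='python' argument is an addition here, needed because the two-character '::' separator is treated as a regular expression; it is not part of the commit):

    import pandas as pd

    # ratings.dat ships with MovieLens 1M: one 'userId::movieId::rating::timestamp' row per rating.
    data = pd.read_csv('ratings.dat', sep='::', header=None,
                       names=['userId', 'movieId', 'rating', 'timestamp'],
                       engine='python')
    print(data.shape)  # (1000209, 4) for the full 1M dataset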