forked from 626_privacy/tensorflow_privacy
Update to use TF 2.0 API in TensorFlow Privacy:
tf.logging -> Removed for absl tf.assert_type -> tf.debugging.assert_type tf.assert_less_equal -> tf.debugging.assert_less_equal tf.global_norm -> tf.linalg.global_norm PiperOrigin-RevId: 425730344
This commit is contained in:
parent
438da5a09b
commit
8a6827b27c
5 changed files with 21 additions and 16 deletions
|
@ -17,7 +17,7 @@ from absl import app
|
|||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
|
||||
|
@ -91,10 +91,8 @@ def load_cifar10():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.ERROR)
|
||||
logging.set_verbosity(logging.ERROR)
|
||||
logging.set_stderrthreshold(logging.ERROR)
|
||||
logging.get_absl_handler().use_absl_log_file()
|
||||
logger = tf.get_logger()
|
||||
logger.set_level(logging.ERROR)
|
||||
|
||||
# Load training and test data.
|
||||
x_train, y_train, x_test, y_test = load_cifar10()
|
||||
|
|
|
@ -34,10 +34,10 @@ import os
|
|||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||
|
@ -157,7 +157,9 @@ def compute_epsilon(steps):
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logger = tf.get_logger()
|
||||
logger.set_level(logging.INFO)
|
||||
|
||||
if FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
import tensorflow as tf
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
|
||||
|
@ -140,7 +140,9 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logger = tf.get_logger()
|
||||
logger.set_level(logging.INFO)
|
||||
|
||||
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
|
||||
raise ValueError('Number of microbatches should divide evenly batch_size')
|
||||
|
||||
|
|
|
@ -25,10 +25,9 @@ import math
|
|||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
|
||||
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
|
||||
from tensorflow_privacy.privacy.optimizers import dp_optimizer
|
||||
|
@ -178,7 +177,9 @@ def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logger = tf.get_logger()
|
||||
logger.set_level(logging.INFO)
|
||||
|
||||
if FLAGS.data_l2_norm <= 0:
|
||||
raise ValueError('data_l2_norm must be positive.')
|
||||
if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:
|
||||
|
|
|
@ -13,8 +13,9 @@
|
|||
# limitations under the License.
|
||||
"""Scratchpad for training a CNN on MNIST with DPSGD."""
|
||||
|
||||
from absl import logging
|
||||
import numpy as np
|
||||
import tensorflow.compat.v1 as tf
|
||||
import tensorflow as tf
|
||||
|
||||
tf.flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
|
||||
tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
|
||||
|
@ -86,7 +87,8 @@ def load_mnist():
|
|||
|
||||
|
||||
def main(unused_argv):
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
logger = tf.get_logger()
|
||||
logger.set_level(logging.INFO)
|
||||
|
||||
# Load training and test data.
|
||||
train_data, train_labels, test_data, test_labels = load_mnist()
|
||||
|
|
Loading…
Reference in a new issue