Update to use TF 2.0 API in TensorFlow Privacy:

tf.logging -> removed in favor of absl logging (tf.get_logger() with absl logging levels)
tf.assert_type -> tf.debugging.assert_type
tf.assert_less_equal -> tf.debugging.assert_less_equal
tf.global_norm -> tf.linalg.global_norm

PiperOrigin-RevId: 425730344
This commit is contained in:
Michael Reneer 2022-02-01 15:28:37 -08:00 committed by A. Unique TensorFlower
parent 438da5a09b
commit 8a6827b27c
5 changed files with 21 additions and 16 deletions

View file

@@ -17,7 +17,7 @@ from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics
@@ -91,10 +91,8 @@ def load_cifar10():
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.ERROR) logger = tf.get_logger()
logging.set_verbosity(logging.ERROR) logger.set_level(logging.ERROR)
logging.set_stderrthreshold(logging.ERROR)
logging.get_absl_handler().use_absl_log_file()
# Load training and test data. # Load training and test data.
x_train, y_train, x_test, y_test = load_cifar10() x_train, y_train, x_test, y_test = load_cifar10()

View file

@@ -34,10 +34,10 @@ import os
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf
import tensorflow_datasets as tfds import tensorflow_datasets as tfds
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
@@ -157,7 +157,9 @@ def compute_epsilon(steps):
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) logger = tf.get_logger()
logger.set_level(logging.INFO)
if FLAGS.batch_size % FLAGS.microbatches != 0: if FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size') raise ValueError('Number of microbatches should divide evenly batch_size')

View file

@@ -15,9 +15,9 @@
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized
@@ -140,7 +140,9 @@ def load_mnist():
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) logger = tf.get_logger()
logger.set_level(logging.INFO)
if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0: if FLAGS.dpsgd and FLAGS.batch_size % FLAGS.microbatches != 0:
raise ValueError('Number of microbatches should divide evenly batch_size') raise ValueError('Number of microbatches should divide evenly batch_size')

View file

@@ -25,10 +25,9 @@ import math
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow.compat.v1 as tf
from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent
from tensorflow_privacy.privacy.optimizers import dp_optimizer from tensorflow_privacy.privacy.optimizers import dp_optimizer
@@ -178,7 +177,9 @@ def print_privacy_guarantees(epochs, batch_size, samples, noise_multiplier):
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) logger = tf.get_logger()
logger.set_level(logging.INFO)
if FLAGS.data_l2_norm <= 0: if FLAGS.data_l2_norm <= 0:
raise ValueError('data_l2_norm must be positive.') raise ValueError('data_l2_norm must be positive.')
if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2: if FLAGS.dpsgd and FLAGS.learning_rate > 8 / FLAGS.data_l2_norm**2:

View file

@@ -13,8 +13,9 @@
# limitations under the License. # limitations under the License.
"""Scratchpad for training a CNN on MNIST with DPSGD.""" """Scratchpad for training a CNN on MNIST with DPSGD."""
from absl import logging
import numpy as np import numpy as np
import tensorflow.compat.v1 as tf import tensorflow as tf
tf.flags.DEFINE_float('learning_rate', .15, 'Learning rate for training') tf.flags.DEFINE_float('learning_rate', .15, 'Learning rate for training')
tf.flags.DEFINE_integer('batch_size', 256, 'Batch size') tf.flags.DEFINE_integer('batch_size', 256, 'Batch size')
@@ -86,7 +87,8 @@ def load_mnist():
def main(unused_argv): def main(unused_argv):
tf.logging.set_verbosity(tf.logging.INFO) logger = tf.get_logger()
logger.set_level(logging.INFO)
# Load training and test data. # Load training and test data.
train_data, train_labels, test_data, test_labels = load_mnist() train_data, train_labels, test_data, test_labels = load_mnist()