From f8515dfd719c081e30a56348bd47ce4c6c65f3e8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 12 Aug 2020 07:23:09 -0700 Subject: [PATCH] Replaces predict with predict_proba. PiperOrigin-RevId: 326227257 --- .../membership_inference_attack/models.py | 31 ++++--------------- .../models_test.py | 2 +- 2 files changed, 7 insertions(+), 26 deletions(-) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 851d3ba..0f4c238 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -115,8 +115,13 @@ class TrainedAttacker: Args: input_features : A vector of features with the same semantics as x_train passed to train_model. + Returns: + An array of probabilities denoting whether the example belongs to test. """ - raise NotImplementedError() + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict_proba(input_features)[:, 1] class LogisticRegressionAttacker(TrainedAttacker): @@ -132,12 +137,6 @@ class LogisticRegressionAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class MultilayerPerceptronAttacker(TrainedAttacker): """Multilayer perceptron attacker.""" @@ -155,12 +154,6 @@ class MultilayerPerceptronAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class RandomForestAttacker(TrainedAttacker): """Random forest attacker.""" @@ -182,12 +175,6 @@ class RandomForestAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class KNearestNeighborsAttacker(TrainedAttacker): """K nearest neighbor attacker.""" @@ -201,9 +188,3 @@ class KNearestNeighborsAttacker(TrainedAttacker): knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) model.fit(input_features, is_training_labels) self.model = model - - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py index e6b9fb6..b5cf1ac 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py @@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase): def test_base_attacker_train_and_predict(self): base_attacker = models.TrainedAttacker() self.assertRaises(NotImplementedError, base_attacker.train_model, [], []) - self.assertRaises(NotImplementedError, base_attacker.predict, []) + self.assertRaises(AssertionError, base_attacker.predict, []) def test_predict_before_training(self): lr_attacker = models.LogisticRegressionAttacker()