diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 851d3ba..0f4c238 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -115,8 +115,13 @@ class TrainedAttacker: Args: input_features : A vector of features with the same semantics as x_train passed to train_model. + Returns: + An array of probabilities denoting whether the example belongs to test. """ - raise NotImplementedError() + if self.model is None: + raise AssertionError( + 'Model not trained yet. Please call train_model first.') + return self.model.predict_proba(input_features)[:, 1] class LogisticRegressionAttacker(TrainedAttacker): @@ -132,12 +137,6 @@ class LogisticRegressionAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class MultilayerPerceptronAttacker(TrainedAttacker): """Multilayer perceptron attacker.""" @@ -155,12 +154,6 @@ class MultilayerPerceptronAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class RandomForestAttacker(TrainedAttacker): """Random forest attacker.""" @@ -182,12 +175,6 @@ class RandomForestAttacker(TrainedAttacker): model.fit(input_features, is_training_labels) self.model = model - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) - class KNearestNeighborsAttacker(TrainedAttacker): """K nearest neighbor attacker.""" @@ -201,9 +188,3 @@ class KNearestNeighborsAttacker(TrainedAttacker): knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) model.fit(input_features, is_training_labels) self.model = model - - def predict(self, input_features): - if self.model is None: - raise AssertionError( - 'Model not trained yet. Please call train_model first.') - return self.model.predict(input_features) diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py index e6b9fb6..b5cf1ac 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py @@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase): def test_base_attacker_train_and_predict(self): base_attacker = models.TrainedAttacker() self.assertRaises(NotImplementedError, base_attacker.train_model, [], []) - self.assertRaises(NotImplementedError, base_attacker.predict, []) + self.assertRaises(AssertionError, base_attacker.predict, []) def test_predict_before_training(self): lr_attacker = models.LogisticRegressionAttacker()