Replaces predict with predict_proba.

PiperOrigin-RevId: 326227257
This commit is contained in:
A. Unique TensorFlower 2020-08-12 07:23:09 -07:00
parent 59192e6f5c
commit f8515dfd71
2 changed files with 7 additions and 26 deletions

View file

@ -115,8 +115,13 @@ class TrainedAttacker:
Args: Args:
input_features : A vector of features with the same semantics as x_train input_features : A vector of features with the same semantics as x_train
passed to train_model. passed to train_model.
Returns:
An array of probabilities denoting whether the example belongs to test.
""" """
raise NotImplementedError() if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict_proba(input_features)[:, 1]
class LogisticRegressionAttacker(TrainedAttacker): class LogisticRegressionAttacker(TrainedAttacker):
@ -132,12 +137,6 @@ class LogisticRegressionAttacker(TrainedAttacker):
model.fit(input_features, is_training_labels) model.fit(input_features, is_training_labels)
self.model = model self.model = model
def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)
class MultilayerPerceptronAttacker(TrainedAttacker): class MultilayerPerceptronAttacker(TrainedAttacker):
"""Multilayer perceptron attacker.""" """Multilayer perceptron attacker."""
@ -155,12 +154,6 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
model.fit(input_features, is_training_labels) model.fit(input_features, is_training_labels)
self.model = model self.model = model
def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)
class RandomForestAttacker(TrainedAttacker): class RandomForestAttacker(TrainedAttacker):
"""Random forest attacker.""" """Random forest attacker."""
@ -182,12 +175,6 @@ class RandomForestAttacker(TrainedAttacker):
model.fit(input_features, is_training_labels) model.fit(input_features, is_training_labels)
self.model = model self.model = model
def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)
class KNearestNeighborsAttacker(TrainedAttacker): class KNearestNeighborsAttacker(TrainedAttacker):
"""K nearest neighbor attacker.""" """K nearest neighbor attacker."""
@ -201,9 +188,3 @@ class KNearestNeighborsAttacker(TrainedAttacker):
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
model.fit(input_features, is_training_labels) model.fit(input_features, is_training_labels)
self.model = model self.model = model
def predict(self, input_features):
if self.model is None:
raise AssertionError(
'Model not trained yet. Please call train_model first.')
return self.model.predict(input_features)

View file

@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase):
def test_base_attacker_train_and_predict(self): def test_base_attacker_train_and_predict(self):
base_attacker = models.TrainedAttacker() base_attacker = models.TrainedAttacker()
self.assertRaises(NotImplementedError, base_attacker.train_model, [], []) self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
self.assertRaises(NotImplementedError, base_attacker.predict, []) self.assertRaises(AssertionError, base_attacker.predict, [])
def test_predict_before_training(self): def test_predict_before_training(self):
lr_attacker = models.LogisticRegressionAttacker() lr_attacker = models.LogisticRegressionAttacker()