forked from 626_privacy/tensorflow_privacy
Replaces predict with predict_proba.
PiperOrigin-RevId: 326227257
This commit is contained in:
parent
59192e6f5c
commit
f8515dfd71
2 changed files with 7 additions and 26 deletions
|
@ -115,8 +115,13 @@ class TrainedAttacker:
|
|||
Args:
|
||||
input_features : A vector of features with the same semantics as x_train
|
||||
passed to train_model.
|
||||
Returns:
|
||||
An array of probabilities denoting whether the example belongs to test.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict_proba(input_features)[:, 1]
|
||||
|
||||
|
||||
class LogisticRegressionAttacker(TrainedAttacker):
|
||||
|
@ -132,12 +137,6 @@ class LogisticRegressionAttacker(TrainedAttacker):
|
|||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||
"""Multilayer perceptron attacker."""
|
||||
|
@ -155,12 +154,6 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
|
|||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class RandomForestAttacker(TrainedAttacker):
|
||||
"""Random forest attacker."""
|
||||
|
@ -182,12 +175,6 @@ class RandomForestAttacker(TrainedAttacker):
|
|||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
||||
|
||||
class KNearestNeighborsAttacker(TrainedAttacker):
|
||||
"""K nearest neighbor attacker."""
|
||||
|
@ -201,9 +188,3 @@ class KNearestNeighborsAttacker(TrainedAttacker):
|
|||
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
def predict(self, input_features):
|
||||
if self.model is None:
|
||||
raise AssertionError(
|
||||
'Model not trained yet. Please call train_model first.')
|
||||
return self.model.predict(input_features)
|
||||
|
|
|
@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
def test_base_attacker_train_and_predict(self):
|
||||
base_attacker = models.TrainedAttacker()
|
||||
self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
|
||||
self.assertRaises(NotImplementedError, base_attacker.predict, [])
|
||||
self.assertRaises(AssertionError, base_attacker.predict, [])
|
||||
|
||||
def test_predict_before_training(self):
|
||||
lr_attacker = models.LogisticRegressionAttacker()
|
||||
|
|
Loading…
Reference in a new issue