forked from 626_privacy/tensorflow_privacy
Replaces predict with predict_proba.
PiperOrigin-RevId: 326227257
This commit is contained in:
parent
59192e6f5c
commit
f8515dfd71
2 changed files with 7 additions and 26 deletions
|
@ -115,8 +115,13 @@ class TrainedAttacker:
|
||||||
Args:
|
Args:
|
||||||
input_features : A vector of features with the same semantics as x_train
|
input_features : A vector of features with the same semantics as x_train
|
||||||
passed to train_model.
|
passed to train_model.
|
||||||
|
Returns:
|
||||||
|
An array of probabilities denoting whether the example belongs to test.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
if self.model is None:
|
||||||
|
raise AssertionError(
|
||||||
|
'Model not trained yet. Please call train_model first.')
|
||||||
|
return self.model.predict_proba(input_features)[:, 1]
|
||||||
|
|
||||||
|
|
||||||
class LogisticRegressionAttacker(TrainedAttacker):
|
class LogisticRegressionAttacker(TrainedAttacker):
|
||||||
|
@ -132,12 +137,6 @@ class LogisticRegressionAttacker(TrainedAttacker):
|
||||||
model.fit(input_features, is_training_labels)
|
model.fit(input_features, is_training_labels)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, input_features):
|
|
||||||
if self.model is None:
|
|
||||||
raise AssertionError(
|
|
||||||
'Model not trained yet. Please call train_model first.')
|
|
||||||
return self.model.predict(input_features)
|
|
||||||
|
|
||||||
|
|
||||||
class MultilayerPerceptronAttacker(TrainedAttacker):
|
class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||||
"""Multilayer perceptron attacker."""
|
"""Multilayer perceptron attacker."""
|
||||||
|
@ -155,12 +154,6 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||||
model.fit(input_features, is_training_labels)
|
model.fit(input_features, is_training_labels)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, input_features):
|
|
||||||
if self.model is None:
|
|
||||||
raise AssertionError(
|
|
||||||
'Model not trained yet. Please call train_model first.')
|
|
||||||
return self.model.predict(input_features)
|
|
||||||
|
|
||||||
|
|
||||||
class RandomForestAttacker(TrainedAttacker):
|
class RandomForestAttacker(TrainedAttacker):
|
||||||
"""Random forest attacker."""
|
"""Random forest attacker."""
|
||||||
|
@ -182,12 +175,6 @@ class RandomForestAttacker(TrainedAttacker):
|
||||||
model.fit(input_features, is_training_labels)
|
model.fit(input_features, is_training_labels)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, input_features):
|
|
||||||
if self.model is None:
|
|
||||||
raise AssertionError(
|
|
||||||
'Model not trained yet. Please call train_model first.')
|
|
||||||
return self.model.predict(input_features)
|
|
||||||
|
|
||||||
|
|
||||||
class KNearestNeighborsAttacker(TrainedAttacker):
|
class KNearestNeighborsAttacker(TrainedAttacker):
|
||||||
"""K nearest neighbor attacker."""
|
"""K nearest neighbor attacker."""
|
||||||
|
@ -201,9 +188,3 @@ class KNearestNeighborsAttacker(TrainedAttacker):
|
||||||
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||||
model.fit(input_features, is_training_labels)
|
model.fit(input_features, is_training_labels)
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
def predict(self, input_features):
|
|
||||||
if self.model is None:
|
|
||||||
raise AssertionError(
|
|
||||||
'Model not trained yet. Please call train_model first.')
|
|
||||||
return self.model.predict(input_features)
|
|
||||||
|
|
|
@ -26,7 +26,7 @@ class TrainedAttackerTest(absltest.TestCase):
|
||||||
def test_base_attacker_train_and_predict(self):
|
def test_base_attacker_train_and_predict(self):
|
||||||
base_attacker = models.TrainedAttacker()
|
base_attacker = models.TrainedAttacker()
|
||||||
self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
|
self.assertRaises(NotImplementedError, base_attacker.train_model, [], [])
|
||||||
self.assertRaises(NotImplementedError, base_attacker.predict, [])
|
self.assertRaises(AssertionError, base_attacker.predict, [])
|
||||||
|
|
||||||
def test_predict_before_training(self):
|
def test_predict_before_training(self):
|
||||||
lr_attacker = models.LogisticRegressionAttacker()
|
lr_attacker = models.LogisticRegressionAttacker()
|
||||||
|
|
Loading…
Reference in a new issue