Bugfixes:

1. Child classes of 'TrainedAttacker' now have a 'backend' parameter, so require an __init__() method.

PiperOrigin-RevId: 451005298
This commit is contained in:
A. Unique TensorFlower 2022-05-25 13:45:59 -07:00
parent 95e527acfb
commit 5461f911a6
2 changed files with 35 additions and 17 deletions

View file

@ -127,7 +127,7 @@ def _column_stack(logits, loss):
return np.column_stack((logits, loss)) return np.column_stack((logits, loss))
class TrainedAttacker: class TrainedAttacker(object):
"""Base class for training attack models. """Base class for training attack models.
Attributes: Attributes:
@ -148,6 +148,7 @@ class TrainedAttacker:
# Default value of `None` will perform single-threaded training. # Default value of `None` will perform single-threaded training.
self.ctx_mgr = contextlib.nullcontext() self.ctx_mgr = contextlib.nullcontext()
self.n_jobs = 1 self.n_jobs = 1
logging.info('Using single-threaded backend for training.')
else: else:
self.n_jobs = -1 self.n_jobs = -1
self.ctx_mgr = parallel_backend( self.ctx_mgr = parallel_backend(
@ -189,6 +190,9 @@ class TrainedAttacker:
class LogisticRegressionAttacker(TrainedAttacker): class LogisticRegressionAttacker(TrainedAttacker):
"""Logistic regression attacker.""" """Logistic regression attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels): def train_model(self, input_features, is_training_labels):
with self.ctx_mgr: with self.ctx_mgr:
lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs) lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
@ -204,6 +208,9 @@ class LogisticRegressionAttacker(TrainedAttacker):
class MultilayerPerceptronAttacker(TrainedAttacker): class MultilayerPerceptronAttacker(TrainedAttacker):
"""Multilayer perceptron attacker.""" """Multilayer perceptron attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels): def train_model(self, input_features, is_training_labels):
with self.ctx_mgr: with self.ctx_mgr:
mlp_model = neural_network.MLPClassifier() mlp_model = neural_network.MLPClassifier()
@ -221,6 +228,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
class RandomForestAttacker(TrainedAttacker): class RandomForestAttacker(TrainedAttacker):
"""Random forest attacker.""" """Random forest attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels): def train_model(self, input_features, is_training_labels):
"""Setup a random forest pipeline with cross-validation.""" """Setup a random forest pipeline with cross-validation."""
with self.ctx_mgr: with self.ctx_mgr:
@ -242,6 +252,9 @@ class RandomForestAttacker(TrainedAttacker):
class KNearestNeighborsAttacker(TrainedAttacker): class KNearestNeighborsAttacker(TrainedAttacker):
"""K nearest neighbor attacker.""" """K nearest neighbor attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels): def train_model(self, input_features, is_training_labels):
with self.ctx_mgr: with self.ctx_mgr:
knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs) knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from absl.testing import absltest from absl.testing import absltest
from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
@ -20,7 +21,7 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
class TrainedAttackerTest(absltest.TestCase): class TrainedAttackerTest(parameterized.TestCase):
def test_base_attacker_train_and_predict(self): def test_base_attacker_train_and_predict(self):
base_attacker = models.TrainedAttacker() base_attacker = models.TrainedAttacker()
@ -90,15 +91,19 @@ class TrainedAttackerTest(absltest.TestCase):
self.assertLen(attacker_data.fold_indices, 6) self.assertLen(attacker_data.fold_indices, 6)
self.assertEmpty(attacker_data.left_out_indices) self.assertEmpty(attacker_data.left_out_indices)
def test_training_with_threading_backend(self): # Parameters for testing: backend.
@parameterized.named_parameters(
('threading_backend', 'threading'),
('None_backend', None),
)
def test_training_with_backends(self, backend):
with self.assertLogs(level='INFO') as log: with self.assertLogs(level='INFO') as log:
attacker = models.create_attacker(AttackType.LOGISTIC_REGRESSION, attacker = models.create_attacker(
'threading') AttackType.MULTI_LAYERED_PERCEPTRON, backend=backend)
self.assertIsInstance(attacker, models.LogisticRegressionAttacker) self.assertIsInstance(attacker, models.MultilayerPerceptronAttacker)
self.assertLen(log.output, 1) self.assertLen(log.output, 1)
self.assertLen(log.records, 1) self.assertLen(log.records, 1)
self.assertRegex(log.output[0], r'.+?Using .+? backend for training.') self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
if __name__ == '__main__': if __name__ == '__main__':
absltest.main() absltest.main()