Bugfixes:

1. Child classes of 'TrainedAttacker' now have a 'backend' parameter, so require an __init__() method.

PiperOrigin-RevId: 451005298
This commit is contained in:
A. Unique TensorFlower 2022-05-25 13:45:59 -07:00
parent 95e527acfb
commit 5461f911a6
2 changed files with 35 additions and 17 deletions

View file

@ -127,18 +127,18 @@ def _column_stack(logits, loss):
return np.column_stack((logits, loss))
class TrainedAttacker:
"""Base class for training attack models.
class TrainedAttacker(object):
"""Base class for training attack models.
Attributes:
backend: Name of Scikit-Learn parallel backend to use for this attack
model. The default value of `None` performs single-threaded training.
model: The trained attack model.
ctx_mgr: The backend context manager within which to perform training.
Defaults to the null context manager for single-threaded training.
n_jobs: Number of jobs that can run in parallel when using a backend.
Set to `1` for single-threading, and to `-1` for all parallel
backends.
Attributes:
backend: Name of Scikit-Learn parallel backend to use for this attack
model. The default value of `None` performs single-threaded training.
model: The trained attack model.
ctx_mgr: The backend context manager within which to perform training.
Defaults to the null context manager for single-threaded training.
n_jobs: Number of jobs that can run in parallel when using a backend.
Set to `1` for single-threading, and to `-1` for all parallel
backends.
"""
def __init__(self, backend: Optional[str] = None):
@ -148,6 +148,7 @@ class TrainedAttacker:
# Default value of `None` will perform single-threaded training.
self.ctx_mgr = contextlib.nullcontext()
self.n_jobs = 1
logging.info('Using single-threaded backend for training.')
else:
self.n_jobs = -1
self.ctx_mgr = parallel_backend(
@ -189,6 +190,9 @@ class TrainedAttacker:
class LogisticRegressionAttacker(TrainedAttacker):
"""Logistic regression attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels):
with self.ctx_mgr:
lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
@ -204,6 +208,9 @@ class LogisticRegressionAttacker(TrainedAttacker):
class MultilayerPerceptronAttacker(TrainedAttacker):
"""Multilayer perceptron attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels):
with self.ctx_mgr:
mlp_model = neural_network.MLPClassifier()
@ -221,6 +228,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
class RandomForestAttacker(TrainedAttacker):
"""Random forest attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels):
"""Setup a random forest pipeline with cross-validation."""
with self.ctx_mgr:
@ -242,6 +252,9 @@ class RandomForestAttacker(TrainedAttacker):
class KNearestNeighborsAttacker(TrainedAttacker):
"""K nearest neighbor attacker."""
def __init__(self, backend: Optional[str] = None):
super().__init__(backend=backend)
def train_model(self, input_features, is_training_labels):
with self.ctx_mgr:
knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)

View file

@ -13,6 +13,7 @@
# limitations under the License.
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
@ -20,7 +21,7 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
class TrainedAttackerTest(absltest.TestCase):
class TrainedAttackerTest(parameterized.TestCase):
def test_base_attacker_train_and_predict(self):
base_attacker = models.TrainedAttacker()
@ -90,15 +91,19 @@ class TrainedAttackerTest(absltest.TestCase):
self.assertLen(attacker_data.fold_indices, 6)
self.assertEmpty(attacker_data.left_out_indices)
def test_training_with_threading_backend(self):
# Parameters for testing: backend.
@parameterized.named_parameters(
('threading_backend', 'threading'),
('None_backend', None),
)
def test_training_with_backends(self, backend):
with self.assertLogs(level='INFO') as log:
attacker = models.create_attacker(AttackType.LOGISTIC_REGRESSION,
'threading')
self.assertIsInstance(attacker, models.LogisticRegressionAttacker)
attacker = models.create_attacker(
AttackType.MULTI_LAYERED_PERCEPTRON, backend=backend)
self.assertIsInstance(attacker, models.MultilayerPerceptronAttacker)
self.assertLen(log.output, 1)
self.assertLen(log.records, 1)
self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
if __name__ == '__main__':
absltest.main()