From 65eadd3a02f092844399388409767fad899012b6 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Thu, 5 May 2022 16:20:46 -0700
Subject: [PATCH] Enable parallel processing in the Scikit-Learn models.

Add support for configuring the parallel processing backend for Scikit-Learn
while setting up the attack models.

PiperOrigin-RevId: 446844669
---
 .../membership_inference_attack.py            |  30 ++--
 .../membership_inference_attack_test.py       |  23 ++-
 .../membership_inference_attack/models.py     | 134 +++++++++++-------
 .../models_test.py                            |  10 ++
 4 files changed, 136 insertions(+), 61 deletions(-)

diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
index 1660497..acbbf94 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack.py
@@ -18,7 +18,7 @@ will be renamed to membership_inference_attack.py after the old API is removed.
 """
 
 import logging
-from typing import Iterable, List, Union
+from typing import Iterable, List, Optional, Union
 
 import numpy as np
 from scipy import special
@@ -54,7 +54,8 @@ def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
 def _run_trained_attack(attack_input: AttackInputData,
                         attack_type: AttackType,
                         balance_attacker_training: bool = True,
-                        cross_validation_folds: int = 2):
+                        cross_validation_folds: int = 2,
+                        backend: Optional[str] = None):
   """Classification attack done by ML models."""
   prepared_attacker_data = models.create_attacker_data(
       attack_input, balance=balance_attacker_training)
@@ -84,7 +85,7 @@ def _run_trained_attack(attack_input: AttackInputData,
     # Make sure one sample only got score predicted once
     assert np.all(np.isnan(scores[test_indices]))
 
-    attacker = models.create_attacker(attack_type)
+    attacker = models.create_attacker(attack_type, backend=backend)
     attacker.train_model(features[train_indices], labels[train_indices])
     predictions = attacker.predict(features[test_indices])
     scores[test_indices] = predictions
@@ -161,7 +162,8 @@ def _run_threshold_entropy_attack(attack_input: AttackInputData):
 def _run_attack(attack_input: AttackInputData,
                 attack_type: AttackType,
                 balance_attacker_training: bool = True,
-                min_num_samples: int = 1):
+                min_num_samples: int = 1,
+                backend: Optional[str] = None):
   """Runs membership inference attacks for specified input and type.
 
   Args:
@@ -172,6 +174,11 @@ def _run_attack(attack_input: AttackInputData,
     number of samples from the training and test sets used to develop
     the model under attack.
     min_num_samples: minimum number of examples in either training or test data.
+    backend: The Scikit-Learn/Joblib backend to use for model training.
+      Defaults to `None`, which uses single-threaded training. Note that some
+      systems may not support multiprocessing; in those cases the `threading`
+      backend should be used instead. See
+      https://joblib.readthedocs.io/en/latest/parallel.html for more details.
 
   Returns:
     the attack result.
@@ -182,8 +189,8 @@ def _run_attack(attack_input: AttackInputData,
     return None
 
   if attack_type.is_trained_attack:
-    return _run_trained_attack(attack_input, attack_type,
-                               balance_attacker_training)
+    return _run_trained_attack(
+        attack_input, attack_type, balance_attacker_training, backend=backend)
   if attack_type == AttackType.THRESHOLD_ENTROPY_ATTACK:
     return _run_threshold_entropy_attack(attack_input)
   return _run_threshold_attack(attack_input)
@@ -195,7 +202,8 @@ def run_attacks(attack_input: AttackInputData,
                     AttackType.THRESHOLD_ATTACK,),
                 privacy_report_metadata: PrivacyReportMetadata = None,
                 balance_attacker_training: bool = True,
-                min_num_samples: int = 1) -> AttackResults:
+                min_num_samples: int = 1,
+                backend: Optional[str] = None) -> AttackResults:
   """Runs membership inference attacks on a classification model.
 
   It runs attacks specified by attack_types on each attack_input slice which is
@@ -211,6 +219,11 @@ def run_attacks(attack_input: AttackInputData,
     number of samples from the training and test sets used to develop
     the model under attack.
     min_num_samples: minimum number of examples in either training or test data.
+    backend: The Scikit-Learn/Joblib backend to use for model training.
+      Defaults to `None`, which uses single-threaded training. Note that some
+      systems may not support multiprocessing; in those cases the `threading`
+      backend should be used instead. See
+      https://joblib.readthedocs.io/en/latest/parallel.html for more details.
 
   Returns:
     the attack result.
@@ -234,7 +247,8 @@ def run_attacks(attack_input: AttackInputData,
     for attack_type in attack_types:
       logging.info('Running attack: %s', attack_type.name)
       attack_result = _run_attack(attack_input_slice, attack_type,
-                                  balance_attacker_training, min_num_samples)
+                                  balance_attacker_training, min_num_samples,
+                                  backend)
       if attack_result is not None:
         logging.info('%s attack had an AUC=%s and attacker advantage=%s',
                      attack_type.name, attack_result.get_auc(),
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
index 282478b..fc33af0 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/membership_inference_attack_test.py
@@ -13,8 +13,8 @@
 # limitations under the License.
 
 from absl.testing import absltest
+from absl.testing import parameterized
 import numpy as np
-
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
@@ -78,7 +78,7 @@ def get_test_input_logits_only(n_train, n_test):
       logits_test=rng.randn(n_test, 5) + 0.2)
 
 
-class RunAttacksTest(absltest.TestCase):
+class RunAttacksTest(parameterized.TestCase):
 
   def test_run_attacks_size(self):
     result = mia.run_attacks(
@@ -87,6 +87,17 @@ class RunAttacksTest(absltest.TestCase):
 
     self.assertLen(result.single_attack_results, 2)
 
+  def test_run_attacks_parallel_backend(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input(100, 100),
+        SlicingSpec(), (
+            AttackType.THRESHOLD_ATTACK,
+            AttackType.LOGISTIC_REGRESSION,
+        ),
+        backend='threading')
+
+    self.assertLen(result.single_attack_results, 2)
+
   def test_trained_attacks_logits_only_size(self):
     result = mia.run_attacks(
         get_test_input_logits_only(100, 100), SlicingSpec(),
@@ -217,6 +228,14 @@ class RunAttacksTestOnMultilabelData(absltest.TestCase):
 
     self.assertLen(result.single_attack_results, 1)
 
+  def test_run_attacks_parallel_backend(self):
+    result = mia.run_attacks(
+        get_multilabel_test_input(100, 100),
+        SlicingSpec(), (AttackType.LOGISTIC_REGRESSION,),
+        backend='threading')
+
+    self.assertLen(result.single_attack_results, 1)
+
   def test_run_attack_trained_sets_attack_type(self):
     result = mia._run_attack(
         get_multilabel_test_input(100, 100), AttackType.LOGISTIC_REGRESSION)
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py
index 28d61b2..8ca2ce6 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 """Trained models for membership inference attacks."""
 
+import contextlib
 import dataclasses
+import logging
 from typing import Optional
 import numpy as np
 from sklearn import ensemble
@@ -21,6 +23,7 @@
 from sklearn import linear_model
 from sklearn import model_selection
 from sklearn import neighbors
 from sklearn import neural_network
+from sklearn.utils import parallel_backend
 
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures
@@ -101,19 +104,6 @@ def create_attacker_data(attack_input_data: data_structures.AttackInputData,
       data_size=data_structures.DataSize(ntrain=ntrain, ntest=ntest))
 
 
-def create_attacker(attack_type):
-  """Returns the corresponding attacker for the provided attack_type."""
-  if attack_type == data_structures.AttackType.LOGISTIC_REGRESSION:
-    return LogisticRegressionAttacker()
-  if attack_type == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON:
-    return MultilayerPerceptronAttacker()
-  if attack_type == data_structures.AttackType.RANDOM_FOREST:
-    return RandomForestAttacker()
-  if attack_type == data_structures.AttackType.K_NEAREST_NEIGHBORS:
-    return KNearestNeighborsAttacker()
-  raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)
-
-
 def _sample_multidimensional_array(array, size):
   indices = np.random.choice(len(array), size, replace=False)
   return array[indices]
@@ -138,8 +128,34 @@ def _column_stack(logits, loss):
 
 
 class TrainedAttacker:
-  """Base class for training attack models."""
-  model = None
+  """Base class for training attack models.
+
+  Attributes:
+    backend: Name of the Scikit-Learn parallel backend to use for this attack
+      model. The default value of `None` performs single-threaded training.
+    model: The trained attack model.
+    ctx_mgr: The backend context manager within which to perform training.
+      Defaults to the null context manager for single-threaded training.
+    n_jobs: Number of jobs that can run in parallel when using a backend.
+      Set to `1` for single-threaded training, and to `-1` (all available
+      processors) when a parallel backend is used.
+  """
+
+  def __init__(self, backend: Optional[str] = None):
+    self.model = None
+    self.backend = backend
+    if backend is None:
+      # Default value of `None` will perform single-threaded training.
+      self.ctx_mgr = contextlib.nullcontext()
+      self.n_jobs = 1
+    else:
+      self.n_jobs = -1
+      self.ctx_mgr = parallel_backend(
+          # Values for 'backend': `loky`, `threading`, `multiprocessing`.
+          # Can also use `dask`, `distributed`, `ray` if they are installed.
+          backend=backend,
+          n_jobs=self.n_jobs)
+      logging.info('Using %s backend for training.', backend)
 
   def train_model(self, input_features, is_training_labels):
     """Train an attacker model.
@@ -174,13 +190,14 @@ class LogisticRegressionAttacker(TrainedAttacker):
   """Logistic regression attacker."""
 
   def train_model(self, input_features, is_training_labels):
-    lr = linear_model.LogisticRegression(solver='lbfgs')
-    param_grid = {
-        'C': np.logspace(-4, 2, 10),
-    }
-    model = model_selection.GridSearchCV(
-        lr, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
-    model.fit(input_features, is_training_labels)
+    with self.ctx_mgr:
+      lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
+      param_grid = {
+          'C': np.logspace(-4, 2, 10),
+      }
+      model = model_selection.GridSearchCV(
+          lr, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
+      model.fit(input_features, is_training_labels)
     self.model = model
 
 
@@ -188,16 +205,16 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
   """Multilayer perceptron attacker."""
 
   def train_model(self, input_features, is_training_labels):
-    mlp_model = neural_network.MLPClassifier()
-    param_grid = {
-        'hidden_layer_sizes': [(64,), (32, 32)],
-        'solver': ['adam'],
-        'alpha': [0.0001, 0.001, 0.01],
-    }
-    n_jobs = -1
-    model = model_selection.GridSearchCV(
-        mlp_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0)
-    model.fit(input_features, is_training_labels)
+    with self.ctx_mgr:
+      mlp_model = neural_network.MLPClassifier()
+      param_grid = {
+          'hidden_layer_sizes': [(64,), (32, 32)],
+          'solver': ['adam'],
+          'alpha': [0.0001, 0.001, 0.01],
+      }
+      model = model_selection.GridSearchCV(
+          mlp_model, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
+      model.fit(input_features, is_training_labels)
     self.model = model
 
 
@@ -206,19 +223,19 @@ class RandomForestAttacker(TrainedAttacker):
 
   def train_model(self, input_features, is_training_labels):
     """Setup a random forest pipeline with cross-validation."""
-    rf_model = ensemble.RandomForestClassifier()
+    with self.ctx_mgr:
+      rf_model = ensemble.RandomForestClassifier(n_jobs=self.n_jobs)
 
-    param_grid = {
-        'n_estimators': [100],
-        'max_features': ['auto', 'sqrt'],
-        'max_depth': [5, 10, 20, None],
-        'min_samples_split': [2, 5, 10],
-        'min_samples_leaf': [1, 2, 4]
-    }
-    n_jobs = -1
-    model = model_selection.GridSearchCV(
-        rf_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0)
-    model.fit(input_features, is_training_labels)
+      param_grid = {
+          'n_estimators': [100],
+          'max_features': ['auto', 'sqrt'],
+          'max_depth': [5, 10, 20, None],
+          'min_samples_split': [2, 5, 10],
+          'min_samples_leaf': [1, 2, 4]
+      }
+      model = model_selection.GridSearchCV(
+          rf_model, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
+      model.fit(input_features, is_training_labels)
     self.model = model
 
 
@@ -226,11 +243,26 @@ class KNearestNeighborsAttacker(TrainedAttacker):
   """K nearest neighbor attacker."""
 
   def train_model(self, input_features, is_training_labels):
-    knn_model = neighbors.KNeighborsClassifier()
-    param_grid = {
-        'n_neighbors': [3, 5, 7],
-    }
-    model = model_selection.GridSearchCV(
-        knn_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
-    model.fit(input_features, is_training_labels)
+    with self.ctx_mgr:
+      knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)
+      param_grid = {
+          'n_neighbors': [3, 5, 7],
+      }
+      model = model_selection.GridSearchCV(
+          knn_model, param_grid=param_grid, cv=3, n_jobs=self.n_jobs, verbose=0)
+      model.fit(input_features, is_training_labels)
     self.model = model
+
+
+def create_attacker(attack_type,
+                    backend: Optional[str] = None) -> TrainedAttacker:
+  """Returns the corresponding attacker for the provided attack_type."""
+  if attack_type == data_structures.AttackType.LOGISTIC_REGRESSION:
+    return LogisticRegressionAttacker(backend=backend)
+  if attack_type == data_structures.AttackType.MULTI_LAYERED_PERCEPTRON:
+    return MultilayerPerceptronAttacker(backend=backend)
+  if attack_type == data_structures.AttackType.RANDOM_FOREST:
+    return RandomForestAttacker(backend=backend)
+  if attack_type == data_structures.AttackType.K_NEAREST_NEIGHBORS:
+    return KNearestNeighborsAttacker(backend=backend)
+  raise NotImplementedError('Attack type %s not implemented yet.' % attack_type)
diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py
index 4ace899..c8a854d 100644
--- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py
+++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/models_test.py
@@ -17,6 +17,7 @@ import numpy as np
 
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
 from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
+from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
 
 
 class TrainedAttackerTest(absltest.TestCase):
@@ -89,6 +90,15 @@ class TrainedAttackerTest(absltest.TestCase):
     self.assertLen(attacker_data.fold_indices, 6)
     self.assertEmpty(attacker_data.left_out_indices)
 
+  def test_training_with_threading_backend(self):
+    with self.assertLogs(level='INFO') as log:
+      attacker = models.create_attacker(AttackType.LOGISTIC_REGRESSION,
+                                        'threading')
+    self.assertIsInstance(attacker, models.LogisticRegressionAttacker)
+    self.assertLen(log.output, 1)
+    self.assertLen(log.records, 1)
+    self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
+
 
 if __name__ == '__main__':
   absltest.main()
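
Usage note: the new `backend` argument threads through run_attacks() ->
_run_attack() -> _run_trained_attack() -> models.create_attacker(), so callers
opt in at the top level. A minimal sketch of calling the patched API follows;
the random logits and the choice of 'threading' are illustrative only, not
part of the change:

    import numpy as np

    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackInputData
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
    from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec

    # Random logits stand in for the outputs of a real model under attack.
    rng = np.random.RandomState(33)
    attack_input = AttackInputData(
        logits_train=rng.randn(100, 5) + 0.2,  # scores on training members
        logits_test=rng.randn(100, 5) - 0.2)   # scores on non-members

    # backend=None (the default) keeps training single-threaded; 'threading'
    # runs the GridSearchCV fan-out through joblib's threading backend.
    results = mia.run_attacks(
        attack_input,
        SlicingSpec(),
        attack_types=(AttackType.LOGISTIC_REGRESSION,),
        backend='threading')
    print(results.single_attack_results[0].get_auc())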
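
The TrainedAttacker change is an instance of a general joblib pattern:
entering parallel_backend(...) changes how any enclosed scikit-learn fit()
distributes its work. A standalone sketch of that pattern, independent of this
patch (the data and parameter grid here are arbitrary):

    import numpy as np
    from sklearn import linear_model
    from sklearn import model_selection
    from sklearn.utils import parallel_backend

    x = np.random.randn(200, 4)
    y = np.random.randint(0, 2, size=200)

    # Inside the context, the n_jobs=-1 fan-out of GridSearchCV uses threads
    # rather than the default loky worker processes.
    with parallel_backend('threading', n_jobs=-1):
      search = model_selection.GridSearchCV(
          linear_model.LogisticRegression(solver='lbfgs'),
          param_grid={'C': np.logspace(-4, 2, 10)},
          cv=3, n_jobs=-1)
      search.fit(x, y)
    print(search.best_params_)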