Bugfixes:
1. Child classes of 'TrainedAttacker' now have a 'backend' parameter, so require an __init__() method. PiperOrigin-RevId: 451005298
This commit is contained in:
parent
95e527acfb
commit
5461f911a6
2 changed files with 35 additions and 17 deletions
|
@ -127,18 +127,18 @@ def _column_stack(logits, loss):
|
||||||
return np.column_stack((logits, loss))
|
return np.column_stack((logits, loss))
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttacker:
|
class TrainedAttacker(object):
|
||||||
"""Base class for training attack models.
|
"""Base class for training attack models.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
backend: Name of Scikit-Learn parallel backend to use for this attack
|
backend: Name of Scikit-Learn parallel backend to use for this attack
|
||||||
model. The default value of `None` performs single-threaded training.
|
model. The default value of `None` performs single-threaded training.
|
||||||
model: The trained attack model.
|
model: The trained attack model.
|
||||||
ctx_mgr: The backend context manager within which to perform training.
|
ctx_mgr: The backend context manager within which to perform training.
|
||||||
Defaults to the null context manager for single-threaded training.
|
Defaults to the null context manager for single-threaded training.
|
||||||
n_jobs: Number of jobs that can run in parallel when using a backend.
|
n_jobs: Number of jobs that can run in parallel when using a backend.
|
||||||
Set to `1` for single-threading, and to `-1` for all parallel
|
Set to `1` for single-threading, and to `-1` for all parallel
|
||||||
backends.
|
backends.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, backend: Optional[str] = None):
|
def __init__(self, backend: Optional[str] = None):
|
||||||
|
@ -148,6 +148,7 @@ class TrainedAttacker:
|
||||||
# Default value of `None` will perform single-threaded training.
|
# Default value of `None` will perform single-threaded training.
|
||||||
self.ctx_mgr = contextlib.nullcontext()
|
self.ctx_mgr = contextlib.nullcontext()
|
||||||
self.n_jobs = 1
|
self.n_jobs = 1
|
||||||
|
logging.info('Using single-threaded backend for training.')
|
||||||
else:
|
else:
|
||||||
self.n_jobs = -1
|
self.n_jobs = -1
|
||||||
self.ctx_mgr = parallel_backend(
|
self.ctx_mgr = parallel_backend(
|
||||||
|
@ -189,6 +190,9 @@ class TrainedAttacker:
|
||||||
class LogisticRegressionAttacker(TrainedAttacker):
|
class LogisticRegressionAttacker(TrainedAttacker):
|
||||||
"""Logistic regression attacker."""
|
"""Logistic regression attacker."""
|
||||||
|
|
||||||
|
def __init__(self, backend: Optional[str] = None):
|
||||||
|
super().__init__(backend=backend)
|
||||||
|
|
||||||
def train_model(self, input_features, is_training_labels):
|
def train_model(self, input_features, is_training_labels):
|
||||||
with self.ctx_mgr:
|
with self.ctx_mgr:
|
||||||
lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
|
lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
|
||||||
|
@ -204,6 +208,9 @@ class LogisticRegressionAttacker(TrainedAttacker):
|
||||||
class MultilayerPerceptronAttacker(TrainedAttacker):
|
class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||||
"""Multilayer perceptron attacker."""
|
"""Multilayer perceptron attacker."""
|
||||||
|
|
||||||
|
def __init__(self, backend: Optional[str] = None):
|
||||||
|
super().__init__(backend=backend)
|
||||||
|
|
||||||
def train_model(self, input_features, is_training_labels):
|
def train_model(self, input_features, is_training_labels):
|
||||||
with self.ctx_mgr:
|
with self.ctx_mgr:
|
||||||
mlp_model = neural_network.MLPClassifier()
|
mlp_model = neural_network.MLPClassifier()
|
||||||
|
@ -221,6 +228,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||||
class RandomForestAttacker(TrainedAttacker):
|
class RandomForestAttacker(TrainedAttacker):
|
||||||
"""Random forest attacker."""
|
"""Random forest attacker."""
|
||||||
|
|
||||||
|
def __init__(self, backend: Optional[str] = None):
|
||||||
|
super().__init__(backend=backend)
|
||||||
|
|
||||||
def train_model(self, input_features, is_training_labels):
|
def train_model(self, input_features, is_training_labels):
|
||||||
"""Setup a random forest pipeline with cross-validation."""
|
"""Setup a random forest pipeline with cross-validation."""
|
||||||
with self.ctx_mgr:
|
with self.ctx_mgr:
|
||||||
|
@ -242,6 +252,9 @@ class RandomForestAttacker(TrainedAttacker):
|
||||||
class KNearestNeighborsAttacker(TrainedAttacker):
|
class KNearestNeighborsAttacker(TrainedAttacker):
|
||||||
"""K nearest neighbor attacker."""
|
"""K nearest neighbor attacker."""
|
||||||
|
|
||||||
|
def __init__(self, backend: Optional[str] = None):
|
||||||
|
super().__init__(backend=backend)
|
||||||
|
|
||||||
def train_model(self, input_features, is_training_labels):
|
def train_model(self, input_features, is_training_labels):
|
||||||
with self.ctx_mgr:
|
with self.ctx_mgr:
|
||||||
knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)
|
knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from absl.testing import absltest
|
from absl.testing import absltest
|
||||||
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
|
||||||
|
@ -20,7 +21,7 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
|
||||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||||
|
|
||||||
|
|
||||||
class TrainedAttackerTest(absltest.TestCase):
|
class TrainedAttackerTest(parameterized.TestCase):
|
||||||
|
|
||||||
def test_base_attacker_train_and_predict(self):
|
def test_base_attacker_train_and_predict(self):
|
||||||
base_attacker = models.TrainedAttacker()
|
base_attacker = models.TrainedAttacker()
|
||||||
|
@ -90,15 +91,19 @@ class TrainedAttackerTest(absltest.TestCase):
|
||||||
self.assertLen(attacker_data.fold_indices, 6)
|
self.assertLen(attacker_data.fold_indices, 6)
|
||||||
self.assertEmpty(attacker_data.left_out_indices)
|
self.assertEmpty(attacker_data.left_out_indices)
|
||||||
|
|
||||||
def test_training_with_threading_backend(self):
|
# Parameters for testing: backend.
|
||||||
|
@parameterized.named_parameters(
|
||||||
|
('threading_backend', 'threading'),
|
||||||
|
('None_backend', None),
|
||||||
|
)
|
||||||
|
def test_training_with_backends(self, backend):
|
||||||
with self.assertLogs(level='INFO') as log:
|
with self.assertLogs(level='INFO') as log:
|
||||||
attacker = models.create_attacker(AttackType.LOGISTIC_REGRESSION,
|
attacker = models.create_attacker(
|
||||||
'threading')
|
AttackType.MULTI_LAYERED_PERCEPTRON, backend=backend)
|
||||||
self.assertIsInstance(attacker, models.LogisticRegressionAttacker)
|
self.assertIsInstance(attacker, models.MultilayerPerceptronAttacker)
|
||||||
self.assertLen(log.output, 1)
|
self.assertLen(log.output, 1)
|
||||||
self.assertLen(log.records, 1)
|
self.assertLen(log.records, 1)
|
||||||
self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
|
self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|
Loading…
Reference in a new issue