Bugfixes:
1. Child classes of 'TrainedAttacker' now have a 'backend' parameter, so require an __init__() method. PiperOrigin-RevId: 451005298
This commit is contained in:
parent
95e527acfb
commit
5461f911a6
2 changed files with 35 additions and 17 deletions
|
@ -127,7 +127,7 @@ def _column_stack(logits, loss):
|
|||
return np.column_stack((logits, loss))
|
||||
|
||||
|
||||
class TrainedAttacker:
|
||||
class TrainedAttacker(object):
|
||||
"""Base class for training attack models.
|
||||
|
||||
Attributes:
|
||||
|
@ -148,6 +148,7 @@ class TrainedAttacker:
|
|||
# Default value of `None` will perform single-threaded training.
|
||||
self.ctx_mgr = contextlib.nullcontext()
|
||||
self.n_jobs = 1
|
||||
logging.info('Using single-threaded backend for training.')
|
||||
else:
|
||||
self.n_jobs = -1
|
||||
self.ctx_mgr = parallel_backend(
|
||||
|
@ -189,6 +190,9 @@ class TrainedAttacker:
|
|||
class LogisticRegressionAttacker(TrainedAttacker):
|
||||
"""Logistic regression attacker."""
|
||||
|
||||
def __init__(self, backend: Optional[str] = None):
|
||||
super().__init__(backend=backend)
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
with self.ctx_mgr:
|
||||
lr = linear_model.LogisticRegression(solver='lbfgs', n_jobs=self.n_jobs)
|
||||
|
@ -204,6 +208,9 @@ class LogisticRegressionAttacker(TrainedAttacker):
|
|||
class MultilayerPerceptronAttacker(TrainedAttacker):
|
||||
"""Multilayer perceptron attacker."""
|
||||
|
||||
def __init__(self, backend: Optional[str] = None):
|
||||
super().__init__(backend=backend)
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
with self.ctx_mgr:
|
||||
mlp_model = neural_network.MLPClassifier()
|
||||
|
@ -221,6 +228,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
|
|||
class RandomForestAttacker(TrainedAttacker):
|
||||
"""Random forest attacker."""
|
||||
|
||||
def __init__(self, backend: Optional[str] = None):
|
||||
super().__init__(backend=backend)
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
"""Setup a random forest pipeline with cross-validation."""
|
||||
with self.ctx_mgr:
|
||||
|
@ -242,6 +252,9 @@ class RandomForestAttacker(TrainedAttacker):
|
|||
class KNearestNeighborsAttacker(TrainedAttacker):
|
||||
"""K nearest neighbor attacker."""
|
||||
|
||||
def __init__(self, backend: Optional[str] = None):
|
||||
super().__init__(backend=backend)
|
||||
|
||||
def train_model(self, input_features, is_training_labels):
|
||||
with self.ctx_mgr:
|
||||
knn_model = neighbors.KNeighborsClassifier(n_jobs=self.n_jobs)
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import models
|
||||
|
@ -20,7 +21,7 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s
|
|||
from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import AttackType
|
||||
|
||||
|
||||
class TrainedAttackerTest(absltest.TestCase):
|
||||
class TrainedAttackerTest(parameterized.TestCase):
|
||||
|
||||
def test_base_attacker_train_and_predict(self):
|
||||
base_attacker = models.TrainedAttacker()
|
||||
|
@ -90,15 +91,19 @@ class TrainedAttackerTest(absltest.TestCase):
|
|||
self.assertLen(attacker_data.fold_indices, 6)
|
||||
self.assertEmpty(attacker_data.left_out_indices)
|
||||
|
||||
def test_training_with_threading_backend(self):
|
||||
# Parameters for testing: backend.
|
||||
@parameterized.named_parameters(
|
||||
('threading_backend', 'threading'),
|
||||
('None_backend', None),
|
||||
)
|
||||
def test_training_with_backends(self, backend):
|
||||
with self.assertLogs(level='INFO') as log:
|
||||
attacker = models.create_attacker(AttackType.LOGISTIC_REGRESSION,
|
||||
'threading')
|
||||
self.assertIsInstance(attacker, models.LogisticRegressionAttacker)
|
||||
attacker = models.create_attacker(
|
||||
AttackType.MULTI_LAYERED_PERCEPTRON, backend=backend)
|
||||
self.assertIsInstance(attacker, models.MultilayerPerceptronAttacker)
|
||||
self.assertLen(log.output, 1)
|
||||
self.assertLen(log.records, 1)
|
||||
self.assertRegex(log.output[0], r'.+?Using .+? backend for training.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main()
|
||||
|
|
Loading…
Reference in a new issue