diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index b4e7056..851d3ba 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -149,8 +149,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker): 'solver': ['adam'], 'alpha': [0.0001, 0.001, 0.01], } + n_jobs = -1 model = model_selection.GridSearchCV( - mlp_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + mlp_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0) model.fit(input_features, is_training_labels) self.model = model @@ -175,8 +176,9 @@ class RandomForestAttacker(TrainedAttacker): 'min_samples_split': [2, 5, 10], 'min_samples_leaf': [1, 2, 4] } + n_jobs = -1 model = model_selection.GridSearchCV( - rf_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0) + rf_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0) model.fit(input_features, is_training_labels) self.model = model