forked from 626_privacy/tensorflow_privacy
Make sklearn classifiers in parallel.
It's done only for those classifiers that run a significant amount of time. PiperOrigin-RevId: 326215987
This commit is contained in:
parent
37ff5d502e
commit
0fd06493cc
1 changed files with 4 additions and 2 deletions
|
@ -149,8 +149,9 @@ class MultilayerPerceptronAttacker(TrainedAttacker):
|
|||
'solver': ['adam'],
|
||||
'alpha': [0.0001, 0.001, 0.01],
|
||||
}
|
||||
n_jobs = -1
|
||||
model = model_selection.GridSearchCV(
|
||||
mlp_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
mlp_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
|
@ -175,8 +176,9 @@ class RandomForestAttacker(TrainedAttacker):
|
|||
'min_samples_split': [2, 5, 10],
|
||||
'min_samples_leaf': [1, 2, 4]
|
||||
}
|
||||
n_jobs = -1
|
||||
model = model_selection.GridSearchCV(
|
||||
rf_model, param_grid=param_grid, cv=3, n_jobs=1, verbose=0)
|
||||
rf_model, param_grid=param_grid, cv=3, n_jobs=n_jobs, verbose=0)
|
||||
model.fit(input_features, is_training_labels)
|
||||
self.model = model
|
||||
|
||||
|
|
Loading…
Reference in a new issue