diff --git a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py index 53c8393..610c209 100644 --- a/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py +++ b/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/keras_evaluation.py @@ -25,6 +25,7 @@ from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_s from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import get_flattened_attack_metrics from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.data_structures import SlicingSpec from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack.utils_tensorboard import write_results_to_tensorboard_tf2 as write_results_to_tensorboard +from tensorflow_privacy.privacy.privacy_tests.utils import log_loss def calculate_losses(model, data, labels, is_logit=False, batch_size=32):