diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py index c60535f..5651e64 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py @@ -19,7 +19,7 @@ This file belongs to the new API for membership inference attacks. This file will be renamed to membership_inference_attack.py after the old API is removed. """ -from typing import Iterable, Union +from typing import Iterable import numpy as np from sklearn import metrics @@ -38,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice -def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec: +def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec: if hasattr(data, 'slice_spec'): return data.slice_spec return SingleSliceSpec() @@ -210,7 +210,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData, roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) attack_results = [SingleAttackResult( - slice_spec=_get_slice_spec(attack_input), + slice_spec=SingleSliceSpec(), attack_type=AttackType.LOGISTIC_REGRESSION, roc_curve=roc_curve)]