Remove call to _get_slicing_spec in run_seq2seq_attack

This commit is contained in:
amad-person 2020-11-14 02:13:11 +08:00
parent 641c4dd98c
commit b25808cfbe

View file

@ -19,7 +19,7 @@ This file belongs to the new API for membership inference attacks. This file
will be renamed to membership_inference_attack.py after the old API is removed. will be renamed to membership_inference_attack.py after the old API is removed.
""" """
from typing import Iterable, Union from typing import Iterable
import numpy as np import numpy as np
from sklearn import metrics from sklearn import metrics
@ -38,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec: def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
if hasattr(data, 'slice_spec'): if hasattr(data, 'slice_spec'):
return data.slice_spec return data.slice_spec
return SingleSliceSpec() return SingleSliceSpec()
@ -210,7 +210,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds) roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
attack_results = [SingleAttackResult( attack_results = [SingleAttackResult(
slice_spec=_get_slice_spec(attack_input), slice_spec=SingleSliceSpec(),
attack_type=AttackType.LOGISTIC_REGRESSION, attack_type=AttackType.LOGISTIC_REGRESSION,
roc_curve=roc_curve)] roc_curve=roc_curve)]