Remove call to _get_slicing_spec in run_seq2seq_attack
This commit is contained in:
parent
641c4dd98c
commit
b25808cfbe
1 changed files with 3 additions and 3 deletions
|
@ -19,7 +19,7 @@ This file belongs to the new API for membership inference attacks. This file
|
||||||
will be renamed to membership_inference_attack.py after the old API is removed.
|
will be renamed to membership_inference_attack.py after the old API is removed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Iterable, Union
|
from typing import Iterable
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing impo
|
||||||
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
from tensorflow_privacy.privacy.membership_inference_attack.dataset_slicing import get_slice
|
||||||
|
|
||||||
|
|
||||||
def _get_slice_spec(data: Union[AttackInputData, Seq2SeqAttackInputData]) -> SingleSliceSpec:
|
def _get_slice_spec(data: AttackInputData) -> SingleSliceSpec:
|
||||||
if hasattr(data, 'slice_spec'):
|
if hasattr(data, 'slice_spec'):
|
||||||
return data.slice_spec
|
return data.slice_spec
|
||||||
return SingleSliceSpec()
|
return SingleSliceSpec()
|
||||||
|
@ -210,7 +210,7 @@ def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
|
||||||
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
|
||||||
|
|
||||||
attack_results = [SingleAttackResult(
|
attack_results = [SingleAttackResult(
|
||||||
slice_spec=_get_slice_spec(attack_input),
|
slice_spec=SingleSliceSpec(),
|
||||||
attack_type=AttackType.LOGISTIC_REGRESSION,
|
attack_type=AttackType.LOGISTIC_REGRESSION,
|
||||||
roc_curve=roc_curve)]
|
roc_curve=roc_curve)]
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue