forked from 626_privacy/tensorflow_privacy
Consistent string formatting.
PiperOrigin-RevId: 326007570
This commit is contained in:
parent
680aaa4499
commit
06bb047525
3 changed files with 7 additions and 7 deletions
|
@ -56,7 +56,7 @@ class SingleSliceSpec:
|
||||||
if self.feature == SlicingFeature.PERCENTILE:
|
if self.feature == SlicingFeature.PERCENTILE:
|
||||||
return 'Loss percentiles: %d-%d' % self.value
|
return 'Loss percentiles: %d-%d' % self.value
|
||||||
|
|
||||||
return f'{self.feature.name}={self.value}'
|
return '%s=%s' % (self.feature.name, self.value)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -98,7 +98,7 @@ class AttackType(enum.Enum):
|
||||||
|
|
||||||
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
|
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f'{self.name}'
|
return '%s' % self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -125,7 +125,7 @@ class AttackInputData:
|
||||||
def num_classes(self):
|
def num_classes(self):
|
||||||
if self.labels_train is None or self.labels_test is None:
|
if self.labels_train is None or self.labels_test is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Can't identify the number of classes as no labels were provided. "
|
'Can\'t identify the number of classes as no labels were provided. '
|
||||||
'Please set labels_train and labels_test')
|
'Please set labels_train and labels_test')
|
||||||
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
|
||||||
|
|
||||||
|
@ -204,7 +204,7 @@ class RocCurve:
|
||||||
over all available classifier thresholds.
|
over all available classifier thresholds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
a single float number with membership attaker's advantage.
|
a single float number with membership attacker's advantage.
|
||||||
"""
|
"""
|
||||||
return max(np.abs(self.tpr - self.fpr))
|
return max(np.abs(self.tpr - self.fpr))
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ def get_slice(data: AttackInputData,
|
||||||
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
|
||||||
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
data_slice = _slice_by_classification_correctness(data, slice_spec.value)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f'Unknown slice spec feature "{slice_spec.feature}"')
|
raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
|
||||||
|
|
||||||
data_slice.slice_spec = slice_spec
|
data_slice.slice_spec = slice_spec
|
||||||
return data_slice
|
return data_slice
|
||||||
|
|
|
@ -54,8 +54,8 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
|
||||||
elif attack_type == AttackType.K_NEAREST_NEIGHBORS:
|
elif attack_type == AttackType.K_NEAREST_NEIGHBORS:
|
||||||
attacker = models.KNearestNeighborsAttacker()
|
attacker = models.KNearestNeighborsAttacker()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError('Attack type %s not implemented yet.' %
|
||||||
'Attack type {} not implemented yet.'.format(attack_type))
|
attack_type)
|
||||||
|
|
||||||
prepared_attacker_data = models.create_attacker_data(attack_input)
|
prepared_attacker_data = models.create_attacker_data(attack_input)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue