Consistent string formatting.

PiperOrigin-RevId: 326007570
This commit is contained in:
A. Unique TensorFlower 2020-08-11 06:17:15 -07:00
parent 680aaa4499
commit 06bb047525
3 changed files with 7 additions and 7 deletions

View file

@ -56,7 +56,7 @@ class SingleSliceSpec:
if self.feature == SlicingFeature.PERCENTILE: if self.feature == SlicingFeature.PERCENTILE:
return 'Loss percentiles: %d-%d' % self.value return 'Loss percentiles: %d-%d' % self.value
return f'{self.feature.name}={self.value}' return '%s=%s' % (self.feature.name, self.value)
@dataclass @dataclass
@ -98,7 +98,7 @@ class AttackType(enum.Enum):
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION # Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
def __str__(self): def __str__(self):
return f'{self.name}' return '%s' % self.name
@dataclass @dataclass
@ -125,7 +125,7 @@ class AttackInputData:
def num_classes(self): def num_classes(self):
if self.labels_train is None or self.labels_test is None: if self.labels_train is None or self.labels_test is None:
raise ValueError( raise ValueError(
"Can't identify the number of classes as no labels were provided. " 'Can\'t identify the number of classes as no labels were provided. '
'Please set labels_train and labels_test') 'Please set labels_train and labels_test')
return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1 return int(max(np.max(self.labels_train), np.max(self.labels_test))) + 1
@ -204,7 +204,7 @@ class RocCurve:
over all available classifier thresholds. over all available classifier thresholds.
Returns: Returns:
a single float number with membership attaker's advantage. a single float number with membership attacker's advantage.
""" """
return max(np.abs(self.tpr - self.fpr)) return max(np.abs(self.tpr - self.fpr))

View file

@ -137,7 +137,7 @@ def get_slice(data: AttackInputData,
elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED: elif slice_spec.feature == SlicingFeature.CORRECTLY_CLASSIFIED:
data_slice = _slice_by_classification_correctness(data, slice_spec.value) data_slice = _slice_by_classification_correctness(data, slice_spec.value)
else: else:
raise ValueError(f'Unknown slice spec feature "{slice_spec.feature}"') raise ValueError('Unknown slice spec feature "%s"' % slice_spec.feature)
data_slice.slice_spec = slice_spec data_slice.slice_spec = slice_spec
return data_slice return data_slice

View file

@ -54,8 +54,8 @@ def run_trained_attack(attack_input: AttackInputData, attack_type: AttackType):
elif attack_type == AttackType.K_NEAREST_NEIGHBORS: elif attack_type == AttackType.K_NEAREST_NEIGHBORS:
attacker = models.KNearestNeighborsAttacker() attacker = models.KNearestNeighborsAttacker()
else: else:
raise NotImplementedError( raise NotImplementedError('Attack type %s not implemented yet.' %
'Attack type {} not implemented yet.'.format(attack_type)) attack_type)
prepared_attacker_data = models.create_attacker_data(attack_input) prepared_attacker_data = models.create_attacker_data(attack_input)