Overrides default __str__ methods.

PiperOrigin-RevId: 327423772
This commit is contained in:
A. Unique TensorFlower 2020-08-19 06:48:18 -07:00
parent 6dccd9b537
commit 52c1f8fdfe

View file

@ -82,6 +82,25 @@ class SlicingSpec:
# examples will be generated.
by_classification_correctness: bool = False
def __str__(self):
"""Only keeps the True values."""
result = ['SlicingSpec(']
if self.entire_dataset:
result.append(' Entire dataset,')
if self.by_class:
if isinstance(self.by_class, Iterable):
result.append(' Into classes %s,' % self.by_class)
elif isinstance(self.by_class, int):
result.append(' Up to class %d,' % self.by_class)
else:
result.append(' By classes,')
if self.by_percentiles:
result.append(' By percentiles,')
if self.by_classification_correctness:
result.append(' By classification correctness,')
result.append(')')
return '\n'.join(result)
class AttackType(enum.Enum):
"""An enum define attack types."""
@ -96,8 +115,8 @@ class AttackType(enum.Enum):
"""Returns whether this type of attack requires training a model."""
return self != AttackType.THRESHOLD_ATTACK
# Return LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION
def __str__(self):
"""Returns LOGISTIC_REGRESSION instead of AttackType.LOGISTIC_REGRESSION."""
return '%s' % self.name
@ -220,6 +239,23 @@ class AttackInputData:
_is_array_one_dimensional(self.labels_train, 'labels_train')
_is_array_one_dimensional(self.labels_test, 'labels_test')
def __str__(self):
"""Return the shapes of variables that are not None."""
result = ['AttackInputData(']
_append_array_shape(self.loss_train, 'loss_train', result)
_append_array_shape(self.loss_test, 'loss_test', result)
_append_array_shape(self.logits_train, 'logits_train', result)
_append_array_shape(self.logits_test, 'logits_test', result)
_append_array_shape(self.labels_train, 'labels_train', result)
_append_array_shape(self.labels_test, 'labels_test', result)
result.append(')')
return '\n'.join(result)
def _append_array_shape(arr: np.array, arr_name: str, result):
if arr is not None:
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
@dataclass
class RocCurve:
@ -251,6 +287,14 @@ class RocCurve:
"""
return max(np.abs(self.tpr - self.fpr))
def __str__(self):
"""Returns AUC and advantage metrics."""
return '\n'.join([
'RocCurve(',
' AUC: %f.02' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
])
@dataclass
class SingleAttackResult:
@ -272,6 +316,16 @@ class SingleAttackResult:
def get_auc(self):
return self.roc_curve.get_auc()
def __str__(self):
"""Returns SliceSpec, AttackType, AUC and advantage metrics."""
return '\n'.join([
'SingleAttackResult(',
' SliceSpec: %s' % str(self.slice_spec),
' AttackType: %s' % str(self.attack_type),
' AUC: %f.02' % self.get_auc(),
' Attacker advantage: %f.02' % self.get_attacker_advantage(), ')'
])
@dataclass
class AttackResults: