forked from 626_privacy/tensorflow_privacy
add entropy feature
This commit is contained in:
parent
e547a10eec
commit
9b2e6a55b6
1 changed files with 61 additions and 2 deletions
|
@ -22,6 +22,7 @@ from dataclasses import dataclass
|
|||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn import metrics
|
||||
from scipy import special
|
||||
|
||||
|
||||
ENTIRE_DATASET_SLICE_STR = 'SingleSliceSpec(Entire dataset)'
|
||||
|
@ -144,6 +145,9 @@ def _is_np_array(arr, arr_name):
|
|||
if arr is not None and not isinstance(arr, np.ndarray):
|
||||
raise ValueError('%s should be a numpy array.' % arr_name)
|
||||
|
||||
def _log_value(probs, small_value=1e-30):
|
||||
"""Compute the log value on the probability. Clip the probability in case it is close to 0"""
|
||||
return -np.log(np.maximum(probs, small_value))
|
||||
|
||||
@dataclass
|
||||
class AttackInputData:
|
||||
|
@ -165,6 +169,11 @@ class AttackInputData:
|
|||
loss_train: np.ndarray = None
|
||||
loss_test: np.ndarray = None
|
||||
|
||||
# Explicitly specified prediction entropy. If provided, this is used instead of deriving
|
||||
# entropy from logits and labels (https://arxiv.org/pdf/2003.10595.pdf by Song and Mittal)
|
||||
entropy_train: np.ndarray = None
|
||||
entropy_test: np.ndarray = None
|
||||
|
||||
@property
|
||||
def num_classes(self):
|
||||
if self.labels_train is None or self.labels_test is None:
|
||||
|
@ -177,6 +186,34 @@ class AttackInputData:
|
|||
def _get_loss(logits: np.ndarray, true_labels: np.ndarray):
|
||||
return logits[range(logits.shape[0]), true_labels]
|
||||
|
||||
@staticmethod
|
||||
def _get_entropy(logits: np.ndarray, true_labels: np.ndarray):
|
||||
if (np.absolute(np.sum(logits,axis=1)-1)<=1e-3).all():
|
||||
probs = logits
|
||||
else:
|
||||
"""Using softmax to compute probability from logits"""
|
||||
probs = special.softmax(logits, axis=1)
|
||||
if true_labels is None:
|
||||
'''
|
||||
When not given ground truth label, we compute the normal prediction entropy.
|
||||
See the Equation (7) in https://arxiv.org/pdf/2003.10595.pdf
|
||||
'''
|
||||
return np.sum(np.multiply(probs, _log_value(probs)),axis=1)
|
||||
else:
|
||||
'''
|
||||
When given the groud truth label, we compute the modified prediction entropy.
|
||||
See the Equation (8) in https://arxiv.org/pdf/2003.10595.pdf
|
||||
'''
|
||||
log_probs = _log_value(probs)
|
||||
reverse_probs = 1-probs
|
||||
log_reverse_probs = _log_value(reverse_probs)
|
||||
modified_probs = np.copy(probs)
|
||||
modified_probs[range(true_labels.size), true_labels] = reverse_probs[range(true_labels.size), true_labels]
|
||||
modified_log_probs = np.copy(log_reverse_probs)
|
||||
modified_log_probs[range(true_labels.size), true_labels] = log_probs[range(true_labels.size), true_labels]
|
||||
return np.sum(np.multiply(modified_probs, modified_log_probs),axis=1)
|
||||
|
||||
|
||||
def get_loss_train(self):
|
||||
"""Calculates cross-entropy losses for the training set."""
|
||||
if self.loss_train is not None:
|
||||
|
@ -189,6 +226,18 @@ class AttackInputData:
|
|||
return self.loss_test
|
||||
return self._get_loss(self.logits_test, self.labels_test)
|
||||
|
||||
def get_entropy_train(self):
|
||||
"""Calculates prediction entropy for the training set."""
|
||||
if self.entropy_train is not None:
|
||||
return self.entropy_train
|
||||
return self._get_entropy(self.logits_train, self.labels_train)
|
||||
|
||||
def get_entropy_test(self):
|
||||
"""Calculates prediction entropy for the test set."""
|
||||
if self.entropy_test is not None:
|
||||
return self.entropy_test
|
||||
return self._get_entropy(self.logits_test, self.labels_test)
|
||||
|
||||
def get_train_size(self):
|
||||
"""Returns size of the training set."""
|
||||
if self.loss_train is not None:
|
||||
|
@ -206,6 +255,10 @@ class AttackInputData:
|
|||
if (self.loss_train is None) != (self.loss_test is None):
|
||||
raise ValueError(
|
||||
'loss_test and loss_train should both be either set or unset')
|
||||
|
||||
if (self.entropy_train is None) != (self.entropy_test is None):
|
||||
raise ValueError(
|
||||
'entropy_test and entropy_train should both be either set or unset')
|
||||
|
||||
if (self.logits_train is None) != (self.logits_test is None):
|
||||
raise ValueError(
|
||||
|
@ -216,8 +269,8 @@ class AttackInputData:
|
|||
'labels_train and labels_test should both be either set or unset')
|
||||
|
||||
if (self.labels_train is None and self.loss_train is None and
|
||||
self.logits_train is None):
|
||||
raise ValueError('At least one of labels, logits or losses should be set')
|
||||
self.logits_train is None and self.entropy_train is None):
|
||||
raise ValueError('At least one of labels, logits, losses or entropy should be set')
|
||||
|
||||
if self.labels_train is not None and not _is_integer_type_array(
|
||||
self.labels_train):
|
||||
|
@ -233,11 +286,15 @@ class AttackInputData:
|
|||
_is_np_array(self.labels_test, 'labels_test')
|
||||
_is_np_array(self.loss_train, 'loss_train')
|
||||
_is_np_array(self.loss_test, 'loss_test')
|
||||
_is_np_array(self.entropy_train, 'entropy_train')
|
||||
_is_np_array(self.entropy_test, 'entropy_test')
|
||||
|
||||
_is_last_dim_equal(self.logits_train, 'logits_train', self.logits_test,
|
||||
'logits_test')
|
||||
_is_array_one_dimensional(self.loss_train, 'loss_train')
|
||||
_is_array_one_dimensional(self.loss_test, 'loss_test')
|
||||
_is_array_one_dimensional(self.entropy_train, 'entropy_train')
|
||||
_is_array_one_dimensional(self.entropy_test, 'entropy_test')
|
||||
_is_array_one_dimensional(self.labels_train, 'labels_train')
|
||||
_is_array_one_dimensional(self.labels_test, 'labels_test')
|
||||
|
||||
|
@ -246,6 +303,8 @@ class AttackInputData:
|
|||
result = ['AttackInputData(']
|
||||
_append_array_shape(self.loss_train, 'loss_train', result)
|
||||
_append_array_shape(self.loss_test, 'loss_test', result)
|
||||
_append_array_shape(self.entropy_train, 'entropy_train', result)
|
||||
_append_array_shape(self.entropy_test, 'entropy_test', result)
|
||||
_append_array_shape(self.logits_train, 'logits_train', result)
|
||||
_append_array_shape(self.logits_test, 'logits_test', result)
|
||||
_append_array_shape(self.labels_train, 'labels_train', result)
|
||||
|
|
Loading…
Reference in a new issue