For MIA in seq2seq model, add support for graph mode, add data information and fix small typo in seq2seq_membership_inference_codelab.ipynb.
PiperOrigin-RevId: 422909904
This commit is contained in:
parent
f47200f60d
commit
3a4c4400a6
2 changed files with 1246 additions and 1195 deletions
File diff suppressed because one or more lines are too long
|
@ -19,9 +19,9 @@ Contains seq2seq specific logic for attack data structures, attack data
|
||||||
generation,
|
generation,
|
||||||
and the logistic regression membership inference attack.
|
and the logistic regression membership inference attack.
|
||||||
"""
|
"""
|
||||||
|
import dataclasses
|
||||||
from typing import Iterator, List
|
from typing import Iterator, List
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy.stats import rankdata
|
from scipy.stats import rankdata
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
@ -46,7 +46,7 @@ def _is_iterator(obj, obj_name):
|
||||||
raise ValueError('%s should be a generator.' % obj_name)
|
raise ValueError('%s should be a generator.' % obj_name)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclasses.dataclass
|
||||||
class Seq2SeqAttackInputData:
|
class Seq2SeqAttackInputData:
|
||||||
"""Input data for running an attack on seq2seq models.
|
"""Input data for running an attack on seq2seq models.
|
||||||
|
|
||||||
|
@ -229,8 +229,13 @@ def _get_batch_loss_metrics(batch_logits: np.ndarray,
|
||||||
tf.keras.backend.constant(sequence_labels),
|
tf.keras.backend.constant(sequence_labels),
|
||||||
tf.keras.backend.constant(sequence_logits),
|
tf.keras.backend.constant(sequence_logits),
|
||||||
from_logits=True)
|
from_logits=True)
|
||||||
batch_loss += sequence_loss.numpy().sum()
|
if tf.executing_eagerly():
|
||||||
|
batch_loss += sequence_loss.numpy().sum()
|
||||||
|
else:
|
||||||
|
batch_loss += tf.reduce_sum(sequence_loss)
|
||||||
|
|
||||||
|
if not tf.executing_eagerly():
|
||||||
|
batch_loss = batch_loss.eval(session=tf.compat.v1.Session())
|
||||||
return batch_loss / batch_length, batch_length
|
return batch_loss / batch_length, batch_length
|
||||||
|
|
||||||
|
|
||||||
|
@ -250,9 +255,15 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
|
||||||
preds = tf.metrics.sparse_categorical_accuracy(
|
preds = tf.metrics.sparse_categorical_accuracy(
|
||||||
tf.keras.backend.constant(sequence_labels),
|
tf.keras.backend.constant(sequence_labels),
|
||||||
tf.keras.backend.constant(sequence_logits))
|
tf.keras.backend.constant(sequence_logits))
|
||||||
batch_correct_preds += preds.numpy().sum()
|
if tf.executing_eagerly():
|
||||||
|
batch_correct_preds += preds.numpy().sum()
|
||||||
|
else:
|
||||||
|
batch_correct_preds += tf.reduce_sum(preds)
|
||||||
batch_total_preds += len(sequence_labels)
|
batch_total_preds += len(sequence_labels)
|
||||||
|
|
||||||
|
if not tf.executing_eagerly():
|
||||||
|
batch_correct_preds = batch_correct_preds.eval(
|
||||||
|
session=tf.compat.v1.Session())
|
||||||
return batch_correct_preds, batch_total_preds
|
return batch_correct_preds, batch_total_preds
|
||||||
|
|
||||||
|
|
||||||
|
@ -302,10 +313,8 @@ def create_seq2seq_attacker_data(
|
||||||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
||||||
|
|
||||||
# Perform a train-test split
|
# Perform a train-test split
|
||||||
features_train, features_test, \
|
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
|
||||||
is_training_labels_train, is_training_labels_test = \
|
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||||
model_selection.train_test_split(
|
|
||||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
|
||||||
|
|
||||||
# Populate accuracy, loss fields in privacy report metadata
|
# Populate accuracy, loss fields in privacy report metadata
|
||||||
privacy_report_metadata.loss_train = loss_train
|
privacy_report_metadata.loss_train = loss_train
|
||||||
|
|
Loading…
Reference in a new issue