For the membership inference attack (MIA) on seq2seq models: add support for graph mode, add data information, and fix a small typo in seq2seq_membership_inference_codelab.ipynb.

PiperOrigin-RevId: 422909904
This commit is contained in:
Shuang Song 2022-01-19 14:50:20 -08:00 committed by A. Unique TensorFlower
parent f47200f60d
commit 3a4c4400a6
2 changed files with 1246 additions and 1195 deletions

View file

@ -19,9 +19,9 @@ Contains seq2seq specific logic for attack data structures, attack data
generation,
and the logistic regression membership inference attack.
"""
import dataclasses
from typing import Iterator, List
from dataclasses import dataclass
import numpy as np
from scipy.stats import rankdata
from sklearn import metrics
@ -46,7 +46,7 @@ def _is_iterator(obj, obj_name):
raise ValueError('%s should be a generator.' % obj_name)
@dataclass
@dataclasses.dataclass
class Seq2SeqAttackInputData:
"""Input data for running an attack on seq2seq models.
@ -229,8 +229,13 @@ def _get_batch_loss_metrics(batch_logits: np.ndarray,
tf.keras.backend.constant(sequence_labels),
tf.keras.backend.constant(sequence_logits),
from_logits=True)
if tf.executing_eagerly():
batch_loss += sequence_loss.numpy().sum()
else:
batch_loss += tf.reduce_sum(sequence_loss)
if not tf.executing_eagerly():
batch_loss = batch_loss.eval(session=tf.compat.v1.Session())
return batch_loss / batch_length, batch_length
@ -250,9 +255,15 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
preds = tf.metrics.sparse_categorical_accuracy(
tf.keras.backend.constant(sequence_labels),
tf.keras.backend.constant(sequence_logits))
if tf.executing_eagerly():
batch_correct_preds += preds.numpy().sum()
else:
batch_correct_preds += tf.reduce_sum(preds)
batch_total_preds += len(sequence_labels)
if not tf.executing_eagerly():
batch_correct_preds = batch_correct_preds.eval(
session=tf.compat.v1.Session())
return batch_correct_preds, batch_total_preds
@ -302,9 +313,7 @@ def create_seq2seq_attacker_data(
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
# Perform a train-test split
features_train, features_test, \
is_training_labels_train, is_training_labels_test = \
model_selection.train_test_split(
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
# Populate accuracy, loss fields in privacy report metadata