For MIA in seq2seq model, add support for graph mode, add data information and fix small typo in seq2seq_membership_inference_codelab.ipynb.
PiperOrigin-RevId: 422909904
This commit is contained in:
parent
f47200f60d
commit
3a4c4400a6
2 changed files with 1246 additions and 1195 deletions
File diff suppressed because one or more lines are too long
|
@ -19,9 +19,9 @@ Contains seq2seq specific logic for attack data structures, attack data
|
|||
generation,
|
||||
and the logistic regression membership inference attack.
|
||||
"""
|
||||
import dataclasses
|
||||
from typing import Iterator, List
|
||||
|
||||
from dataclasses import dataclass
|
||||
import numpy as np
|
||||
from scipy.stats import rankdata
|
||||
from sklearn import metrics
|
||||
|
@ -46,7 +46,7 @@ def _is_iterator(obj, obj_name):
|
|||
raise ValueError('%s should be a generator.' % obj_name)
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclasses.dataclass
|
||||
class Seq2SeqAttackInputData:
|
||||
"""Input data for running an attack on seq2seq models.
|
||||
|
||||
|
@ -229,8 +229,13 @@ def _get_batch_loss_metrics(batch_logits: np.ndarray,
|
|||
tf.keras.backend.constant(sequence_labels),
|
||||
tf.keras.backend.constant(sequence_logits),
|
||||
from_logits=True)
|
||||
if tf.executing_eagerly():
|
||||
batch_loss += sequence_loss.numpy().sum()
|
||||
else:
|
||||
batch_loss += tf.reduce_sum(sequence_loss)
|
||||
|
||||
if not tf.executing_eagerly():
|
||||
batch_loss = batch_loss.eval(session=tf.compat.v1.Session())
|
||||
return batch_loss / batch_length, batch_length
|
||||
|
||||
|
||||
|
@ -250,9 +255,15 @@ def _get_batch_accuracy_metrics(batch_logits: np.ndarray,
|
|||
preds = tf.metrics.sparse_categorical_accuracy(
|
||||
tf.keras.backend.constant(sequence_labels),
|
||||
tf.keras.backend.constant(sequence_logits))
|
||||
if tf.executing_eagerly():
|
||||
batch_correct_preds += preds.numpy().sum()
|
||||
else:
|
||||
batch_correct_preds += tf.reduce_sum(preds)
|
||||
batch_total_preds += len(sequence_labels)
|
||||
|
||||
if not tf.executing_eagerly():
|
||||
batch_correct_preds = batch_correct_preds.eval(
|
||||
session=tf.compat.v1.Session())
|
||||
return batch_correct_preds, batch_total_preds
|
||||
|
||||
|
||||
|
@ -302,9 +313,7 @@ def create_seq2seq_attacker_data(
|
|||
labels_all = np.concatenate(((np.zeros(ntrain)), (np.ones(ntest))))
|
||||
|
||||
# Perform a train-test split
|
||||
features_train, features_test, \
|
||||
is_training_labels_train, is_training_labels_test = \
|
||||
model_selection.train_test_split(
|
||||
features_train, features_test, is_training_labels_train, is_training_labels_test = model_selection.train_test_split(
|
||||
features_all, labels_all, test_size=test_fraction, stratify=labels_all)
|
||||
|
||||
# Populate accuracy, loss fields in privacy report metadata
|
||||
|
|
Loading…
Reference in a new issue