diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/LICENSE b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/LICENSE
new file mode 100644
index 0000000..172b5d8
--- /dev/null
+++ b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Congzheng Song
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb
new file mode 100644
index 0000000..108f89d
--- /dev/null
+++ b/tensorflow_privacy/privacy/membership_inference_attack/codelabs/third_party/seq2seq_membership_inference/seq2seq_membership_inference_codelab.ipynb
@@ -0,0 +1,1191 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Assess privacy risks on a seq2seq model with TensorFlow Privacy Membership Inference Attacks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "\n",
+    "In this codelab we'll train a simple translation model on the SATED dataset, which consists of sentences grouped by \"user\" (i.e. the person who spoke them). We will then run a \"membership inference attack\" against this model to assess whether an attacker can \"guess\" if a particular user was present in the training set."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup\n",
+ "\n",
+ "First, set this notebook's runtime to use a GPU, under Runtime > Change runtime type > Hardware accelerator.\n",
+ "\n",
+ "Then, begin importing the necessary libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Import statements.\n",
+ "\n",
+ "import numpy as np\n",
+ "from collections import Counter, defaultdict\n",
+ "from itertools import chain\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "import tensorflow.keras.backend as K\n",
+ "from tensorflow.keras import Model\n",
+ "from tensorflow.keras import activations, initializers, regularizers, constraints\n",
+    "from tensorflow.keras.layers import Layer, InputSpec, Input, Embedding, LSTM, GRU, Dropout, Dense, Add\n",
+ "from tensorflow.keras.optimizers import Adam, SGD\n",
+ "from tensorflow.keras.regularizers import l2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install TensorFlow Privacy."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Collecting git+https://github.com/tensorflow/privacy\n",
+ " Cloning https://github.com/tensorflow/privacy to /private/var/folders/z8/yyl7bbz90bx_mcf8y1sghj700000gn/T/pip-req-build-858lr99q\n",
+ "Requirement already satisfied (use --upgrade to upgrade): tensorflow-privacy==0.5.1 from git+https://github.com/tensorflow/privacy in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages\n",
+ "Requirement already satisfied: scipy>=0.17 in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from tensorflow-privacy==0.5.1) (1.5.2)\n",
+ "Requirement already satisfied: tensorflow-estimator>=2.3.0 in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from tensorflow-privacy==0.5.1) (2.3.0)\n",
+ "Requirement already satisfied: mpmath in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from tensorflow-privacy==0.5.1) (1.1.0)\n",
+ "Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from tensorflow-privacy==0.5.1) (0.1.5)\n",
+ "Requirement already satisfied: numpy>=1.14.5 in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from scipy>=0.17->tensorflow-privacy==0.5.1) (1.19.2)\n",
+ "Requirement already satisfied: six>=1.12.0 in /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages (from dm-tree~=0.1.1->tensorflow-privacy==0.5.1) (1.15.0)\n",
+ "Building wheels for collected packages: tensorflow-privacy\n",
+ " Building wheel for tensorflow-privacy (setup.py) ... \u001B[?25ldone\n",
+ "\u001B[?25h Created wheel for tensorflow-privacy: filename=tensorflow_privacy-0.5.1-py3-none-any.whl size=144389 sha256=0dcea18c0b4b06c3f19bc765d29a807fd5782fa75c0d27a881c7f1a88de8e3da\n",
+ " Stored in directory: /private/var/folders/z8/yyl7bbz90bx_mcf8y1sghj700000gn/T/pip-ephem-wheel-cache-hteoqxwl/wheels/2f/fb/b8/7eabbe4b85682ff7e299a9446b36521ed33dd97dff1f1a86ba\n",
+ "Successfully built tensorflow-privacy\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip3 install git+https://github.com/tensorflow/privacy\n",
+ "\n",
+ "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Load the SATED dataset.\n",
+ "\n",
+    "Download the SATED dataset from [here](https://www.cs.cmu.edu/~pmichel1/sated/). We use **English-French** sentence pairs for this codelab.\n",
+ "\n",
+ "The code for data-loading is adapted from [csong27/auditing-text-generation/data_loader/load_sated](https://github.com/csong27/auditing-text-generation/blob/master/data_loader/load_sated.py)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# File paths for SATED dataset.\n",
+ "\n",
+ "SATED_PATH = 'sated-release-0.9.0/en-fr/'\n",
+ "SATED_TRAIN_ENG = SATED_PATH + 'train.en'\n",
+ "SATED_TRAIN_FR = SATED_PATH + 'train.fr'\n",
+ "SATED_TRAIN_USER = SATED_PATH + 'train.usr'\n",
+ "SATED_DEV_ENG = SATED_PATH + 'dev.en'\n",
+ "SATED_DEV_FR = SATED_PATH + 'dev.fr'\n",
+ "SATED_DEV_USER = SATED_PATH + 'dev.usr'\n",
+ "SATED_TEST_ENG = SATED_PATH + 'test.en'\n",
+ "SATED_TEST_FR = SATED_PATH + 'test.fr'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Helper methods for preprocessing and loading the SATED dataset.\n",
+ "\n",
+ "def load_users(p=SATED_TRAIN_USER):\n",
+ " \"\"\"Loads users from the dataset.\"\"\"\n",
+ " users = []\n",
+ " with open(p, 'r', encoding='UTF-8') as f:\n",
+ " for line in f:\n",
+ " users.append(line.replace('\\n', ''))\n",
+ " return users\n",
+ "\n",
+ "\n",
+ "def load_texts(p=SATED_TRAIN_ENG):\n",
+ " \"\"\"Loads and adds start and end tokens to sentences.\"\"\"\n",
+ " texts = []\n",
+ " with open(p, 'r', encoding='UTF-8') as f:\n",
+ " for line in f:\n",
+    "            arr = ['<sos>'] + line.replace('\\n', '').split(' ') + ['<eos>']\n",
+ " words = []\n",
+ " for w in arr:\n",
+ " words.append(w)\n",
+ " texts.append(words)\n",
+ "\n",
+ " return texts\n",
+ "\n",
+ "\n",
+ "def process_texts(texts, vocabs):\n",
+    "    \"\"\"Processes sentences according to vocabs, i.e. if a word is not present\n",
+    "    in the vocab it is replaced with the '<unk>' token.\"\"\"\n",
+ " for t in texts:\n",
+ " for i, w in enumerate(t):\n",
+ " if w not in vocabs:\n",
+    "                t[i] = '<unk>'\n",
+ "\n",
+ "\n",
+ "def process_vocabs(vocabs, num_words=10000):\n",
+    "    \"\"\"Returns a word-to-id mapping for the (num_words - 1) most frequent words.\"\"\"\n",
+ "\n",
+ " counter = Counter(vocabs)\n",
+ " count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))\n",
+ "\n",
+ " if num_words is not None:\n",
+ " count_pairs = count_pairs[:num_words - 1]\n",
+ "\n",
+ " words, _ = list(zip(*count_pairs))\n",
+ " word_to_id = dict(zip(words, np.arange(len(words))))\n",
+ " return word_to_id"
+ ]
+ },
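+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The next cell is a purely illustrative sanity check of the helpers above (the toy sentences are made up and are not part of SATED): `process_vocabs` keeps the `num_words - 1` most frequent words and maps them to integer ids, and `process_texts` replaces out-of-vocabulary words in place."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only; assumes the helper cell above has been run.\n",
+    "toy_texts = [['<sos>', 'hello', 'world', '<eos>'], ['<sos>', 'hello', 'there', '<eos>']]\n",
+    "toy_vocab = process_vocabs([w for t in toy_texts for w in t], num_words=5)\n",
+    "print(toy_vocab)   # the 4 most frequent words mapped to ids 0..3\n",
+    "process_texts(toy_texts, toy_vocab)\n",
+    "print(toy_texts)   # 'world' is out of vocabulary and gets replaced in place"
+   ]
+  },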
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Define data-loading method."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def load_sated_data_by_user(num_users=100, num_words=5000, test_on_user=False, seed=12345):\n",
+ " \"\"\"Loads training, validation, and test sets of sentences.\n",
+ "\n",
+ " Sentences in the training set are grouped by the user who spoke the sentences.\n",
+ " One can specify if sentences need to be grouped in the test set.\n",
+ "\n",
+ " Args:\n",
+ " num_users: Number of users to include in the training and test sets.\n",
+ " num_words: Denotes how many of the most frequent words to include in\n",
+ " the source and target language vocabularies.\n",
+ " test_on_user: Determines if sentences in the test set will be grouped\n",
+ " by the user who spoke the sentences.\n",
+ " seed: Seed for shuffling users in the entire dataset before sampling.\n",
+ "\n",
+ " Returns:\n",
+    "      Training, validation, and test sets of sentences, plus the source and\n",
+    "      target vocabularies (the exact tuple depends on test_on_user).\n",
+ " \"\"\"\n",
+ " src_users = load_users(SATED_TRAIN_USER)\n",
+ "\n",
+ " # Load training set sentences\n",
+ " train_src_texts = load_texts(SATED_TRAIN_ENG)\n",
+ " train_trg_texts = load_texts(SATED_TRAIN_FR)\n",
+ "\n",
+ " # Load validation set sentences\n",
+ " dev_src_texts = load_texts(SATED_DEV_ENG)\n",
+ " dev_trg_texts = load_texts(SATED_DEV_FR)\n",
+ "\n",
+ " # Load test set sentences\n",
+ " test_src_texts = load_texts(SATED_TEST_ENG)\n",
+ " test_trg_texts = load_texts(SATED_TEST_FR)\n",
+ "\n",
+ " # Shuffle users\n",
+ " user_counter = Counter(src_users)\n",
+ " all_users = [tup[0] for tup in user_counter.most_common()]\n",
+ " np.random.seed(seed)\n",
+ " np.random.shuffle(all_users)\n",
+ " np.random.seed(None)\n",
+ "\n",
+ " # Sample users for training and test sets\n",
+ " train_users = set(all_users[:num_users])\n",
+ " test_users = set(all_users[num_users: num_users * 2])\n",
+ "\n",
+ " user_src_texts = defaultdict(list)\n",
+ " user_trg_texts = defaultdict(list)\n",
+ "\n",
+ " test_user_src_texts = defaultdict(list)\n",
+ " test_user_trg_texts = defaultdict(list)\n",
+ "\n",
+ " # Create training set (and optionally the test set), grouped by user\n",
+ " for u, s, t in zip(src_users, train_src_texts, train_trg_texts):\n",
+ " if u in train_users:\n",
+ " user_src_texts[u].append(s)\n",
+ " user_trg_texts[u].append(t)\n",
+ " if test_on_user and u in test_users:\n",
+ " test_user_src_texts[u].append(s)\n",
+ " test_user_trg_texts[u].append(t)\n",
+ "\n",
+ " # Create source and target language vocabs for tokenizing sentences\n",
+ " # Restrict number of words in vocabs to num_words\n",
+ " src_words = []\n",
+ " trg_words = []\n",
+ " for u in train_users:\n",
+ " src_words += list(chain(*user_src_texts[u]))\n",
+ " trg_words += list(chain(*user_trg_texts[u]))\n",
+ "\n",
+ " src_vocabs = process_vocabs(src_words, num_words)\n",
+ " trg_vocabs = process_vocabs(trg_words, num_words)\n",
+ "\n",
+    "    # Replace out-of-vocabulary words in the training set\n",
+ " for u in train_users:\n",
+ " process_texts(user_src_texts[u], src_vocabs)\n",
+ " process_texts(user_trg_texts[u], trg_vocabs)\n",
+ "\n",
+    "    # Replace out-of-vocabulary words in the test set, if grouped by user\n",
+ " if test_on_user:\n",
+ " for u in test_users:\n",
+ " process_texts(test_user_src_texts[u], src_vocabs)\n",
+ " process_texts(test_user_trg_texts[u], trg_vocabs)\n",
+ "\n",
+    "    # Replace out-of-vocabulary words in the validation and test sets\n",
+ " process_texts(dev_src_texts, src_vocabs)\n",
+ " process_texts(dev_trg_texts, trg_vocabs)\n",
+ "\n",
+ " process_texts(test_src_texts, src_vocabs)\n",
+ " process_texts(test_trg_texts, trg_vocabs)\n",
+ "\n",
+    "    # Rebuild the source and target vocabs over all remaining training words (no size limit);\n",
+    "    # these are used later only for mapping words to indices\n",
+ " src_words = []\n",
+ " trg_words = []\n",
+ " for u in train_users:\n",
+ " src_words += list(chain(*user_src_texts[u]))\n",
+ " trg_words += list(chain(*user_trg_texts[u]))\n",
+ "\n",
+ " src_vocabs = process_vocabs(src_words, None)\n",
+ " trg_vocabs = process_vocabs(trg_words, None)\n",
+ "\n",
+ " # Return the appropriate training, validation, test sets and source and target vocabs\n",
+ " if test_on_user:\n",
+ " return user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs\n",
+ " else:\n",
+ " return user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts,\\\n",
+ " src_vocabs, trg_vocabs"
+ ]
+ },
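+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, the cell below (illustrative only) loads a very small subset just to show the structure of what the loader returns. It assumes the SATED files are available at `SATED_PATH`."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only; loads a tiny subset to inspect the return structure.\n",
+    "u_src, u_trg, d_src, d_trg, t_src, t_trg, s_vocab, t_vocab = load_sated_data_by_user(\n",
+    "    num_users=10, num_words=1000, test_on_user=False)\n",
+    "print(len(u_src), 'training users,', sum(len(v) for v in u_src.values()), 'sentences')\n",
+    "print('full source vocab size:', len(s_vocab))"
+   ]
+  },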
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the seq2seq model.\n",
+ "\n",
+ "We follow the model architecture specified in [Extreme Adaptation for Personalized Neural Machine Translation (P. Michel, G. Neubig)](https://arxiv.org/pdf/1805.01817.pdf).\n",
+ "\n",
+ "The code for the model architecture is adapted from [csong27/auditing-text-generation/sated_nmt](https://github.com/csong27/auditing-text-generation/blob/master/sated_nmt.py)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Define layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def flatten_data(data):\n",
+ " return np.asarray([w for t in data for w in t]).astype(np.int32)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class Attention(Layer):\n",
+ " def __init__(self, units,\n",
+ " activation='linear',\n",
+ " use_bias=True,\n",
+ " kernel_initializer='glorot_uniform',\n",
+ " bias_initializer='zeros',\n",
+ " kernel_regularizer=None,\n",
+ " bias_regularizer=None,\n",
+ " activity_regularizer=None,\n",
+ " kernel_constraint=None,\n",
+ " bias_constraint=None,\n",
+ " **kwargs):\n",
+ " if 'input_shape' not in kwargs and 'input_dim' in kwargs:\n",
+ " kwargs['input_shape'] = (kwargs.pop('input_dim'),)\n",
+ " super(Attention, self).__init__(**kwargs)\n",
+ " self.units = units\n",
+ " self.activation = activations.get(activation)\n",
+ " self.use_bias = use_bias\n",
+ " self.kernel_initializer = initializers.get(kernel_initializer)\n",
+ " self.bias_initializer = initializers.get(bias_initializer)\n",
+ " self.kernel_regularizer = regularizers.get(kernel_regularizer)\n",
+ " self.bias_regularizer = regularizers.get(bias_regularizer)\n",
+ " self.activity_regularizer = regularizers.get(activity_regularizer)\n",
+ " self.kernel_constraint = constraints.get(kernel_constraint)\n",
+ " self.bias_constraint = constraints.get(bias_constraint)\n",
+ " self.supports_masking = True\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " if not isinstance(input_shape, list) or len(input_shape) != 2:\n",
+ " raise ValueError('An attention layer should be called '\n",
+ " 'on a list of 2 inputs.')\n",
+ " enc_dim = input_shape[0][-1]\n",
+ " dec_dim = input_shape[1][-1]\n",
+ "\n",
+ " self.W_enc = self.add_weight(shape=(enc_dim, self.units),\n",
+ " initializer=self.kernel_initializer,\n",
+ " name='W_enc',\n",
+ " regularizer=self.kernel_regularizer,\n",
+ " constraint=self.kernel_constraint)\n",
+ "\n",
+ " self.W_dec = self.add_weight(shape=(dec_dim, self.units),\n",
+ " initializer=self.kernel_initializer,\n",
+ " name='W_dec',\n",
+ " regularizer=self.kernel_regularizer,\n",
+ " constraint=self.kernel_constraint)\n",
+ "\n",
+ " self.W_score = self.add_weight(shape=(self.units, 1),\n",
+ " initializer=self.kernel_initializer,\n",
+ " name='W_score',\n",
+ " regularizer=self.kernel_regularizer,\n",
+ " constraint=self.kernel_constraint)\n",
+ "\n",
+ " if self.use_bias:\n",
+ " self.bias_enc = self.add_weight(shape=(self.units,),\n",
+ " initializer=self.bias_initializer,\n",
+ " name='bias_enc',\n",
+ " regularizer=self.bias_regularizer,\n",
+ " constraint=self.bias_constraint)\n",
+ " self.bias_dec = self.add_weight(shape=(self.units,),\n",
+ " initializer=self.bias_initializer,\n",
+ " name='bias_dec',\n",
+ " regularizer=self.bias_regularizer,\n",
+ " constraint=self.bias_constraint)\n",
+ " self.bias_score = self.add_weight(shape=(1,),\n",
+ " initializer=self.bias_initializer,\n",
+ " name='bias_score',\n",
+ " regularizer=self.bias_regularizer,\n",
+ " constraint=self.bias_constraint)\n",
+ "\n",
+ " else:\n",
+ " self.bias_enc = None\n",
+ " self.bias_dec = None\n",
+ " self.bias_score = None\n",
+ "\n",
+ " self.built = True\n",
+ "\n",
+ " def call(self, inputs, **kwargs):\n",
+ " if not isinstance(inputs, list) or len(inputs) != 2:\n",
+ " raise ValueError('An attention layer should be called '\n",
+ " 'on a list of 2 inputs.')\n",
+ " encodings, decodings = inputs\n",
+ " d_enc = K.dot(encodings, self.W_enc)\n",
+ " d_dec = K.dot(decodings, self.W_dec)\n",
+ "\n",
+ " if self.use_bias:\n",
+ " d_enc = K.bias_add(d_enc, self.bias_enc)\n",
+ " d_dec = K.bias_add(d_dec, self.bias_dec)\n",
+ "\n",
+ " if self.activation is not None:\n",
+ " d_enc = self.activation(d_enc)\n",
+ " d_dec = self.activation(d_dec)\n",
+ "\n",
+ " enc_seqlen = K.shape(d_enc)[1]\n",
+ " d_dec_shape = K.shape(d_dec)\n",
+ "\n",
+ " stacked_d_dec = K.tile(d_dec, [enc_seqlen, 1, 1]) # enc time x batch x dec time x da\n",
+ " stacked_d_dec = K.reshape(stacked_d_dec, [enc_seqlen, d_dec_shape[0], d_dec_shape[1], d_dec_shape[2]])\n",
+ " stacked_d_dec = K.permute_dimensions(stacked_d_dec, [2, 1, 0, 3]) # dec time x batch x enc time x da\n",
+ " tanh_add = K.tanh(stacked_d_dec + d_enc) # dec time x batch x enc time x da\n",
+ " scores = K.dot(tanh_add, self.W_score)\n",
+ " if self.use_bias:\n",
+ " scores = K.bias_add(scores, self.bias_score)\n",
+    "        scores = K.squeeze(scores, 3)  # dec time x batch x enc time\n",
+ "\n",
+ " weights = K.softmax(scores) # dec time x batch x enc time\n",
+ " weights = K.expand_dims(weights)\n",
+ "\n",
+ " weighted_encodings = weights * encodings # dec time x batch x enc time x h\n",
+ " contexts = K.sum(weighted_encodings, axis=2) # dec time x batch x h\n",
+ " contexts = K.permute_dimensions(contexts, [1, 0, 2]) # batch x dec time x h\n",
+ "\n",
+ " return contexts\n",
+ "\n",
+ " def compute_output_shape(self, input_shape):\n",
+ " assert isinstance(input_shape, list) and len(input_shape) == 2\n",
+ " assert input_shape[-1]\n",
+ " output_shape = list(input_shape[1])\n",
+ " output_shape[-1] = self.units\n",
+ " return tuple(output_shape)\n",
+ "\n",
+ " def get_config(self):\n",
+ " config = {\n",
+ " 'units': self.units,\n",
+ " 'activation': activations.serialize(self.activation),\n",
+ " 'use_bias': self.use_bias,\n",
+ " 'kernel_initializer': initializers.serialize(self.kernel_initializer),\n",
+ " 'bias_initializer': initializers.serialize(self.bias_initializer),\n",
+ " 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),\n",
+ " 'bias_regularizer': regularizers.serialize(self.bias_regularizer),\n",
+ " 'activity_regularizer': regularizers.serialize(self.activity_regularizer),\n",
+ " 'kernel_constraint': constraints.serialize(self.kernel_constraint),\n",
+ " 'bias_constraint': constraints.serialize(self.bias_constraint)\n",
+ " }\n",
+ " base_config = super(Attention, self).get_config()\n",
+ " return dict(list(base_config.items()) + list(config.items()))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "class DenseTransposeTied(Layer):\n",
+ " def __init__(self, units,\n",
+ " tied_to=None, # Enter a layer as input to enforce weight-tying\n",
+ " activation=None,\n",
+ " use_bias=True,\n",
+ " kernel_initializer='glorot_uniform',\n",
+ " bias_initializer='zeros',\n",
+ " kernel_regularizer=None,\n",
+ " bias_regularizer=None,\n",
+ " activity_regularizer=None,\n",
+ " kernel_constraint=None,\n",
+ " bias_constraint=None,\n",
+ " **kwargs):\n",
+ " if 'input_shape' not in kwargs and 'input_dim' in kwargs:\n",
+ " kwargs['input_shape'] = (kwargs.pop('input_dim'),)\n",
+ " super(DenseTransposeTied, self).__init__(**kwargs)\n",
+ " self.units = units\n",
+ " # We add these two properties to save the tied weights\n",
+ " self.tied_to = tied_to\n",
+ " self.tied_weights = self.tied_to.weights\n",
+ " self.activation = activations.get(activation)\n",
+ " self.use_bias = use_bias\n",
+ " self.kernel_initializer = initializers.get(kernel_initializer)\n",
+ " self.bias_initializer = initializers.get(bias_initializer)\n",
+ " self.kernel_regularizer = regularizers.get(kernel_regularizer)\n",
+ " self.bias_regularizer = regularizers.get(bias_regularizer)\n",
+ " self.activity_regularizer = regularizers.get(activity_regularizer)\n",
+ " self.kernel_constraint = constraints.get(kernel_constraint)\n",
+ " self.bias_constraint = constraints.get(bias_constraint)\n",
+ " self.input_spec = InputSpec(min_ndim=2)\n",
+ " self.supports_masking = True\n",
+ "\n",
+ " def build(self, input_shape):\n",
+ " assert len(input_shape) >= 2\n",
+ " input_dim = input_shape[-1]\n",
+ "\n",
+    "        # No kernel is created here: the layer reuses the tied layer's weights (transposed)\n",
+ " self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dim})\n",
+ " if self.use_bias:\n",
+ " self.bias = self.add_weight(shape=(self.units,),\n",
+ " initializer=self.bias_initializer,\n",
+ " name='bias',\n",
+ " regularizer=self.bias_regularizer,\n",
+ " constraint=self.bias_constraint)\n",
+ " else:\n",
+ " self.bias = None\n",
+ " self.built = True\n",
+ "\n",
+ " def call(self, inputs, **kwargs):\n",
+ " # Return the transpose layer mapping using the explicit weight matrices\n",
+ " output = K.dot(inputs, K.transpose(self.tied_weights[0]))\n",
+ " if self.use_bias:\n",
+ " output = K.bias_add(output, self.bias, data_format='channels_last')\n",
+ "\n",
+ " if self.activation is not None:\n",
+ " output = self.activation(output)\n",
+ "\n",
+ " return output\n",
+ "\n",
+ " def compute_output_shape(self, input_shape):\n",
+ " assert input_shape and len(input_shape) >= 2\n",
+ " assert input_shape[-1]\n",
+ " output_shape = list(input_shape)\n",
+ " output_shape[-1] = self.units\n",
+ " return tuple(output_shape)\n",
+ "\n",
+ " def get_config(self):\n",
+ " config = {\n",
+ " 'units': self.units,\n",
+ " 'activation': activations.serialize(self.activation),\n",
+ " 'use_bias': self.use_bias,\n",
+ " 'kernel_initializer': initializers.serialize(self.kernel_initializer),\n",
+ " 'bias_initializer': initializers.serialize(self.bias_initializer),\n",
+ " 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),\n",
+ " 'bias_regularizer': regularizers.serialize(self.bias_regularizer),\n",
+ " 'activity_regularizer': regularizers.serialize(self.activity_regularizer),\n",
+ " 'kernel_constraint': constraints.serialize(self.kernel_constraint),\n",
+ " 'bias_constraint': constraints.serialize(self.bias_constraint)\n",
+ " }\n",
+ " base_config = super(DenseTransposeTied, self).get_config()\n",
+ " return dict(list(base_config.items()) + list(config.items()))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Define batch processing and model creation methods."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def group_texts_by_len(src_texts, trg_texts, bs=20):\n",
+ " \"\"\"Groups sentences by their length, so batches can have minimal padding.\"\"\"\n",
+ " # Bucket samples by source sentence length\n",
+ " buckets = defaultdict(list)\n",
+ " batches = []\n",
+ " for src, trg in zip(src_texts, trg_texts):\n",
+ " buckets[len(src)].append((src, trg))\n",
+ "\n",
+ " # Create batches\n",
+ " for src_len, bucket in buckets.items():\n",
+ " np.random.shuffle(bucket)\n",
+ " num_batches = int(np.ceil(len(bucket) * 1.0 / bs))\n",
+ " for i in range(num_batches):\n",
+ " cur_batch_size = bs if i < num_batches - 1 else len(bucket) - bs * i\n",
+ " batches.append(([bucket[i * bs + j][0] for j in range(cur_batch_size)],\n",
+ " [bucket[i * bs + j][1] for j in range(cur_batch_size)]))\n",
+ " return batches"
+ ]
+ },
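+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A small illustration of the bucketing above (not needed for training): sentence pairs are grouped by source-sentence length, so each batch needs little or no padding. The toy inputs below are made up."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only; assumes the cell above has been run.\n",
+    "toy_src = [[1, 2], [3, 4], [5, 6], [7, 8, 9]]\n",
+    "toy_trg = [[1], [2], [3], [4]]\n",
+    "for src_batch, trg_batch in group_texts_by_len(toy_src, toy_trg, bs=2):\n",
+    "    print('source lengths in this batch:', [len(s) for s in src_batch])"
+   ]
+  },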
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "def build_nmt_model(Vs, Vt, demb=128, h=128, drop_p=0.5, tied=True, mask=True, attn=True, l2_ratio=1e-4,\n",
+ " training=None, rnn_fn='lstm'):\n",
+ " \"\"\"Builds the target machine translation model.\n",
+ "\n",
+ " Args:\n",
+ " Vs, Vt: Vocab sizes for source, target vocabs.\n",
+ " demb: Embedding dimension.\n",
+ " h: Number of hidden units.\n",
+    "      drop_p: Dropout rate.\n",
+ " attn: Flag to include attention units.\n",
+ " rnn_fn: RNN type to use in the model. Can be 'lstm' or 'gru'.\n",
+ " \"\"\"\n",
+ " if rnn_fn == 'lstm':\n",
+ " rnn = LSTM\n",
+ " elif rnn_fn == 'gru':\n",
+    "        rnn = GRU\n",
+ " else:\n",
+ " raise ValueError(rnn_fn)\n",
+ "\n",
+ " # Build encoder\n",
+ " encoder_input = Input((None,), dtype='float32', name='encoder_input')\n",
+ " if mask:\n",
+ " encoder_emb_layer = Embedding(Vs + 1, demb, mask_zero=True, embeddings_regularizer=l2(l2_ratio),\n",
+ " name='encoder_emb')\n",
+ " else:\n",
+ " encoder_emb_layer = Embedding(Vs, demb, mask_zero=False, embeddings_regularizer=l2(l2_ratio),\n",
+ " name='encoder_emb')\n",
+ "\n",
+ " encoder_emb = encoder_emb_layer(encoder_input)\n",
+ "\n",
+ " # Dropout for encoder\n",
+ " if drop_p > 0.:\n",
+ " encoder_emb = Dropout(drop_p)(encoder_emb, training=training)\n",
+ "\n",
+ " encoder_rnn = rnn(h, return_sequences=True, return_state=True, kernel_regularizer=l2(l2_ratio), name='encoder_rnn')\n",
+ " encoder_rtn = encoder_rnn(encoder_emb)\n",
+ " encoder_outputs = encoder_rtn[0]\n",
+ " encoder_states = encoder_rtn[1:]\n",
+ "\n",
+ " # Build decoder\n",
+ " decoder_input = Input((None,), dtype='float32', name='decoder_input')\n",
+ " if mask:\n",
+ " decoder_emb_layer = Embedding(Vt + 1, demb, mask_zero=True, embeddings_regularizer=l2(l2_ratio),\n",
+ " name='decoder_emb')\n",
+ " else:\n",
+ " decoder_emb_layer = Embedding(Vt, demb, mask_zero=False, embeddings_regularizer=l2(l2_ratio),\n",
+ " name='decoder_emb')\n",
+ "\n",
+ " decoder_emb = decoder_emb_layer(decoder_input)\n",
+ "\n",
+ " # Dropout for decoder\n",
+ " if drop_p > 0.:\n",
+ " decoder_emb = Dropout(drop_p)(decoder_emb, training=training)\n",
+ "\n",
+ " decoder_rnn = rnn(h, return_sequences=True, kernel_regularizer=l2(l2_ratio), name='decoder_rnn')\n",
+ " decoder_outputs = decoder_rnn(decoder_emb, initial_state=encoder_states)\n",
+ "\n",
+ " if drop_p > 0.:\n",
+ " decoder_outputs = Dropout(drop_p)(decoder_outputs, training=training)\n",
+ "\n",
+ " if tied:\n",
+ " final_outputs = DenseTransposeTied(Vt, kernel_regularizer=l2(l2_ratio), name='outputs',\n",
+ " tied_to=decoder_emb_layer, activation='linear')(decoder_outputs)\n",
+ " else:\n",
+ " final_outputs = Dense(Vt, activation='linear', kernel_regularizer=l2(l2_ratio), name='outputs')(decoder_outputs)\n",
+ "\n",
+ " # Add attention units\n",
+ " if attn:\n",
+ " contexts = Attention(units=h, kernel_regularizer=l2(l2_ratio), name='attention',\n",
+ " use_bias=False)([encoder_outputs, decoder_outputs])\n",
+ " if drop_p > 0.:\n",
+ " contexts = Dropout(drop_p)(contexts, training=training)\n",
+ "\n",
+ " contexts_outputs = Dense(Vt, activation='linear', use_bias=False, name='context_outputs',\n",
+ " kernel_regularizer=l2(l2_ratio))(contexts)\n",
+ "\n",
+ " final_outputs = Add(name='final_outputs')([final_outputs, contexts_outputs])\n",
+ "\n",
+ " model = Model(inputs=[encoder_input, decoder_input], outputs=[final_outputs])\n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Helper functions for training the translation model.\n",
+ "\n",
+ "def words_to_indices(data, vocab, mask=True):\n",
+ " \"\"\"\n",
+ " Converts words to indices according to vocabulary.\n",
+ " \"\"\"\n",
+ " if mask:\n",
+ " return [[vocab[w] + 1 for w in t] for t in data]\n",
+ " else:\n",
+ " return [[vocab[w] for w in t] for t in data]\n",
+ "\n",
+ "\n",
+ "def pad_texts(texts, eos, mask=True):\n",
+ " \"\"\"\n",
+ " Adds padding to a batch of texts.\n",
+ " \"\"\"\n",
+ " maxlen = max(len(t) for t in texts)\n",
+ " for t in texts:\n",
+ " while len(t) < maxlen:\n",
+ " if mask:\n",
+ " t.insert(0, 0)\n",
+ " else:\n",
+ " t.append(eos)\n",
+ " return np.asarray(texts, dtype='float32')\n",
+ "\n",
+ "\n",
+ "def get_perp(user_src_data, user_trg_data, pred_fn, prop=1.0, shuffle=False):\n",
+ " \"\"\"\n",
+    "    Returns the summed cross-entropy loss and the number of target tokens;\n",
+    "    perplexity can then be computed as exp(loss / tokens).\n",
+ " \"\"\"\n",
+ " loss = 0.\n",
+ " iters = 0.\n",
+ "\n",
+ " indices = np.arange(len(user_src_data))\n",
+ " n = int(prop * len(indices))\n",
+ "\n",
+ " if shuffle:\n",
+ " np.random.shuffle(indices)\n",
+ "\n",
+ " for idx in indices[:n]:\n",
+ " src_text = np.asarray(user_src_data[idx], dtype=np.float32).reshape(1, -1)\n",
+ " trg_text = np.asarray(user_trg_data[idx], dtype=np.float32)\n",
+ " trg_input = trg_text[:-1].reshape(1, -1)\n",
+ " trg_label = trg_text[1:].reshape(1, -1)\n",
+ "\n",
+ " err = pred_fn([src_text, trg_input, trg_label, 0])[0]\n",
+ "\n",
+ " loss += err\n",
+ " iters += trg_label.shape[1]\n",
+ "\n",
+ " return loss, iters"
+ ]
+ },
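+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As a small, optional illustration of `pad_texts`: with `mask=True` sentences are left-padded with 0 (the id reserved for masking), otherwise they are right-padded with the `eos` id. The toy batch below is made up."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only; pad_texts mutates its input, so pass copies.\n",
+    "toy_batch = [[5, 6], [7, 8, 9]]\n",
+    "print(pad_texts([t[:] for t in toy_batch], eos=1, mask=False))  # right-pad with eos id\n",
+    "print(pad_texts([t[:] for t in toy_batch], eos=1, mask=True))   # left-pad with 0"
+   ]
+  },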
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "### Define training method. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "MODEL_PATH = 'checkpoints/'\n",
+ "\n",
+ "tf.compat.v1.disable_eager_execution()\n",
+ "\n",
+ "def train_sated_nmt(num_users=300, num_words=5000, num_epochs=20, h=128, emb_h=128,\n",
+ " l2_ratio=1e-4, lr=0.001, batch_size=32, mask=False, drop_p=0.5,\n",
+ " tied=False, rnn_fn='lstm', optim_fn='adam'):\n",
+ " \"\"\"Trains the machine translation model.\n",
+ "\n",
+ " Args:\n",
+ " num_users: Number of users to sample from the dataset.\n",
+ " num_words: Size of vocabulary.\n",
+ " h: Number of hidden units.\n",
+ " emb_h: Embedding dimension.\n",
+    "      l2_ratio: L2 regularization coefficient.\n",
+    "      lr: Learning rate.\n",
+    "      drop_p: Dropout rate.\n",
+ " tied: Flag to use DenseTransposeTied or Dense layer for the model's output layer.\n",
+ " rnn_fn: Can be 'lstm' or 'gru'.\n",
+ " optim_fn: Can be 'adam' or 'mom_sgd'.\n",
+ " \"\"\"\n",
+ " # Load dataset for training\n",
+ " user_src_texts, user_trg_texts, dev_src_texts, dev_trg_texts, test_src_texts, test_trg_texts, \\\n",
+ " src_vocabs, trg_vocabs = load_sated_data_by_user(num_users, num_words, test_on_user=False)\n",
+ " train_src_texts, train_trg_texts = [], []\n",
+ "\n",
+ " users = sorted(user_src_texts.keys())\n",
+ "\n",
+ " for i, user in enumerate(users):\n",
+ " train_src_texts += user_src_texts[user]\n",
+ " train_trg_texts += user_trg_texts[user]\n",
+ "\n",
+ " # Convert words to indices based on the source and target vocabs\n",
+ " train_src_texts = words_to_indices(train_src_texts, src_vocabs, mask=mask)\n",
+ " train_trg_texts = words_to_indices(train_trg_texts, trg_vocabs, mask=mask)\n",
+ " dev_src_texts = words_to_indices(dev_src_texts, src_vocabs, mask=mask)\n",
+ " dev_trg_texts = words_to_indices(dev_trg_texts, trg_vocabs, mask=mask)\n",
+ "\n",
+ " # Vocab lengths for source and target language vocabularies.\n",
+ " Vs = len(src_vocabs)\n",
+ " Vt = len(trg_vocabs)\n",
+ "\n",
+ " # Build model\n",
+ " model = build_nmt_model(Vs=Vs, Vt=Vt, mask=mask, drop_p=drop_p, h=h, demb=emb_h, tied=tied, l2_ratio=l2_ratio,\n",
+ " rnn_fn=rnn_fn)\n",
+ " src_input_var, trg_input_var = model.inputs\n",
+ " prediction = model.output\n",
+ "\n",
+ " trg_label_var = K.placeholder((None, None), dtype='float32')\n",
+ "\n",
+ " # Define loss\n",
+ " loss = K.sparse_categorical_crossentropy(trg_label_var, prediction, from_logits=True)\n",
+ " loss = K.mean(K.sum(loss, axis=-1))\n",
+ "\n",
+ " # Define optimizer\n",
+ " if optim_fn == 'adam':\n",
+ " optimizer = Adam(learning_rate=lr, clipnorm=5.)\n",
+ " elif optim_fn == 'mom_sgd':\n",
+ " optimizer = SGD(learning_rate=lr, momentum=0.9)\n",
+ " else:\n",
+ " raise ValueError(optim_fn)\n",
+ " updates = optimizer.get_updates(loss, model.trainable_weights)\n",
+ "\n",
+ " # Define train and prediction functions\n",
+ " train_fn = K.function(inputs=[src_input_var, trg_input_var, trg_label_var, K.learning_phase()], outputs=[loss],\n",
+ " updates=updates)\n",
+ " pred_fn = K.function(inputs=[src_input_var, trg_input_var, trg_label_var, K.learning_phase()], outputs=[loss])\n",
+ "\n",
+ " # Pad batches to same length\n",
+ " train_prop = 0.2\n",
+ " batches = []\n",
+ " for batch in group_texts_by_len(train_src_texts, train_trg_texts, bs=batch_size):\n",
+ " src_input, trg_input = batch\n",
+    "        src_input = pad_texts(src_input, src_vocabs['<eos>'], mask=mask)\n",
+    "        trg_input = pad_texts(trg_input, trg_vocabs['<eos>'], mask=mask)\n",
+ " batches.append((src_input, trg_input))\n",
+ "\n",
+ " # Train machine translation model\n",
+ " print(\"Training NMT model...\")\n",
+ " for epoch in range(num_epochs):\n",
+ " np.random.shuffle(batches)\n",
+ "\n",
+ " for batch in batches:\n",
+ " src_input, trg_input = batch\n",
+ " _ = train_fn([src_input, trg_input[:, :-1], trg_input[:, 1:], 1])[0]\n",
+ "\n",
+ " train_loss, train_it = get_perp(train_src_texts, train_trg_texts, pred_fn, shuffle=True, prop=train_prop)\n",
+ " test_loss, test_it = get_perp(dev_src_texts, dev_trg_texts, pred_fn)\n",
+ "\n",
+ " print(\"Epoch {}, train loss={:.3f}, train perp={:.3f}, test loss={:.3f}, test perp={:.3f}\".format(\n",
+ " epoch,\n",
+ " train_loss / len(train_src_texts) / train_prop,\n",
+ " np.exp(train_loss / train_it),\n",
+ " test_loss / len(dev_src_texts),\n",
+ " np.exp(test_loss / test_it)))\n",
+ "\n",
+ " fname = 'sated_nmt'\n",
+ "\n",
+ " # Save model\n",
+ " model.save(MODEL_PATH + '{}_{}.h5'.format(fname, num_users))\n",
+    "    print('Target model saved to {}{}_{}.h5.'.format(MODEL_PATH, fname, num_users))\n",
+ " K.clear_session()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ },
+ "source": [
+ "## Train the seq2seq model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Define parameters for data-loading\n",
+ "num_users = 300\n",
+ "num_words = 5000\n",
+ "\n",
+ "# Define hyperparameters for target model\n",
+ "lr = 0.001\n",
+ "h = 128\n",
+ "emb_h = 128\n",
+ "num_epochs = 30\n",
+ "batch_size = 20\n",
+ "drop_p = 0.5\n",
+ "rnn_fn = 'lstm'\n",
+ "optim_fn = 'adam'\n",
+ "\n",
+ "train_sated_nmt(lr=lr, h=h, emb_h=emb_h, num_epochs=num_epochs,\n",
+ " num_users=num_users, batch_size=batch_size,\n",
+ " drop_p=drop_p, rnn_fn=rnn_fn, optim_fn=optim_fn)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Calculate logits and labels for the training and test sets.\n",
+ "\n",
+ "We will use these values later in the membership inference attack to separate training and test samples."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Load the train and test sets grouped by user\n",
+ "user_src_texts, user_trg_texts, test_user_src_texts, test_user_trg_texts, src_vocabs, trg_vocabs \\\n",
+ " = load_sated_data_by_user(num_users, num_words, test_on_user=True)\n",
+ "\n",
+ "train_users = sorted(user_src_texts.keys())\n",
+ "train_src_texts, train_trg_texts = [], []\n",
+ "for user in train_users:\n",
+ " user_src_text = words_to_indices(user_src_texts[user], src_vocabs, mask=False)\n",
+ " user_trg_text = words_to_indices(user_trg_texts[user], trg_vocabs, mask=False)\n",
+ " train_src_texts.append(user_src_text)\n",
+ " train_trg_texts.append(user_trg_text)\n",
+ "\n",
+ "test_users = sorted(test_user_src_texts.keys())\n",
+ "test_src_texts, test_trg_texts = [], []\n",
+ "for user in test_users:\n",
+ " user_src_text = words_to_indices(test_user_src_texts[user], src_vocabs, mask=False)\n",
+ " user_trg_text = words_to_indices(test_user_trg_texts[user], trg_vocabs, mask=False)\n",
+ " test_src_texts.append(user_src_text)\n",
+ " test_trg_texts.append(user_trg_text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "WARNING:tensorflow:From /usr/local/anaconda3/envs/tfprivacyenv/lib/python3.6/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1630: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n",
+ "Instructions for updating:\n",
+ "If using Keras pass *_constraint arguments to layers.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Get prediction function of the translation model\n",
+ "model = build_nmt_model(Vs=num_words, Vt=num_words, mask=False, drop_p=0., h=h, demb=emb_h, tied=False)\n",
+ "model_path = 'sated_nmt'\n",
+ "model.load_weights(MODEL_PATH + '{}_{}.h5'.format(model_path, num_users))\n",
+ "src_input_var, trg_input_var = model.inputs\n",
+ "prediction = model.output\n",
+ "trg_label_var = K.placeholder((None, None), dtype='float32')\n",
+ "prediction = K.softmax(prediction)\n",
+ "pred_fn = K.function([src_input_var, trg_input_var, trg_label_var, K.learning_phase()], [prediction])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Generator for loading logits by user\n",
+ "def get_logits(src_texts, trg_texts):\n",
+ " for user_src_texts, user_trg_texts in zip(src_texts, trg_texts):\n",
+ " user_trg_logits = []\n",
+ "\n",
+ " sentence_indices = np.arange(len(user_trg_texts))\n",
+ " for idx in sentence_indices:\n",
+ " src_sentence = np.asarray(user_src_texts[idx], dtype=np.float32).reshape(1, -1)\n",
+ " trg_sentence = np.asarray(user_trg_texts[idx], dtype=np.float32)\n",
+ " trg_input = trg_sentence[:-1].reshape(1, -1)\n",
+ " trg_label = trg_sentence[1:].reshape(1, -1)\n",
+ " trg_logits = pred_fn([src_sentence, trg_input, trg_label, 0])[0][0]\n",
+ " user_trg_logits.append(trg_logits)\n",
+ "\n",
+ " yield np.array(user_trg_logits, dtype=object)"
+ ]
+ },
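+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As an optional, illustrative check, the next cell pulls one element from a fresh `get_logits` generator: each element is a ragged array with one entry per sentence of that user, and each entry holds per-token probability vectors over the target vocabulary."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Illustrative sketch only; runs the model on the sentences of one training user.\n",
+    "example_user_logits = next(get_logits(train_src_texts[:1], train_trg_texts[:1]))\n",
+    "print(len(example_user_logits), 'sentences for this user')\n",
+    "print(example_user_logits[0].shape)  # (num_target_tokens, vocab_size)"
+   ]
+  },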
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "# Generator for loading labels by user.\n",
+ "def get_labels(trg_texts):\n",
+ " for user_trg_texts in trg_texts:\n",
+ " user_trg_labels = []\n",
+ "\n",
+ " for sentence in user_trg_texts:\n",
+ " trg_sentence = np.asarray(sentence, dtype=np.float32)\n",
+ " trg_label = trg_sentence[1:]\n",
+ " user_trg_labels.append(trg_label)\n",
+ "\n",
+ " yield np.array(user_trg_labels, dtype=object)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run membership inference attacks.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [],
+ "source": [
+ "n_train = 50\n",
+ "n_test = 50\n",
+ "\n",
+    "logits_train_gen = get_logits(train_src_texts[:n_train], train_trg_texts[:n_train])\n",
+ "logits_test_gen = get_logits(test_src_texts[:n_test], test_trg_texts[:n_test])\n",
+ "labels_train_gen = get_labels(train_trg_texts[:n_train])\n",
+ "labels_test_gen = get_labels(test_trg_texts[:n_test])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Best-performing attacks over all slices\n",
+ " LOGISTIC_REGRESSION achieved an AUC of 1.00 on slice Entire dataset\n",
+ " LOGISTIC_REGRESSION achieved an advantage of 1.00 on slice Entire dataset\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAlT0lEQVR4nO3deXhU9dn/8fet6IPWpYKKSsQgi0WSsAVwK4KICC60FRXrT5/KBATBrS6IVlt+7dPSx15WRSngblVQERFKkKiA7E0ChGSAUiMiBFDZBBQQEu7fH0n5RQwhQE5OZubzuq5cV84yM59v0LnnPufM95i7IyIiieuosAOIiEi4VAhERBKcCoGISIJTIRARSXAqBCIiCa5O2AEO1amnnurJyclhxxARiSkLFy7c6O6nVbQt5gpBcnIyubm5YccQEYkpZvb5gbbp0JCISIJTIRARSXAqBCIiCU6FQEQkwakQiIgkuMAKgZm9aGZfmVn0ANvNzJ42s0IzyzeztkFlERGRAwuyI3gZuLKS7T2AZmU//YG/BZhFREQOILDvEbj7LDNLrmSXXsCrXjoP9gIz+7GZnenu64PIc9tL2cxYsSGIpxYRqTGrhl9V7c8Z5jmChsCacstFZet+wMz6m1mumeVu2HB4b+YqAiIiFQvzm8VWwboK75Lj7mOAMQDp6elHdCedIKqpiEgQ5syZQyQSIS0tjREjRnDGGWcE8jphFoIi4Oxyy0nAupCyiIjUGtu3b2fo0KFMmDCBZ555hl/84heBvl6Yh4YmAbeWXT10AbA1qPMDIiKxYtq0aaSmprJjxw6WLl0aeBGAADsCMxsLdAZONbMi4LfAMQDuPgrIBHoChcAO4LagsoiI1HabN2/m3nvv5eOPP2bMmDFcccUVNfbaQV41dNNBtjswKKjXFxGJBe7OO++8w1133cX1119PNBrlhBNOqNEMMTcNtYhIvFi/fj2DBg1i+fLljB8/nosuuiiUHJpiQkSkhrk7L730Eq1ateL8889n8eLFoRUBUEcgIlKjPvvsM/r378/mzZvJysqidevWYUdSRyAiUhNKSkp4+umnad++PZdffjn//Oc/a0URAHUEIiKBW758OZFIhDp16jBv3jyaN28edqTvUUcgIhKQPXv28Ic//IFOnTpxyy23MHPmzFpXBEAdgYhIIBYuXEjfvn1p2LAhCxcupFGjRmFHOiB1BCIi1Wjnzp0MGTKEnj178sADDzBlypRaXQRAHYGISLWZNWsWGRkZtGnThoKCAk4//fSwI1WJCoGIyBHatm0bDz30EJMmTeKZZ57hZz/7WdiRDokODYmIHIHMzExSUlLYvXs30Wg05ooAqCMQETksGzdu5N5772Xu3Lm89NJLdO3aNexIh00dgYjIIXB33nzzTVJTUznttNMoKCiI6SIA6ghERKps3bp1DBw4kMLCQt59910uuOCCsCNVC3UEIiIH4e48//zztGrVitatW7No0aK4KQKgjkBEpFIrV66kX79+bNu2jY8++oi0tLSwI1U7dQQiIhUoKSnhr3/9Kx06dKBHjx7Mnz8/LosAqCMQEfmBaDRKRkYGdevWZcGCBTRt2jTsSIFSRyAiUmb37t0MGzaMLl260LdvX6ZPnx73RQDUEYiIAJCTk0Pfvn0555xzWLx4MUlJSWFHqjHqCEQkoe3YsYP777+fa665hocffpjJkycnVBEAFQIRSWAzZ84kLS2NdevWUVBQwE033YSZhR2rxunQkIgknK1bt/Lggw+SmZnJyJEjueaaa8KOFCp1BCKSUCZPnkxKSgpmRjQaTfgiAOoIRCRBbNiwgbvvvpvs7GxeffVVunTpEnakWkMdgYjENXfnjTfeIDU1lYYNG5Kfn68isB91BCISt4qKihg4cCCrVq1i0qRJdOjQIexItZI6AhGJO3v37mX06NG0adOG9u3bs3DhQhWBSqgjEJG4UlhYSL9+/dixYwczZswgJSUl7Ei1njoCEYkLxcXF/OUvf+GCCy7g2muvZd68eSoCVaSOQERiXn5+PpFIhJNOOons7GzOPffcsCPFlEA7AjO70sxWmFmhmT1UwfaTzWyymS0xs6VmdluQeUQkvnz33Xf89re/pWvXrtx+++18+OGHKgKHIbCOwMyOBp4FugFFQI6ZTXL3ZeV2GwQsc/drzOw0YIWZve7uu4PKJSLxYcGCBUQiEZo2bUpeXh4NGzYMO1LMCvLQUAeg0N1XApjZOKAXUL4QOHCilU7ucQKwGSgOMJOIxLhvv/2WRx99lLFjx/Lkk09yww03JOT8QNUpyENDDYE15ZaLytaV9wzQAlgHFAB3u/ve/Z/IzPqbWa6Z5W7YsCGovCJSy3300UekpqayYcMGCgoKuPHGG1UEqkGQHUFF/zq+33J3IA+4DGgCfGBms9192/ce5D4GGAOQnp6+/3OISJz7+uuvuf/++8nKymLUqFH07Nkz7EhxJciOoAg4u9xyEqWf/Mu7DZjgpQqBz4CfBJhJRGLMe++9R0pKCsceeyzRaFRFIABBdgQ5QDMzawysBfoAv9xvn9VAV2C2mTUAzgNWBphJRGLEl19+yV133cXixYt544036NSpU9iR4lZgHYG7FwODgWnAcuAtd19qZgPMbEDZbr8HLjKzAuAjYIi7bwwqk4jUfu7Oa6+9RlpaGsnJySxZskRFIGCBfqHM3TOBzP3WjSr3+zrgiiAziEjsWL16NQMGDGDt2rVMmTKF9PT0sCMlBE0xISKh27t3LyNHjqRdu3ZcfPHF5ObmqgjUIE0xISKh+ve//01GRgbFxcXMmjWLFi1ahB0p4agjEJFQFBcX8+c//5mLLrqI3r17M3v2bBWBkKgjEJEat2TJEvr27Uu9evXIycmhcePGYUdKaOoIRKTG7Nq1i9/85jd069aNwYMHk5WVpSJQC6gjEJEaMW/ePCKRCC1atGDJkiWceeaZYUeSMioEIhKob775hkceeYS3336bp59+muuuu07zA9UyOjQkIoHJysoiNTWVrVu3Eo1G6d27t4pALaSOQESq3ZYtW/j1r3/NjBkzGD16NN27dw87klRCHYGIVKsJEyaQkpLCCSecQEFBgYpADFBHICLV4osvvmDw4MFEo1HefPNNLrnkkrAjSRWpIxCRI+LuvPLKK6SlpdG8eXPy8vJUBGKMOgIROWyrVq3i9ttv56uvvmLatGm0adMm7EhyGNQRiMgh27t3LyNGjCA9PZ3OnTuTnZ2tIhDD1BGIyCH517/+RUZGBgBz5szhJz/RTQVjnToCEamSPXv28Mc//pFLLrmEPn36MGvWLBWBOKGOQEQOavHixfTt25cGDRqwcOFCzjnnnLAjSTVSRyAiB7Rr1y6GDh1K9+7dueeee5g6daqKQBxSRyAiFZozZw6RSIS0tDTy8/M544wzwo4kAVEhEJHv2b59O0OHDuXdd99lxIgR/OIXvwg7kgRMh4ZEZJ/333+flJQUduzYQTQaVRFIEOoIRIRNmzbx61//mlmzZvH888/TrVu3sCNJDVJHIJLA3J3x48eTmpr
Kj3/8YwoKClQEEpA6ApEEtX79egYNGsTy5csZP348F110UdiRJCTqCEQSjLvz4osv0qpVK1q2bMnixYtVBBKcOgKRBPLZZ5/Rv39/tmzZwgcffECrVq3CjiS1gDoCkQRQUlLCU089Rfv27enWrRsLFixQEZB91BGIxLlly5aRkZFBnTp1mDdvHs2bNw87ktQy6ghE4tSePXv4wx/+wKWXXsott9zCzJkzVQSkQuoIROJQbm4ukUiEhg0bsnDhQho1ahR2JKnFAu0IzOxKM1thZoVm9tAB9ulsZnlmttTMPg4yj0i827lzJw8++CBXXXUVDzzwAFOmTFERkIMKrCMws6OBZ4FuQBGQY2aT3H1ZuX1+DIwErnT31WZ2elB5ROLdxx9/TEZGBu3ataOgoIDTT9f/TlI1QR4a6gAUuvtKADMbB/QClpXb55fABHdfDeDuXwWYRyQubdu2jSFDhjB58mSeffZZevXqFXYkiTFBHhpqCKwpt1xUtq685sApZjbTzBaa2a0VPZGZ9TezXDPL3bBhQ0BxRWJPZmYmKSkpFBcXE41GVQTksATZEVgF67yC128HdAWOA+ab2QJ3//f3HuQ+BhgDkJ6evv9ziCScjRs3cs899zB//nxeeuklunbtGnYkiWFBdgRFwNnllpOAdRXs8767f+vuG4FZgL7lInIA7s6bb75JamoqDRo0ID8/X0VAjliQHUEO0MzMGgNrgT6UnhMo7z3gGTOrAxwLdAT+GmAmkZi1du1a7rjjDgoLC5k4cSIdO3YMO5LEicA6AncvBgYD04DlwFvuvtTMBpjZgLJ9lgPvA/lANvC8u0eDyiQSi9yd5557jtatW9OmTRsWLVqkIiDVKtAvlLl7JpC537pR+y0/DjweZA6RWPXpp5/Sr18/tm/fzvTp00lNTQ07ksQhTTEhUguVlJTwxBNP0LFjR6666irmz5+vIiCB0RQTIrVMNBolEolw3HHHsWDBApo2bRp2JIlz6ghEaondu3czbNgwunTpQiQSYfr06SoCUiPUEYjUAtnZ2UQiEZKTk1m8eDFJSUlhR5IEokIgEqIdO3bw2GOP8dprr/HXv/6VPn36YFbRdzFFgqNDQyIhmTFjBmlpaaxfv56CggJuuukmFQEJxSF3BGWzivZx99cDyCMS97Zu3cqDDz5IZmYmI0eO5Jprrgk7kiS4A3YEZnaSmQ01s2fM7AordSewErih5iKKxI/JkyeTkpKCmRGNRlUEpFaorCP4O7AFmA9kAA9QOg1EL3fPCz6aSPzYsGEDd999Nzk5Ofz973+nc+fOYUcS2aeycwTnuvuv3H00cBOQDlytIiBSde7OG2+8QWpqKg0bNmTJkiUqAlLrVNYR7PnPL+5eYmafufv2GsgkEhfWrFnDwIEDWb16NZMnT6Z9+/ZhRxKpUGUdQSsz22Zm281sO5BWbnlbTQUUiTV79+5l9OjRtG3blg4dOpCbm6siILXaATsCdz+6JoOIxINPPvmEfv36sWvXLmbOnEnLli3DjiRyUJVdNVTXzO4pu2qof9k9A0SkAsXFxfzlL3/hwgsvpFevXsydO1dFQGJGZW/ur1B6nmA20BNoCdxdE6FEYkl+fj6RSISTTjqJ7Oxszj333LAjiRySys4RnO/u/6fsqqHewE9rKJNITPjuu+947LHH6Nq1KwMGDODDDz9UEZCYVNWrhor11XeR/2/BggVEIhGaNWvGkiVLOOuss8KOJHLYKisErctdHWTAcWXLBri7nxR4OpFa5ttvv+U3v/kN48aN46mnnuL666/X/EAS8yo7NLTE3U8q+znR3euU+11FQBLORx99RGpqKhs3biQajXLDDTeoCEhcqKwj8BpLIVKLff3119x///1kZWUxatQoevbsGXYkkWpVWSE43cx+faCN7v5EAHlEapWJEycyePBgevXqRTQa5aST1AxL/KmsEBwNnEDpOQGRhPLll19y5513kpeXxxtvvEGnTp3CjiQSmMoKwXp3/781lkSkFnB3XnvtNe6//35uu+02XnnlFY477riwY4kEqrJCoE5AEsrq1asZMGAA69atIzMzk3bt2oUdSaRGVHbVUNcaSyESor179zJy5EjatWvHxRdfTE5OjoqAJJTKJp3bXJNBRMKwYsUK+vXrR3FxMbNmzaJFixZhRxKpcbp5vSSk4uJihg8fzsUXX0zv3r2ZPXu2ioAkLM0oKgknLy+PSCRC/fr1yc3NJTk5OexIIqFSRyAJY9euXTzyyCNcccUV3HnnnUybNk1FQAR1BJIg5s2bRyQSoUWLFixZsoQzzzwz7EgitYYKgcS1b775hocffpjx48czYsQIrrvuurAjidQ6gR4aMrMrzWyFmRWa2UOV7NfezErMrHeQeSSxZGVlkZqayrZt24hGoyoCIgcQWEdgZkcDzwLdgCIgx8wmufuyCvb7MzAtqCySWDZv3sx9993HjBkzGD16NN27dw87kkitFmRH0AEodPeV7r4bGAf0qmC/O4F3gK8CzCIJ4p133iElJYUTTjiBgoICFQGRKgjyHEFDYE255SKgY/kdzKwh8HPgMqD9gZ7IzPoD/QEaNWpU7UEl9n3xxRcMHjyYaDTKW2+9xSWXXBJ2JJGYEWRHUNFcRfvf4+BJYIi7l1T2RO4+xt3T3T39tNNOq658EgfcnZdffpm0tDSaN29OXl6eioDIIQqyIygCzi63nASs22+fdGBc2V2eTgV6mlmxu08MMJfEiVWrVnH77bfz1VdfMW3aNNq0aRN2JJGYFGRHkAM0M7PGZnYs0AeYVH4Hd2/s7snungyMB+5QEZCD2bt3LyNGjCA9PZ0uXbqQnZ2tIiByBALrCNy92MwGU3o10NHAi+6+1MwGlG0fFdRrS/xavnw5GRkZHHXUUcydO5fzzjsv7EgiMS/QL5S5eyaQud+6CguAu/8qyCwS2/bs2cPjjz/OE088wbBhwxg4cCBHHaUZUkSqg75ZLLXeokWLiEQiNGjQgIULF3LOOeeEHUkkrugjldRaO3fuZOjQofTo0YN7772XqVOnqgiIBEAdgdRKc+bMIRKJkJaWRn5+Pg0aNAg7kkjcUiGQWmX79u0MHTqUd999l2eeeYaf//znYUcSiXs6NCS1xtSpU0lJSWHnzp1Eo1EVAZEaoo5AQrdp0ybuvfdeZs+ezQsvvMDll18ediSRhKKOQELj7rz99tukpKRQr149CgoKVAREQqCOQEKxfv167rjjDlasWMGECRO48MILw44kkrDUEUiNcndefPFFWrVqRUpKCosXL1YREAmZOgKpMStXruT2229ny5YtfPDBB7Rq1SrsSCKCOgKpASUlJTz55JN06NCBK664ggULFqgIiNQi6ggkUMuWLSMSiXDssccyb948mjdvHnYkEdmPOgIJxO7du/n973/PpZdeyn//938zY8YMFQGRWkodgVS73NxcIpEIDRs2ZNGiRZx99tkHf5CIhEYdgVSbnTt38uCDD3LVVVfx4IMPMmXKFBUBkRigQiDV4uOPPyYtLY01a9ZQUFDAzTffTNktSEWkltOhITki27
ZtY8iQIUyePJmRI0dy7bXXhh1JRA6ROgI5bFOmTCElJYWSkhKi0aiKgEiMUkcgh2zjxo3cc889zJ8/n5dffpnLLrss7EgicgTUEUiVuTvjxo0jJSWFBg0akJ+fryIgEgfUEUiVrF27ljvuuIPCwkLee+89OnbsGHYkEakm6gikUu7Oc889R+vWrWnTpg2LFi1SERCJM+oI5IA+/fRT+vXrxzfffMP06dNJTU0NO5KIBEAdgfxASUkJTzzxBB07duSqq65i/vz5KgIicUwdgXxPNBolEolw/PHHs2DBApo2bRp2JBEJmDoCAUoniRs2bBhdunQhEonw0UcfqQiIJAh1BEJ2djaRSITk5GQWL15MUlJS2JFEpAapECSwHTt28Oijj/L666/z5JNPcuONN2p+IJEEpENDCWrGjBmkpqbyxRdfEI1G6dOnj4qASIJSR5Bgtm7dygMPPMDUqVP529/+xtVXXx12JBEJWaAdgZldaWYrzKzQzB6qYPvNZpZf9jPPzHQj2wBNnjyZlJQUjjrqKKLRqIqAiAABdgRmdjTwLNANKAJyzGySuy8rt9tnwKXuvsXMegBjAH1ttZpt2LCBu+66i5ycHP7+97/TuXPnsCOJSC0SZEfQASh095XuvhsYB/Qqv4O7z3P3LWWLCwBdrlKN3J3XX3+dlJQUkpKSyM/PVxEQkR8I8hxBQ2BNueUiKv+0HwGmVrTBzPoD/QEaNWpUXfni2po1axg4cCCrV6/mH//4B+3btw87kojUUkF2BBVdguIV7mjWhdJCMKSi7e4+xt3T3T39tNNOq8aI8Wfv3r2MGjWKtm3b0rFjR3Jzc1UERKRSQXYERUD5O5cnAev238nM0oDngR7uvinAPHHvk08+oV+/fuzatYuZM2fSsmXLsCOJSAwIsiPIAZqZWWMzOxboA0wqv4OZNQImALe4+78DzBLXiouLefzxx7nwwgv52c9+xty5c1UERKTKAusI3L3YzAYD04CjgRfdfamZDSjbPgp4DKgPjCz7MlOxu6cHlSke5efnE4lEOPnkk8nOzubcc88NO5KIxJhAv1Dm7plA5n7rRpX7PQPICDJDvPruu+/4n//5H0aNGsWf/vQn+vbtq28Gi8hh0TeLY9D8+fOJRCI0b96cvLw8zjrrrLAjiUgMUyGIId9++y2PPPIIb775Jk8//TS9e/dWFyAiR0yTzsWIDz/8kNTUVDZv3kw0GuX6669XERCRaqGOoJb7+uuvue+++/jwww8ZNWoUPXr0CDuSiMQZdQS12MSJE2nZsiV169aloKBARUBEAqGOoBb68ssvufPOO1myZAljx46lU6dOYUcSkTimjqAWcXdeffVV0tLSaNKkCXl5eSoCIhI4dQS1xOrVq7n99ttZv349mZmZtGvXLuxIIpIg1BGEbO/evTz77LO0bduWn/70p+Tk5KgIiEiNUkcQohUrVpCRkUFJSQmzZ8+mRYsWYUcSkQSkjiAExcXFDB8+nIsvvpgbbrhBRUBEQqWOoIbl5eURiUSoX78+ubm5JCcnhx1JRBKcOoIasmvXLh555BGuuOIK7rzzTqZNm6YiICK1gjqCGjB37lwikQgtW7YkPz+fM844I+xIIiL7qBAE6JtvvuHhhx9m/PjxjBgxguuuuy7sSCIiP6BDQwHJysoiJSWFbdu2EY1GVQREpNZSR1DNNm/ezH333ceMGTMYPXo03bt3DzuSiEil1BFUo3feeYeUlBROPPFECgoKVAREJCaoI6gG69evZ/DgwSxdupS3336biy++OOxIIiJVpo7gCLg7L7/8Mq1ateInP/kJeXl5KgIiEnPUERymVatW0b9/fzZu3EhWVhatW7cOO5KIyGFRIThE/5kkbtiwYdx///3cd999HHPMMWHHEqnV9uzZQ1FREbt27Qo7StyrW7cuSUlJh/S+pEJwCJYvX05GRgZHHXUUc+fO5bzzzgs7kkhMKCoq4sQTTyQ5OVn32g6Qu7Np0yaKiopo3LhxlR+ncwRVsGfPHv74xz/SqVMnbr75Zj7++GMVAZFDsGvXLurXr68iEDAzo379+ofceakjOIhFixbRt29fzjzzTHJzcznnnHPCjiQSk1QEasbh/J3VERzAzp07eeihh+jRowf33XcfmZmZKgIiEpdUCCowe/ZsWrduzcqVK8nPz+eWW27RpxmROPDuu+9iZvzrX//at27mzJlcffXV39vvV7/6FePHjwdKDw0/9NBDNGvWjJSUFDp06MDUqVMrfZ1Zs2bRtm1b6tSps+95KrJw4UJSU1Np2rQpd911F+4OwHfffceNN95I06ZN6dixI6tWrdr3mFdeeYVmzZrRrFkzXnnllUP9E1RIhaCc7du3M2jQIPr06cPw4cN56623aNCgQdixRKSajB07lksuuYRx48ZV+TGPPvoo69evJxqNEo1GmTx5Mtu3b6/0MY0aNeLll1/ml7/8ZaX7DRw4kDFjxvDJJ5/wySef8P777wPwwgsvcMopp1BYWMi9997LkCFDgNIpbIYNG8Y///lPsrOzGTZsGFu2bKnyWA5E5wjKTJ06lQEDBnD55ZcTjUY55ZRTwo4kEpeSH5oSyPOuGn5Vpdu/+eYb5s6dy4wZM7j22mv53e9+d9Dn3LFjB8899xyfffYZ//Vf/wVAgwYNuOGGGyp93H/uNXLUUQf+rL1+/Xq2bdvGhRdeCMCtt97KxIkT6dGjB++9996+fL1792bw4MG4O9OmTaNbt27Uq1cPgG7duvH+++9z0003HXQslUn4jmDTpk3ceuutDBo0iBdeeGFfJRaR+DJx4kSuvPJKmjdvTr169Vi0aNFBH1NYWEijRo046aSTKtyekZFBbm7uYeVZu3YtSUlJ+5aTkpJYu3btvm1nn302AHXq1OHkk09m06ZN31u//2OORMJ2BO7O22+/zd13302fPn0oKCjgRz/6UdixROLewT65B2Xs2LHcc889APTp04exY8fStm3bA57/q8p5weeff/6w8/znfEBFr3mgbZU95kgEWgjM7ErgKeBo4Hl3H77fdivb3hPYAfzK3Q9epo/QunXrGDRoECtWrGDChAn7WjMRiU+bNm1i+vTpRKNRzIySkhLMjP/93/+lfv36PzjOvnnzZk499VSaNm3K6tWr2b59OyeeeGK1ZkpKSqKoqGjfclFREWeddda+bWvWrCEpKYni4mK2bt1KvXr1SEpKYubMmd97TOfOnY84S2CHhszsaOBZoAdwPnCTmZ2/3249gGZlP/2BvwWV5z9eeOEFWrduTWpqKosXL1YREEkA48eP59Zbb+Xzzz9n1apVrFmzhsaNGzNnzhyaNWvGunXrWL58OQCff/45S5YsoXXr1hx//PFEIhHuuusudu/eDZQe23/ttdeOONOZZ57JiSeeyIIFC3B3Xn31VXr16gXAtddeu++KoPHjx3PZZZdhZnTv3p2srCy2bNnCli1byMrKqp7p7t09kB/gQmBaueWhwND99hkN3FRueQVwZmXP265dOz8c5wz5h58z5B/erl07z8vLO6znEJHDs2zZslBf/9JLL/WpU
6d+b91TTz3lAwYMcHf3OXPmeMeOHb1Vq1aenp7uWVlZ+/b77rvv/IEHHvAmTZp4y5YtvUOHDv7++++7u3skEvGcnJwfvF52drY3bNjQjz/+eK9Xr56ff/75+7a1atVq3+85OTnesmVLP/fcc33QoEG+d+9ed3ffuXOn9+7d25s0aeLt27f3Tz/9dN9jXnjhBW/SpIk3adLEX3zxxQrHW9HfG8j1A7yvmldwzKk6mFlv4Ep3zyhbvgXo6O6Dy+3zD2C4u88pW/4IGOLuufs9V39KOwYaNWrU7vPPPz/kPP+5UqHwD92pUydhT42IhGL58uW0aNEi7BgJo6K/t5ktdPf0ivYP8h2xojMY+1edquyDu48BxgCkp6cfVuUK6wSViEhtF+Tlo0XA2eWWk4B1h7GPiIgEKMhCkAM0M7PGZnYs0AeYtN8+k4BbrdQFwFZ3Xx9gJhEJSVCHoeX7DufvHNihIXcvNrPBwDRKLx990d2XmtmAsu2jgExKLx0tpPTy0duCyiMi4albty6bNm3SVNQB87L7EdStW/eQHhfYyeKgpKen++F+k09EwqE7lNWcA92hLKyTxSIiABxzzDGHdMcsqVkJP9eQiEiiUyEQEUlwKgQiIgku5k4Wm9kG4NC/WlzqVGBjNcaJBRpzYtCYE8ORjPkcdz+tog0xVwiOhJnlHuisebzSmBODxpwYghqzDg2JiCQ4FQIRkQSXaIVgTNgBQqAxJwaNOTEEMuaEOkcgIiI/lGgdgYiI7EeFQEQkwcVlITCzK81shZkVmtlDFWw3M3u6bHu+mbUNI2d1qsKYby4ba76ZzTOzVmHkrE4HG3O5/dqbWUnZXfNiWlXGbGadzSzPzJaa2cc1nbG6VeG/7ZPNbLKZLSkbc0zPYmxmL5rZV2YWPcD26n//OtA9LGP1h9Iprz8FzgWOBZYA5++3T09gKqV3SLsA+GfYuWtgzBcBp5T93iMRxlxuv+mUTnneO+zcNfDv/GNgGdCobPn0sHPXwJgfBv5c9vtpwGbg2LCzH8GYOwFtgegBtlf7+1c8dgQdgEJ3X+nuu4FxQK/99ukFvOqlFgA/NrMzazpoNTromN19nrtvKVtcQOnd4GJZVf6dAe4E3gG+qslwAanKmH8JTHD31QDuHuvjrsqYHTjRSm90cAKlhaC4ZmNWH3efRekYDqTa37/isRA0BNaUWy4qW3eo+8SSQx1PhNJPFLHsoGM2s4bAz4FRNZgrSFX5d24OnGJmM81soZndWmPpglGVMT8DtKD0NrcFwN3uvrdm4oWi2t+/4vF+BBXd/mj/a2Srsk8sqfJ4zKwLpYXgkkATBa8qY34SGOLuJXFyV6yqjLkO0A7oChwHzDezBe7+76DDBaQqY+4O5AGXAU2AD8xstrtvCzhbWKr9/SseC0ERcHa55SRKPykc6j6xpErjMbM04Hmgh7tvqqFsQanKmNOBcWVF4FSgp5kVu/vEGklY/ar63/ZGd/8W+NbMZgGtgFgtBFUZ823AcC89gF5oZp8BPwGyayZijav29694PDSUAzQzs8ZmdizQB5i03z6TgFvLzr5fAGx19/U1HbQaHXTMZtYImADcEsOfDss76JjdvbG7J7t7MjAeuCOGiwBU7b/t94CfmlkdMzse6Agsr+Gc1akqY15NaQeEmTUAzgNW1mjKmlXt719x1xG4e7GZDQamUXrFwYvuvtTMBpRtH0XpFSQ9gUJgB6WfKGJWFcf8GFAfGFn2CbnYY3jmxiqOOa5UZczuvtzM3gfygb3A8+5e4WWIsaCK/86/B142swJKD5sMcfeYnZ7azMYCnYFTzawI+C1wDAT3/qUpJkREElw8HhoSEZFDoEIgIpLgVAhERBKcCoGISIJTIRARSXAqBCJVVDaDaV65n+SymT63mtliM1tuZr8t27f8+n+Z2V/Czi9yIHH3PQKRAO1099blV5hZMjDb3a82sx8BeWb2j7LN/1l/HLDYzN5197k1G1nk4NQRiFSTsmkdFlI630359TspnQsnlic2lDimQiBSdceVOyz07v4bzaw+pfPDL91v/SlAM2BWzcQUOTQ6NCRSdT84NFTmp2a2mNIpHYaXTYHQuWx9PqVz3wx39y9qLKnIIVAhEDlys9396gOtN7PmwJyycwR5NZxN5KB0aEgkYGWzvf4JGBJ2FpGKqBCI1IxRQCczaxx2EJH9afZREZEEp45ARCTBqRCIiCQ4FQIRkQSnQiAikuBUCEREEpwKgYhIglMhEBFJcP8Pxofgk1x2uvoAAAAASUVORK5CYII=\n",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia\n",
+ "from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData\n",
+ "import tensorflow_privacy.privacy.membership_inference_attack.plotting as plotting\n",
+ "\n",
+ "attack_input = Seq2SeqAttackInputData(\n",
+ " logits_train = logits_train_gen,\n",
+ " logits_test = logits_test_gen,\n",
+ " labels_train = labels_train_gen,\n",
+ " labels_test = labels_test_gen,\n",
+ " vocab_size = num_words,\n",
+ " train_size = n_train,\n",
+ " test_size = n_test\n",
+ ")\n",
+ "\n",
+ "# Run several attacks for different data slices\n",
+ "attack_result = mia.run_seq2seq_attack(attack_input)\n",
+ "\n",
+ "# Plot the ROC curve of the best classifier\n",
+ "fig = plotting.plot_roc_curve(attack_result.get_result_with_max_auc().roc_curve)\n",
+ "\n",
+ "# Print a user-friendly summary of the attacks\n",
+ "print(attack_result.summary())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
\ No newline at end of file
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
index 1140611..ba58990 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures.py
@@ -18,7 +18,7 @@ import enum
import glob
import os
import pickle
-from typing import Any, Iterable, Union
+from typing import Any, Iterable, Union, Iterator
from dataclasses import dataclass
import numpy as np
@@ -378,6 +378,91 @@ def _append_array_shape(arr: np.array, arr_name: str, result):
result.append(' %s with shape: %s,' % (arr_name, arr.shape))
+def _is_iterator(obj, obj_name):
+  """Raises ValueError if obj is set but is not an iterator (e.g. a generator)."""
+ if obj is not None and not isinstance(obj, Iterator):
+ raise ValueError('%s should be a generator.' % obj_name)
+
+
+@dataclass
+class Seq2SeqAttackInputData:
+ """Input data for running an attack on seq2seq models.
+
+ This includes only the data, and not configuration.
+ """
+ logits_train: Iterator[np.ndarray] = None
+ logits_test: Iterator[np.ndarray] = None
+
+ # Contains ground-truth token indices for the target sequences.
+ labels_train: Iterator[np.ndarray] = None
+ labels_test: Iterator[np.ndarray] = None
+
+ # Size of the target sequence vocabulary.
+ vocab_size: int = None
+
+ # Number of batches in the training and test sets. These values must be
+ # supplied by the user because logits and labels are lazily loaded for
+ # seq2seq models.
+ train_size: int = 0
+ test_size: int = 0
+
+ def validate(self):
+ """Validates the inputs."""
+
+ if (self.logits_train is None) != (self.logits_test is None):
+ raise ValueError(
+ 'logits_train and logits_test should both be either set or unset')
+
+ if (self.labels_train is None) != (self.labels_test is None):
+ raise ValueError(
+ 'labels_train and labels_test should both be either set or unset')
+
+ if self.logits_train is None or self.labels_train is None:
+ raise ValueError(
+ 'Logits and labels of the training and test sets should all be set')
+
+ if (self.vocab_size is None or self.train_size is None or
+ self.test_size is None):
+ raise ValueError('vocab_size, train_size, test_size should all be set')
+
+ if self.vocab_size is not None and not isinstance(self.vocab_size, int):
+ raise ValueError('vocab_size should be of integer type')
+
+ if self.train_size is not None and not isinstance(self.train_size, int):
+ raise ValueError('train_size should be of integer type')
+
+ if self.test_size is not None and not isinstance(self.test_size, int):
+ raise ValueError('test_size should be of integer type')
+
+ _is_iterator(self.logits_train, 'logits_train')
+ _is_iterator(self.logits_test, 'logits_test')
+ _is_iterator(self.labels_train, 'labels_train')
+ _is_iterator(self.labels_test, 'labels_test')
+
+ def __str__(self):
+ """Return the shapes of variables that are not None."""
+ result = ['AttackInputData(']
+
+ if self.vocab_size is not None and self.train_size is not None:
+ result.append(
+ 'logits_train with shape (%d, num_sequences, num_tokens, %d)' %
+ (self.train_size, self.vocab_size))
+ result.append(
+ 'labels_train with shape (%d, num_sequences, num_tokens, 1)' %
+ self.train_size)
+
+ if self.vocab_size is not None and self.test_size is not None:
+ result.append(
+ 'logits_test with shape (%d, num_sequences, num_tokens, %d)' %
+ (self.test_size, self.vocab_size))
+ result.append(
+ 'labels_test with shape (%d, num_sequences, num_tokens, 1)' %
+ self.test_size)
+
+ result.append(')')
+ return '\n'.join(result)
+
+
@dataclass
class RocCurve:
"""Represents ROC curve of a membership inference classifier."""
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
index caa8ab6..eb1d8db 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/data_structures_test.py
@@ -27,6 +27,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
@@ -152,6 +153,75 @@ class AttackInputDataTest(absltest.TestCase):
probs_test=np.array([])).validate)
+class Seq2SeqAttackInputDataTest(absltest.TestCase):
+
+ def test_validator(self):
+ valid_logits_train = iter([np.array([]), np.array([])])
+ valid_logits_test = iter([np.array([]), np.array([])])
+ valid_labels_train = iter([np.array([]), np.array([])])
+ valid_labels_test = iter([np.array([]), np.array([])])
+
+ invalid_logits_train = []
+ invalid_logits_test = []
+ invalid_labels_train = []
+ invalid_labels_test = []
+
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(logits_train=valid_logits_train).validate)
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(labels_train=valid_labels_train).validate)
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(logits_test=valid_logits_test).validate)
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(labels_test=valid_labels_test).validate)
+ self.assertRaises(ValueError, Seq2SeqAttackInputData(vocab_size=0).validate)
+ self.assertRaises(ValueError, Seq2SeqAttackInputData(train_size=0).validate)
+ self.assertRaises(ValueError, Seq2SeqAttackInputData(test_size=0).validate)
+ self.assertRaises(ValueError, Seq2SeqAttackInputData().validate)
+
+ # Tests that both logits and labels must be set.
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(
+ logits_train=valid_logits_train,
+ logits_test=valid_logits_test,
+ vocab_size=0,
+ train_size=0,
+ test_size=0).validate)
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(
+ labels_train=valid_labels_train,
+ labels_test=valid_labels_test,
+ vocab_size=0,
+ train_size=0,
+ test_size=0).validate)
+
+ # Tests that vocab, train, test sizes must all be set.
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(
+ logits_train=valid_logits_train,
+ logits_test=valid_logits_test,
+ labels_train=valid_labels_train,
+ labels_test=valid_labels_test).validate)
+
+ self.assertRaises(
+ ValueError,
+ Seq2SeqAttackInputData(
+ logits_train=invalid_logits_train,
+ logits_test=invalid_logits_test,
+ labels_train=invalid_labels_train,
+ labels_test=invalid_labels_test,
+ vocab_size=0,
+ train_size=0,
+ test_size=0).validate)
+
+
class RocCurveTest(absltest.TestCase):
def test_auc_random_classifier(self):
@@ -275,7 +345,6 @@ class AttackResultsCollectionTest(absltest.TestCase):
class AttackResultsTest(absltest.TestCase):
-
perfect_classifier_result: SingleAttackResult
random_classifier_result: SingleAttackResult
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
index f731958..3d18648 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack.py
@@ -30,6 +30,7 @@ from tensorflow_privacy.privacy.membership_inference_attack.data_structures impo
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import \
PrivacyReportMetadata
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import RocCurve
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleAttackResult
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -170,6 +171,54 @@ def run_attacks(attack_input: AttackInputData,
privacy_report_metadata=privacy_report_metadata)
+def run_seq2seq_attack(attack_input: Seq2SeqAttackInputData,
+ unused_report_metadata: PrivacyReportMetadata = None,
+ balance_attacker_training: bool = True) -> AttackResults:
+ """Runs membership inference attacks on a seq2seq model.
+
+ Args:
+ attack_input: input data for running an attack
+ unused_report_metadata: unused; metadata of the model under attack.
+ balance_attacker_training: Whether the training and test sets for the
+ membership inference attacker should have a balanced (roughly equal)
+ number of samples from the training and test sets used to develop the
+ model under attack.
+
+ Returns:
+ the attack result.
+ """
+ attack_input.validate()
+
+ # The attacker uses the average rank (a single number) of a seq2seq dataset
+ # record to determine membership. So only Logistic Regression is supported,
+ # as it makes the most sense for single-number features.
+ attacker = models.LogisticRegressionAttacker()
+
+ prepared_attacker_data = models.create_seq2seq_attacker_data(
+ attack_input, balance=balance_attacker_training)
+
+ attacker.train_model(prepared_attacker_data.features_train,
+ prepared_attacker_data.is_training_labels_train)
+
+ # Run the attacker on (permuted) test examples.
+ predictions_test = attacker.predict(prepared_attacker_data.features_test)
+
+ # Generate ROC curves with predictions.
+ fpr, tpr, thresholds = metrics.roc_curve(
+ prepared_attacker_data.is_training_labels_test, predictions_test)
+
+ roc_curve = RocCurve(tpr=tpr, fpr=fpr, thresholds=thresholds)
+
+ attack_results = [
+ SingleAttackResult(
+ slice_spec=SingleSliceSpec(),
+ attack_type=AttackType.LOGISTIC_REGRESSION,
+ roc_curve=roc_curve)
+ ]
+
+ return AttackResults(single_attack_results=attack_results)
+
+
def _compute_missing_privacy_report_metadata(
metadata: PrivacyReportMetadata,
attack_input: AttackInputData) -> PrivacyReportMetadata:
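For completeness, a self-contained sketch of driving run_seq2seq_attack end to end on synthetic data; the sizes, the make_batches helper and the random inputs are illustrative, and the shape conventions mirror the unit tests below rather than the SATED pipeline.

import numpy as np

from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData

vocab_size, n_train, n_test = 5, 8, 8
rng = np.random.RandomState(0)

def make_batches(num_batches):
  # Each batch holds two sequences of different lengths, so the per-batch
  # containers are ragged object arrays of float32 sub-arrays.
  logits, labels = [], []
  for _ in range(num_batches):
    lengths = (2, 3)
    logits.append(np.array(
        [rng.random_sample((n, vocab_size)).astype(np.float32)
         for n in lengths], dtype=object))
    labels.append(np.array(
        [rng.choice(vocab_size, n).astype(np.float32) for n in lengths],
        dtype=object))
  return logits, labels

logits_train, labels_train = make_batches(n_train)
logits_test, labels_test = make_batches(n_test)

attack_input = Seq2SeqAttackInputData(
    logits_train=iter(logits_train), logits_test=iter(logits_test),
    labels_train=iter(labels_train), labels_test=iter(labels_test),
    vocab_size=vocab_size, train_size=n_train, test_size=n_test)

results = mia.run_seq2seq_attack(attack_input)
print(results.single_attack_results[0].roc_curve.get_auc())

Since these random logits carry no membership signal, the reported AUC should sit near chance (0.5), up to noise from the tiny test split; a substantially higher AUC on real model outputs indicates a privacy leak.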
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
index d6b9867..4c80f49 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/membership_inference_attack_test.py
@@ -19,6 +19,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import membership_inference_attack as mia
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackType
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SingleSliceSpec
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingFeature
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import SlicingSpec
@@ -34,6 +35,68 @@ def get_test_input(n_train, n_test):
labels_test=np.array([i % 5 for i in range(n_test)]))
+def get_seq2seq_test_input(n_train,
+ n_test,
+ max_seq_in_batch,
+ max_tokens_in_sequence,
+ vocab_size,
+ seed=None):
+ """Returns example inputs for attacks on seq2seq models."""
+ if seed is not None:
+ np.random.seed(seed=seed)
+
+ logits_train, labels_train = [], []
+ for _ in range(n_train):
+ num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
+ batch_logits, batch_labels = _get_batch_logits_and_labels(
+ num_sequences, max_tokens_in_sequence, vocab_size)
+ logits_train.append(batch_logits)
+ labels_train.append(batch_labels)
+
+ logits_test, labels_test = [], []
+ for _ in range(n_test):
+ num_sequences = np.random.choice(max_seq_in_batch, 1)[0] + 1
+ batch_logits, batch_labels = _get_batch_logits_and_labels(
+ num_sequences, max_tokens_in_sequence, vocab_size)
+ logits_test.append(batch_logits)
+ labels_test.append(batch_labels)
+
+ return Seq2SeqAttackInputData(
+ logits_train=iter(logits_train),
+ logits_test=iter(logits_test),
+ labels_train=iter(labels_train),
+ labels_test=iter(labels_test),
+ vocab_size=vocab_size,
+ train_size=n_train,
+ test_size=n_test)
+
+
+def _get_batch_logits_and_labels(num_sequences, max_tokens_in_sequence,
+ vocab_size):
+ num_tokens_in_sequence = np.random.choice(max_tokens_in_sequence,
+ num_sequences) + 1
+ batch_logits, batch_labels = [], []
+ for num_tokens in num_tokens_in_sequence:
+ logits, labels = _get_sequence_logits_and_labels(num_tokens, vocab_size)
+ batch_logits.append(logits)
+ batch_labels.append(labels)
+ return np.array(
+ batch_logits, dtype=object), np.array(
+ batch_labels, dtype=object)
+
+
+def _get_sequence_logits_and_labels(num_tokens, vocab_size):
+ sequence_logits = []
+ for _ in range(num_tokens):
+ token_logits = np.random.random(vocab_size)
+ token_logits /= token_logits.sum()
+ sequence_logits.append(token_logits)
+ sequence_labels = np.random.choice(vocab_size, num_tokens)
+ return np.array(
+ sequence_logits, dtype=np.float32), np.array(
+ sequence_labels, dtype=np.float32)
+
+
class RunAttacksTest(absltest.TestCase):
def test_run_attacks_size(self):
@@ -97,6 +160,42 @@ class RunAttacksTest(absltest.TestCase):
# If accuracy is already present, simply return it.
self.assertIsNone(mia._get_accuracy(None, labels))
+ def test_run_seq2seq_attack_size(self):
+ result = mia.run_seq2seq_attack(
+ get_seq2seq_test_input(
+ n_train=10,
+ n_test=5,
+ max_seq_in_batch=3,
+ max_tokens_in_sequence=5,
+ vocab_size=2))
+
+ self.assertLen(result.single_attack_results, 1)
+
+ def test_run_seq2seq_attack_trained_sets_attack_type(self):
+ result = mia.run_seq2seq_attack(
+ get_seq2seq_test_input(
+ n_train=10,
+ n_test=5,
+ max_seq_in_batch=3,
+ max_tokens_in_sequence=5,
+ vocab_size=2))
+ seq2seq_result = list(result.single_attack_results)[0]
+ self.assertEqual(seq2seq_result.attack_type, AttackType.LOGISTIC_REGRESSION)
+
+ def test_run_seq2seq_attack_calculates_correct_auc(self):
+ result = mia.run_seq2seq_attack(
+ get_seq2seq_test_input(
+ n_train=20,
+ n_test=10,
+ max_seq_in_batch=3,
+ max_tokens_in_sequence=5,
+ vocab_size=3,
+ seed=12345),
+ balance_attacker_training=False)
+ seq2seq_result = list(result.single_attack_results)[0]
+ np.testing.assert_almost_equal(
+ seq2seq_result.roc_curve.get_auc(), 0.63, decimal=2)
+
if __name__ == '__main__':
absltest.main()
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py
index 54674e0..dd33804 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/models.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py
@@ -15,8 +15,11 @@
# Lint as: python3
"""Trained models for membership inference attacks."""
+from typing import Iterator, List
+
from dataclasses import dataclass
import numpy as np
+from scipy.stats import rankdata
from sklearn import ensemble
from sklearn import linear_model
from sklearn import model_selection
@@ -24,6 +27,7 @@ from sklearn import neighbors
from sklearn import neural_network
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
@dataclass
@@ -110,6 +114,98 @@ def _column_stack(logits, loss):
return np.column_stack((logits, loss))
+def create_seq2seq_attacker_data(attack_input_data: Seq2SeqAttackInputData,
+ test_fraction: float = 0.25,
+ balance: bool = True) -> AttackerData:
+ """Prepare Seq2SeqAttackInputData to train ML attackers.
+
+ Uses logits and labels to compute average token ranks and performs a random
+ train-test split.
+
+ Args:
+ attack_input_data: Original Seq2SeqAttackInputData
+ test_fraction: Fraction of the dataset to include in the test split.
+ balance: Whether the training and test sets for the membership inference
+ attacker should have a balanced (roughly equal) number of samples from the
+ training and test sets used to develop the model under attack.
+
+ Returns:
+ AttackerData.
+ """
+ attack_input_train = _get_average_ranks(attack_input_data.logits_train,
+ attack_input_data.labels_train)
+ attack_input_test = _get_average_ranks(attack_input_data.logits_test,
+ attack_input_data.labels_test)
+
+ if balance:
+ min_size = min(len(attack_input_train), len(attack_input_test))
+ attack_input_train = _sample_multidimensional_array(attack_input_train,
+ min_size)
+ attack_input_test = _sample_multidimensional_array(attack_input_test,
+ min_size)
+
+ features_all = np.concatenate((attack_input_train, attack_input_test))
+
+ # Reshape for classifying one-dimensional features
+ features_all = features_all.reshape(-1, 1)
+
+ labels_all = np.concatenate(
+ ((np.zeros(len(attack_input_train))), (np.ones(len(attack_input_test)))))
+
+ # Perform a train-test split
+ features_train, features_test, \
+ is_training_labels_train, is_training_labels_test = \
+ model_selection.train_test_split(
+ features_all, labels_all, test_size=test_fraction, stratify=labels_all)
+
+ return AttackerData(features_train, is_training_labels_train, features_test,
+ is_training_labels_test)
+
+
+def _get_average_ranks(logits: Iterator[np.ndarray],
+ labels: Iterator[np.ndarray]) -> np.ndarray:
+ """Returns the average rank of tokens in a batch of sequences.
+
+ Args:
+ logits: Logits returned by a seq2seq model, dim = (num_batches,
+ num_sequences, num_tokens, vocab_size).
+ labels: Target labels for the seq2seq model, dim = (num_batches,
+ num_sequences, num_tokens, 1).
+
+ Returns:
+ An array of average ranks, dim = (num_batches,).
+ Each average rank is calculated over the ranks of all target tokens in the
+ sequences of a particular batch.
+ """
+ ranks = []
+ for batch_logits, batch_labels in zip(logits, labels):
+ batch_ranks = []
+ for sequence_logits, sequence_labels in zip(batch_logits, batch_labels):
+ batch_ranks += _get_ranks_for_sequence(sequence_logits, sequence_labels)
+ ranks.append(np.mean(batch_ranks))
+
+ return np.array(ranks)
+
+
+def _get_ranks_for_sequence(logits: np.ndarray,
+ labels: np.ndarray) -> List[float]:
+ """Returns ranks for a sequence.
+
+ Args:
+ logits: Logits of a single sequence, dim = (num_tokens, vocab_size).
+ labels: Target labels of a single sequence, dim = (num_tokens, 1).
+
+ Returns:
+ A list of ranks of the target tokens in the sequence, length num_tokens.
+ """
+ sequence_ranks = []
+ for logit, label in zip(logits, labels.astype(int)):
+ rank = rankdata(-logit, method='min')[label] - 1.0
+ sequence_ranks.append(rank)
+
+ return sequence_ranks
+
+
class TrainedAttacker:
"""Base class for training attack models."""
model = None
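To make the rank feature concrete, here is a small worked example of the per-token rank computed in _get_ranks_for_sequence above; the logits and target token are made up.

import numpy as np
from scipy.stats import rankdata

# Logits of one decoding step over a vocabulary of four tokens.
token_logits = np.array([0.1, 0.5, 0.3, 0.1], dtype=np.float32)
true_label = 2  # target token id at this step

# Rank 0 means the model gave the target token the highest logit.
rank = rankdata(-token_logits, method='min')[true_label] - 1.0
print(rank)  # 1.0: token 2 has the second-highest logit

_get_average_ranks then averages these per-token ranks over all sequences in a batch, yielding one scalar feature per batch; records the model was trained on typically receive lower average ranks, which is the signal the logistic-regression attacker exploits.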
diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
index c55ab98..eb672c3 100644
--- a/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
+++ b/tensorflow_privacy/privacy/membership_inference_attack/models_test.py
@@ -19,6 +19,7 @@ import numpy as np
from tensorflow_privacy.privacy.membership_inference_attack import models
from tensorflow_privacy.privacy.membership_inference_attack.data_structures import AttackInputData
+from tensorflow_privacy.privacy.membership_inference_attack.data_structures import Seq2SeqAttackInputData
class TrainedAttackerTest(absltest.TestCase):
@@ -55,6 +56,66 @@ class TrainedAttackerTest(absltest.TestCase):
expected = feature[:2] not in attack_input.logits_train
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
+ def test_create_seq2seq_attacker_data_logits_and_labels(self):
+ attack_input = Seq2SeqAttackInputData(
+ logits_train=iter([
+ np.array([
+ np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
+ np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array(
+ [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
+ dtype=object),
+ np.array([
+ np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
+ np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ logits_test=iter([
+ np.array([
+ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([
+ np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
+ np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ labels_train=iter([
+ np.array([
+ np.array([2, 0], dtype=np.float32),
+ np.array([1], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
+ np.array([
+ np.array([0, 1], dtype=np.float32),
+ np.array([1, 2], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ labels_test=iter([
+ np.array([np.array([2, 1], dtype=np.float32)]),
+ np.array([
+ np.array([2, 0], dtype=np.float32),
+ np.array([1], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ vocab_size=3,
+ train_size=3,
+ test_size=2)
+ attacker_data = models.create_seq2seq_attacker_data(
+ attack_input, 0.25, balance=False)
+ self.assertLen(attacker_data.features_train, 3)
+ self.assertLen(attacker_data.features_test, 2)
+
+ for feature in attacker_data.features_train:
+ self.assertLen(feature, 1) # each feature has one average rank
+
def test_balanced_create_attacker_data_loss_and_logits(self):
attack_input = AttackInputData(
logits_train=np.array([[1, 2], [5, 6], [8, 9]]),
@@ -70,6 +131,71 @@ class TrainedAttackerTest(absltest.TestCase):
expected = feature[:2] not in attack_input.logits_train
self.assertEqual(attacker_data.is_training_labels_train[i], expected)
+ def test_balanced_create_seq2seq_attacker_data_logits_and_labels(self):
+ attack_input = Seq2SeqAttackInputData(
+ logits_train=iter([
+ np.array([
+ np.array([[0.1, 0.1, 0.8], [0.7, 0.3, 0]], dtype=np.float32),
+ np.array([[0.4, 0.5, 0.1]], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array(
+ [np.array([[0.25, 0.6, 0.15], [1, 0, 0]], dtype=np.float32)],
+ dtype=object),
+ np.array([
+ np.array([[0.9, 0, 0.1], [0.25, 0.5, 0.25]], dtype=np.float32),
+ np.array([[0, 1, 0], [0.2, 0.1, 0.7]], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ logits_test=iter([
+ np.array([
+ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([
+ np.array([[0.3, 0.3, 0.4], [0.4, 0.4, 0.2]], dtype=np.float32),
+ np.array([[0.3, 0.35, 0.35]], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([
+ np.array([[0.25, 0.4, 0.35], [0.2, 0.4, 0.4]], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ labels_train=iter([
+ np.array([
+ np.array([2, 0], dtype=np.float32),
+ np.array([1], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([np.array([1, 0], dtype=np.float32)], dtype=object),
+ np.array([
+ np.array([0, 1], dtype=np.float32),
+ np.array([1, 2], dtype=np.float32)
+ ],
+ dtype=object)
+ ]),
+ labels_test=iter([
+ np.array([np.array([2, 1], dtype=np.float32)]),
+ np.array([
+ np.array([2, 0], dtype=np.float32),
+ np.array([1], dtype=np.float32)
+ ],
+ dtype=object),
+ np.array([np.array([2, 1], dtype=np.float32)])
+ ]),
+ vocab_size=3,
+ train_size=3,
+ test_size=3)
+ attacker_data = models.create_seq2seq_attacker_data(
+ attack_input, 0.33, balance=True)
+ self.assertLen(attacker_data.features_train, 4)
+ self.assertLen(attacker_data.features_test, 2)
+
+ for feature in attacker_data.features_train:
+ self.assertLen(feature, 1) # each feature has one average rank
+
if __name__ == '__main__':
absltest.main()