Add word2vec attack example to attack library.

PiperOrigin-RevId: 453489150
This commit is contained in:
Matthew Jagielski 2022-06-07 11:46:15 -07:00 committed by A. Unique TensorFlower
parent fca208e514
commit 6c0cc858e0
4 changed files with 821 additions and 9 deletions

View file

@ -178,7 +178,7 @@ def main(unused_argv):
loss_train=scores[in_indices_target], loss_train=scores[in_indices_target],
loss_test=scores[~in_indices_target]) loss_test=scores[~in_indices_target])
result_lira = mia.run_attacks(attack_input).single_attack_results[0] result_lira = mia.run_attacks(attack_input).single_attack_results[0]
print('Better MIA attack with Gaussian:', print('Advanced MIA attack with Gaussian:',
f'auc = {result_lira.get_auc():.4f}', f'auc = {result_lira.get_auc():.4f}',
f'adv = {result_lira.get_attacker_advantage():.4f}') f'adv = {result_lira.get_attacker_advantage():.4f}')
@ -190,7 +190,7 @@ def main(unused_argv):
loss_train=scores[in_indices_target], loss_train=scores[in_indices_target],
loss_test=scores[~in_indices_target]) loss_test=scores[~in_indices_target])
result_offset = mia.run_attacks(attack_input).single_attack_results[0] result_offset = mia.run_attacks(attack_input).single_attack_results[0]
print('Better MIA attack with offset:', print('Advanced MIA attack with offset:',
f'auc = {result_offset.get_auc():.4f}', f'auc = {result_offset.get_auc():.4f}',
f'adv = {result_offset.get_attacker_advantage():.4f}') f'adv = {result_offset.get_attacker_advantage():.4f}')

View file

@ -25,6 +25,15 @@ risk score) developed by Song and Mittal, please see their
The accompanying paper is on [arXiv](https://arxiv.org/abs/2003.10595). The accompanying paper is on [arXiv](https://arxiv.org/abs/2003.10595).
## Word2Vec models
If you're interested in word2vec models, please see the
[word2vec codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb).
This notebook implements advanced membership inference, as well as a secret
sharer attack. Based on [this paper](https://arxiv.org/abs/2004.00053) and
[this code](https://github.com/google/embedding-tests).
## Copyright ## Copyright
Copyright 2020 - Google LLC Copyright 2020 - Google LLC

View file

@ -0,0 +1,793 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "z8Cp3j_xTuTg"
},
"source": [
"Copyright 2022 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ARoHX_6CTq8A"
},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yJJ1fGaz66ZJ"
},
"source": [
"# Use Membership Inference and Secret Sharer to Test Word Embedding Models\n",
"\n",
"This notebook shows how to run privacy tests for word2vec models, trained with gensim. Models are trained using the procedure used in https://arxiv.org/abs/2004.00053, code for which is found here: https://github.com/google/embedding-tests .\n",
"\n",
"We run membership inference as well as secret sharer. Membership inference attempts to identify whether a given document was included in training. Secret sharer adds random \"canary\" documents into training, and identifies which canary was added."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iWIggmE4V8Rm"
},
"source": [
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
" \u003c/td\u003e\n",
" \u003ctd\u003e\n",
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
" \u003c/td\u003e\n",
"\u003c/table\u003e"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e-E_nBv2tyuC"
},
"outputs": [],
"source": [
"# install dependencies\n",
"!pip install gensim --upgrade\n",
"!pip install git+https://github.com/tensorflow/privacy\n",
"\n",
"from IPython.display import clear_output\n",
"clear_output()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0hsxw6Qi0rF7"
},
"outputs": [],
"source": [
"# imports\n",
"import smart_open\n",
"import random\n",
"import gensim.utils\n",
"import os\n",
"import bz2\n",
"import multiprocessing\n",
"import logging\n",
"import tqdm\n",
"import xml\n",
"import numpy as np\n",
"\n",
"from gensim.models import Word2Vec\n",
"from six import raise_from\n",
"from gensim.corpora.wikicorpus import WikiCorpus, init_to_ignore_interrupt, \\\n",
" ARTICLE_MIN_WORDS, _process_article, IGNORED_NAMESPACES, get_namespace\n",
"from pickle import PicklingError\n",
"from xml.etree.cElementTree import iterparse, ParseError\n",
"\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import membership_inference_attack as mia\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import data_structures as mia_data_structures\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import plotting as mia_plotting\n",
"\n",
"from tensorflow_privacy.privacy.privacy_tests.secret_sharer.exposures import compute_exposure_interpolation, compute_exposure_extrapolation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PQWE_lgT-A4k"
},
"outputs": [],
"source": [
"# all the functions we need to get data and canary it\n",
"# we will use google drive to store data models to be able to reuse them\n",
"# you can change this to local directories by changing DATA_DIR and MODEL_DIR\n",
"# make sure to copy the data locally, otherwise training will be very slow\n",
"\n",
"# code in this cell originates from https://github.com/google/embedding-tests\n",
"# some edits were made to allow saving to google drive, and to add canaries\n",
"\n",
"from google.colab import drive\n",
"drive.mount('/content/drive/')\n",
"\n",
"LOCAL_DATA_DIR = 'data_dir'\n",
"LOCAL_MODEL_DIR = 'model_dir'\n",
"DATA_DIR = '/content/drive/MyDrive/w2v/data_dir/'\n",
"MODEL_DIR = '/content/drive/MyDrive/w2v/model_dir/'\n",
"\n",
"# made up words will be used for canaries\n",
"MADE_UP_WORDS = []\n",
"for i in range(20):\n",
" MADE_UP_WORDS.append(\"o\"*i + \"oongaboonga\")\n",
"\n",
"# deterministic dataset partitioning\n",
"def gen_seed(idx, n=10000):\n",
" random.seed(12345)\n",
"\n",
" seeds = []\n",
" for i in range(n):\n",
" s = random.random()\n",
" seeds.append(s)\n",
"\n",
" return seeds[idx]\n",
"\n",
"\n",
"def make_wiki9_dirs(data_dir):\n",
" # makes all the directories we'll need to store data\n",
" wiki9_path = os.path.join(data_dir, 'wiki9', 'enwik9.bz2')\n",
" wiki9_dir = os.path.join(data_dir, 'wiki9', 'articles')\n",
" wiki9_split_dir = os.path.join(data_dir, 'wiki9', 'split')\n",
" for d in [wiki9_dir, wiki9_split_dir]:\n",
" if not os.path.exists(d):\n",
" os.makedirs(d)\n",
" return wiki9_path, wiki9_dir, wiki9_split_dir\n",
"\n",
"\n",
"def extract_pages(f, filter_namespaces=False, filter_articles=None):\n",
" try:\n",
" elems = (elem for _, elem in iterparse(f, events=(\"end\",)))\n",
" except ParseError:\n",
" yield None, \"\", None\n",
"\n",
" elem = next(elems)\n",
" namespace = get_namespace(elem.tag)\n",
" ns_mapping = {\"ns\": namespace}\n",
" page_tag = \"{%(ns)s}page\" % ns_mapping\n",
" text_path = \"./{%(ns)s}revision/{%(ns)s}text\" % ns_mapping\n",
" title_path = \"./{%(ns)s}title\" % ns_mapping\n",
" ns_path = \"./{%(ns)s}ns\" % ns_mapping\n",
" pageid_path = \"./{%(ns)s}id\" % ns_mapping\n",
"\n",
" try:\n",
"\n",
" for elem in elems:\n",
" if elem.tag == page_tag:\n",
" title = elem.find(title_path).text\n",
" text = elem.find(text_path).text\n",
"\n",
" if filter_namespaces:\n",
" ns = elem.find(ns_path).text\n",
" if ns not in filter_namespaces:\n",
" text = None\n",
"\n",
" if filter_articles is not None:\n",
" if not filter_articles(\n",
" elem, namespace=namespace, title=title,\n",
" text=text, page_tag=page_tag,\n",
" text_path=text_path, title_path=title_path,\n",
" ns_path=ns_path, pageid_path=pageid_path):\n",
" text = None\n",
"\n",
" pageid = elem.find(pageid_path).text\n",
" yield title, text or \"\", pageid # empty page will yield None\n",
"\n",
" elem.clear()\n",
" except ParseError:\n",
" yield None, \"\", None\n",
" return\n",
"\n",
"class MyWikiCorpus(WikiCorpus):\n",
"\n",
" def get_texts(self):\n",
" logger = logging.getLogger(__name__)\n",
"\n",
" articles, articles_all = 0, 0\n",
" positions, positions_all = 0, 0\n",
"\n",
" tokenization_params = (\n",
" self.tokenizer_func, self.token_min_len, self.token_max_len, self.lower)\n",
" texts = ((text, title, pageid, tokenization_params)\n",
" for title, text, pageid in extract_pages(bz2.BZ2File(self.fname),\n",
" self.filter_namespaces,\n",
" self.filter_articles))\n",
" print(\"got texts\")\n",
" pool = multiprocessing.Pool(self.processes, init_to_ignore_interrupt)\n",
"\n",
" try:\n",
" # process the corpus in smaller chunks of docs,\n",
" # because multiprocessing.Pool\n",
" # is dumb and would load the entire input into RAM at once...\n",
" for group in gensim.utils.chunkize(texts, chunksize=10 * self.processes,\n",
" maxsize=1):\n",
" for tokens, title, pageid in pool.imap(_process_article, group):\n",
" articles_all += 1\n",
" positions_all += len(tokens)\n",
" # article redirects and short stubs are pruned here\n",
" if len(tokens) \u003c self.article_min_tokens or \\\n",
" any(title.startswith(ignore + ':') for ignore in\n",
" IGNORED_NAMESPACES):\n",
" continue\n",
" articles += 1\n",
" positions += len(tokens)\n",
" yield (tokens, (pageid, title))\n",
"\n",
" except KeyboardInterrupt:\n",
" logger.warn(\n",
" \"user terminated iteration over Wikipedia corpus after %i\"\n",
" \" documents with %i positions \"\n",
" \"(total %i articles, %i positions before pruning articles\"\n",
" \" shorter than %i words)\",\n",
" articles, positions, articles_all, positions_all, ARTICLE_MIN_WORDS\n",
" )\n",
" except PicklingError as exc:\n",
" raise_from(\n",
" PicklingError('Can not send filtering function {} to multiprocessing, '\n",
" 'make sure the function can be pickled.'.format(\n",
" self.filter_articles)), exc)\n",
" else:\n",
" logger.info(\n",
" \"finished iterating over Wikipedia corpus of %i \"\n",
" \"documents with %i positions \"\n",
" \"(total %i articles, %i positions before pruning articles\"\n",
" \" shorter than %i words)\",\n",
" articles, positions, articles_all, positions_all, ARTICLE_MIN_WORDS\n",
" )\n",
" self.length = articles # cache corpus length\n",
" finally:\n",
" pool.terminate()\n",
"\n",
"\n",
"def write_wiki9_articles(data_dir):\n",
" wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)\n",
" wiki = MyWikiCorpus(wiki9_path, dictionary={},\n",
" filter_namespaces=False)\n",
" i = 0\n",
" for text, (p_id, title) in tqdm.tqdm(wiki.get_texts()):\n",
" i += 1\n",
" if title is None:\n",
" continue\n",
"\n",
" article_path = os.path.join(wiki9_dir, p_id)\n",
" if os.path.exists(article_path):\n",
" continue\n",
"\n",
" with open(article_path, 'wb') as f:\n",
" f.write(' '.join(text).encode(\"utf-8\"))\n",
" print(\"done\", i)\n",
"\n",
"def split_wiki9_articles(data_dir, exp_id=0):\n",
" wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)\n",
" all_docs = list(os.listdir(wiki9_dir))\n",
" print(\"wiki9 len\", len(all_docs))\n",
" print(wiki9_dir)\n",
" s = gen_seed(exp_id)\n",
" random.seed(s)\n",
" random.shuffle(all_docs)\n",
" random.seed()\n",
"\n",
" n = len(all_docs) // 2\n",
" return all_docs[:n], all_docs[n:]\n",
"\n",
"\n",
"def read_wiki9_train_split(data_dir, exp_id=0):\n",
" wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)\n",
"\n",
" split_path = os.path.join(wiki9_split_dir, 'split{}.train'.format(exp_id))\n",
" if not os.path.exists(split_path):\n",
" train_docs, _ = split_wiki9_articles(exp_id=exp_id)\n",
" with open(split_path, 'w') as f:\n",
" for doc in tqdm.tqdm(train_docs):\n",
" with open(os.path.join(wiki9_dir, doc), 'r') as fd:\n",
" f.write(fd.read())\n",
" f.write(' ')\n",
"\n",
" return split_path\n",
"\n",
"def build_vocab(word2vec_model):\n",
" vocab = word2vec_model.wv.index_to_key\n",
" counts = [word2vec_model.wv.get_vecattr(word, \"count\") for word in vocab]\n",
" sorted_inds = np.argsort(counts)\n",
" sorted_vocab = [vocab[ind] for ind in sorted_inds]\n",
" return sorted_vocab\n",
"\n",
"def sample_words(vocab, count, rng):\n",
" inds = rng.choice(len(vocab), count, replace=False)\n",
" return [vocab[ind] for ind in inds], rng\n",
"\n",
"\n",
"def gen_canaries(num_canaries, canary_repeat, vocab_model_path, seed=0):\n",
" # create canaries, injecting made up words into the corpus\n",
" existing_w2v = Word2Vec.load(vocab_model_path)\n",
" existing_vocab = build_vocab(existing_w2v)\n",
" rng = np.random.RandomState(seed)\n",
"\n",
" all_canaries = []\n",
" for i in range(num_canaries):\n",
" new_word = MADE_UP_WORDS[i%len(MADE_UP_WORDS)]\n",
" assert new_word not in existing_vocab\n",
" canary_words, rng = sample_words(existing_vocab, 4, rng)\n",
" canary = canary_words[:2] + [new_word] + canary_words[2:]\n",
" all_canaries.append(canary)\n",
" all_canaries = all_canaries * canary_repeat\n",
" return all_canaries\n",
"\n",
"# iterator for training documents, with an option to canary\n",
"class WIKI9Articles:\n",
" def __init__(self, docs, data_dir, verbose=0, ssharer=False, num_canaries=0,\n",
" canary_repeat=0, canary_seed=0, vocab_model_path=None):\n",
" self.docs = [(0, doc) for doc in docs]\n",
" if ssharer:\n",
" all_canaries = gen_canaries(\n",
" num_canaries, canary_repeat, vocab_model_path, canary_seed)\n",
" self.docs.extend([(1, canary) for canary in all_canaries])\n",
" np.random.RandomState(0).shuffle(self.docs)\n",
"\n",
" wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(data_dir)\n",
" self.dirname = wiki9_dir\n",
" self.verbose = verbose\n",
"\n",
" def __iter__(self):\n",
" for is_canary, fname in tqdm.tqdm(self.docs) if self.verbose else self.docs:\n",
" if not is_canary:\n",
" for line in smart_open.open(os.path.join(self.dirname, fname),\n",
" 'r', encoding='utf-8'):\n",
" yield line.split()\n",
" else:\n",
" yield fname\n",
"\n",
"\n",
"def train_word_embedding(data_dir, model_dir, exp_id=0, use_secret_sharer=False,\n",
" num_canaries=0, canary_repeat=1, canary_seed=0,\n",
" vocab_model_path=None):\n",
" # this function trains the word2vec model, after setting up the training set\n",
" logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s',\n",
" level=logging.INFO)\n",
"\n",
" params = {\n",
" 'sg': 1,\n",
" 'negative': 25,\n",
" 'alpha': 0.05,\n",
" 'sample': 1e-4,\n",
" 'workers': 48,\n",
" 'epochs': 5,\n",
" 'window': 5,\n",
" }\n",
"\n",
" train_docs, test_docs = split_wiki9_articles(data_dir, exp_id)\n",
" print(len(train_docs), len(test_docs))\n",
" wiki9_articles = WIKI9Articles(\n",
" train_docs, data_dir, ssharer=use_secret_sharer, num_canaries=num_canaries,\n",
" canary_repeat=canary_repeat, canary_seed=canary_seed, vocab_model_path=vocab_model_path)\n",
"\n",
" if not os.path.exists(model_dir):\n",
" os.makedirs(model_dir)\n",
"\n",
" model = Word2Vec(wiki9_articles, **params)\n",
"\n",
" if not use_secret_sharer:\n",
" model_path = os.path.join(model_dir, 'wiki9_w2v_{}.model'.format(exp_id))\n",
" else:\n",
" model_path = os.path.join(model_dir, 'wiki9_w2v_{}_{}_{}_{}.model'.format(\n",
" exp_id, num_canaries, canary_repeat, canary_seed\n",
" ))\n",
" model.save(model_path)\n",
" return model_path, train_docs, test_docs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "gA5VDclIrg2R"
},
"outputs": [],
"source": [
"# setup directories\n",
"wiki9_path, wiki9_dir, wiki9_split_dir = make_wiki9_dirs(DATA_DIR)\n",
"local_wiki9_path, local_wiki9_dir, local_wiki9_splitdir = make_wiki9_dirs(LOCAL_DATA_DIR)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1Hhvsg9wrqYh"
},
"outputs": [],
"source": [
"# download and format documents\n",
"!wget http://mattmahoney.net/dc/enwik9.zip\n",
"!unzip enwik9.zip\n",
"!bzip2 enwik9\n",
"!cp enwik9.bz2 $wiki9_path\n",
"!cp $wiki9_path $local_wiki9_path\n",
"write_wiki9_articles(LOCAL_DATA_DIR) # need local data for fast training"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "r2zxFJpBryZt"
},
"source": [
"# Membership Inference Attacks\n",
"\n",
"Let's start by running membership inference on a word2vec model.\n",
"\n",
"We'll start by training a bunch of word2vec models with different train/test splits. This can take a long time, so be patient!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ZPNT7Om3sovQ"
},
"outputs": [],
"source": [
"for i in range(10):\n",
" if os.path.exists(os.path.join(MODEL_DIR, f\"wiki9_w2v_{i}.model\")):\n",
" print(\"done\", i)\n",
" continue\n",
" model_path, train_docs, test_docs = train_word_embedding(LOCAL_DATA_DIR, MODEL_DIR, exp_id=i)\n",
" print(model_path)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fIWWTiZC98zv"
},
"source": [
"We now define our loss function. We follow https://arxiv.org/abs/2004.00053, computing the loss of a document as the average loss over all 5 token \"windows\" in the document."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x6pqtPZ7K36E"
},
"outputs": [],
"source": [
"from re import split\n",
"\n",
"def loss(model, window):\n",
" # compute loss for a single window of 5 tokens\n",
" try:\n",
" sum_embedding = np.array([model.wv[word] for word in window]).sum(axis=0)\n",
" except:\n",
" return np.nan\n",
" middle_embedding = model.wv[window[2]]\n",
" context_embedding = 0.25*(sum_embedding - middle_embedding)\n",
" return np.linalg.norm(middle_embedding - context_embedding)\n",
"\n",
"def loss_per_article(model, article):\n",
" # compute loss for a full document\n",
" losses = []\n",
" article = article.split(' ')\n",
" embs = [model.wv[word] if word in model.wv else np.nan for word in article]\n",
"\n",
" for i in range(len(article) - 4):\n",
" middle_embedding = embs[i+2]\n",
" context_embedding = 0.25*(np.mean(embs[i:i+2] + embs[i+3:i+5]))\n",
" losses.append(np.linalg.norm(middle_embedding - context_embedding))\n",
" return np.nanmean(losses)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "udsmRmZLDQYo"
},
"source": [
"Let's now get the losses of all models on all documents. This also takes a while, so we'll only get a subset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6ckAeZH4nwHP"
},
"outputs": [],
"source": [
"all_models = []\n",
"for i in range(1000, 1020):\n",
" model_path = os.path.join(MODEL_DIR, f\"wiki9_w2v_{i}.model\")\n",
" if not os.path.exists(model_path):\n",
" continue\n",
" all_models.append(Word2Vec.load(model_path))\n",
"\n",
"train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 0)\n",
"all_docs = sorted(train_docs + test_docs)\n",
"all_losses = np.zeros((len(all_docs), len(all_models)))\n",
"\n",
"for i, doc in tqdm.tqdm(enumerate(all_docs)):\n",
" if i \u003e 1000:\n",
" continue\n",
" with open(os.path.join(local_wiki9_dir, doc), 'r') as fd:\n",
" doc_text = fd.read()\n",
" for j, model in enumerate(all_models):\n",
" all_losses[i,j] = loss_per_article(model, doc_text)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3mBV8qnrGzwH"
},
"source": [
"We're going to be running the LiRA attack, so, for each document, we get the document's losses when it is in the model, and the losses when it is not in the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6j7zj_fIvSXK"
},
"outputs": [],
"source": [
"all_losses = all_losses[:500, :]\n",
"doc_lookup = {doc: i for i, doc in enumerate(all_docs)}\n",
"\n",
"def compute_scores_in_out(losses, seeds):\n",
" in_scores = [[] for _ in range(losses.shape[0])]\n",
" out_scores = [[] for _ in range(losses.shape[0])]\n",
" for seed in seeds:\n",
" train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, seed)\n",
" for train_doc in train_docs:\n",
" ind = doc_lookup[train_doc]\n",
" if ind \u003e= all_losses.shape[0]:\n",
" continue\n",
" in_scores[ind].append([all_losses[ind, seed-1000]])\n",
" for test_doc in test_docs:\n",
" ind = doc_lookup[test_doc]\n",
" if ind \u003e= all_losses.shape[0]:\n",
" continue\n",
" out_scores[ind].append([all_losses[ind, seed-1000]])\n",
" in_scores = [np.array(s) for s in in_scores]\n",
" out_scores = [np.array(s) for s in out_scores]\n",
" print(in_scores[0].shape)\n",
" return in_scores, out_scores\n",
"# we will do MI on model 0\n",
"in_scores, out_scores = compute_scores_in_out(all_losses, list(range(1001, 1020)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aKI7Kv44HTBu"
},
"source": [
"Now let's run the global threshold membership inference attack. It gets an advantage of around 0.07."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6WR4s-W5ykus"
},
"outputs": [],
"source": [
"# global threshold MIA attack\n",
"train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 1000)\n",
"train_losses, test_losses = [], []\n",
"for train_doc in train_docs:\n",
" ind = doc_lookup[train_doc]\n",
" if ind \u003e= all_losses.shape[0]:\n",
" continue\n",
" train_losses.append(all_losses[ind, 0])\n",
"for test_doc in test_docs:\n",
" ind = doc_lookup[test_doc]\n",
" if ind \u003e= all_losses.shape[0]:\n",
" continue\n",
" test_losses.append(all_losses[ind, 0])\n",
"\n",
"attacks_result_baseline = mia.run_attacks(\n",
" mia_data_structures.AttackInputData(\n",
" loss_train = -np.nan_to_num(train_losses),\n",
" loss_test = -np.nan_to_num(test_losses))).single_attack_results[0]\n",
"print('Global Threshold MIA attack:',\n",
" f'auc = {attacks_result_baseline.get_auc():.4f}',\n",
" f'adv = {attacks_result_baseline.get_attacker_advantage():.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4WZX4L8lHaB1"
},
"source": [
"And now we run LiRA. First we need to compute LiRA scores."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RoQc6XNeMid2"
},
"outputs": [],
"source": [
"# run LiRA\n",
"from tensorflow_privacy.privacy.privacy_tests.membership_inference_attack import advanced_mia as amia\n",
"good_inds = []\n",
"for i, (in_s, out_s) in enumerate(zip(in_scores, out_scores)):\n",
" if len(in_s) \u003e 0 and len(out_s) \u003e 0:\n",
" good_inds.append(i)\n",
"\n",
"for i in good_inds:\n",
" assert len(in_scores[i]) \u003e 0\n",
" assert len(in_scores[i]) \u003e 0\n",
"\n",
"scores = amia.compute_score_lira(all_losses[good_inds, 0],\n",
" [in_scores[i] for i in good_inds],\n",
" [out_scores[i] for i in good_inds],\n",
" fix_variance=True)\n",
"\n",
"train_docs, test_docs = split_wiki9_articles(LOCAL_DATA_DIR, 1000)\n",
"in_mask = np.zeros(len(good_inds), dtype=bool)\n",
"for doc in train_docs:\n",
" ind = doc_lookup[doc]\n",
" if ind \u003e= all_losses.shape[0]:\n",
" continue\n",
" if ind in good_inds:\n",
" in_mask[good_inds.index(ind)] = True\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SMofmN9xHg8E"
},
"source": [
"And now we threshold on LiRA scores, as before. Advantage goes from .07 to .13, it almost doubled!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IMESCOer6HAG"
},
"outputs": [],
"source": [
"attacks_result_baseline = mia.run_attacks(\n",
" mia_data_structures.AttackInputData(\n",
" loss_train = scores[in_mask],\n",
" loss_test = scores[~in_mask])).single_attack_results[0]\n",
"print('Advanced MIA attack with Gaussian:',\n",
" f'auc = {attacks_result_baseline.get_auc():.4f}',\n",
" f'adv = {attacks_result_baseline.get_attacker_advantage():.4f}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6Onv39Aiyyg9"
},
"source": [
"# Secret Sharer\n",
"\n",
"Here, we're going to run a secret sharer attack on a word2vec model. Our canaries (generated above in gen_canaries) look like the following:\n",
"\n",
"\"word1 word2 made_up_word word3 word4\",\n",
"\n",
"where all the words except for the made up word are real words from the vocabulary. The model's decision on where to put the made up word in embedding space will depend solely on the canary, which will make this an effective attack. We insert canaries with various repetition counts, and train some models:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kuCrXcv5y0ry"
},
"outputs": [],
"source": [
"vocab_model_path = os.path.join(MODEL_DIR, 'wiki9_w2v_1.model')\n",
"interp_exposures = {}\n",
"extrap_exposures = {}\n",
"all_canaries = gen_canaries(10000, 1, vocab_model_path, 0)\n",
"\n",
"for repeat_count in [5, 10, 20]:\n",
" model_path = os.path.join(MODEL_DIR, 'wiki9_w2v_0_20_{}_0.model'.format(repeat_count))\n",
" print(os.path.exists(model_path))\n",
" model_path, _, _ = train_word_embedding(\n",
" LOCAL_DATA_DIR, MODEL_DIR, exp_id=0, use_secret_sharer=True, num_canaries=20,\n",
" canary_repeat=repeat_count, canary_seed=0, vocab_model_path=vocab_model_path)\n",
" canaried_model = Word2Vec.load(model_path)\n",
" canary_losses = [loss(canaried_model, canary) for canary in all_canaries]\n",
" loss_secrets = np.array(canary_losses[:20])\n",
" loss_ref = np.array(canary_losses[20:])\n",
" loss_secrets = {1: loss_secrets[~np.isnan(loss_secrets)]}\n",
" loss_ref = loss_ref[~np.isnan(loss_ref)]\n",
" exposure_interpolation = compute_exposure_interpolation(loss_secrets, loss_ref)\n",
" exposure_extrapolation = compute_exposure_extrapolation(loss_secrets, loss_ref)\n",
" interp_exposures[repeat_count] = exposure_interpolation[1]\n",
" extrap_exposures[repeat_count] = exposure_extrapolation[1]\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dp0ejoZTIsMc"
},
"source": [
"And now let's run secret sharer! Exposure is quite high!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Vrbi1RkUGJr9",
"outputId": "fe5cc82a-421d-41dd-95ca-ac71862bf7d7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Repeats: 5, Interpolation Exposure: 12.307770031890703, Extrapolation Exposure: 54.51861034822009\n",
"Repeats: 10, Interpolation Exposure: 12.290018846932618, Extrapolation Exposure: 56.91255812786129\n",
"Repeats: 20, Interpolation Exposure: 12.290018846932618, Extrapolation Exposure: 64.00837536957133 \n"
]
}
],
"source": [
"for key in interp_exposures:\n",
" print(f\"Repeats: {key}, Interpolation Exposure: {np.median(interp_exposures[key])}, Extrapolation Exposure: {np.median(extrap_exposures[key])}\")"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "word2vec_codelab.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View file

@ -1,14 +1,14 @@
# Secret Sharer Attack # Secret Sharer Attack
A good privacy-preserving model learns from the training data, but A good privacy-preserving model learns from the training data, but doesn't
doesn't memorize it. memorize it. This folder contains codes for conducting the Secret Sharer attack
This folder contains codes for conducting the Secret Sharer attack from [this paper](https://arxiv.org/abs/1802.08232). from [this paper](https://arxiv.org/abs/1802.08232). It is a method to test if a
It is a method to test if a machine learning model memorizes its training data. machine learning model memorizes its training data.
The high level idea is to insert some random sequences as “secrets” into the The high level idea is to insert some random sequences as “secrets” into the
training data, and then measure if the model has memorized those secrets. training data, and then measure if the model has memorized those secrets. If
If there is significant memorization, it means that there can be potential there is significant memorization, it means that there can be potential privacy
privacy risk. risk.
## How to Use ## How to Use
@ -19,6 +19,16 @@ privacy risk.
- `secret_sharer_example.ipynb` is an example (character-level LSTM) for using - `secret_sharer_example.ipynb` is an example (character-level LSTM) for using
the above code to conduct secret sharer attack. the above code to conduct secret sharer attack.
### More Usage Examples
## Word2Vec models
If you're interested in word2vec models, please see the
[word2vec codelab](https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/privacy_tests/membership_inference_attack/codelabs/word2vec_codelab.ipynb).
In addition to secret sharer, this notebook also implements membership inference
attacks. Based on [this paper](https://arxiv.org/abs/2004.00053) and
[this code](https://github.com/google/embedding-tests).
### Contact / Feedback ### Contact / Feedback