# tensorflow_privacy/research/pate_2018/ICLR2018/plot_partition.py
# (Python, 413 lines, 14 KiB)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Plots privacy costs of GNMax aggregators and their analyses.

Produces two plots. One compares aggregators and their analyses. The other
illustrates sources of privacy loss for Confident-GNMax.

A script in support of the paper "Scalable Private Learning with PATE" by
Nicolas Papernot, Shuang Song, Ilya Mironov, Ananth Raghunathan, Kunal Talwar,
Ulfar Erlingsson (https://arxiv.org/abs/1802.08908).

The input is a file (loaded with np.load) containing a numpy array of votes,
one query per row, one class per column. For example:
  43, 1821, ..., 3
  31, 16, ..., 0
  ...
  0, 86, ..., 438

The output is written to a specified directory and consists of two files,
comparison.pdf and partition.pdf.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import os
import pickle
import sys
sys.path.append('..') # Main modules reside in the parent directory.
from absl import app
from absl import flags
from collections import namedtuple
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
import numpy as np
import core as pate
import smooth_sensitivity as pate_ss
# Every figure produced by this script uses matplotlib's ggplot style.
plt.style.use('ggplot')

FLAGS = flags.FLAGS

# Command-line configuration of the analysis and the plots.
flags.DEFINE_boolean(name='cache', default=False,
                     help='Read results of privacy analysis from cache.')
flags.DEFINE_string(name='counts_file', default=None, help='Counts file.')
flags.DEFINE_string(name='figures_dir', default='',
                    help='Path where figures are written to.')
flags.DEFINE_float(name='threshold', default=None,
                   help='Threshold for step 1 (selection).')
flags.DEFINE_float(name='sigma1', default=None,
                   help='Sigma for step 1 (selection).')
flags.DEFINE_float(name='sigma2', default=None,
                   help='Sigma for step 2 (argmax).')
flags.DEFINE_integer(name='queries', default=None,
                     help='Number of queries made by the student.')
flags.DEFINE_float(name='delta', default=1e-8, help='Target delta.')

for _required_flag in ('counts_file', 'threshold', 'sigma1', 'sigma2'):
  flags.mark_flag_as_required(_required_flag)
del _required_flag

# Breakdown of the data-dependent epsilon at the optimal order, as built by
# analyze_gnmax_conf_data_dep; the four fields sum to epsilon:
#   step1: accumulated RDP of step 1 (selection),
#   step2: accumulated RDP of step 2 (noisy argmax),
#   ss:    RDP cost of the smooth-sensitivity release,
#   delta: the -log(delta) / (order - 1) conversion term.
Partition = namedtuple('Partition', ['step1', 'step2', 'ss', 'delta'])
def analyze_gnmax_conf_data_ind(votes, threshold, sigma1, sigma2, delta):
  """Computes the data-independent privacy cost of (Confident) GNMax.

  Costs are accumulated query by query. If either `threshold` or `sigma1` is
  None, the selection step is skipped and every query is answered (Simple
  GNMax).

  Args:
    votes: 2-D array of teacher votes, one query per row, one class per column.
    threshold: Threshold for step 1 (selection), or None to disable step 1.
    sigma1: Sigma for step 1 (selection), or None to disable step 1.
    sigma2: Sigma for step 2 (argmax).
    delta: Target delta.

  Returns:
    A tuple (eps_cum, answered) of 1-D float arrays of length votes.shape[0]:
    eps_cum[i] is the epsilon after the first i + 1 queries and answered[i]
    is the expected number of those queries that were answered.
  """
  orders = np.logspace(np.log10(1.5), np.log10(500), num=100)
  use_step1 = threshold is not None and sigma1 is not None

  # The data-independent RDP of each Gaussian step does not depend on the
  # votes, so compute it once rather than once per query.
  rdp_step1 = (pate.rdp_data_independent_gaussian(sigma1, orders)
               if use_step1 else None)
  rdp_step2 = pate.rdp_data_independent_gaussian(sigma2, orders)

  n = votes.shape[0]
  rdp_total = np.zeros(len(orders))
  answered_total = 0
  answered = np.zeros(n)
  eps_cum = np.full(n, np.nan, dtype=float)
  for i in range(n):
    if use_step1:
      q_step1 = np.exp(pate.compute_logpr_answered(threshold, sigma1,
                                                   votes[i,]))
      # Step 1 runs on every query, so its cost is always charged.
      rdp_total += rdp_step1
    else:
      q_step1 = 1.  # always answer
    answered_total += q_step1
    answered[i] = answered_total
    # Step 2 runs only for answered queries: charge its expected cost.
    rdp_total += q_step1 * rdp_step2
    eps_cum[i], order_opt = pate.compute_eps_from_delta(orders, rdp_total,
                                                        delta)

    if (i + 1) % 1000 == 0:
      print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} '
            'at order = {:.2f}.'.format(
                i + 1,
                answered[i],
                eps_cum[i],
                order_opt))
      sys.stdout.flush()

  return eps_cum, answered
def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
  """Computes the data-dependent privacy cost of (Confident) GNMax.

  The cost combines the data-dependent RDP of step 1 (selection) and step 2
  (noisy argmax) with the RDP of the smooth-sensitivity release. After every
  query it is converted to epsilon at the best order. If either `threshold`
  or `sigma1` is None, step 1 is skipped and every query is answered (Simple
  GNMax).

  Args:
    votes: 2-D array of teacher votes, one query per row, one class per
      column. The number of teachers is taken as the sum of the first row.
    threshold: Threshold for step 1 (selection), or None to disable step 1.
    sigma1: Sigma for step 1 (selection), or None to disable step 1.
    sigma2: Sigma for step 2 (argmax).
    delta: Target delta.

  Returns:
    A tuple (eps_partitioned, answered, ss_std_opt, order_opt) of 1-D arrays
    of length votes.shape[0]. Entry i describes the first i + 1 queries:
      eps_partitioned[i]: Partition of epsilon at the optimal order; its
        fields sum to epsilon.
      answered[i]: expected number of queries answered.
      ss_std_opt[i]: smooth sensitivity times its noise scale at the optimal
        order (printed as the "+/-" term of epsilon).
      order_opt[i]: the optimal order.
    All four arrays are empty if `votes` has no rows.
  """
  # Long list of orders. A shorter and faster alternative is
  # np.round(np.logspace(np.log10(20), np.log10(200), num=20)).
  orders = np.concatenate((np.arange(20, 40, .2),
                           np.arange(40, 75, .5),
                           np.logspace(np.log10(75), np.log10(200), num=20)))
  n = votes.shape[0]
  if n == 0:
    # No queries to analyze, and no row to count the teachers from.
    return (np.full(0, None, dtype=object), np.zeros(0), np.zeros(0),
            np.zeros(0))

  num_classes = votes.shape[1]
  num_teachers = int(sum(votes[0,]))
  use_step1 = threshold is not None and sigma1 is not None

  # Orders flagged True here contribute zero local sensitivity for that step
  # (presumably the data-independent bound is always optimal there).
  if use_step1:
    is_data_ind_step1 = pate.is_data_independent_always_opt_gaussian(
        num_teachers, num_classes, sigma1, orders)
  else:
    is_data_ind_step1 = [True] * len(orders)
  is_data_ind_step2 = pate.is_data_independent_always_opt_gaussian(
      num_teachers, num_classes, sigma2, orders)

  # Object dtype: every entry holds a Partition namedtuple.
  eps_partitioned = np.full(n, None, dtype=object)
  order_opt = np.full(n, None, dtype=float)
  ss_std_opt = np.full(n, None, dtype=float)
  answered = np.zeros(n)
  rdp_step1_total = np.zeros(len(orders))
  rdp_step2_total = np.zeros(len(orders))
  ls_total = np.zeros((len(orders), num_teachers))
  answered_total = 0

  for i in range(n):
    v = votes[i,]
    if use_step1:
      logq_step1 = pate.compute_logpr_answered(threshold, sigma1, v)
      rdp_step1_total += pate.compute_rdp_threshold(logq_step1, sigma1, orders)
    else:
      logq_step1 = 0.  # always answer
    pr_answered = np.exp(logq_step1)
    logq_step2 = pate.compute_logq_gaussian(v, sigma2)
    # Step 2 runs only for answered queries: charge its expected cost.
    rdp_step2_total += pr_answered * pate.rdp_gaussian(logq_step2, sigma2,
                                                       orders)
    answered_total += pr_answered

    # Cost of releasing the accumulated local-sensitivity bounds, per order.
    rdp_ss = np.zeros(len(orders))
    ss_std = np.zeros(len(orders))
    for j, order in enumerate(orders):
      if not is_data_ind_step1[j]:
        ls_step1 = pate_ss.compute_local_sensitivity_bounds_threshold(
            v, num_teachers, threshold, sigma1, order)
      else:
        ls_step1 = np.full(num_teachers, 0, dtype=float)
      if not is_data_ind_step2[j]:
        ls_step2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
            v, num_teachers, sigma2, order)
      else:
        ls_step2 = np.full(num_teachers, 0, dtype=float)
      ls_total[j,] += ls_step1 + pr_answered * ls_step2
      beta_ss = .49 / order
      ss = pate_ss.compute_discounted_max(beta_ss, ls_total[j,])
      sigma_ss = ((order * math.exp(2 * beta_ss)) / ss) ** (1 / 3)
      rdp_ss[j] = pate_ss.compute_rdp_of_smooth_sensitivity_gaussian(
          beta_ss, sigma_ss, order)
      ss_std[j] = ss * sigma_ss

    rdp_total = rdp_step1_total + rdp_step2_total + rdp_ss
    answered[i] = answered_total
    _, order_opt[i] = pate.compute_eps_from_delta(orders, rdp_total, delta)
    order_idx = np.searchsorted(orders, order_opt[i])

    # Since optimal orders are always non-increasing, shrink the orders array
    # and all cumulative arrays (including the local-sensitivity totals) to
    # speed up computation.
    if order_idx < len(orders):
      keep = order_idx + 1
      orders = orders[:keep]
      rdp_step1_total = rdp_step1_total[:keep]
      rdp_step2_total = rdp_step2_total[:keep]
      ls_total = ls_total[:keep]

    eps_partitioned[i] = Partition(step1=rdp_step1_total[order_idx],
                                   step2=rdp_step2_total[order_idx],
                                   ss=rdp_ss[order_idx],
                                   delta=-math.log(delta) / (order_opt[i] - 1))
    ss_std_opt[i] = ss_std[order_idx]

    if i > 0:  # Report progress after every query except the first.
      print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
            'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
            'step2 = {:.3f}, ss = {:.3f}'.format(
                i + 1,
                answered[i],
                sum(eps_partitioned[i]),
                ss_std_opt[i],
                order_opt[i],
                eps_partitioned[i].delta,
                eps_partitioned[i].step1,
                eps_partitioned[i].step2,
                eps_partitioned[i].ss))
      sys.stdout.flush()

  return eps_partitioned, answered, ss_std_opt, order_opt
def plot_comparison(figures_dir, simple_ind, conf_ind, simple_dep, conf_dep,
                    delta=1e-8):
  """Plots variants of GNMax algorithm and their analyses.

  Writes `figures_dir`/comparison.pdf and shows the figure.

  Args:
    figures_dir: Directory where comparison.pdf is written.
    simple_ind: (eps_cum, answered) from the data-independent analysis of
      Simple GNMax.
    conf_ind: (eps_cum, answered) from the data-independent analysis of
      Confident GNMax.
    simple_dep: (eps_partitioned, answered, ss_std_opt, order_opt) from the
      data-dependent analysis of Simple GNMax.
    conf_dep: Same as `simple_dep`, for Confident GNMax.
    delta: The delta at which epsilon was computed. It is used only in the
      y-axis label and defaults to the 1e-8 the label used to hard-code.
  """

  def pivot(x_axis, eps, answered):
    """Samples eps at the first query whose answered count reaches each x.

    Entries whose x is beyond the last answered count stay NaN.
    """
    y = np.full(len(x_axis), None, dtype=float)  # NaN until filled.
    for i, x in enumerate(x_axis):
      idx = np.searchsorted(answered, x)
      if idx < len(eps):
        y[i] = eps[idx]
    return y

  def pivot_dep(x_axis, data_dep):
    """Like pivot, for data-dependent results whose eps is partitioned."""
    eps_partitioned, answered, _, _ = data_dep
    eps = [sum(p) for p in eps_partitioned]  # Total eps of each Partition.
    return pivot(x_axis, eps, answered)

  def latex_scientific(value):
    """Formats a positive number for LaTeX, e.g. 1e-08 -> '10^{-8}'."""
    exponent = int(math.floor(math.log10(value) + 1e-9))
    mantissa = value / 10.0 ** exponent
    if abs(mantissa - 1.) < 1e-9:
      return '10^{%d}' % exponent
    return r'%g \cdot 10^{%d}' % (mantissa, exponent)

  xlim = 10000
  x_axis = range(0, xlim, 10)
  curves = (
      (pivot(x_axis, *simple_ind), '--', 'r',
       r'Simple GNMax, data-ind analysis'),
      (pivot(x_axis, *conf_ind), '--', 'b',
       r'Confident GNMax, data-ind analysis'),
      (pivot_dep(x_axis, simple_dep), '-', 'r',
       r'Simple GNMax, data-dep analysis'),
      (pivot_dep(x_axis, conf_dep), '-', 'b',
       r'Confident GNMax, data-dep analysis'),
  )

  fig, ax = plt.subplots()
  fig.set_figheight(4.5)
  fig.set_figwidth(4.7)
  for y, linestyle, color, label in curves:
    ax.plot(x_axis, y, ls=linestyle, color=color, lw=3, label=label)
  ax.set_xticks(np.arange(0, xlim + 1000, 2000))
  ax.set_xlim([0, xlim])
  ax.set_ylim(bottom=0)
  ax.legend(fontsize=16)
  ax.set_xlabel('Number of queries answered', fontsize=16)
  ax.set_ylabel(
      r'Privacy cost $\varepsilon$ at $\delta={}$'.format(
          latex_scientific(delta)),
      fontsize=16)
  ax.tick_params(labelsize=14)

  plot_filename = os.path.join(figures_dir, 'comparison.pdf')
  print('Saving the graph to ' + plot_filename)
  fig.savefig(plot_filename, bbox_inches='tight')
  plt.show()
def plot_partition(figures_dir, gnmax_conf, print_order, delta=1e-8):
  """Plots an expert version of the privacy-per-answered-query graph.

  The total privacy cost is drawn as stacked blue bands for its parts: the
  delta term, step 1, step 2 and the smooth-sensitivity cost. A red band of
  +/- ss_std_opt is drawn around the total. The figure is written to
  `figures_dir`/partition.pdf and shown.

  Args:
    figures_dir: A name of the directory where to save the plot.
    gnmax_conf: Tuple (eps_partitioned, answered, ss_std_opt, order_opt) as
      returned by analyze_gnmax_conf_data_dep. It holds the per-query
      Partition of the cumulative privacy cost, the cumulative expected
      number of queries answered, the smooth-sensitivity spread, and the
      optimal orders.
    print_order: If True, also plot the optimal order on a secondary y-axis.
    delta: The delta at which epsilon was computed. It is used only in the
      y-axis label and defaults to the 1e-8 the label used to hard-code.
  """
  eps_partitioned, answered, ss_std_opt, order_opt = gnmax_conf

  def latex_scientific(value):
    """Formats a positive number for LaTeX, e.g. 1e-08 -> '10^{-8}'."""
    exponent = int(math.floor(math.log10(value) + 1e-9))
    mantissa = value / 10.0 ** exponent
    if abs(mantissa - 1.) < 1e-9:
      return '10^{%d}' % exponent
    return r'%g \cdot 10^{%d}' % (mantissa, exponent)

  xlim = 10000
  x = range(0, int(xlim), 10)
  lenx = len(x)
  # Cumulative band boundaries; NaN where x exceeds the answered count.
  y0 = np.full(lenx, np.nan, dtype=float)  # delta
  y1 = np.full(lenx, np.nan, dtype=float)  # delta + step1
  y2 = np.full(lenx, np.nan, dtype=float)  # delta + step1 + step2
  y3 = np.full(lenx, np.nan, dtype=float)  # delta + step1 + step2 + ss
  noise_std = np.full(lenx, np.nan, dtype=float)
  y_right = np.full(lenx, np.nan, dtype=float)  # optimal order
  for i, xi in enumerate(x):
    idx = np.searchsorted(answered, xi)
    if idx < len(eps_partitioned):
      part = eps_partitioned[idx]
      y0[i] = part.delta
      y1[i] = y0[i] + part.step1
      y2[i] = y1[i] + part.step2
      y3[i] = y2[i] + part.ss
      noise_std[i] = ss_std_opt[idx]
      y_right[i] = order_opt[idx]

  fig, ax = plt.subplots()
  fig.set_figheight(4.5)
  fig.set_figwidth(4.7)
  fig.patch.set_alpha(0)  # Transparent figure background.
  ax.plot(x, y3, color='b', ls='-', label=r'Total privacy cost', linewidth=1)
  for y in (y0, y1, y2):
    ax.plot(x, y, color='b', ls='-', label=r'_nolegend_', alpha=.5, linewidth=1)
  # Stacked bands, from the delta term at the bottom to the ss cost on top.
  bands = ((np.zeros(lenx), y0, .5), (y0, y1, .4), (y1, y2, .3),
           (y2, y3, .2))
  for lower, upper, alpha in bands:
    ax.fill_between(x, lower, upper, facecolor='b', alpha=alpha)
  ax.fill_between(x, y3 - noise_std, y3 + noise_std, facecolor='r', alpha=.5)
  ax.set_xticks(np.arange(0, xlim + 1000, 2000))
  ax.set_xlim([0, xlim])
  ax.set_ylim([0, 3.])
  ax.set_xlabel('Number of queries answered', fontsize=16)
  ax.set_ylabel(
      r'Privacy cost $\varepsilon$ at $\delta={}$'.format(
          latex_scientific(delta)),
      fontsize=16)

  # No legend is drawn; the optimal order goes on a separate y-axis.
  if print_order:
    ax2 = ax.twinx()
    ax2.plot(x, y_right, 'r', ls='-', label=r'Optimal order', linewidth=5,
             alpha=.5)
    ax2.grid(False)
    ax2.set_ylim([0, 200.])

  ax.tick_params(labelsize=14)
  plot_filename = os.path.join(figures_dir, 'partition.pdf')
  print('Saving the graph to ' + plot_filename)
  fig.savefig(plot_filename, bbox_inches='tight', dpi=800)
  plt.show()
def run_all_analyses(votes, threshold, sigma1, sigma2, delta):
  """Analyzes Simple and Confident GNMax, both data-independently and not.

  Simple GNMax is Confident GNMax with the selection step disabled, i.e. with
  threshold and sigma1 set to None.

  Returns:
    The tuple (simple_ind, conf_ind, simple_dep, conf_dep).
  """
  # (threshold, sigma1) for Simple and then Confident GNMax.
  step1_settings = ((None, None), (threshold, sigma1))
  analyzers = (analyze_gnmax_conf_data_ind, analyze_gnmax_conf_data_dep)
  results = []
  for analyze in analyzers:
    for step1_threshold, step1_sigma in step1_settings:
      results.append(
          analyze(votes, step1_threshold, step1_sigma, sigma2, delta))
  return tuple(results)
def run_or_load_all_analyses():
  """Returns the four GNMax analyses, loaded from or saved to a cache file.

  If --cache is set and the cache file exists, the pickled results are
  returned and the counts file is not read. Otherwise the votes in
  --counts_file are analyzed, truncated to --queries rows when that flag is
  set. The results are then written to the cache file, creating its directory
  if needed.

  NOTE(review): the cache is not keyed by threshold/sigma/delta/queries, so
  with --cache a result computed under different flags is silently reused.

  Returns:
    The tuple (simple_ind, conf_ind, simple_dep, conf_dep).

  Raises:
    ValueError: If --queries exceeds the number of rows in the counts file.
  """
  temp_filename = os.path.expanduser('~/tmp/partition_cached.pkl')
  if FLAGS.cache and os.path.isfile(temp_filename):
    print('Reading from cache ' + temp_filename)
    # pickle can run arbitrary code: only load caches this script wrote.
    with open(temp_filename, 'rb') as f:
      return pickle.load(f)

  fin_name = os.path.expanduser(FLAGS.counts_file)
  print('Reading raw votes from ' + fin_name)
  sys.stdout.flush()
  votes = np.load(fin_name)
  if FLAGS.queries is not None:
    if votes.shape[0] < FLAGS.queries:
      raise ValueError('Expect {} rows, got {} in {}'.format(
          FLAGS.queries, votes.shape[0], fin_name))
    # Truncate the votes matrix to the number of queries made.
    votes = votes[:FLAGS.queries, ]
  all_analyses = run_all_analyses(votes, FLAGS.threshold, FLAGS.sigma1,
                                  FLAGS.sigma2, FLAGS.delta)

  # open() below raises FileNotFoundError when ~/tmp does not exist.
  cache_dir = os.path.dirname(temp_filename)
  if not os.path.isdir(cache_dir):
    os.makedirs(cache_dir)
  print('Writing to cache ' + temp_filename)
  with open(temp_filename, 'wb') as f:
    pickle.dump(all_analyses, f)
  return all_analyses
def main(argv):
  """Runs (or loads) all analyses and writes both figures."""
  del argv  # Unused.
  analyses = run_or_load_all_analyses()
  simple_ind, conf_ind, simple_dep, conf_dep = analyses
  out_dir = os.path.expanduser(FLAGS.figures_dir)
  plot_comparison(out_dir, simple_ind, conf_ind, simple_dep, conf_dep)
  # Only the Confident GNMax data-dependent analysis gets the breakdown plot.
  plot_partition(out_dir, conf_dep, print_order=True)
  plt.close('all')


if __name__ == '__main__':
  app.run(main)