93e9585f18
PiperOrigin-RevId: 229241117
412 lines
14 KiB
Python
412 lines
14 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Produces two plots. One compares aggregators and their analyses. The other
|
|
illustrates sources of privacy loss for Confident-GNMax.
|
|
|
|
A script in support of the paper "Scalable Private Learning with PATE" by
|
|
Nicolas Papernot, Shuang Song, Ilya Mironov, Ananth Raghunathan, Kunal Talwar,
|
|
Ulfar Erlingsson (https://arxiv.org/abs/1802.08908).
|
|
|
|
The input is a file containing a numpy array of votes, one query per row, one
|
|
class per column. Ex:
|
|
43, 1821, ..., 3
|
|
31, 16, ..., 0
|
|
...
|
|
0, 86, ..., 438
|
|
The output is written to a specified directory and consists of two files.
|
|
"""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import math
|
|
import os
|
|
import pickle
|
|
import sys
|
|
|
|
sys.path.append('..') # Main modules reside in the parent directory.
|
|
|
|
from absl import app
|
|
from absl import flags
|
|
from collections import namedtuple
|
|
import matplotlib
|
|
|
|
matplotlib.use('TkAgg')
|
|
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
|
|
import numpy as np
|
|
import core as pate
|
|
import smooth_sensitivity as pate_ss
|
|
|
|
plt.style.use('ggplot')
|
|
|
|
FLAGS = flags.FLAGS
# Command-line configuration. threshold/sigma1 parameterize step 1 (selection)
# of Confident-GNMax; sigma2 parameterizes step 2 (noisy argmax).
flags.DEFINE_boolean('cache', False,
                     'Read results of privacy analysis from cache.')
flags.DEFINE_string('counts_file', None, 'Counts file.')
flags.DEFINE_string('figures_dir', '', 'Path where figures are written to.')
flags.DEFINE_float('threshold', None, 'Threshold for step 1 (selection).')
flags.DEFINE_float('sigma1', None, 'Sigma for step 1 (selection).')
flags.DEFINE_float('sigma2', None, 'Sigma for step 2 (argmax).')
flags.DEFINE_integer('queries', None, 'Number of queries made by the student.')
flags.DEFINE_float('delta', 1e-8, 'Target delta.')

flags.mark_flag_as_required('counts_file')
flags.mark_flag_as_required('threshold')
flags.mark_flag_as_required('sigma1')
flags.mark_flag_as_required('sigma2')

# Per-query decomposition of the cumulative privacy cost: the selection step
# (step1), the argmax step (step2), the smooth-sensitivity noise (ss), and the
# delta conversion term.
Partition = namedtuple('Partition', ['step1', 'step2', 'ss', 'delta'])
|
|
|
|
|
|
def analyze_gnmax_conf_data_ind(votes, threshold, sigma1, sigma2, delta):
  """Runs the data-independent privacy analysis of (Confident-)GNMax.

  When threshold and sigma1 are both provided, the Confident-GNMax selection
  step is charged for every query; otherwise every query is assumed answered
  (Simple GNMax).

  Args:
    votes: 2D numpy array of votes, one query per row, one class per column.
    threshold: Threshold for step 1 (selection), or None to always answer.
    sigma1: Noise scale for step 1 (selection), or None to always answer.
    sigma2: Noise scale for step 2 (argmax).
    delta: Target delta for the RDP-to-(eps, delta) conversion.

  Returns:
    A pair (eps_cum, answered): per-query cumulative epsilon and cumulative
    expected number of answered queries.
  """
  orders = np.logspace(np.log10(1.5), np.log10(500), num=100)
  num_queries = votes.shape[0]

  rdp_running = np.zeros(len(orders))
  expected_answered = 0
  answered = np.zeros(num_queries)
  eps_cum = np.full(num_queries, None, dtype=float)

  # The selection step is only present when both of its parameters are set.
  has_selection = threshold is not None and sigma1 is not None

  for idx in range(num_queries):
    vote_row = votes[idx,]
    if has_selection:
      q_step1 = np.exp(pate.compute_logpr_answered(threshold, sigma1,
                                                   vote_row))
      rdp_running += pate.rdp_data_independent_gaussian(sigma1, orders)
    else:
      q_step1 = 1.  # always answer

    expected_answered += q_step1
    answered[idx] = expected_answered

    # Step 2 cost is charged in expectation, weighted by Pr[answered].
    rdp_running += q_step1 * pate.rdp_data_independent_gaussian(sigma2, orders)

    eps_cum[idx], order_opt = pate.compute_eps_from_delta(
        orders, rdp_running, delta)

    # Progress report every 1000 queries.
    if idx > 0 and (idx + 1) % 1000 == 0:
      print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} '
            'at order = {:.2f}.'.format(idx + 1, answered[idx], eps_cum[idx],
                                        order_opt))
      sys.stdout.flush()

  return eps_cum, answered
|
|
|
|
|
|
def analyze_gnmax_conf_data_dep(votes, threshold, sigma1, sigma2, delta):
  """Runs the data-dependent (smooth-sensitivity) analysis of Confident-GNMax.

  When threshold and sigma1 are both provided, the selection step is included;
  otherwise every query is assumed answered (Simple GNMax).

  Args:
    votes: 2D numpy array of votes, one query per row, one class per column.
      Row sums are assumed constant (the number of teachers is taken from
      row 0).
    threshold: Threshold for step 1 (selection), or None to always answer.
    sigma1: Noise scale for step 1 (selection), or None to always answer.
    sigma2: Noise scale for step 2 (argmax).
    delta: Target delta for the RDP-to-(eps, delta) conversion.

  Returns:
    A tuple (eps_partitioned, answered, ss_std_opt, order_opt):
      eps_partitioned: per-query Partition namedtuples splitting the
        cumulative privacy cost into step1/step2/ss/delta contributions.
      answered: per-query cumulative expected number of answered queries.
      ss_std_opt: std of the smooth-sensitivity noise at the optimal order.
      order_opt: per-query optimal Renyi order.
  """
  # Short list of orders.
  # orders = np.round(np.logspace(np.log10(20), np.log10(200), num=20))

  # Long list of orders.
  orders = np.concatenate((np.arange(20, 40, .2),
                           np.arange(40, 75, .5),
                           np.logspace(np.log10(75), np.log10(200), num=20)))

  n = votes.shape[0]
  num_classes = votes.shape[1]
  num_teachers = int(sum(votes[0,]))

  # Precompute, per order, whether the data-independent bound is always
  # optimal; in that case local sensitivity is zero for that step.
  if threshold is not None and sigma1 is not None:
    is_data_ind_step1 = pate.is_data_independent_always_opt_gaussian(
        num_teachers, num_classes, sigma1, orders)
  else:
    is_data_ind_step1 = [True] * len(orders)

  is_data_ind_step2 = pate.is_data_independent_always_opt_gaussian(
      num_teachers, num_classes, sigma2, orders)

  eps_partitioned = np.full(n, None, dtype=Partition)
  order_opt = np.full(n, None, dtype=float)
  ss_std_opt = np.full(n, None, dtype=float)
  answered = np.zeros(n)

  # Running RDP totals for steps 1 and 2, one entry per candidate order.
  rdp_step1_total = np.zeros(len(orders))
  rdp_step2_total = np.zeros(len(orders))

  # Cumulative local-sensitivity bounds, per order, per distance (in teachers).
  ls_total = np.zeros((len(orders), num_teachers))
  answered_total = 0

  for i in range(n):
    v = votes[i,]

    if threshold is not None and sigma1 is not None:
      logq_step1 = pate.compute_logpr_answered(threshold, sigma1, v)
      rdp_step1_total += pate.compute_rdp_threshold(logq_step1, sigma1, orders)
    else:
      logq_step1 = 0.  # always answer

    pr_answered = np.exp(logq_step1)
    logq_step2 = pate.compute_logq_gaussian(v, sigma2)
    # Step 2 cost is charged in expectation, weighted by Pr[answered].
    rdp_step2_total += pr_answered * pate.rdp_gaussian(logq_step2, sigma2,
                                                       orders)

    answered_total += pr_answered

    rdp_ss = np.zeros(len(orders))
    ss_std = np.zeros(len(orders))

    for j, order in enumerate(orders):
      if not is_data_ind_step1[j]:
        ls_step1 = pate_ss.compute_local_sensitivity_bounds_threshold(v,
            num_teachers, threshold, sigma1, order)
      else:
        # Data-independent bound is optimal: zero local sensitivity.
        ls_step1 = np.full(num_teachers, 0, dtype=float)

      if not is_data_ind_step2[j]:
        ls_step2 = pate_ss.compute_local_sensitivity_bounds_gnmax(
            v, num_teachers, sigma2, order)
      else:
        ls_step2 = np.full(num_teachers, 0, dtype=float)

      ls_total[j,] += ls_step1 + pr_answered * ls_step2

      # Smoothing parameter; .49/order keeps beta * order < 1/2.
      beta_ss = .49 / order

      ss = pate_ss.compute_discounted_max(beta_ss, ls_total[j,])
      sigma_ss = ((order * math.exp(2 * beta_ss)) / ss) ** (1 / 3)
      rdp_ss[j] = pate_ss.compute_rdp_of_smooth_sensitivity_gaussian(
          beta_ss, sigma_ss, order)
      ss_std[j] = ss * sigma_ss

    rdp_total = rdp_step1_total + rdp_step2_total + rdp_ss

    answered[i] = answered_total
    _, order_opt[i] = pate.compute_eps_from_delta(orders, rdp_total, delta)
    order_idx = np.searchsorted(orders, order_opt[i])

    # Since optimal orders are always non-increasing, shrink orders array
    # and all cumulative arrays to speed up computation.
    if order_idx < len(orders):
      orders = orders[:order_idx + 1]
      rdp_step1_total = rdp_step1_total[:order_idx + 1]
      rdp_step2_total = rdp_step2_total[:order_idx + 1]

    eps_partitioned[i] = Partition(step1=rdp_step1_total[order_idx],
                                   step2=rdp_step2_total[order_idx],
                                   ss=rdp_ss[order_idx],
                                   delta=-math.log(delta) / (order_opt[i] - 1))
    ss_std_opt[i] = ss_std[order_idx]
    # NOTE(review): `% 1 == 0` is always true, so this prints every query —
    # presumably intentional since each data-dependent iteration is slow.
    if i > 0 and (i + 1) % 1 == 0:
      print('queries = {}, E[answered] = {:.2f}, E[eps] = {:.3f} +/- {:.3f} '
            'at order = {:.2f}. Contributions: delta = {:.3f}, step1 = {:.3f}, '
            'step2 = {:.3f}, ss = {:.3f}'.format(
                i + 1,
                answered[i],
                sum(eps_partitioned[i]),
                ss_std_opt[i],
                order_opt[i],
                eps_partitioned[i].delta,
                eps_partitioned[i].step1,
                eps_partitioned[i].step2,
                eps_partitioned[i].ss))
      sys.stdout.flush()

  return eps_partitioned, answered, ss_std_opt, order_opt
|
|
|
|
|
|
def plot_comparison(figures_dir, simple_ind, conf_ind, simple_dep, conf_dep):
  """Plots variants of GNMax algorithm and their analyses.

  Saves the figure as 'comparison.pdf' in figures_dir and shows it.

  Args:
    figures_dir: Directory where the plot is written.
    simple_ind: (eps, answered) pair from the data-independent analysis of
      Simple GNMax.
    conf_ind: (eps, answered) pair from the data-independent analysis of
      Confident GNMax.
    simple_dep: Data-dependent analysis tuple for Simple GNMax.
    conf_dep: Data-dependent analysis tuple for Confident GNMax.
  """

  def pivot(x_axis, eps, answered):
    # Resample cumulative eps onto a regular grid of answered-query counts;
    # grid points beyond the last answered query stay NaN.
    resampled = np.full(len(x_axis), None, dtype=float)
    for pos, target in enumerate(x_axis):
      src = np.searchsorted(answered, target)
      if src < len(eps):
        resampled[pos] = eps[src]
    return resampled

  def pivot_dep(x_axis, data_dep):
    eps_partitioned, answered, _, _ = data_dep
    totals = [sum(p) for p in eps_partitioned]  # Collapse the partition.
    return pivot(x_axis, totals, answered)

  xlim = 10000
  x_axis = range(0, xlim, 10)

  y_simple_ind = pivot(x_axis, *simple_ind)
  y_conf_ind = pivot(x_axis, *conf_ind)
  y_simple_dep = pivot_dep(x_axis, simple_dep)
  y_conf_dep = pivot_dep(x_axis, conf_dep)

  fig, ax = plt.subplots()
  fig.set_figheight(4.5)
  fig.set_figwidth(4.7)

  # Dashed = data-independent, solid = data-dependent; red = Simple,
  # blue = Confident.
  curves = (
      (y_simple_ind, '--', 'r', r'Simple GNMax, data-ind analysis'),
      (y_conf_ind, '--', 'b', r'Confident GNMax, data-ind analysis'),
      (y_simple_dep, '-', 'r', r'Simple GNMax, data-dep analysis'),
      (y_conf_dep, '-', 'b', r'Confident GNMax, data-dep analysis'),
  )
  for y, linestyle, color, label in curves:
    ax.plot(x_axis, y, ls=linestyle, color=color, lw=3, label=label)

  plt.xticks(np.arange(0, xlim + 1000, 2000))
  plt.xlim([0, xlim])
  plt.ylim(bottom=0)
  plt.legend(fontsize=16)
  ax.set_xlabel('Number of queries answered', fontsize=16)
  ax.set_ylabel(r'Privacy cost $\varepsilon$ at $\delta=10^{-8}$', fontsize=16)

  ax.tick_params(labelsize=14)
  plot_filename = os.path.join(figures_dir, 'comparison.pdf')
  print('Saving the graph to ' + plot_filename)
  fig.savefig(plot_filename, bbox_inches='tight')
  plt.show()
|
|
|
|
|
|
def plot_partition(figures_dir, gnmax_conf, print_order):
  """Plots an expert version of the privacy-per-answered-query graph.

  Stacks the delta, step1, step2, and smooth-sensitivity contributions of the
  privacy cost, shades one noise std around the total, and optionally overlays
  the optimal Renyi order on a secondary axis. Saves the figure as
  'partition.pdf' in figures_dir and shows it.

  Args:
    figures_dir: A name of the directory where to save the plot.
    gnmax_conf: The tuple (eps_partitioned, answered, ss_std_opt, order_opt)
      as returned by analyze_gnmax_conf_data_dep.
    print_order: If True, also plot the optimal order on a twin y-axis.
  """
  eps_partitioned, answered, ss_std_opt, order_opt = gnmax_conf

  # Resample all per-query series onto a regular grid of answered counts.
  xlim = 10000
  x = range(0, int(xlim), 10)
  lenx = len(x)
  y0 = np.full(lenx, np.nan, dtype=float)  # delta
  y1 = np.full(lenx, np.nan, dtype=float)  # delta + step1
  y2 = np.full(lenx, np.nan, dtype=float)  # delta + step1 + step2
  y3 = np.full(lenx, np.nan, dtype=float)  # delta + step1 + step2 + ss
  noise_std = np.full(lenx, np.nan, dtype=float)

  y_right = np.full(lenx, np.nan, dtype=float)  # optimal order (right axis)

  for i in range(lenx):
    idx = np.searchsorted(answered, x[i])
    if idx < len(eps_partitioned):
      # Cumulative stack: each y adds one more contribution on top.
      y0[i] = eps_partitioned[idx].delta
      y1[i] = y0[i] + eps_partitioned[idx].step1
      y2[i] = y1[i] + eps_partitioned[idx].step2
      y3[i] = y2[i] + eps_partitioned[idx].ss

      noise_std[i] = ss_std_opt[idx]
      y_right[i] = order_opt[idx]

  # plt.close('all')
  fig, ax = plt.subplots()
  fig.set_figheight(4.5)
  fig.set_figwidth(4.7)
  fig.patch.set_alpha(0)

  # l1 is kept for the commented-out merged legend below.
  l1 = ax.plot(
      x, y3, color='b', ls='-', label=r'Total privacy cost', linewidth=1).pop()

  for y in (y0, y1, y2):
    ax.plot(x, y, color='b', ls='-', label=r'_nolegend_', alpha=.5, linewidth=1)

  # Shade the stacked bands with decreasing opacity (bottom to top).
  ax.fill_between(x, [0] * lenx, y0.tolist(), facecolor='b', alpha=.5)
  ax.fill_between(x, y0.tolist(), y1.tolist(), facecolor='b', alpha=.4)
  ax.fill_between(x, y1.tolist(), y2.tolist(), facecolor='b', alpha=.3)
  ax.fill_between(x, y2.tolist(), y3.tolist(), facecolor='b', alpha=.2)

  # One noise std around the total cost.
  ax.fill_between(x, (y3 - noise_std).tolist(), (y3 + noise_std).tolist(),
                  facecolor='r', alpha=.5)

  plt.xticks(np.arange(0, xlim + 1000, 2000))
  plt.xlim([0, xlim])
  ax.set_ylim([0, 3.])

  ax.set_xlabel('Number of queries answered', fontsize=16)
  ax.set_ylabel(r'Privacy cost $\varepsilon$ at $\delta=10^{-8}$', fontsize=16)

  # Merging legends.
  if print_order:
    ax2 = ax.twinx()
    l2 = ax2.plot(
        x, y_right, 'r', ls='-', label=r'Optimal order', linewidth=5,
        alpha=.5).pop()
    ax2.grid(False)
    # ax2.set_ylabel(r'Optimal Renyi order', fontsize=16)
    ax2.set_ylim([0, 200.])
  # ax.legend((l1, l2), (l1.get_label(), l2.get_label()), loc=0, fontsize=13)

  ax.tick_params(labelsize=14)
  plot_filename = os.path.join(figures_dir, 'partition.pdf')
  print('Saving the graph to ' + plot_filename)
  fig.savefig(plot_filename, bbox_inches='tight', dpi=800)
  plt.show()
|
|
|
|
|
|
def run_all_analyses(votes, threshold, sigma1, sigma2, delta):
  """Runs all four privacy analyses on the votes matrix.

  Args:
    votes: 2D numpy array of votes, one query per row, one class per column.
    threshold: Threshold for step 1 (selection) of Confident GNMax.
    sigma1: Noise scale for step 1 (selection) of Confident GNMax.
    sigma2: Noise scale for step 2 (argmax).
    delta: Target delta.

  Returns:
    The tuple (simple_ind, conf_ind, simple_dep, conf_dep) of analysis
    results.
  """
  results = []
  # Data-independent first, then data-dependent; within each, Simple GNMax
  # (selection disabled by passing None/None) precedes Confident GNMax.
  for analysis in (analyze_gnmax_conf_data_ind, analyze_gnmax_conf_data_dep):
    for step1_threshold, step1_sigma in ((None, None), (threshold, sigma1)):
      results.append(
          analysis(votes, step1_threshold, step1_sigma, sigma2, delta))
  return tuple(results)
|
|
|
|
|
|
def run_or_load_all_analyses():
  """Returns all four analyses, reading them from cache when possible.

  If --cache is set and a previously written pickle exists, results are loaded
  from it. Otherwise the votes matrix is read from --counts_file, optionally
  truncated to --queries rows, all four analyses are run, and the results are
  cached for subsequent runs.

  Returns:
    The tuple (simple_ind, conf_ind, simple_dep, conf_dep) as produced by
    run_all_analyses.

  Raises:
    ValueError: If the counts file has fewer rows than --queries.
  """
  temp_filename = os.path.expanduser('~/tmp/partition_cached.pkl')

  if FLAGS.cache and os.path.isfile(temp_filename):
    print('Reading from cache ' + temp_filename)
    with open(temp_filename, 'rb') as f:
      all_analyses = pickle.load(f)
  else:
    fin_name = os.path.expanduser(FLAGS.counts_file)
    print('Reading raw votes from ' + fin_name)
    sys.stdout.flush()

    votes = np.load(fin_name)

    if FLAGS.queries is not None:
      if votes.shape[0] < FLAGS.queries:
        raise ValueError('Expect {} rows, got {} in {}'.format(
            FLAGS.queries, votes.shape[0], fin_name))
      # Truncate the votes matrix to the number of queries made.
      votes = votes[:FLAGS.queries, ]

    all_analyses = run_all_analyses(votes, FLAGS.threshold, FLAGS.sigma1,
                                    FLAGS.sigma2, FLAGS.delta)

    # Create the cache directory if missing; otherwise open() below raises
    # and the just-computed (expensive) results are lost.
    cache_dir = os.path.dirname(temp_filename)
    if not os.path.isdir(cache_dir):
      os.makedirs(cache_dir)

    print('Writing to cache ' + temp_filename)
    with open(temp_filename, 'wb') as f:
      pickle.dump(all_analyses, f)

  return all_analyses
|
|
|
|
|
|
def main(argv):
  """Entry point: runs (or loads) the analyses and produces both figures."""
  del argv  # Unused.

  analyses = run_or_load_all_analyses()
  simple_ind, conf_ind, simple_dep, conf_dep = analyses

  out_dir = os.path.expanduser(FLAGS.figures_dir)

  plot_comparison(out_dir, simple_ind, conf_ind, simple_dep, conf_dep)
  # The partition plot is only drawn for the data-dependent Confident-GNMax
  # analysis, with the optimal-order overlay enabled.
  plot_partition(out_dir, conf_dep, True)
  plt.close('all')
|
|
|
|
|
|
if __name__ == '__main__':
  # absl's app.run parses flags (enforcing the required ones) before main.
  app.run(main)
|