# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Illustrates how noisy thresholding check changes distribution of queries.

A script in support of the paper "Scalable Private Learning with PATE" by
Nicolas Papernot, Shuang Song, Ilya Mironov, Ananth Raghunathan, Kunal Talwar,
Ulfar Erlingsson (https://arxiv.org/abs/1802.08908).

The input is a file containing a numpy array of votes, one query per row, one
class per column. Ex:
  43, 1821, ..., 3
  31, 16, ..., 0
  ...
  0, 86, ..., 438
The output is one of two graphs depending on the setting of the plot variable.
The output is written to a pdf file.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import sys

sys.path.append('..')  # Main modules reside in the parent directory.

from absl import app
from absl import flags
import core as pate
import matplotlib

matplotlib.use('TkAgg')
import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
import numpy as np
from six.moves import xrange

plt.style.use('ggplot')

FLAGS = flags.FLAGS
flags.DEFINE_enum('plot', 'small', ['small', 'large'],
                  'Selects which of the two plots is produced.')
flags.DEFINE_string('counts_file', None, 'Counts file.')
flags.DEFINE_string('plot_file', '', 'Plot file to write.')

flags.mark_flag_as_required('counts_file')


def _bin_index(v, bin_num, total):
  """Returns the index of the bin that the top vote share of v falls into.

  The interval [0, 1] of possible shares max(v) / total is split into bin_num
  equal bins.

  Args:
    v: A vector of votes for one query, one entry per class.
    bin_num: Number of bins.
    total: Total number of votes cast for the query.

  Returns:
    An integer in [0, bin_num).
  """
  bin_idx = int(math.floor(max(v) * bin_num / total))
  # A unanimous vote has max(v) == total, which evaluates to bin_num. Such a
  # query belongs to the last bin, so clamp instead of running past the end.
  bin_idx = min(bin_idx, bin_num - 1)
  assert bin_idx >= 0
  return bin_idx


def compute_count_per_bin(bin_num, votes):
  """Tabulates number of examples in each bin.

  Args:
    bin_num: Number of bins.
    votes: A matrix of votes, where each row contains votes in one instance.

  Returns:
    Array of counts of length bin_num.

  Raises:
    ValueError: If the rows of votes do not all sum to the same total.
  """
  sums = np.sum(votes, axis=1)
  total = max(sums)
  # Every row must contain the same number of votes; a single shared total is
  # used to normalize all rows below.
  if total != min(sums):
    raise ValueError('All rows of votes must contain the same number of votes.')

  counts = np.zeros(bin_num)
  for i in xrange(votes.shape[0]):
    counts[_bin_index(votes[i,], bin_num, total)] += 1
  return counts


def compute_privacy_cost_per_bins(bin_num, votes, sigma2, order):
  """Outputs average privacy cost per bin.

  Args:
    bin_num: Number of bins.
    votes: A matrix of votes, where each row contains votes in one instance.
    sigma2: The scale (std) of the Gaussian noise. (Same as sigma_2 in
      Algorithms 1 and 2.)
    order: The Renyi order for which privacy cost is computed.

  Returns:
    Expected eps of RDP (ignoring delta) per example in each bin. Bins that
    contain no examples are reported as NaN.
  """
  n = votes.shape[0]

  bin_counts = np.zeros(bin_num)
  bin_rdp = np.zeros(bin_num)  # RDP at order=order

  for i in xrange(n):
    v = votes[i,]
    logq = pate.compute_logq_gaussian(v, sigma2)
    rdp_at_order = pate.rdp_gaussian(logq, sigma2, order)

    bin_idx = _bin_index(v, bin_num, sum(v))
    bin_counts[bin_idx] += 1
    bin_rdp[bin_idx] += rdp_at_order
    if (i + 1) % 1000 == 0:
      print('example {}'.format(i + 1))
      sys.stdout.flush()

  # Empty bins divide 0 by 0; the resulting NaN is intentional, so only the
  # accompanying RuntimeWarning is silenced.
  with np.errstate(invalid='ignore'):
    return bin_rdp / bin_counts


def compute_expected_answered_per_bin(bin_num, votes, threshold, sigma1):
  """Computes expected number of answers per bin.

  Args:
    bin_num: Number of bins.
    votes: A matrix of votes, where each row contains votes in one instance.
    threshold: The threshold against which check is performed.
    sigma1: The std of the Gaussian noise with which check is performed. (Same
      as sigma_1 in Algorithms 1 and 2.)

  Returns:
    Expected number of queries answered per bin.
  """
  n = votes.shape[0]

  bin_answered = np.zeros(bin_num)

  for i in xrange(n):
    v = votes[i,]
    # Probability that this query passes the noisy threshold check.
    p = math.exp(pate.compute_logpr_answered(threshold, sigma1, v))
    bin_answered[_bin_index(v, bin_num, sum(v))] += p
    if (i + 1) % 1000 == 0:
      print('example {}'.format(i + 1))
      sys.stdout.flush()

  return bin_answered


def main(argv):
  """Loads the votes, computes the per-bin statistics and draws the plot.

  With --plot=small, 5 bins are used and the raw query counts are drawn next
  to the expected number of queries answered at threshold 3500. With
  --plot=large, 10 bins are used and the plot additionally shows the expected
  answers at threshold 5000 and the per-bin privacy cost on a second,
  logarithmic y axis.

  Args:
    argv: Unused positional command-line arguments passed by absl.

  Raises:
    ValueError: If --plot is neither 'small' nor 'large'.
  """
  del argv  # Unused.
  fin_name = os.path.expanduser(FLAGS.counts_file)
  print('Reading raw votes from ' + fin_name)
  sys.stdout.flush()

  votes = np.load(fin_name)
  votes = votes[:4000,]  # truncate to 4000 samples

  if FLAGS.plot == 'small':
    bin_num = 5
  elif FLAGS.plot == 'large':
    bin_num = 10
  else:
    raise ValueError('--plot flag must be one of ["small", "large"]')

  # Both plots show the "moderate" check (threshold 3500, sigma1 1500).
  m_check = compute_expected_answered_per_bin(bin_num, votes, 3500, 1500)
  if FLAGS.plot == 'large':
    # "Aggressive" check (threshold 5000) and the privacy cost, evaluated with
    # sigma2 = 100 at Renyi order 50.
    a_check = compute_expected_answered_per_bin(bin_num, votes, 5000, 1500)
    eps = compute_privacy_cost_per_bins(bin_num, votes, 100, 50)

  counts = compute_count_per_bin(bin_num, votes)
  # Left edges of the bins on a 0-100 percent axis; each bar spans one bin.
  bins = np.linspace(0, 100, num=bin_num, endpoint=False)
  bar_width = 100 / bin_num

  plt.close('all')
  fig, ax = plt.subplots()
  if FLAGS.plot == 'small':
    fig.set_figheight(5)
    fig.set_figwidth(5)
    ax.bar(
        bins,
        counts,
        bar_width,
        color='orangered',
        linestyle='dotted',
        linewidth=5,
        edgecolor='red',
        fill=False,
        alpha=.5,
        align='edge',
        label='LNMax answers')
    ax.bar(
        bins,
        m_check,
        bar_width,
        color='g',
        alpha=.5,
        linewidth=0,
        edgecolor='g',
        align='edge',
        label='Confident-GNMax\nanswers')
  elif FLAGS.plot == 'large':
    fig.set_figheight(4.7)
    fig.set_figwidth(7)
    ax.bar(
        bins,
        counts,
        bar_width,
        linestyle='dashed',
        linewidth=5,
        edgecolor='red',
        fill=False,
        alpha=.5,
        align='edge',
        label='LNMax answers')
    ax.bar(
        bins,
        m_check,
        bar_width,
        color='g',
        alpha=.5,
        linewidth=0,
        edgecolor='g',
        align='edge',
        label='Confident-GNMax\nanswers (moderate)')
    ax.bar(
        bins,
        a_check,
        bar_width,
        color='b',
        alpha=.5,
        align='edge',
        label='Confident-GNMax\nanswers (aggressive)')
    # The privacy cost goes on a second y axis that shares the x axis.
    ax2 = ax.twinx()
    bin_centers = [x + bar_width / 2 for x in bins]
    ax2.plot(bin_centers, eps, 'ko', alpha=.8)
    ax2.set_ylim([1e-200, 1.])
    ax2.set_yscale('log')
    ax2.grid(False)
    ax2.set_yticks([1e-3, 1e-50, 1e-100, 1e-150, 1e-200])
    # tick_params expects booleans; the string 'off' is truthy and would not
    # hide the minor ticks.
    plt.tick_params(which='minor', right=False)
    ax2.set_ylabel(r'Per query privacy cost $\varepsilon$', fontsize=16)

  plt.xlim([0, 100])
  ax.set_ylim([0, 2500])
  ax.set_xlabel('Percentage of teachers that agree', fontsize=16)
  ax.set_ylabel('Number of queries answered', fontsize=16)
  vals = ax.get_xticks()
  ax.set_xticklabels([str(int(x)) + '%' for x in vals])
  ax.tick_params(labelsize=14, bottom=True, top=True, left=True, right=True)
  ax.legend(loc=2, prop={'size': 16})

  # simple: 'figures/noisy_thresholding_check_perf.pdf'
  # detailed: 'figures/noisy_thresholding_check_perf_details.pdf'
  if FLAGS.plot_file:
    print('Saving the graph to ' + FLAGS.plot_file)
    plt.savefig(os.path.expanduser(FLAGS.plot_file), bbox_inches='tight')
  else:
    # --plot_file defaults to the empty string, which savefig cannot write to.
    print('No --plot_file given; the graph is shown but not saved.')
  plt.show()


if __name__ == '__main__':
  app.run(main)