forked from 626_privacy/tensorflow_privacy
419 lines
14 KiB
Python
419 lines
14 KiB
Python
# Copyright 2017 The 'Scalable Private Learning with PATE' Authors All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Functions for smooth sensitivity analysis for PATE mechanisms.
|
|
|
|
This library implements functionality for doing smooth sensitivity analysis
|
|
for Gaussian Noise Max (GNMax), Threshold with Gaussian noise, and Gaussian
|
|
Noise with Smooth Sensitivity (GNSS) mechanisms.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import math
|
|
from absl import app
|
|
import numpy as np
|
|
import scipy
|
|
import sympy as sp
|
|
|
|
import core as pate
|
|
|
|
################################
|
|
# SMOOTH SENSITIVITY FOR GNMAX #
|
|
################################
|
|
|
|
# Memoized logq0 values, keyed by the (sigma, order) pair they belong to.
_logq0_cache = {}


def _compute_logq0(sigma, order):
    """Returns logq0 for (sigma, order), computing it at most once per pair.

    Args:
      sigma: std of the Gaussian noise.
      order: Renyi order.

    Returns:
      The result of compute_logq0_gnmax(sigma, order), served from the
      module-level cache whenever the pair has been seen before.
    """
    cache_key = (sigma, order)
    try:
        return _logq0_cache[cache_key]
    except KeyError:
        pass

    # Cache miss: run the root-finding routine and remember its answer.
    value = compute_logq0_gnmax(sigma, order)
    _logq0_cache[cache_key] = value
    return value
|
|
|
|
|
|
def _compute_logq1(sigma, order, num_classes):
    """Returns logq1, the log of _compute_bl_gnmax evaluated at q0 = exp(logq0).

    Args:
      sigma: std of the Gaussian noise.
      order: Renyi order.
      num_classes: Number of classes.

    Returns:
      logq1, which is asserted to be no larger than logq0.
    """
    logq0 = _compute_logq0(sigma, order)  # Usually a cache hit.
    q1 = _compute_bl_gnmax(math.exp(logq0), sigma, num_classes)
    logq1 = math.log(q1)
    assert logq1 <= logq0
    return logq1
|
|
|
|
|
|
def _compute_mu1_mu2_gnmax(sigma, logq):
    """Computes the pair (mu1, mu2) according to Proposition 10.

    Args:
      sigma: std of the Gaussian noise.
      logq: Natural log of q; must be non-positive for the square root to exist.

    Returns:
      A tuple (mu1, mu2) with mu2 = sigma * sqrt(-logq) and mu1 = mu2 + 1.
    """
    second = sigma * math.sqrt(-logq)
    return second + 1, second
|
|
|
|
|
|
def _compute_data_dep_bound_gnmax(sigma, logq, order):
    """Evaluates the data-dependent RDP bound for GNMax.

    Applies Theorem 6 in Appendix without checking that logq satisfies the
    necessary constraints. The caller must assure the pre-conditions by
    comparing logq against logq0.

    Args:
      sigma: std of the Gaussian noise.
      logq: Natural log of q.
      order: Renyi order.

    Returns:
      The value of the data-dependent bound at (sigma, logq, order).
    """
    sigma_sq = sigma**2
    mu1, mu2 = _compute_mu1_mu2_gnmax(sigma, logq)
    eps1 = mu1 / sigma_sq
    eps2 = mu2 / sigma_sq
    order_minus_1 = order - 1

    # log(1 - q), computed via log1p for accuracy when q is tiny.
    log_one_minus_q = np.log1p(-math.exp(logq))

    exponent = (logq + eps2) * (1 - 1 / mu2)
    log_a = order_minus_1 * (log_one_minus_q - (np.log1p(-math.exp(exponent))))
    log_b = order_minus_1 * (eps1 - logq / (mu1 - 1))

    # Combine the two terms in log space: log((1 - q) * A + q * B).
    combined = np.logaddexp(log_one_minus_q + log_a, logq + log_b)
    return combined / order_minus_1
|
|
|
|
|
|
def _compute_rdp_gnmax(sigma, logq, order):
    """Returns the RDP bound for GNMax, choosing between the two regimes.

    Args:
      sigma: std of the Gaussian noise.
      logq: Natural log of q.
      order: Renyi order.

    Returns:
      The data-independent Gaussian bound when logq >= logq0, and the
      data-dependent bound otherwise.
    """
    crossover = _compute_logq0(sigma, order)
    if logq >= crossover:
        return pate.rdp_data_independent_gaussian(sigma, order)
    return _compute_data_dep_bound_gnmax(sigma, logq, order)
|
|
|
|
|
|
def compute_logq0_gnmax(sigma, order):
    """Computes the point where we start using data-independent bounds.

    Args:
      sigma: std of the Gaussian noise
      order: Renyi order lambda

    Returns:
      logq0: the point above which the data-independent bound overtakes the
        data-dependent bound.
    """

    def _is_valid(logq):
        # True iff logq is in the range where the data-dependent bound is
        # valid. (Theorem 6 in Appendix.)
        mu1, mu2 = _compute_mu1_mu2_gnmax(sigma, logq)
        if mu1 < order:
            return False
        eps2 = mu2 / sigma**2
        # The computation is done in log space. The condition comes from
        # Lemma 9 from Appendix.
        limit = (mu2 - 1) * eps2 - mu2 * math.log(
            mu1 / (mu1 - 1) * mu2 / (mu2 - 1))
        return logq <= limit

    def _dep_minus_ind(logq):
        # Negative when the data-dependent bound beats the data-independent one.
        dependent = _compute_data_dep_bound_gnmax(sigma, logq, order)
        independent = pate.rdp_data_independent_gaussian(sigma, order)
        return dependent - independent

    # Natural upper bounds on q0.
    candidates = (
        -(1 + 1. / sigma)**2,
        -((order - .99) / sigma)**2,
        -1 / sigma**2,
    )
    upper = min(candidates)
    assert _is_valid(upper)

    # If the data-dependent bound is already better, we are done.
    if _dep_minus_ind(upper) < 0:
        return upper

    # Identify a reasonable lower bound to bracket logq0. Since upper is
    # negative, doubling it (and later scaling by 1.5) moves further left.
    lower = 2 * upper
    while _dep_minus_ind(lower) > 0:
        assert lower > -10000, "The lower bound on q0 is way too low."
        lower *= 1.5

    root, report = scipy.optimize.brentq(
        _dep_minus_ind, lower, upper, full_output=True)
    assert report.converged, "The root finding procedure failed to converge."
    assert _is_valid(root)  # just in case.

    return root
|
|
|
|
|
|
def _compute_bl_gnmax(q, sigma, num_classes):
    """Evaluates (m - 1) / 2 * erfc(1 / sigma + erfcinv(2 * q / (m - 1))).

    Here m is num_classes. Presumably this is the lower bound B_L(q) from the
    paper; confirm against the text.

    Args:
      q: Probability value to transform.
      sigma: std of the Gaussian noise.
      num_classes: Number of classes (m).

    Returns:
      The transformed value.
    """
    other_classes = num_classes - 1
    shifted = 1 / sigma + scipy.special.erfcinv(2 * q / other_classes)
    return other_classes / 2 * scipy.special.erfc(shifted)
|
|
|
|
|
|
def _compute_bu_gnmax(q, sigma, num_classes):
    """Evaluates min(1, (m - 1) / 2 * erfc(-1 / sigma + erfcinv(2 * q / (m - 1)))).

    Here m is num_classes. Presumably this is the upper bound B_U(q) from the
    paper; confirm against the text.

    Args:
      q: Probability value to transform.
      sigma: std of the Gaussian noise.
      num_classes: Number of classes (m).

    Returns:
      The transformed value, capped at 1.
    """
    other_classes = num_classes - 1
    shifted = -1 / sigma + scipy.special.erfcinv(2 * q / other_classes)
    uncapped = other_classes / 2 * scipy.special.erfc(shifted)
    return min(1, uncapped)
|
|
|
|
|
|
def _compute_local_sens_gnmax(logq, sigma, num_classes, order):
    """Implements Algorithm 3 (computes an upper bound on local sensitivity).

    (See Proposition 13 for proof of correctness.)

    Args:
      logq: Natural log of q.
      sigma: std of the Gaussian noise.
      num_classes: Number of classes.
      order: Renyi order.

    Returns:
      The larger of the RDP increase when q moves to its upper neighbor and the
      RDP decrease when q moves to its lower neighbor.
    """
    logq0 = _compute_logq0(sigma, order)
    logq1 = _compute_logq1(sigma, order, num_classes)

    # Any logq inside [logq1, logq0] is evaluated at logq1 instead.
    inside_plateau = logq1 <= logq <= logq0
    if inside_plateau:
        logq = logq1

    rdp_here = _compute_rdp_gnmax(sigma, logq, order)

    log_upper = math.log(_compute_bu_gnmax(math.exp(logq), sigma, num_classes))
    rdp_upper = _compute_rdp_gnmax(sigma, log_upper, order)

    log_lower = math.log(_compute_bl_gnmax(math.exp(logq), sigma, num_classes))
    rdp_lower = _compute_rdp_gnmax(sigma, log_lower, order)

    return max(rdp_upper - rdp_here, rdp_here - rdp_lower)
|
|
|
|
|
|
def compute_local_sensitivity_bounds_gnmax(votes, num_teachers, sigma, order):
    """Computes a list of max-LS-at-distance-d for the GNMax mechanism.

    A more efficient implementation of Algorithms 4 and 5 working in time
    O(teachers*classes). A naive implementation is O(teachers^2*classes) or
    worse.

    Args:
      votes: A numpy array of votes.
      num_teachers: Total number of voting teachers.
      sigma: Standard deviation of the Gaussian noise.
      order: The Renyi order.

    Returns:
      A numpy array of length num_teachers whose entry d holds the local
      sensitivity bound at distance d. Entries that are never visited keep the
      plateau value, i.e. the bound evaluated at logq1.
    """
    num_classes = len(votes)  # Called m in the paper.

    logq0 = _compute_logq0(sigma, order)
    logq1 = _compute_logq1(sigma, order, num_classes)
    logq = pate.compute_logq_gaussian(votes, sigma)
    plateau = _compute_local_sens_gnmax(logq1, sigma, num_classes, order)

    bounds = np.full(num_teachers, plateau)

    # Already on the plateau: every distance gets the plateau value.
    if logq1 <= logq <= logq0:
        return bounds

    # Invariant: tally is sorted in the non-increasing order.
    tally = sorted(votes, reverse=True)

    bounds[0] = _compute_local_sens_gnmax(logq, sigma, num_classes, order)
    distance = 0

    if logq > logq0:
        # Going left: keep decreasing logq while it is still larger than logq0
        # and the gap between tally[0] and tally[1] can still be widened.
        while logq > logq0 and tally[1] > 0:
            distance += 1
            tally[0] += 1
            tally[1] -= 1
            # Restore the invariant by bubbling the decremented entry to the
            # right. (Can be implemented more efficiently by keeping track of
            # the range of indices equal to tally[1]. Does not seem to matter
            # for the overall running time.)
            pos = 1
            while pos < len(tally) - 1 and tally[pos] < tally[pos + 1]:
                tally[pos], tally[pos + 1] = tally[pos + 1], tally[pos]
                pos += 1

            logq = pate.compute_logq_gaussian(tally, sigma)
            bounds[distance] = _compute_local_sens_gnmax(
                logq, sigma, num_classes, order)
    else:
        # Going right: logq < logq1, so keep increasing logq until it reaches
        # logq1.
        while logq < logq1:
            distance += 1
            tally[0] -= 1
            tally[1] += 1  # The invariant holds since otherwise logq >= logq1.

            logq = pate.compute_logq_gaussian(tally, sigma)
            bounds[distance] = _compute_local_sens_gnmax(
                logq, sigma, num_classes, order)

    return bounds
|
|
|
|
|
|
##################################################
|
|
# SMOOTH SENSITIVITY FOR THE THRESHOLD MECHANISM #
|
|
##################################################
|
|
|
|
# Cache of RDP arrays for the threshold mechanism, keyed by the 4-tuple
# (num_teachers, threshold, sigma, order).
_rdp_thresholds = {}


def _compute_rdp_list_threshold(num_teachers, threshold, sigma, order):
    """Returns the threshold-mechanism RDP for every vote count 0..num_teachers.

    Args:
      num_teachers: Total number of voting teachers.
      threshold: Threshold the noisy vote count is compared against.
      sigma: Scale of the Gaussian noise.
      order: The Renyi order.

    Returns:
      A numpy array of length num_teachers + 1 whose entry v is the RDP for a
      vote count of v. The array is cached, so repeated calls with the same
      arguments return the same object.
    """
    cache_key = (num_teachers, threshold, sigma, order)
    if cache_key not in _rdp_thresholds:
        rdp_by_count = np.zeros(num_teachers + 1)
        for count in range(num_teachers + 1):
            # Log-probability that Gaussian noise exceeds (threshold - count).
            log_pr = scipy.stats.norm.logsf(threshold - count, scale=sigma)
            rdp_by_count[count] = pate.compute_rdp_threshold(
                log_pr, sigma, order)
        _rdp_thresholds[cache_key] = rdp_by_count

    return _rdp_thresholds[cache_key]
|
|
|
|
|
|
def compute_local_sensitivity_bounds_threshold(counts, num_teachers, threshold,
                                               sigma, order):
    """Computes a list of max-LS-at-distance-d for the threshold mechanism.

    Args:
      counts: Vote counts; only the (rounded) maximum is used.
      num_teachers: Total number of voting teachers.
      threshold: Threshold of the mechanism.
      sigma: Scale of the Gaussian noise.
      order: The Renyi order.

    Returns:
      A numpy array of length num_teachers whose entry d is the largest
      one-step RDP change over the vote counts at distance d from the maximum.
      Distances at which neither count is within [0, num_teachers] stay zero.
    """
    # -inf acts as the neutral element for max() when a neighbor is missing.
    missing = float("-inf")

    def _one_step_change(v):
        # Largest absolute RDP change from moving the count v by one vote.
        down = abs(rdp_list[v - 1] - rdp_list[v]) if v > 0 else missing
        up = abs(rdp_list[v + 1] - rdp_list[v]) if v < num_teachers else missing
        return max(down, up)

    top_count = int(round(max(counts)))
    rdp_list = _compute_rdp_list_threshold(num_teachers, threshold, sigma,
                                           order)

    ls = np.zeros(num_teachers)
    farthest = max(top_count, num_teachers - top_count)
    for d in range(farthest):
        above = (_one_step_change(top_count + d)
                 if top_count + d <= num_teachers else missing)
        below = (_one_step_change(top_count - d)
                 if top_count - d >= 0 else missing)
        ls[d] = max(above, below)
    return ls
|
|
|
|
|
|
#############################################
|
|
# PROCEDURES FOR SMOOTH SENSITIVITY RELEASE #
|
|
#############################################
|
|
|
|
# Cache of exponentially decaying arrays exp(-beta * k), k = 0, 1, ..., indexed
# by beta.
dict_beta_discount = {}


def compute_discounted_max(beta, a):
    """Returns max over k of a[k] * exp(-beta * k).

    Args:
      beta: Decay rate of the exponential discount.
      a: Sequence of values; a[k] is discounted by exp(-beta * k).

    Returns:
      The largest discounted value.
    """
    n = len(a)

    # (Re)build the discount table when none exists for this beta or the
    # cached one is too short for the current input.
    needs_rebuild = (beta not in dict_beta_discount or
                     len(dict_beta_discount[beta]) < n)
    if needs_rebuild:
        dict_beta_discount[beta] = np.exp(-beta * np.arange(n))

    discounts = dict_beta_discount[beta][:n]
    return max(a * discounts)
|
|
|
|
|
|
def compute_smooth_sensitivity_gnmax(beta, counts, num_teachers, sigma, order):
    """Computes smooth sensitivity of a single application of GNMax.

    Args:
      beta: Decay rate; the local sensitivity bound at distance d is weighted
        by exp(-beta * d).
      counts: Votes of the teachers, one entry per class.
      num_teachers: Total number of voting teachers.
      sigma: Standard deviation of the Gaussian noise.
      order: The Renyi order.

    Returns:
      The maximum over distances d of exp(-beta * d) times the local
      sensitivity bound at distance d.
    """
    # Pass arguments by keyword: the callee's positional order is
    # (votes, num_teachers, sigma, order), which differs from the order in
    # which they are naturally listed here.
    ls = compute_local_sensitivity_bounds_gnmax(
        votes=counts, num_teachers=num_teachers, sigma=sigma, order=order)
    return compute_discounted_max(beta, ls)
|
|
|
|
|
|
def compute_rdp_of_smooth_sensitivity_gaussian(beta, sigma, order):
    """Computes the RDP curve for the GNSS mechanism.

    Implements Theorem 23 (https://arxiv.org/pdf/1802.08908.pdf).

    Args:
      beta: Smoothness parameter.
      sigma: Scale of the Gaussian noise.
      order: The Renyi order.

    Returns:
      The RDP value at the given order.

    Raises:
      ValueError: If beta is positive and order is not strictly between 1 and
        1 / (2 * beta).
    """
    if beta > 0:
        max_order = 1 / (2 * beta)
        if not 1 < order < max_order:
            raise ValueError("Order outside the (1, 1/(2*beta)) range.")

    noise_term = order * math.exp(2 * beta) / sigma**2
    smoothing_term = (
        -.5 * math.log(1 - 2 * order * beta) + beta * order) / (order - 1)
    return noise_term + smoothing_term
|
|
|
|
|
|
def compute_params_for_ss_release(eps, delta):
    """Computes parameters for additive Gaussian noise scaled by smooth sensitivity.

    Presently not used. (We proceed via RDP analysis.)

    Compute alpha, beta for applying Lemma 2.6 (full version of Nissim et al.)
    via Lemma 2.10.

    Args:
      eps: The epsilon of the privacy guarantee.
      delta: The delta of the privacy guarantee.

    Returns:
      A tuple (alpha, beta).
    """
    # Rather than applying Lemma 2.10 directly, which would give suboptimal
    # alpha, (see http://www.cse.psu.edu/~ads22/pubs/NRS07/NRS07-full-draft-v1.pdf),
    # we extract a sufficient condition on alpha from its proof.
    #
    # Let a = rho_(delta/2)(Z_1). Then solve for alpha such that
    # 2 alpha a + alpha^2 = eps/2.
    half_delta = delta / 2
    a = scipy.special.ndtri(1 - half_delta)
    # Positive root of the quadratic alpha^2 + 2 a alpha - eps/2 = 0.
    alpha = math.sqrt(a**2 + eps / 2) - a

    chi2_quantile = scipy.special.chdtri(1, half_delta)
    beta = eps / (2 * chi2_quantile)

    return alpha, beta
|
|
|
|
|
|
#######################################################
|
|
# SYMBOLIC-NUMERIC VERIFICATION OF CONDITIONS C5--C6. #
|
|
#######################################################
|
|
|
|
|
|
def _construct_symbolic_beta(q, sigma, order):
    """Builds a symbolic expression in q mirroring the data-dependent bound.

    Args:
      q: Symbol (or expression) standing for the probability q.
      sigma: std of the Gaussian noise.
      order: Renyi order.

    Returns:
      The expression log((1 - q) * A^(order - 1) + q * B^(order - 1)) scaled
      by 1 / (order - 1), with A and B defined below.
    """
    mu2 = sigma * sp.sqrt(sp.log(1 / q))
    mu1 = mu2 + 1
    eps1 = mu1 / sigma**2
    eps2 = mu2 / sigma**2

    term_a = (1 - q) / (1 - (q * sp.exp(eps2))**(1 - 1 / mu2))
    term_b = sp.exp(eps1) / q**(1 / (mu1 - 1))
    mixture = (1 - q) * term_a**(order - 1) + q * term_b**(order - 1)

    return (1 / (order - 1)) * sp.log(mixture)
|
|
|
|
|
|
def _construct_symbolic_bu(q, sigma, m):
    """Builds the symbolic form (m - 1) / 2 * erfc(erfcinv(2 * q / (m - 1)) - 1 / sigma).

    Args:
      q: Symbol (or expression) standing for the probability q.
      sigma: std of the Gaussian noise.
      m: Number of classes.

    Returns:
      The symbolic expression; unlike _compute_bu_gnmax it is not capped at 1.
    """
    shifted = sp.erfcinv(2 * q / (m - 1)) - 1 / sigma
    return (m - 1) / 2 * sp.erfc(shifted)
|
|
|
|
|
|
def _is_non_decreasing(fn, q, bounds):
    """Verifies whether the function is non-decreasing within a range.

    Args:
      fn: Symbolic function of a single variable.
      q: The name of fn's variable.
      bounds: Pair of (lower_bound, upper_bound) reals.

    Returns:
      True iff the minimum of the derivative found by the bounded scalar
      minimizer over the range is non-negative.
    """
    # Symbolically compute the derivative.
    derivative = sp.diff(fn, q)

    # Numeric implementations for the special functions used in the
    # expressions.
    special_functions = {
        "erfc": scipy.special.erfc,
        "erfcinv": scipy.special.erfcinv
    }
    derivative_numeric = sp.lambdify(
        q, derivative, modules=["numpy", special_functions])

    outcome = scipy.optimize.minimize_scalar(
        derivative_numeric, bounds=bounds, method="bounded")
    assert outcome.success, "Minimizer failed to converge."

    # Check whether the derivative is non-negative at its minimum.
    return outcome.fun >= 0
|
|
|
|
|
|
def check_conditions(sigma, m, order):
    """Checks conditions C5 and C6 (Section B.4.2 in Appendix).

    Args:
      sigma: std of the Gaussian noise.
      m: Number of classes.
      order: Renyi order.

    Returns:
      A pair (cond5, cond6). cond6 is reported as False without being checked
      whenever cond5 fails.
    """
    q = sp.symbols("q", positive=True, real=True)

    beta = _construct_symbolic_beta(q, sigma, order)
    q0 = math.exp(compute_logq0_gnmax(sigma, order))

    # C5: beta is non-decreasing on (0, q0).
    cond5 = _is_non_decreasing(beta, q, (0, q0))
    if not cond5:
        # Skip the check of Condition 6, since Condition 5 is false already.
        return (cond5, False)

    # C6: beta(B_U(q)) - beta(q) is non-decreasing on (0, B_L(q0)).
    bl_q0 = _compute_bl_gnmax(q0, sigma, m)
    bu = _construct_symbolic_bu(q, sigma, m)
    delta_beta = beta.subs(q, bu) - beta
    cond6 = _is_non_decreasing(delta_beta, q, (0, bl_q0))

    return (cond5, cond6)
|
|
|
|
|
|
def main(argv):
    """Entry point handed to absl; does nothing beyond discarding argv."""
    del argv  # Unused.


if __name__ == "__main__":
    app.run(main)
|