diff --git a/tensorflow_privacy/__init__.py b/tensorflow_privacy/__init__.py new file mode 100644 index 00000000..7c00cb05 --- /dev/null +++ b/tensorflow_privacy/__init__.py @@ -0,0 +1,57 @@ +# Copyright 2019, The TensorFlow Privacy Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TensorFlow Privacy library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +# pylint: disable=g-import-not-at-top + +if hasattr(sys, 'skip_tf_privacy_import'): # Useful for standalone scripts. + pass +else: + from tensorflow_privacy.privacy.analysis.privacy_ledger import GaussianSumQueryEntry + from tensorflow_privacy.privacy.analysis.privacy_ledger import PrivacyLedger + from tensorflow_privacy.privacy.analysis.privacy_ledger import QueryWithLedger + from tensorflow_privacy.privacy.analysis.privacy_ledger import SampleEntry + + from tensorflow_privacy.privacy.dp_query.dp_query import DPQuery + from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianAverageQuery + from tensorflow_privacy.privacy.dp_query.gaussian_query import GaussianSumQuery + from tensorflow_privacy.privacy.dp_query.nested_query import NestedQuery + from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacyAverageQuery + from tensorflow_privacy.privacy.dp_query.no_privacy_query import NoPrivacySumQuery + from tensorflow_privacy.privacy.dp_query.normalized_query import NormalizedQuery + from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipSumQuery + from tensorflow_privacy.privacy.dp_query.quantile_adaptive_clip_sum_query import QuantileAdaptiveClipAverageQuery + + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradGaussianOptimizer + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdagradOptimizer + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamGaussianOptimizer + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPAdamOptimizer + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer + from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentOptimizer + + try: + from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel + from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn + from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin + from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy + from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber + except ImportError: + # module `bolt_on` not yet available in this version of TF Privacy + pass diff --git a/tensorflow_privacy/privacy/BUILD b/tensorflow_privacy/privacy/BUILD new file mode 100644 index 00000000..001dcb9a --- /dev/null +++ b/tensorflow_privacy/privacy/BUILD @@ -0,0 +1,5 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) + +exports_files(["LICENSE"]) diff --git 
a/tensorflow_privacy/privacy/__init__.py b/tensorflow_privacy/privacy/__init__.py new file mode 100644 index 00000000..30d107ea --- /dev/null +++ b/tensorflow_privacy/privacy/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2019, The TensorFlow Privacy Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tensorflow_privacy/privacy/analysis/__init__.py b/tensorflow_privacy/privacy/analysis/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py new file mode 100644 index 00000000..296618b9 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/compute_dp_sgd_privacy.py @@ -0,0 +1,97 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Command-line script for computing privacy of a model trained with DP-SGD. + +The script applies the RDP accountant to estimate privacy budget of an iterated +Sampled Gaussian Mechanism. The mechanism's parameters are controlled by flags. + +Example: + compute_dp_sgd_privacy + --N=60000 \ + --batch_size=256 \ + --noise_multiplier=1.12 \ + --epochs=60 \ + --delta=1e-5 + +The output states that DP-SGD with these parameters satisfies (2.92, 1e-5)-DP. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import sys + +from absl import app +from absl import flags + +# Opting out of loading all sibling packages and their dependencies. 
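+# (It is enough for the attribute to exist: tensorflow_privacy/__init__.py only checks +# hasattr(sys, 'skip_tf_privacy_import') before importing the heavier submodules.)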
+sys.skip_tf_privacy_import = True + +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp # pylint: disable=g-import-not-at-top +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent + +FLAGS = flags.FLAGS + +flags.DEFINE_integer('N', None, 'Total number of examples') +flags.DEFINE_integer('batch_size', None, 'Batch size') +flags.DEFINE_float('noise_multiplier', None, 'Noise multiplier for DP-SGD') +flags.DEFINE_float('epochs', None, 'Number of epochs (may be fractional)') +flags.DEFINE_float('delta', 1e-6, 'Target delta') + +flags.mark_flag_as_required('N') +flags.mark_flag_as_required('batch_size') +flags.mark_flag_as_required('noise_multiplier') +flags.mark_flag_as_required('epochs') + + +def apply_dp_sgd_analysis(q, sigma, steps, orders, delta): + """Compute and print results of DP-SGD analysis.""" + + # compute_rdp requires that sigma be the ratio of the standard deviation of + # the Gaussian noise to the l2-sensitivity of the function to which it is + # added. Hence, sigma here corresponds to the `noise_multiplier` parameter + # in the DP-SGD implementation found in privacy.optimizers.dp_optimizer + rdp = compute_rdp(q, sigma, steps, orders) + + eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=delta) + + print('DP-SGD with sampling rate = {:.3g}% and noise_multiplier = {} iterated' + ' over {} steps satisfies'.format(100 * q, sigma, steps), end=' ') + print('differential privacy with eps = {:.3g} and delta = {}.'.format( + eps, delta)) + print('The optimal RDP order is {}.'.format(opt_order)) + + if opt_order == max(orders) or opt_order == min(orders): + print('The privacy estimate is likely to be improved by expanding ' + 'the set of orders.') + + +def main(argv): + del argv # argv is not used. + + q = FLAGS.batch_size / FLAGS.N # q - the sampling ratio. + if q > 1: + raise app.UsageError('N must be larger than the batch size.') + orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + + list(range(5, 64)) + [128, 256, 512]) + steps = int(math.ceil(FLAGS.epochs * FLAGS.N / FLAGS.batch_size)) + + apply_dp_sgd_analysis(q, FLAGS.noise_multiplier, steps, orders, FLAGS.delta) + + +if __name__ == '__main__': + app.run(main) diff --git a/tensorflow_privacy/privacy/analysis/privacy_ledger.py b/tensorflow_privacy/privacy/analysis/privacy_ledger.py new file mode 100644 index 00000000..22eb1f0b --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/privacy_ledger.py @@ -0,0 +1,257 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""PrivacyLedger class for keeping a record of private queries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from distutils.version import LooseVersion +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import tensor_buffer +from tensorflow_privacy.privacy.dp_query import dp_query + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + nest = tf.contrib.framework.nest +else: + nest = tf.nest + +SampleEntry = collections.namedtuple( # pylint: disable=invalid-name + 'SampleEntry', ['population_size', 'selection_probability', 'queries']) + +GaussianSumQueryEntry = collections.namedtuple( # pylint: disable=invalid-name + 'GaussianSumQueryEntry', ['l2_norm_bound', 'noise_stddev']) + + +def format_ledger(sample_array, query_array): + """Converts array representation into a list of SampleEntries.""" + samples = [] + query_pos = 0 + sample_pos = 0 + for sample in sample_array: + population_size, selection_probability, num_queries = sample + queries = [] + for _ in range(int(num_queries)): + query = query_array[query_pos] + assert int(query[0]) == sample_pos + queries.append(GaussianSumQueryEntry(*query[1:])) + query_pos += 1 + samples.append(SampleEntry(population_size, selection_probability, queries)) + sample_pos += 1 + return samples + + +class PrivacyLedger(object): + """Class for keeping a record of private queries. + + The PrivacyLedger keeps a record of all queries executed over a given dataset + for the purpose of computing privacy guarantees. + """ + + def __init__(self, + population_size, + selection_probability): + """Initialize the PrivacyLedger. + + Args: + population_size: An integer (may be variable) specifying the size of the + population, i.e. size of the training data used in each epoch. + selection_probability: A float (may be variable) specifying the + probability each record is included in a sample. + + Raises: + ValueError: If selection_probability is 0. + """ + self._population_size = population_size + self._selection_probability = selection_probability + + if tf.executing_eagerly(): + if tf.equal(selection_probability, 0): + raise ValueError('Selection probability cannot be 0.') + init_capacity = tf.cast(tf.ceil(1 / selection_probability), tf.int32) + else: + if selection_probability == 0: + raise ValueError('Selection probability cannot be 0.') + init_capacity = np.int(np.ceil(1 / selection_probability)) + + # The query buffer stores rows corresponding to GaussianSumQueryEntries. + self._query_buffer = tensor_buffer.TensorBuffer( + init_capacity, [3], tf.float32, 'query') + self._sample_var = tf.Variable( + initial_value=tf.zeros([3]), trainable=False, name='sample') + + # The sample buffer stores rows corresponding to SampleEntries. + self._sample_buffer = tensor_buffer.TensorBuffer( + init_capacity, [3], tf.float32, 'sample') + self._sample_count = tf.Variable( + initial_value=0.0, trainable=False, name='sample_count') + self._query_count = tf.Variable( + initial_value=0.0, trainable=False, name='query_count') + try: + # Newer versions of TF + self._cs = tf.CriticalSection() + except AttributeError: + # Older versions of TF + self._cs = tf.contrib.framework.CriticalSection() + + def record_sum_query(self, l2_norm_bound, noise_stddev): + """Records that a query was issued. + + Args: + l2_norm_bound: The maximum l2 norm of the tensor group in the query. + noise_stddev: The standard deviation of the noise applied to the sum. 
+ + Returns: + An operation recording the sum query to the ledger. + """ + + def _do_record_query(): + with tf.control_dependencies( + [tf.assign(self._query_count, self._query_count + 1)]): + return self._query_buffer.append( + [self._sample_count, l2_norm_bound, noise_stddev]) + + return self._cs.execute(_do_record_query) + + def finalize_sample(self): + """Finalizes sample and records sample ledger entry.""" + with tf.control_dependencies([ + tf.assign(self._sample_var, [ + self._population_size, self._selection_probability, + self._query_count + ]) + ]): + with tf.control_dependencies([ + tf.assign(self._sample_count, self._sample_count + 1), + tf.assign(self._query_count, 0) + ]): + return self._sample_buffer.append(self._sample_var) + + def get_unformatted_ledger(self): + return self._sample_buffer.values, self._query_buffer.values + + def get_formatted_ledger(self, sess): + """Gets the formatted query ledger. + + Args: + sess: The tensorflow session in which the ledger was created. + + Returns: + The query ledger as a list of SampleEntries. + """ + sample_array = sess.run(self._sample_buffer.values) + query_array = sess.run(self._query_buffer.values) + + return format_ledger(sample_array, query_array) + + def get_formatted_ledger_eager(self): + """Gets the formatted query ledger. + + Returns: + The query ledger as a list of SampleEntries. + """ + sample_array = self._sample_buffer.values.numpy() + query_array = self._query_buffer.values.numpy() + + return format_ledger(sample_array, query_array) + + +class QueryWithLedger(dp_query.DPQuery): + """A class for DP queries that record events to a PrivacyLedger. + + QueryWithLedger should be the top-level query in a structure of queries that + may include sum queries, nested queries, etc. It should simply wrap another + query and contain a reference to the ledger. Any contained queries (including + those contained in the leaves of a nested query) should also contain a + reference to the same ledger object. + + For example usage, see privacy_ledger_test.py. + """ + + def __init__(self, query, + population_size=None, selection_probability=None, + ledger=None): + """Initializes the QueryWithLedger. + + Args: + query: The query whose events should be recorded to the ledger. Any + subqueries (including those in the leaves of a nested query) should also + contain a reference to the same ledger given here. + population_size: An integer (may be variable) specifying the size of the + population, i.e. size of the training data used in each epoch. May be + None if `ledger` is specified. + selection_probability: A float (may be variable) specifying the + probability each record is included in a sample. May be None if `ledger` + is specified. + ledger: A PrivacyLedger to use. Must be specified if either of + `population_size` or `selection_probability` is None. 
+ """ + self._query = query + if population_size is not None and selection_probability is not None: + self.set_ledger(PrivacyLedger(population_size, selection_probability)) + elif ledger is not None: + self.set_ledger(ledger) + else: + raise ValueError('One of (population_size, selection_probability) or ' + 'ledger must be specified.') + + @property + def ledger(self): + return self._ledger + + def set_ledger(self, ledger): + self._ledger = ledger + self._query.set_ledger(ledger) + + def initial_global_state(self): + """See base class.""" + return self._query.initial_global_state() + + def derive_sample_params(self, global_state): + """See base class.""" + return self._query.derive_sample_params(global_state) + + def initial_sample_state(self, template): + """See base class.""" + return self._query.initial_sample_state(template) + + def preprocess_record(self, params, record): + """See base class.""" + return self._query.preprocess_record(params, record) + + def accumulate_preprocessed_record(self, sample_state, preprocessed_record): + """See base class.""" + return self._query.accumulate_preprocessed_record( + sample_state, preprocessed_record) + + def merge_sample_states(self, sample_state_1, sample_state_2): + """See base class.""" + return self._query.merge_sample_states(sample_state_1, sample_state_2) + + def get_noised_result(self, sample_state, global_state): + """Ensures sample is recorded to the ledger and returns noised result.""" + # Ensure sample_state is fully aggregated before calling get_noised_result. + with tf.control_dependencies(nest.flatten(sample_state)): + result, new_global_state = self._query.get_noised_result( + sample_state, global_state) + # Ensure inner queries have recorded before finalizing. + with tf.control_dependencies(nest.flatten(result)): + finalize = self._ledger.finalize_sample() + # Ensure finalizing happens. + with tf.control_dependencies([finalize]): + return nest.map_structure(tf.identity, result), new_global_state diff --git a/tensorflow_privacy/privacy/analysis/privacy_ledger_test.py b/tensorflow_privacy/privacy/analysis/privacy_ledger_test.py new file mode 100644 index 00000000..4407ad24 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/privacy_ledger_test.py @@ -0,0 +1,137 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
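(For orientation before the tests that follow, a minimal usage sketch of the QueryWithLedger wrapper defined above; it assumes eager execution, and the parameter values are purely illustrative.)

  query = gaussian_query.GaussianSumQuery(l2_norm_clip=1.0, stddev=1.0)
  query = privacy_ledger.QueryWithLedger(
      query, population_size=10000, selection_probability=0.01)
  # Run the query over successive samples (e.g. via test_utils.run_query),
  # then read back the recorded samples from the ledger.
  samples = query.ledger.get_formatted_ledger_eager()  # list of SampleEntry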
+ +"""Tests for PrivacyLedger.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.dp_query import gaussian_query +from tensorflow_privacy.privacy.dp_query import nested_query +from tensorflow_privacy.privacy.dp_query import test_utils + +tf.enable_eager_execution() + + +class PrivacyLedgerTest(tf.test.TestCase): + + def test_fail_on_probability_zero(self): + with self.assertRaisesRegexp(ValueError, + 'Selection probability cannot be 0.'): + privacy_ledger.PrivacyLedger(10, 0) + + def test_basic(self): + ledger = privacy_ledger.PrivacyLedger(10, 0.1) + ledger.record_sum_query(5.0, 1.0) + ledger.record_sum_query(2.0, 0.5) + + ledger.finalize_sample() + + expected_queries = [[5.0, 1.0], [2.0, 0.5]] + formatted = ledger.get_formatted_ledger_eager() + + sample = formatted[0] + self.assertAllClose(sample.population_size, 10.0) + self.assertAllClose(sample.selection_probability, 0.1) + self.assertAllClose(sorted(sample.queries), sorted(expected_queries)) + + def test_sum_query(self): + record1 = tf.constant([2.0, 0.0]) + record2 = tf.constant([-1.0, 1.0]) + + population_size = tf.Variable(0) + selection_probability = tf.Variable(1.0) + + query = gaussian_query.GaussianSumQuery( + l2_norm_clip=10.0, stddev=0.0) + query = privacy_ledger.QueryWithLedger( + query, population_size, selection_probability) + + # First sample. + tf.assign(population_size, 10) + tf.assign(selection_probability, 0.1) + test_utils.run_query(query, [record1, record2]) + + expected_queries = [[10.0, 0.0]] + formatted = query.ledger.get_formatted_ledger_eager() + sample_1 = formatted[0] + self.assertAllClose(sample_1.population_size, 10.0) + self.assertAllClose(sample_1.selection_probability, 0.1) + self.assertAllClose(sample_1.queries, expected_queries) + + # Second sample. + tf.assign(population_size, 20) + tf.assign(selection_probability, 0.2) + test_utils.run_query(query, [record1, record2]) + + formatted = query.ledger.get_formatted_ledger_eager() + sample_1, sample_2 = formatted + self.assertAllClose(sample_1.population_size, 10.0) + self.assertAllClose(sample_1.selection_probability, 0.1) + self.assertAllClose(sample_1.queries, expected_queries) + + self.assertAllClose(sample_2.population_size, 20.0) + self.assertAllClose(sample_2.selection_probability, 0.2) + self.assertAllClose(sample_2.queries, expected_queries) + + def test_nested_query(self): + population_size = tf.Variable(0) + selection_probability = tf.Variable(1.0) + + query1 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=4.0, sum_stddev=2.0, denominator=5.0) + query2 = gaussian_query.GaussianAverageQuery( + l2_norm_clip=5.0, sum_stddev=1.0, denominator=5.0) + + query = nested_query.NestedQuery([query1, query2]) + query = privacy_ledger.QueryWithLedger( + query, population_size, selection_probability) + + record1 = [1.0, [12.0, 9.0]] + record2 = [5.0, [1.0, 2.0]] + + # First sample. + tf.assign(population_size, 10) + tf.assign(selection_probability, 0.1) + test_utils.run_query(query, [record1, record2]) + + expected_queries = [[4.0, 2.0], [5.0, 1.0]] + formatted = query.ledger.get_formatted_ledger_eager() + sample_1 = formatted[0] + self.assertAllClose(sample_1.population_size, 10.0) + self.assertAllClose(sample_1.selection_probability, 0.1) + self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries)) + + # Second sample. 
+ tf.assign(population_size, 20) + tf.assign(selection_probability, 0.2) + test_utils.run_query(query, [record1, record2]) + + formatted = query.ledger.get_formatted_ledger_eager() + sample_1, sample_2 = formatted + self.assertAllClose(sample_1.population_size, 10.0) + self.assertAllClose(sample_1.selection_probability, 0.1) + self.assertAllClose(sorted(sample_1.queries), sorted(expected_queries)) + + self.assertAllClose(sample_2.population_size, 20.0) + self.assertAllClose(sample_2.selection_probability, 0.2) + self.assertAllClose(sorted(sample_2.queries), sorted(expected_queries)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant.py b/tensorflow_privacy/privacy/analysis/rdp_accountant.py new file mode 100644 index 00000000..195b91e5 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant.py @@ -0,0 +1,318 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RDP analysis of the Sampled Gaussian Mechanism. + +Functionality for computing Renyi differential privacy (RDP) of an additive +Sampled Gaussian Mechanism (SGM). Its public interface consists of two methods: + compute_rdp(q, noise_multiplier, T, orders) computes RDP for SGM iterated + T times. + get_privacy_spent(orders, rdp, target_eps, target_delta) computes delta + (or eps) given RDP at multiple orders and + a target value for eps (or delta). + +Example use: + +Suppose that we have run an SGM applied to a function with l2-sensitivity 1. +Its parameters are given as a list of tuples (q_1, sigma_1, T_1), ..., +(q_k, sigma_k, T_k), and we wish to compute eps for a given delta. +The example code would be: + + max_order = 32 + orders = range(2, max_order + 1) + rdp = np.zeros_like(orders, dtype=float) + for q, sigma, T in parameters: + rdp += rdp_accountant.compute_rdp(q, sigma, T, orders) + eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, target_delta=delta) +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import sys + +import numpy as np +from scipy import special +import six + +######################## +# LOG-SPACE ARITHMETIC # +######################## + + +def _log_add(logx, logy): + """Add two numbers in the log space.""" + a, b = min(logx, logy), max(logx, logy) + if a == -np.inf: # adding 0 + return b + # Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b) + return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1) + + +def _log_sub(logx, logy): + """Subtract two numbers in the log space. Answer must be non-negative.""" + if logx < logy: + raise ValueError("The result of subtraction must be non-negative.") + if logy == -np.inf: # subtracting 0 + return logx + if logx == logy: + return -np.inf # 0 is represented as -np.inf in the log space. + + try: + # Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
+ return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1 + except OverflowError: + return logx + + +def _log_print(logx): + """Pretty print.""" + if logx < math.log(sys.float_info.max): + return "{}".format(math.exp(logx)) + else: + return "exp({})".format(logx) + + +def _compute_log_a_int(q, sigma, alpha): + """Compute log(A_alpha) for integer alpha. 0 < q < 1.""" + assert isinstance(alpha, six.integer_types) + + # Initialize with 0 in the log space. + log_a = -np.inf + + for i in range(alpha + 1): + log_coef_i = ( + math.log(special.binom(alpha, i)) + i * math.log(q) + + (alpha - i) * math.log(1 - q)) + + s = log_coef_i + (i * i - i) / (2 * (sigma**2)) + log_a = _log_add(log_a, s) + + return float(log_a) + + +def _compute_log_a_frac(q, sigma, alpha): + """Compute log(A_alpha) for fractional alpha. 0 < q < 1.""" + # The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are + # initialized to 0 in the log space: + log_a0, log_a1 = -np.inf, -np.inf + i = 0 + + z0 = sigma**2 * math.log(1 / q - 1) + .5 + + while True: # do ... until loop + coef = special.binom(alpha, i) + log_coef = math.log(abs(coef)) + j = alpha - i + + log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q) + log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q) + + log_e0 = math.log(.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma)) + log_e1 = math.log(.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma)) + + log_s0 = log_t0 + (i * i - i) / (2 * (sigma**2)) + log_e0 + log_s1 = log_t1 + (j * j - j) / (2 * (sigma**2)) + log_e1 + + if coef > 0: + log_a0 = _log_add(log_a0, log_s0) + log_a1 = _log_add(log_a1, log_s1) + else: + log_a0 = _log_sub(log_a0, log_s0) + log_a1 = _log_sub(log_a1, log_s1) + + i += 1 + if max(log_s0, log_s1) < -30: + break + + return _log_add(log_a0, log_a1) + + +def _compute_log_a(q, sigma, alpha): + """Compute log(A_alpha) for any positive finite alpha.""" + if float(alpha).is_integer(): + return _compute_log_a_int(q, sigma, int(alpha)) + else: + return _compute_log_a_frac(q, sigma, alpha) + + +def _log_erfc(x): + """Compute log(erfc(x)) with high accuracy for large x.""" + try: + return math.log(2) + special.log_ndtr(-x * 2**.5) + except NameError: + # If log_ndtr is not available, approximate as follows: + r = special.erfc(x) + if r == 0.0: + # Using the Laurent series at infinity for the tail of the erfc function: + # erfc(x) ~ exp(-x^2-.5/x^2+.625/x^4)/(x*pi^.5) + # To verify in Mathematica: + # Series[Log[Erfc[x]] + Log[x] + Log[Pi]/2 + x^2, {x, Infinity, 6}] + return (-math.log(math.pi) / 2 - math.log(x) - x**2 - .5 * x**-2 + + .625 * x**-4 - 37. / 24. * x**-6 + 353. / 64. * x**-8) + else: + return math.log(r) + + +def _compute_delta(orders, rdp, eps): + """Compute delta given a list of RDP values and target epsilon. + + Args: + orders: An array (or a scalar) of orders. + rdp: A list (or a scalar) of RDP guarantees. + eps: The target epsilon. + + Returns: + Pair of (delta, optimal_order). + + Raises: + ValueError: If input is malformed. + + """ + orders_vec = np.atleast_1d(orders) + rdp_vec = np.atleast_1d(rdp) + + if len(orders_vec) != len(rdp_vec): + raise ValueError("Input lists must have the same length.") + + deltas = np.exp((rdp_vec - eps) * (orders_vec - 1)) + idx_opt = np.argmin(deltas) + return min(deltas[idx_opt], 1.), orders_vec[idx_opt] + + +def _compute_eps(orders, rdp, delta): + """Compute epsilon given a list of RDP values and target delta. + + Args: + orders: An array (or a scalar) of orders. + rdp: A list (or a scalar) of RDP guarantees. 
+ delta: The target delta. + + Returns: + Pair of (eps, optimal_order). + + Raises: + ValueError: If input is malformed. + + """ + orders_vec = np.atleast_1d(orders) + rdp_vec = np.atleast_1d(rdp) + + if len(orders_vec) != len(rdp_vec): + raise ValueError("Input lists must have the same length.") + + eps = rdp_vec - math.log(delta) / (orders_vec - 1) + + idx_opt = np.nanargmin(eps) # Ignore NaNs + return eps[idx_opt], orders_vec[idx_opt] + + +def _compute_rdp(q, sigma, alpha): + """Compute RDP of the Sampled Gaussian mechanism at order alpha. + + Args: + q: The sampling rate. + sigma: The std of the additive Gaussian noise. + alpha: The order at which RDP is computed. + + Returns: + RDP at alpha, can be np.inf. + """ + if q == 0: + return 0 + + if q == 1.: + return alpha / (2 * sigma**2) + + if np.isinf(alpha): + return np.inf + + return _compute_log_a(q, sigma, alpha) / (alpha - 1) + + +def compute_rdp(q, noise_multiplier, steps, orders): + """Compute RDP of the Sampled Gaussian Mechanism. + + Args: + q: The sampling rate. + noise_multiplier: The ratio of the standard deviation of the Gaussian noise + to the l2-sensitivity of the function to which it is added. + steps: The number of steps. + orders: An array (or a scalar) of RDP orders. + + Returns: + The RDPs at all orders, can be np.inf. + """ + if np.isscalar(orders): + rdp = _compute_rdp(q, noise_multiplier, orders) + else: + rdp = np.array([_compute_rdp(q, noise_multiplier, order) + for order in orders]) + + return rdp * steps + + +def get_privacy_spent(orders, rdp, target_eps=None, target_delta=None): + """Compute delta (or eps) for given eps (or delta) from RDP values. + + Args: + orders: An array (or a scalar) of RDP orders. + rdp: An array of RDP values. Must be of the same length as the orders list. + target_eps: If not None, the epsilon for which we compute the corresponding + delta. + target_delta: If not None, the delta for which we compute the corresponding + epsilon. Exactly one of target_eps and target_delta must be None. + + Returns: + eps, delta, opt_order. + + Raises: + ValueError: If target_eps and target_delta are messed up. + """ + if target_eps is None and target_delta is None: + raise ValueError( + "Exactly one out of eps and delta must be None. (Both are).") + + if target_eps is not None and target_delta is not None: + raise ValueError( + "Exactly one out of eps and delta must be None. (None is).") + + if target_eps is not None: + delta, opt_order = _compute_delta(orders, rdp, target_eps) + return target_eps, delta, opt_order + else: + eps, opt_order = _compute_eps(orders, rdp, target_delta) + return eps, target_delta, opt_order + + +def compute_rdp_from_ledger(ledger, orders): + """Compute RDP of Sampled Gaussian Mechanism from ledger. + + Args: + ledger: A formatted privacy ledger. + orders: An array (or a scalar) of RDP orders. + + Returns: + RDP at all orders, can be np.inf. + """ + total_rdp = np.zeros_like(orders, dtype=float) + for sample in ledger: + # Compute equivalent z from l2_clip_bounds and noise stddevs in sample. + # See https://arxiv.org/pdf/1812.06210.pdf for derivation of this formula. 
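+    # Each query contributes a noise multiplier z_i = noise_stddev_i / l2_norm_bound_i; +    # within one sample the queries compose like a single Gaussian mechanism with +    # 1 / z^2 = sum_i 1 / z_i^2, which is the effective_z computed below.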
+ effective_z = sum([ + (q.noise_stddev / q.l2_norm_bound)**-2 for q in sample.queries])**-0.5 + total_rdp += compute_rdp( + sample.selection_probability, effective_z, 1, orders) + return total_rdp diff --git a/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py new file mode 100644 index 00000000..acc46a8e --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/rdp_accountant_test.py @@ -0,0 +1,177 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for rdp_accountant.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +from absl.testing import absltest +from absl.testing import parameterized +from mpmath import exp +from mpmath import inf +from mpmath import log +from mpmath import npdf +from mpmath import quad +import numpy as np + +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.analysis import rdp_accountant + + +class TestGaussianMoments(parameterized.TestCase): + ################################# + # HELPER FUNCTIONS: # + # Exact computations using # + # multi-precision arithmetic. # + ################################# + + def _log_float_mp(self, x): + # Convert multi-precision input to float log space. + if x >= sys.float_info.min: + return float(log(x)) + else: + return -np.inf + + def _integral_mp(self, fn, bounds=(-inf, inf)): + integral, _ = quad(fn, bounds, error=True, maxdegree=8) + return integral + + def _distributions_mp(self, sigma, q): + + def _mu0(x): + return npdf(x, mu=0, sigma=sigma) + + def _mu1(x): + return npdf(x, mu=1, sigma=sigma) + + def _mu(x): + return (1 - q) * _mu0(x) + q * _mu1(x) + + return _mu0, _mu # Closure! + + def _mu1_over_mu0(self, x, sigma): + # Closed-form expression for N(1, sigma^2) / N(0, sigma^2) at x. 
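+    # N(1, sigma^2)(x) / N(0, sigma^2)(x) = exp((x^2 - (x - 1)^2) / (2 * sigma^2)) +    #   = exp((2 * x - 1) / (2 * sigma^2)).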
+ return exp((2 * x - 1) / (2 * sigma**2)) + + def _mu_over_mu0(self, x, q, sigma): + return (1 - q) + q * self._mu1_over_mu0(x, sigma) + + def _compute_a_mp(self, sigma, q, alpha): + """Compute A_alpha for arbitrary alpha by numerical integration.""" + mu0, _ = self._distributions_mp(sigma, q) + a_alpha_fn = lambda z: mu0(z) * self._mu_over_mu0(z, q, sigma)**alpha + a_alpha = self._integral_mp(a_alpha_fn) + return a_alpha + + # TEST ROUTINES + def test_compute_rdp_no_data(self): + # q = 0 + self.assertEqual(rdp_accountant.compute_rdp(0, 10, 1, 20), 0) + + def test_compute_rdp_no_sampling(self): + # q = 1, RDP = alpha / (2 * sigma^2) + self.assertEqual(rdp_accountant.compute_rdp(1, 10, 1, 20), 0.1) + + def test_compute_rdp_scalar(self): + rdp_scalar = rdp_accountant.compute_rdp(0.1, 2, 10, 5) + self.assertAlmostEqual(rdp_scalar, 0.07737, places=5) + + def test_compute_rdp_sequence(self): + rdp_vec = rdp_accountant.compute_rdp(0.01, 2.5, 50, + [1.5, 2.5, 5, 50, 100, np.inf]) + self.assertSequenceAlmostEqual( + rdp_vec, [0.00065, 0.001085, 0.00218075, 0.023846, 167.416307, np.inf], + delta=1e-5) + + params = ({'q': 1e-7, 'sigma': .1, 'order': 1.01}, + {'q': 1e-6, 'sigma': .1, 'order': 256}, + {'q': 1e-5, 'sigma': .1, 'order': 256.1}, + {'q': 1e-6, 'sigma': 1, 'order': 27}, + {'q': 1e-4, 'sigma': 1., 'order': 1.5}, + {'q': 1e-3, 'sigma': 1., 'order': 2}, + {'q': .01, 'sigma': 10, 'order': 20}, + {'q': .1, 'sigma': 100, 'order': 20.5}, + {'q': .99, 'sigma': .1, 'order': 256}, + {'q': .999, 'sigma': 100, 'order': 256.1}) + + # pylint:disable=undefined-variable + @parameterized.parameters(p for p in params) + def test_compute_log_a_equals_mp(self, q, sigma, order): + # Compare the cheap computation of log(A) with an expensive, multi-precision + # computation. + log_a = rdp_accountant._compute_log_a(q, sigma, order) + log_a_mp = self._log_float_mp(self._compute_a_mp(sigma, q, order)) + np.testing.assert_allclose(log_a, log_a_mp, rtol=1e-4) + + def test_get_privacy_spent_check_target_delta(self): + orders = range(2, 33) + rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + eps, _, opt_order = rdp_accountant.get_privacy_spent( + orders, rdp, target_delta=1e-5) + self.assertAlmostEqual(eps, 1.258575, places=5) + self.assertEqual(opt_order, 20) + + def test_get_privacy_spent_check_target_eps(self): + orders = range(2, 33) + rdp = rdp_accountant.compute_rdp(0.01, 4, 10000, orders) + _, delta, opt_order = rdp_accountant.get_privacy_spent( + orders, rdp, target_eps=1.258575) + self.assertAlmostEqual(delta, 1e-5) + self.assertEqual(opt_order, 20) + + def test_check_composition(self): + orders = (1.25, 1.5, 1.75, 2., 2.5, 3., 4., 5., 6., 7., 8., 10., 12., 14., + 16., 20., 24., 28., 32., 64., 256.)
+ + rdp = rdp_accountant.compute_rdp(q=1e-4, + noise_multiplier=.4, + steps=40000, + orders=orders) + + eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, + target_delta=1e-6) + + rdp += rdp_accountant.compute_rdp(q=0.1, + noise_multiplier=2, + steps=100, + orders=orders) + eps, _, opt_order = rdp_accountant.get_privacy_spent(orders, rdp, + target_delta=1e-5) + self.assertAlmostEqual(eps, 8.509656, places=5) + self.assertEqual(opt_order, 2.5) + + def test_compute_rdp_from_ledger(self): + orders = range(2, 33) + q = 0.1 + n = 1000 + l2_norm_clip = 3.14159 + noise_stddev = 2.71828 + steps = 3 + + query_entry = privacy_ledger.GaussianSumQueryEntry( + l2_norm_clip, noise_stddev) + ledger = [privacy_ledger.SampleEntry(n, q, [query_entry])] * steps + + z = noise_stddev / l2_norm_clip + rdp = rdp_accountant.compute_rdp(q, z, steps, orders) + rdp_from_ledger = rdp_accountant.compute_rdp_from_ledger(ledger, orders) + self.assertSequenceAlmostEqual(rdp, rdp_from_ledger) + + +if __name__ == '__main__': + absltest.main() diff --git a/tensorflow_privacy/privacy/analysis/tensor_buffer.py b/tensorflow_privacy/privacy/analysis/tensor_buffer.py new file mode 100644 index 00000000..a0cf6655 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/tensor_buffer.py @@ -0,0 +1,134 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""A lightweight buffer for maintaining tensors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + + +class TensorBuffer(object): + """A lightweight buffer for maintaining lists. + + The TensorBuffer accumulates tensors of the given shape into a tensor (whose + rank is one more than that of the given shape) via calls to `append`. The + current value of the accumulated tensor can be extracted via the property + `values`. + """ + + def __init__(self, capacity, shape, dtype=tf.int32, name=None): + """Initializes the TensorBuffer. + + Args: + capacity: Initial capacity. Buffer will double in capacity each time it is + filled to capacity. + shape: The shape (as tuple or list) of the tensors to accumulate. + dtype: The type of the tensors. + name: A string name for the variable_scope used. + + Raises: + ValueError: If the shape is empty (specifies scalar shape). + """ + shape = list(shape) + self._rank = len(shape) + self._name = name + self._dtype = dtype + if not self._rank: + raise ValueError('Shape cannot be scalar.') + shape = [capacity] + shape + + with tf.variable_scope(self._name): + # We need to use a placeholder as the initial value to allow resizing. 
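+      # (Initializing from a placeholder with shape=None leaves the variable's static shape +      # unknown, so tf.assign(..., validate_shape=False) in _double_capacity below can later +      # assign a larger tensor when the buffer grows.)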
+ self._buffer = tf.Variable( + initial_value=tf.placeholder_with_default( + tf.zeros(shape, dtype), shape=None), + trainable=False, + name='buffer', + use_resource=True) + self._current_size = tf.Variable( + initial_value=0, dtype=tf.int32, trainable=False, name='current_size') + self._capacity = tf.Variable( + initial_value=capacity, + dtype=tf.int32, + trainable=False, + name='capacity') + + def append(self, value): + """Appends a new tensor to the end of the buffer. + + Args: + value: The tensor to append. Must match the shape specified in the + initializer. + + Returns: + An op appending the new tensor to the end of the buffer. + """ + + def _double_capacity(): + """Doubles the capacity of the current tensor buffer.""" + padding = tf.zeros_like(self._buffer, self._buffer.dtype) + new_buffer = tf.concat([self._buffer, padding], axis=0) + if tf.executing_eagerly(): + with tf.variable_scope(self._name, reuse=True): + self._buffer = tf.get_variable( + name='buffer', + dtype=self._dtype, + initializer=new_buffer, + trainable=False) + return self._buffer, tf.assign(self._capacity, + tf.multiply(self._capacity, 2)) + else: + return tf.assign( + self._buffer, new_buffer, + validate_shape=False), tf.assign(self._capacity, + tf.multiply(self._capacity, 2)) + + update_buffer, update_capacity = tf.cond( + tf.equal(self._current_size, self._capacity), + _double_capacity, lambda: (self._buffer, self._capacity)) + + with tf.control_dependencies([update_buffer, update_capacity]): + with tf.control_dependencies([ + tf.assert_less( + self._current_size, + self._capacity, + message='Appending past end of TensorBuffer.'), + tf.assert_equal( + tf.shape(value), + tf.shape(self._buffer)[1:], + message='Appending value of inconsistent shape.') + ]): + with tf.control_dependencies( + [tf.assign(self._buffer[self._current_size, :], value)]): + return tf.assign_add(self._current_size, 1) + + @property + def values(self): + """Returns the accumulated tensor.""" + begin_value = tf.zeros([self._rank + 1], dtype=tf.int32) + value_size = tf.concat([[self._current_size], + tf.constant(-1, tf.int32, [self._rank])], 0) + return tf.slice(self._buffer, begin_value, value_size) + + @property + def current_size(self): + """Returns the current number of tensors in the buffer.""" + return self._current_size + + @property + def capacity(self): + """Returns the current capacity of the buffer.""" + return self._capacity diff --git a/tensorflow_privacy/privacy/analysis/tensor_buffer_test_eager.py b/tensorflow_privacy/privacy/analysis/tensor_buffer_test_eager.py new file mode 100644 index 00000000..ef019104 --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/tensor_buffer_test_eager.py @@ -0,0 +1,84 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
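(A minimal sketch of the append/values API documented above, assuming a TF 1.x runtime with eager execution enabled as in the tests below; the values are illustrative only.)

  buf = tensor_buffer.TensorBuffer(capacity=2, shape=[2], dtype=tf.int32, name='demo')
  buf.append(tf.constant([1, 2]))
  buf.append(tf.constant([3, 4]))
  buf.values        # [[1, 2], [3, 4]]
  buf.current_size  # 2; one more append doubles `capacity` to 4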
+"""Tests for tensor_buffer in eager mode.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import tensor_buffer + +tf.enable_eager_execution() + + +class TensorBufferTest(tf.test.TestCase): + """Tests for TensorBuffer in eager mode.""" + + def test_basic(self): + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + value1 = [[1, 2, 3], [4, 5, 6]] + my_buffer.append(value1) + self.assertAllEqual(my_buffer.values.numpy(), [value1]) + + value2 = [[4, 5, 6], [7, 8, 9]] + my_buffer.append(value2) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2]) + + def test_fail_on_scalar(self): + with self.assertRaisesRegexp(ValueError, 'Shape cannot be scalar.'): + tensor_buffer.TensorBuffer(1, ()) + + def test_fail_on_inconsistent_shape(self): + size, shape = 1, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + with self.assertRaisesRegexp( + tf.errors.InvalidArgumentError, + 'Appending value of inconsistent shape.'): + my_buffer.append(tf.ones(shape=[3, 4], dtype=tf.int32)) + + def test_resize(self): + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + + # Append three buffers. Third one should succeed after resizing. + value1 = [[1, 2, 3], [4, 5, 6]] + my_buffer.append(value1) + self.assertAllEqual(my_buffer.values.numpy(), [value1]) + self.assertAllEqual(my_buffer.current_size.numpy(), 1) + self.assertAllEqual(my_buffer.capacity.numpy(), 2) + + value2 = [[4, 5, 6], [7, 8, 9]] + my_buffer.append(value2) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2]) + self.assertAllEqual(my_buffer.current_size.numpy(), 2) + self.assertAllEqual(my_buffer.capacity.numpy(), 2) + + value3 = [[7, 8, 9], [10, 11, 12]] + my_buffer.append(value3) + self.assertAllEqual(my_buffer.values.numpy(), [value1, value2, value3]) + self.assertAllEqual(my_buffer.current_size.numpy(), 3) + # Capacity should have doubled. + self.assertAllEqual(my_buffer.capacity.numpy(), 4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/analysis/tensor_buffer_test_graph.py b/tensorflow_privacy/privacy/analysis/tensor_buffer_test_graph.py new file mode 100644 index 00000000..5a66ec6e --- /dev/null +++ b/tensorflow_privacy/privacy/analysis/tensor_buffer_test_graph.py @@ -0,0 +1,72 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for tensor_buffer in graph mode.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import tensor_buffer + + +class TensorBufferTest(tf.test.TestCase): + """Tests for TensorBuffer in graph mode.""" + + def test_noresize(self): + """Test buffer does not resize if capacity is not exceeded.""" + with self.cached_session() as sess: + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + value1 = [[1, 2, 3], [4, 5, 6]] + with tf.control_dependencies([my_buffer.append(value1)]): + value2 = [[7, 8, 9], [10, 11, 12]] + with tf.control_dependencies([my_buffer.append(value2)]): + values = my_buffer.values + current_size = my_buffer.current_size + capacity = my_buffer.capacity + self.evaluate(tf.global_variables_initializer()) + + v, cs, cap = sess.run([values, current_size, capacity]) + self.assertAllEqual(v, [value1, value2]) + self.assertEqual(cs, 2) + self.assertEqual(cap, 2) + + def test_resize(self): + """Test buffer resizes if capacity is exceeded.""" + with self.cached_session() as sess: + size, shape = 2, [2, 3] + + my_buffer = tensor_buffer.TensorBuffer(size, shape, name='my_buffer') + value1 = [[1, 2, 3], [4, 5, 6]] + with tf.control_dependencies([my_buffer.append(value1)]): + value2 = [[7, 8, 9], [10, 11, 12]] + with tf.control_dependencies([my_buffer.append(value2)]): + value3 = [[13, 14, 15], [16, 17, 18]] + with tf.control_dependencies([my_buffer.append(value3)]): + values = my_buffer.values + current_size = my_buffer.current_size + capacity = my_buffer.capacity + self.evaluate(tf.global_variables_initializer()) + + v, cs, cap = sess.run([values, current_size, capacity]) + self.assertAllEqual(v, [value1, value2, value3]) + self.assertEqual(cs, 3) + self.assertEqual(cap, 4) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/bolt_on/README.md b/tensorflow_privacy/privacy/bolt_on/README.md new file mode 100644 index 00000000..1eb9a6a5 --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/README.md @@ -0,0 +1,67 @@ +# BoltOn Subpackage + +This package contains source code for the BoltOn method, a particular +differential-privacy (DP) technique that uses output perturbations and +leverages additional assumptions to provide a new way of approaching the +privacy guarantees. + +## BoltOn Description + +This method uses 4 key steps to achieve privacy guarantees: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R, the radius of the hypothesis space, + after each batch. This value is configurable by the user. + 3. Limits learning rate + 4. Uses a strongly convex loss function (see compile) + +For more details on the strong convexity requirements, see: +Bolt-on Differential Privacy for Scalable Stochastic Gradient +Descent-based Analytics by Xi Wu et al. at https://arxiv.org/pdf/1606.04722.pdf + +## Why BoltOn? + +The major difference for the BoltOn method is that it injects noise post model +convergence, rather than noising gradients or weights during training. This +approach requires some additional constraints listed in the Description. +Should the use-case and model satisfy these constraints, this is another +approach that can be trained to maximize utility while maintaining the privacy. 
+The paper describes in detail the advantages and disadvantages of this approach +and its results compared to other methods, namely noising at each iteration +and not noising at all. + +## Tutorials + +This package has a tutorial that can be found in the root tutorials directory, +under `bolton_tutorial.py`. + +## Contribution + +This package was initially contributed by Georgian Partners with the hope of +growing the tensorflow/privacy library. There are several rich use cases for +epsilon-delta differential privacy in machine learning, some of which can be explored here: +https://medium.com/apache-mxnet/epsilon-differential-privacy-for-machine-learning-using-mxnet-a4270fe3865e +https://arxiv.org/pdf/1811.04911.pdf + +## Stability + +Because this package is pegged to TensorFlow 2.0, it may encounter stability +issues while TensorFlow 2.0 is still under active development. + +This sub-package is currently stable for 2.0.0a0, 2.0.0b0, and 2.0.0b1. If you +would like to use this subpackage, please use one of these versions, as we +cannot guarantee it will work with all later releases. If you do find issues, +feel free to raise an issue with the contributors listed below. + +## Contacts + +In addition to the maintainers of tensorflow/privacy listed in the root +README.md, please feel free to contact members of Georgian Partners. In +particular, + +* Georgian Partners (@georgianpartners) +* Ji Chao Zhang (@Jichaogp) +* Christopher Choquette (@cchoquette) + +## Copyright + +Copyright 2019 - Google LLC diff --git a/tensorflow_privacy/privacy/bolt_on/__init__.py b/tensorflow_privacy/privacy/bolt_on/__init__.py new file mode 100644 index 00000000..2f87e3c3 --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2019, The TensorFlow Privacy Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BoltOn Method for privacy.""" +import sys +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion("2.0.0"): + raise ImportError("Please upgrade your version " + "of tensorflow from: {0} to at least 2.0.0 to " + "use privacy/bolt_on".format(LooseVersion(tf.__version__))) +if hasattr(sys, "skip_tf_privacy_import"): # Useful for standalone scripts. + pass +else: + from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel # pylint: disable=g-import-not-at-top + from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn # pylint: disable=g-import-not-at-top + from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber # pylint: disable=g-import-not-at-top + from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy # pylint: disable=g-import-not-at-top diff --git a/tensorflow_privacy/privacy/bolt_on/losses.py b/tensorflow_privacy/privacy/bolt_on/losses.py new file mode 100644 index 00000000..81bd0c36 --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/losses.py @@ -0,0 +1,304 @@ +# Copyright 2019, The TensorFlow Authors.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Loss functions for BoltOn method.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras import losses +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow.python.keras.utils import losses_utils +from tensorflow.python.platform import tf_logging as logging + + +class StrongConvexMixin: # pylint: disable=old-style-class + """Strong Convex Mixin base class. + + Strong Convex Mixin base class for any loss function that will be used with + a BoltOn model. Subclasses must be strongly convex and implement the + associated constants. They must also conform to the requirements of tf losses + (see super class). + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et al. + """ + + def radius(self): + """Radius, R, of the hypothesis space W. + + W is a convex set that forms the hypothesis space. + + Returns: + R + """ + raise NotImplementedError("Radius not implemented for StrongConvex Loss " + "function: %s" % str(self.__class__.__name__)) + + def gamma(self): + """Returns strongly convex parameter, gamma.""" + raise NotImplementedError("Gamma not implemented for StrongConvex Loss " + "function: %s" % str(self.__class__.__name__)) + + def beta(self, class_weight): + """Smoothness, beta. + + Args: + class_weight: the class weights as scalar or 1d tensor, where its + dimensionality is equal to the number of outputs. + + Returns: + Beta + """ + raise NotImplementedError("Beta not implemented for StrongConvex Loss " + "function: %s" % str(self.__class__.__name__)) + + def lipchitz_constant(self, class_weight): + """Lipchitz constant, L. + + Args: + class_weight: class weights used + + Returns: L + """ + raise NotImplementedError("lipchitz constant not implemented for " + "StrongConvex Loss " + "function: %s" % str(self.__class__.__name__)) + + def kernel_regularizer(self): + """Returns the kernel_regularizer to be used. + + Any subclass should override this method if it needs a kernel_regularizer + (e.g. if one is required for the loss function to be strongly convex). + """ + return None + + def max_class_weight(self, class_weight, dtype): + """The maximum weighting in class weights (max value) as a scalar tensor. + + Args: + class_weight: class weights used + dtype: the data type for tensor conversions. + + Returns: + maximum class weighting as tensor scalar + """ + class_weight = _ops.convert_to_tensor_v2(class_weight, dtype) + return tf.math.reduce_max(class_weight) + + +class StrongConvexHuber(losses.Loss, StrongConvexMixin): + """Strong Convex version of Huber loss using l2 weight regularization.""" + + def __init__(self, + reg_lambda, + c_arg, + radius_constant, + delta, + reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + dtype=tf.float32): + """Constructor.
+ + Args: + reg_lambda: Weight regularization constant + c_arg: Penalty parameter C of the loss term + radius_constant: constant defining the length of the radius + delta: delta value in huber loss. When to switch from quadratic to + absolute deviation. + reduction: reduction type to use. See super class + dtype: tf datatype to use for tensor conversions. + + Returns: + Loss values per sample. + """ + if c_arg <= 0: + raise ValueError("c: {0}, should be >= 0".format(c_arg)) + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if radius_constant <= 0: + raise ValueError("radius_constant: {0}, should be >= 0".format( + radius_constant + )) + if delta <= 0: + raise ValueError("delta: {0}, should be >= 0".format( + delta + )) + self.C = c_arg # pylint: disable=invalid-name + self.delta = delta + self.radius_constant = radius_constant + self.dtype = dtype + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexHuber, self).__init__( + name="strongconvexhuber", + reduction=reduction, + ) + + def call(self, y_true, y_pred): + """Computes loss. + + Args: + y_true: Ground truth values. One hot encoded using -1 and 1. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + h = self.delta + z = y_pred * y_true + one = tf.constant(1, dtype=self.dtype) + four = tf.constant(4, dtype=self.dtype) + + if z > one + h: # pylint: disable=no-else-return + return _ops.convert_to_tensor_v2(0, dtype=self.dtype) + elif tf.math.abs(one - z) <= h: + return one / (four * h) * tf.math.pow(one + h - z, 2) + return one - z + + def radius(self): + """See super class.""" + return self.radius_constant / self.reg_lambda + + def gamma(self): + """See super class.""" + return self.reg_lambda + + def beta(self, class_weight): + """See super class.""" + max_class_weight = self.max_class_weight(class_weight, self.dtype) + delta = _ops.convert_to_tensor_v2(self.delta, + dtype=self.dtype + ) + return self.C * max_class_weight / (delta * + tf.constant(2, dtype=self.dtype)) + \ + self.reg_lambda + + def lipchitz_constant(self, class_weight): + """See super class.""" + # if class_weight is provided, + # it should be a vector of the same size of number of classes + max_class_weight = self.max_class_weight(class_weight, self.dtype) + lc = self.C * max_class_weight + \ + self.reg_lambda * self.radius() + return lc + + def kernel_regularizer(self): + """Return l2 loss using 0.5*reg_lambda as the l2 term (as desired). + + L2 regularization is required for this loss function to be strongly convex. + + Returns: + The L2 regularizer layer for this loss function, with regularizer constant + set to half the 0.5 * reg_lambda. + """ + return L1L2(l2=self.reg_lambda/2) + + +class StrongConvexBinaryCrossentropy( + losses.BinaryCrossentropy, + StrongConvexMixin +): + """Strongly Convex BinaryCrossentropy loss using l2 weight regularization.""" + + def __init__(self, + reg_lambda, + c_arg, + radius_constant, + from_logits=True, + label_smoothing=0, + reduction=losses_utils.ReductionV2.SUM_OVER_BATCH_SIZE, + dtype=tf.float32): + """StrongConvexBinaryCrossentropy class. + + Args: + reg_lambda: Weight regularization constant + c_arg: Penalty parameter C of the loss term + radius_constant: constant defining the length of the radius + from_logits: True if the input are unscaled logits. False if they are + already scaled. + label_smoothing: amount of smoothing to perform on labels + relaxation of trust in labels, e.g. (1 -> 1-x, 0 -> 0+x). 
Note, the + impact of this parameter's effect on privacy is not known and thus the + default should be used. + reduction: reduction type to use. See super class + dtype: tf datatype to use for tensor conversions. + """ + if label_smoothing != 0: + logging.warning("The impact of label smoothing on privacy is unknown. " + "Use label smoothing at your own risk as it may not " + "guarantee privacy.") + + if reg_lambda <= 0: + raise ValueError("reg lambda: {0} must be positive".format(reg_lambda)) + if c_arg <= 0: + raise ValueError("c: {0}, should be >= 0".format(c_arg)) + if radius_constant <= 0: + raise ValueError("radius_constant: {0}, should be >= 0".format( + radius_constant + )) + self.dtype = dtype + self.C = c_arg # pylint: disable=invalid-name + self.reg_lambda = tf.constant(reg_lambda, dtype=self.dtype) + super(StrongConvexBinaryCrossentropy, self).__init__( + reduction=reduction, + name="strongconvexbinarycrossentropy", + from_logits=from_logits, + label_smoothing=label_smoothing, + ) + self.radius_constant = radius_constant + + def call(self, y_true, y_pred): + """Computes loss. + + Args: + y_true: Ground truth values. + y_pred: The predicted values. + + Returns: + Loss values per sample. + """ + loss = super(StrongConvexBinaryCrossentropy, self).call(y_true, y_pred) + loss = loss * self.C + return loss + + def radius(self): + """See super class.""" + return self.radius_constant / self.reg_lambda + + def gamma(self): + """See super class.""" + return self.reg_lambda + + def beta(self, class_weight): + """See super class.""" + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda + + def lipchitz_constant(self, class_weight): + """See super class.""" + max_class_weight = self.max_class_weight(class_weight, self.dtype) + return self.C * max_class_weight + self.reg_lambda * self.radius() + + def kernel_regularizer(self): + """Return l2 loss using 0.5*reg_lambda as the l2 term (as desired). + + L2 regularization is required for this loss function to be strongly convex. + + Returns: + The L2 regularizer layer for this loss function, with regularizer constant + set to half the 0.5 * reg_lambda. + """ + return L1L2(l2=self.reg_lambda/2) diff --git a/tensorflow_privacy/privacy/bolt_on/losses_test.py b/tensorflow_privacy/privacy/bolt_on/losses_test.py new file mode 100644 index 00000000..67f3d9c9 --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/losses_test.py @@ -0,0 +1,431 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
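The constants these losses expose (radius, gamma, beta, and the Lipschitz constant) are what the BoltOn optimizer consumes to set its learning-rate schedule and noise scale. A minimal sketch of constructing a loss and reading those constants, assuming TensorFlow 2.x in eager mode and the imports exposed by `bolt_on/__init__.py` above (the constructor values are illustrative only):

```python
import tensorflow as tf
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber

# reg_lambda, C and radius_constant all set to 1.0 for illustration.
bce = StrongConvexBinaryCrossentropy(1.0, 1.0, 1.0)
huber = StrongConvexHuber(1.0, 1.0, 1.0, delta=1.0)  # also takes a Huber delta

# Per-sample loss value for a single logit/label pair.
print(bce(tf.constant([[1.0]]), tf.constant([[2.0]])).numpy())

# Strong-convexity constants consumed by the BoltOn optimizer and LR schedule.
print(bce.radius().numpy())                  # radius_constant / reg_lambda
print(bce.gamma().numpy())                   # reg_lambda
print(bce.beta(class_weight=1).numpy())      # C * max_class_weight + reg_lambda
print(bce.lipchitz_constant(class_weight=1).numpy())
print(bce.kernel_regularizer().l2)           # L2 term of 0.5 * reg_lambda
```

Calling the loss follows the usual Keras `loss(y_true, y_pred)` convention; only the extra constants above distinguish it from a standard Keras loss.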
+"""Unit testing for losses.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from contextlib import contextmanager # pylint: disable=g-importing-member +from io import StringIO # pylint: disable=g-importing-member +import sys +from absl.testing import parameterized +import tensorflow as tf +from tensorflow.python.framework import test_util +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexHuber +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin + + +@contextmanager +def captured_output(): + """Capture std_out and std_err within context.""" + new_out, new_err = StringIO(), StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err + + +class StrongConvexMixinTests(keras_parameterized.TestCase): + """Tests for the StrongConvexMixin.""" + @parameterized.named_parameters([ + {'testcase_name': 'beta not implemented', + 'fn': 'beta', + 'args': [1]}, + {'testcase_name': 'gamma not implemented', + 'fn': 'gamma', + 'args': []}, + {'testcase_name': 'lipchitz not implemented', + 'fn': 'lipchitz_constant', + 'args': [1]}, + {'testcase_name': 'radius not implemented', + 'fn': 'radius', + 'args': []}, + ]) + + def test_not_implemented(self, fn, args): + """Test that the given fn's are not implemented on the mixin. + + Args: + fn: fn on Mixin to test + args: arguments to fn of Mixin + """ + with self.assertRaises(NotImplementedError): + loss = StrongConvexMixin() + getattr(loss, fn, None)(*args) + + @parameterized.named_parameters([ + {'testcase_name': 'radius not implemented', + 'fn': 'kernel_regularizer', + 'args': []}, + ]) + def test_return_none(self, fn, args): + """Test that fn of Mixin returns None. + + Args: + fn: fn of Mixin to test + args: arguments to fn of Mixin + """ + loss = StrongConvexMixin() + ret = getattr(loss, fn, None)(*args) + self.assertEqual(ret, None) + + +class BinaryCrossesntropyTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss.""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'C': 1, + 'radius_constant': 1 + }, # pylint: disable=invalid-name + ]) + def test_init_params(self, reg_lambda, C, radius_constant): + """Test initialization for given arguments. + + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + loss = StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) + self.assertIsInstance(loss, StrongConvexBinaryCrossentropy) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'C': -1, + 'radius_constant': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'C': 1, + 'radius_constant': -1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'C': 1, + 'radius_constant': 1 + }, # pylint: disable=invalid-name + ]) + def test_bad_init_params(self, reg_lambda, C, radius_constant): + """Test invalid domain for given params. Should return ValueError. 
+ + Args: + reg_lambda: initialization value for reg_lambda arg + C: initialization value for C arg + radius_constant: initialization value for radius_constant arg + """ + # test valid domains for each variable + with self.assertRaises(ValueError): + StrongConvexBinaryCrossentropy(reg_lambda, C, radius_constant) + + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + # [] for compatibility with tensorflow loss calculation + {'testcase_name': 'both positive', + 'logits': [10000], + 'y_true': [1], + 'result': 0, + }, + {'testcase_name': 'positive gradient negative logits', + 'logits': [-10000], + 'y_true': [1], + 'result': 10000, + }, + {'testcase_name': 'positivee gradient positive logits', + 'logits': [10000], + 'y_true': [0], + 'result': 10000, + }, + {'testcase_name': 'both negative', + 'logits': [-10000], + 'y_true': [0], + 'result': 0 + }, + ]) + def test_calculation(self, logits, y_true, result): + """Test the call method to ensure it returns the correct value. + + Args: + logits: unscaled output of model + y_true: label + result: correct loss calculation value + """ + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexBinaryCrossentropy(0.00001, 1, 1) + loss = loss(y_true, logits) + self.assertEqual(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1], + 'args': [], + 'result': tf.constant(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1], + 'args': [1], + 'result': tf.constant(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1], + 'args': [], + 'result': L1L2(l2=0.5), + }, + ]) + def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result. + + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above function + result: the correct result from the fn + """ + loss = StrongConvexBinaryCrossentropy(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + @parameterized.named_parameters([ + {'testcase_name': 'label_smoothing', + 'init_args': [1, 1, 1, True, 0.1], + 'fn': None, + 'args': None, + 'print_res': 'The impact of label smoothing on privacy is unknown.' + }, + ]) + def test_prints(self, init_args, fn, args, print_res): + """Test logger warning from StrongConvexBinaryCrossentropy. + + Args: + init_args: arguments to init the object with. + fn: function to test + args: arguments to above function + print_res: print result that should have been printed. 
+ """ + with captured_output() as (out, err): # pylint: disable=unused-variable + loss = StrongConvexBinaryCrossentropy(*init_args) + if fn is not None: + getattr(loss, fn, lambda *arguments: print('error'))(*args) + self.assertRegexMatch(err.getvalue().strip(), [print_res]) + + +class HuberTests(keras_parameterized.TestCase): + """tests for BinaryCrossesntropy StrongConvex loss.""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1, + }, + ]) + def test_init_params(self, reg_lambda, c, radius_constant, delta): + """Test initialization for given arguments. + + Args: + reg_lambda: initialization value for reg_lambda arg + c: initialization value for C arg + radius_constant: initialization value for radius_constant arg + delta: the delta parameter for the huber loss + """ + # test valid domains for each variable + loss = StrongConvexHuber(reg_lambda, c, radius_constant, delta) + self.assertIsInstance(loss, StrongConvexHuber) + + @parameterized.named_parameters([ + {'testcase_name': 'negative c', + 'reg_lambda': 1, + 'c': -1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative radius', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': -1, + 'delta': 1 + }, + {'testcase_name': 'negative lambda', + 'reg_lambda': -1, + 'c': 1, + 'radius_constant': 1, + 'delta': 1 + }, + {'testcase_name': 'negative delta', + 'reg_lambda': 1, + 'c': 1, + 'radius_constant': 1, + 'delta': -1 + }, + ]) + def test_bad_init_params(self, reg_lambda, c, radius_constant, delta): + """Test invalid domain for given params. Should return ValueError. + + Args: + reg_lambda: initialization value for reg_lambda arg + c: initialization value for C arg + radius_constant: initialization value for radius_constant arg + delta: the delta parameter for the huber loss + """ + # test valid domains for each variable + with self.assertRaises(ValueError): + StrongConvexHuber(reg_lambda, c, radius_constant, delta) + + # test the bounds and test varied delta's + @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'delta=1,y_true=1 z>1+h decision boundary', + 'logits': 2.1, + 'y_true': 1, + 'delta': 1, + 'result': 0, + }, + {'testcase_name': 'delta=1,y_true=1 z<1+h decision boundary', + 'logits': 1.9, + 'y_true': 1, + 'delta': 1, + 'result': 0.01*0.25, + }, + {'testcase_name': 'delta=1,y_true=1 1-z< h decision boundary', + 'logits': 0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.9**2 * 0.25, + }, + {'testcase_name': 'delta=1,y_true=1 z < 1-h decision boundary', + 'logits': -0.1, + 'y_true': 1, + 'delta': 1, + 'result': 1.1, + }, + {'testcase_name': 'delta=2,y_true=1 z>1+h decision boundary', + 'logits': 3.1, + 'y_true': 1, + 'delta': 2, + 'result': 0, + }, + {'testcase_name': 'delta=2,y_true=1 z<1+h decision boundary', + 'logits': 2.9, + 'y_true': 1, + 'delta': 2, + 'result': 0.01*0.125, + }, + {'testcase_name': 'delta=2,y_true=1 1-z < h decision boundary', + 'logits': 1.1, + 'y_true': 1, + 'delta': 2, + 'result': 1.9**2 * 0.125, + }, + {'testcase_name': 'delta=2,y_true=1 z < 1-h decision boundary', + 'logits': -1.1, + 'y_true': 1, + 'delta': 2, + 'result': 2.1, + }, + {'testcase_name': 'delta=1,y_true=-1 z>1+h decision boundary', + 'logits': -2.1, + 'y_true': -1, + 'delta': 1, + 'result': 0, + }, + ]) + def test_calculation(self, logits, y_true, delta, result): + """Test the call method to ensure it returns the correct value. 
+ + Args: + logits: unscaled output of model + y_true: label + delta: delta value for StrongConvexHuber loss. + result: correct loss calculation value + """ + logits = tf.Variable(logits, False, dtype=tf.float32) + y_true = tf.Variable(y_true, False, dtype=tf.float32) + loss = StrongConvexHuber(0.00001, 1, 1, delta) + loss = loss(y_true, logits) + self.assertAllClose(loss.numpy(), result) + + @parameterized.named_parameters([ + {'testcase_name': 'beta', + 'init_args': [1, 1, 1, 1], + 'fn': 'beta', + 'args': [1], + 'result': tf.Variable(1.5, dtype=tf.float32) + }, + {'testcase_name': 'gamma', + 'fn': 'gamma', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': tf.Variable(1, dtype=tf.float32), + }, + {'testcase_name': 'lipchitz constant', + 'fn': 'lipchitz_constant', + 'init_args': [1, 1, 1, 1], + 'args': [1], + 'result': tf.Variable(2, dtype=tf.float32), + }, + {'testcase_name': 'kernel regularizer', + 'fn': 'kernel_regularizer', + 'init_args': [1, 1, 1, 1], + 'args': [], + 'result': L1L2(l2=0.5), + }, + ]) + def test_fns(self, init_args, fn, args, result): + """Test that fn of BinaryCrossentropy loss returns the correct result. + + Args: + init_args: init values for loss instance + fn: the fn to test + args: the arguments to above function + result: the correct result from the fn + """ + loss = StrongConvexHuber(*init_args) + expected = getattr(loss, fn, lambda: 'fn not found')(*args) + if hasattr(expected, 'numpy') and hasattr(result, 'numpy'): # both tensor + expected = expected.numpy() + result = result.numpy() + if hasattr(expected, 'l2') and hasattr(result, 'l2'): # both l2 regularizer + expected = expected.l2 + result = result.l2 + self.assertEqual(expected, result) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/bolt_on/models.py b/tensorflow_privacy/privacy/bolt_on/models.py new file mode 100644 index 00000000..efea5cda --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/models.py @@ -0,0 +1,303 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BoltOn model for Bolt-on method of differentially private ML.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import tensorflow as tf +from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras import optimizers +from tensorflow.python.keras.models import Model +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin +from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn + + +class BoltOnModel(Model): # pylint: disable=abstract-method + """BoltOn episilon-delta differential privacy model. + + The privacy guarantees are dependent on the noise that is sampled. Please + see the paper linked below for more details. + + Uses 4 key steps to achieve privacy guarantees: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. 
Use a strongly convex loss function (see compile) + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et al. + """ + + def __init__(self, + n_outputs, + seed=1, + dtype=tf.float32): + """Private constructor. + + Args: + n_outputs: number of output classes to predict. + seed: random seed to use + dtype: data type to use for tensors + """ + super(BoltOnModel, self).__init__(name='bolton', dynamic=False) + if n_outputs <= 0: + raise ValueError('n_outputs = {0} is not valid. Must be > 0.'.format( + n_outputs + )) + self.n_outputs = n_outputs + self.seed = seed + self._layers_instantiated = False + self._dtype = dtype + + def call(self, inputs): # pylint: disable=arguments-differ + """Forward pass of network. + + Args: + inputs: inputs to neural network + + Returns: + Output logits for the given inputs. + + """ + return self.output_layer(inputs) + + def compile(self, + optimizer, + loss, + kernel_initializer=tf.initializers.GlorotUniform, + **kwargs): # pylint: disable=arguments-differ + """See super class. Default optimizer used in BoltOn method is SGD. + + Args: + optimizer: The optimizer to use. This will be automatically wrapped + with the BoltOn Optimizer. + loss: The loss function to use. Must be a StrongConvex loss (extend the + StrongConvexMixin). + kernel_initializer: The kernel initializer to use for the single layer. + **kwargs: kwargs to keras Model.compile. See super. + """ + if not isinstance(loss, StrongConvexMixin): + raise ValueError('loss function must be a Strongly Convex and therefore ' + 'extend the StrongConvexMixin.') + if not self._layers_instantiated: # compile may be called multiple times + # for instance, if the input/outputs are not defined until fit. + self.output_layer = tf.keras.layers.Dense( + self.n_outputs, + kernel_regularizer=loss.kernel_regularizer(), + kernel_initializer=kernel_initializer(), + ) + self._layers_instantiated = True + if not isinstance(optimizer, BoltOn): + optimizer = optimizers.get(optimizer) + optimizer = BoltOn(optimizer, loss) + + super(BoltOnModel, self).compile(optimizer, loss=loss, **kwargs) + + def fit(self, + x=None, + y=None, + batch_size=None, + class_weight=None, + n_samples=None, + epsilon=2, + noise_distribution='laplace', + steps_per_epoch=None, + **kwargs): # pylint: disable=arguments-differ + """Reroutes to super fit with BoltOn delta-epsilon privacy requirements. + + Note, inputs must be normalized s.t. ||x|| < 1. + Requirements are as follows: + 1. Adds noise to weights after training (output perturbation). + 2. Projects weights to R after each batch + 3. Limits learning rate + 4. Use a strongly convex loss function (see compile) + See super implementation for more details. + + Args: + x: Inputs to fit on, see super. + y: Labels to fit on, see super. + batch_size: The batch size to use for training, see super. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. + n_samples: the number of individual samples in x. + epsilon: privacy parameter, which trades off between utility an privacy. + See the bolt-on paper for more description. + noise_distribution: the distribution to pull noise from. + steps_per_epoch: + **kwargs: kwargs to keras Model.fit. See super. + + Returns: + Output from super fit method. 
+ """ + if class_weight is None: + class_weight_ = self.calculate_class_weights(class_weight) + else: + class_weight_ = class_weight + if n_samples is not None: + data_size = n_samples + elif hasattr(x, 'shape'): + data_size = x.shape[0] + elif hasattr(x, '__len__'): + data_size = len(x) + else: + data_size = None + batch_size_ = self._validate_or_infer_batch_size(batch_size, + steps_per_epoch, + x) + if batch_size_ is None: + batch_size_ = 32 + # inferring batch_size to be passed to optimizer. batch_size must remain its + # initial value when passed to super().fit() + if batch_size_ is None: + raise ValueError('batch_size: {0} is an ' + 'invalid value'.format(batch_size_)) + if data_size is None: + raise ValueError('Could not infer the number of samples. Please pass ' + 'this in using n_samples.') + with self.optimizer(noise_distribution, + epsilon, + self.layers, + class_weight_, + data_size, + batch_size_) as _: + out = super(BoltOnModel, self).fit(x=x, + y=y, + batch_size=batch_size, + class_weight=class_weight, + steps_per_epoch=steps_per_epoch, + **kwargs) + return out + + def fit_generator(self, + generator, + class_weight=None, + noise_distribution='laplace', + epsilon=2, + n_samples=None, + steps_per_epoch=None, + **kwargs): # pylint: disable=arguments-differ + """Fit with a generator. + + This method is the same as fit except for when the passed dataset + is a generator. See super method and fit for more details. + + Args: + generator: Inputs generator following Tensorflow guidelines, see super. + class_weight: the class weights to be used. Can be a scalar or 1D tensor + whose dim == n_classes. + noise_distribution: the distribution to get noise from. + epsilon: privacy parameter, which trades off utility and privacy. See + BoltOn paper for more description. + n_samples: number of individual samples in x + steps_per_epoch: Number of steps per training epoch, see super. + **kwargs: **kwargs + + Returns: + Output from super fit_generator method. + """ + if class_weight is None: + class_weight = self.calculate_class_weights(class_weight) + if n_samples is not None: + data_size = n_samples + elif hasattr(generator, 'shape'): + data_size = generator.shape[0] + elif hasattr(generator, '__len__'): + data_size = len(generator) + else: + raise ValueError('The number of samples could not be determined. ' + 'Please make sure that if you are using a generator' + 'to call this method directly with n_samples kwarg ' + 'passed.') + batch_size = self._validate_or_infer_batch_size(None, steps_per_epoch, + generator) + if batch_size is None: + batch_size = 32 + with self.optimizer(noise_distribution, + epsilon, + self.layers, + class_weight, + data_size, + batch_size) as _: + out = super(BoltOnModel, self).fit_generator( + generator, + class_weight=class_weight, + steps_per_epoch=steps_per_epoch, + **kwargs) + return out + + def calculate_class_weights(self, + class_weights=None, + class_counts=None, + num_classes=None): + """Calculates class weighting to be used in training. + + Args: + class_weights: str specifying type, array giving weights, or None. + class_counts: If class_weights is not None, then an array of + the number of samples for each class + num_classes: If class_weights is not None, then the number of + classes. + Returns: + class_weights as 1D tensor, to be passed to model's fit method. 
+ """ + # Value checking + class_keys = ['balanced'] + is_string = False + if isinstance(class_weights, str): + is_string = True + if class_weights not in class_keys: + raise ValueError('Detected string class_weights with ' + 'value: {0}, which is not one of {1}.' + 'Please select a valid class_weight type' + 'or pass an array'.format(class_weights, + class_keys)) + if class_counts is None: + raise ValueError('Class counts must be provided if using ' + 'class_weights=%s' % class_weights) + class_counts_shape = tf.Variable(class_counts, + trainable=False, + dtype=self._dtype).shape + if len(class_counts_shape) != 1: + raise ValueError('class counts must be a 1D array.' + 'Detected: {0}'.format(class_counts_shape)) + if num_classes is None: + raise ValueError('num_classes must be provided if using ' + 'class_weights=%s' % class_weights) + elif class_weights is not None: + if num_classes is None: + raise ValueError('You must pass a value for num_classes if ' + 'creating an array of class_weights') + # performing class weight calculation + if class_weights is None: + class_weights = 1 + elif is_string and class_weights == 'balanced': + num_samples = sum(class_counts) + weighted_counts = tf.dtypes.cast(tf.math.multiply(num_classes, + class_counts), + self._dtype) + class_weights = tf.Variable(num_samples, dtype=self._dtype) / \ + tf.Variable(weighted_counts, dtype=self._dtype) + else: + class_weights = _ops.convert_to_tensor_v2(class_weights) + if len(class_weights.shape) != 1: + raise ValueError('Detected class_weights shape: {0} instead of ' + '1D array'.format(class_weights.shape)) + if class_weights.shape[0] != num_classes: + raise ValueError( + 'Detected array length: {0} instead of: {1}'.format( + class_weights.shape[0], + num_classes)) + return class_weights diff --git a/tensorflow_privacy/privacy/bolt_on/models_test.py b/tensorflow_privacy/privacy/bolt_on/models_test.py new file mode 100644 index 00000000..a47e8b4f --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/models_test.py @@ -0,0 +1,548 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
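Putting the pieces together, here is a rough end-to-end sketch of training a `BoltOnModel`; the repository's `bolton_tutorial.py` covers this in more depth. The toy data, optimizer choice, and parameter values below are assumptions for illustration, and the inputs are scaled because `fit` expects them normalized so that ||x|| < 1:

```python
import tensorflow as tf
from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexBinaryCrossentropy
from tensorflow_privacy.privacy.bolt_on.models import BoltOnModel

n_samples, input_dim, n_outputs = 20, 2, 1

# Toy binary-classification data, scaled so that every input has norm < 1.
x = tf.random.uniform((n_samples, input_dim), maxval=0.5)
y = tf.cast(tf.reduce_sum(x, axis=1, keepdims=True) > 0.5, tf.float32)

loss = StrongConvexBinaryCrossentropy(reg_lambda=1.0, c_arg=1.0, radius_constant=1.0)
model = BoltOnModel(n_outputs)

# compile() wraps any plain Keras optimizer in the BoltOn optimizer.
model.compile(optimizer=tf.optimizers.SGD(), loss=loss)

model.fit(x, y,
          epochs=2,
          batch_size=4,
          n_samples=n_samples,            # lets the output noise be calibrated
          epsilon=2,                      # privacy parameter
          noise_distribution='laplace')   # the only distribution implemented here
```

Per the class docstring, by the time `fit` returns the BoltOn context has projected the weights back to the R-ball and added the output-perturbation noise, so the model can be used directly for prediction.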
+"""Unit testing for models.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import tensorflow as tf +from tensorflow.python.framework import ops as _ops +from tensorflow.python.keras import keras_parameterized +from tensorflow.python.keras import losses +from tensorflow.python.keras.optimizer_v2.optimizer_v2 import OptimizerV2 +from tensorflow.python.keras.regularizers import L1L2 +from tensorflow_privacy.privacy.bolt_on import models +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin +from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn + + +class TestLoss(losses.Loss, StrongConvexMixin): + """Test loss function for testing BoltOn model.""" + + def __init__(self, reg_lambda, c_arg, radius_constant, name='test'): + super(TestLoss, self).__init__(name=name) + self.reg_lambda = reg_lambda + self.C = c_arg # pylint: disable=invalid-name + self.radius_constant = radius_constant + + def radius(self): + """Radius, R, of the hypothesis space W. + + W is a convex set that forms the hypothesis space. + + Returns: + radius + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def gamma(self): + """Returns strongly convex parameter, gamma.""" + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def beta(self, class_weight): # pylint: disable=unused-argument + """Smoothness, beta. + + Args: + class_weight: the class weights as scalar or 1d tensor, where its + dimensionality is equal to the number of outputs. + + Returns: + Beta + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def lipchitz_constant(self, class_weight): # pylint: disable=unused-argument + """Lipchitz constant, L. + + Args: + class_weight: class weights used + + Returns: + L + """ + return _ops.convert_to_tensor_v2(1, dtype=tf.float32) + + def call(self, y_true, y_pred): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum( + tf.math.squared_difference(y_true, y_pred), + axis=1 + ) + + def max_class_weight(self, class_weight): + """the maximum weighting in class weights (max value) as a scalar tensor. + + Args: + class_weight: class weights used + + Returns: + maximum class weighting as tensor scalar + """ + if class_weight is None: + return 1 + raise ValueError('') + + def kernel_regularizer(self): + """Returns the kernel_regularizer to be used. + + Any subclass should override this method if they want a kernel_regularizer + (if required for the loss function to be StronglyConvex. + """ + return L1L2(l2=self.reg_lambda) + + +class TestOptimizer(OptimizerV2): + """Test optimizer used for testing BoltOn model.""" + + def __init__(self): + super(TestOptimizer, self).__init__('test') + + def compute_gradients(self): + return 0 + + def get_config(self): + return {} + + def _create_slots(self, var): + pass + + def _resource_apply_dense(self, grad, handle): + return grad + + def _resource_apply_sparse(self, grad, handle, indices): + return grad + + +class InitTests(keras_parameterized.TestCase): + """Tests for keras model initialization.""" + + @parameterized.named_parameters([ + {'testcase_name': 'normal', + 'n_outputs': 1, + }, + {'testcase_name': 'many outputs', + 'n_outputs': 100, + }, + ]) + def test_init_params(self, n_outputs): + """Test initialization of BoltOnModel. 
+ + Args: + n_outputs: number of output neurons + """ + # test valid domains for each variable + clf = models.BoltOnModel(n_outputs) + self.assertIsInstance(clf, models.BoltOnModel) + + @parameterized.named_parameters([ + {'testcase_name': 'invalid n_outputs', + 'n_outputs': -1, + }, + ]) + def test_bad_init_params(self, n_outputs): + """test bad initializations of BoltOnModel that should raise errors. + + Args: + n_outputs: number of output neurons + """ + # test invalid domains for each variable, especially noise + with self.assertRaises(ValueError): + models.BoltOnModel(n_outputs) + + @parameterized.named_parameters([ + {'testcase_name': 'string compile', + 'n_outputs': 1, + 'loss': TestLoss(1, 1, 1), + 'optimizer': 'adam', + }, + {'testcase_name': 'test compile', + 'n_outputs': 100, + 'loss': TestLoss(1, 1, 1), + 'optimizer': TestOptimizer(), + }, + ]) + def test_compile(self, n_outputs, loss, optimizer): + """Test compilation of BoltOnModel. + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instantiated TestOptimizer instance + """ + # test compilation of valid tf.optimizer and tf.loss + with self.cached_session(): + clf = models.BoltOnModel(n_outputs) + clf.compile(optimizer, loss) + self.assertEqual(clf.loss, loss) + + @parameterized.named_parameters([ + {'testcase_name': 'Not strong loss', + 'n_outputs': 1, + 'loss': losses.BinaryCrossentropy(), + 'optimizer': 'adam', + }, + {'testcase_name': 'Not valid optimizer', + 'n_outputs': 1, + 'loss': TestLoss(1, 1, 1), + 'optimizer': 'ada', + } + ]) + def test_bad_compile(self, n_outputs, loss, optimizer): + """test bad compilations of BoltOnModel that should raise errors. + + Args: + n_outputs: number of output neurons + loss: instantiated TestLoss instance + optimizer: instantiated TestOptimizer instance + """ + # test compilaton of invalid tf.optimizer and non instantiated loss. + with self.cached_session(): + with self.assertRaises((ValueError, AttributeError)): + clf = models.BoltOnModel(n_outputs) + clf.compile(optimizer, loss) + + +def _cat_dataset(n_samples, input_dim, n_classes, batch_size, generator=False): + """Creates a categorically encoded dataset. + + Creates a categorically encoded dataset (y is categorical). + returns the specified dataset either as a static array or as a generator. + Will have evenly split samples across each output class. + Each output class will be a different point in the input space. + + Args: + n_samples: number of rows + input_dim: input dimensionality + n_classes: output dimensionality + batch_size: The desired batch_size + generator: False for array, True for generator + + Returns: + X as (n_samples, input_dim), Y as (n_samples, n_outputs) + """ + x_stack = [] + y_stack = [] + for i_class in range(n_classes): + x_stack.append( + tf.constant(1*i_class, tf.float32, (n_samples, input_dim)) + ) + y_stack.append( + tf.constant(i_class, tf.float32, (n_samples, n_classes)) + ) + x_set, y_set = tf.stack(x_stack), tf.stack(y_stack) + if generator: + dataset = tf.data.Dataset.from_tensor_slices( + (x_set, y_set) + ) + dataset = dataset.batch(batch_size=batch_size) + return dataset + return x_set, y_set + + +def _do_fit(n_samples, + input_dim, + n_outputs, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + distribution='laplace'): + """Instantiate necessary components for fitting and perform a model fit. 
+ + Args: + n_samples: number of samples in dataset + input_dim: the sample dimensionality + n_outputs: number of output neurons + epsilon: privacy parameter + generator: True to create a generator, False to use an iterator + batch_size: batch_size to use + reset_n_samples: True to set _samples to None prior to fitting. + False does nothing + optimizer: instance of TestOptimizer + loss: instance of TestLoss + distribution: distribution to get noise from. + + Returns: + BoltOnModel instsance + """ + clf = models.BoltOnModel(n_outputs) + clf.compile(optimizer, loss) + if generator: + x = _cat_dataset( + n_samples, + input_dim, + n_outputs, + batch_size, + generator=generator + ) + y = None + # x = x.batch(batch_size) + x = x.shuffle(n_samples//2) + batch_size = None + if reset_n_samples: + n_samples = None + clf.fit_generator(x, + n_samples=n_samples, + noise_distribution=distribution, + epsilon=epsilon) + else: + x, y = _cat_dataset( + n_samples, + input_dim, + n_outputs, + batch_size, + generator=generator) + if reset_n_samples: + n_samples = None + clf.fit(x, + y, + batch_size=batch_size, + n_samples=n_samples, + noise_distribution=distribution, + epsilon=epsilon) + return clf + + +class FitTests(keras_parameterized.TestCase): + """Test cases for keras model fitting.""" + + # @test_util.run_all_in_graph_and_eager_modes + @parameterized.named_parameters([ + {'testcase_name': 'iterator fit', + 'generator': False, + 'reset_n_samples': True, + }, + {'testcase_name': 'iterator fit no samples', + 'generator': False, + 'reset_n_samples': True, + }, + {'testcase_name': 'generator fit', + 'generator': True, + 'reset_n_samples': False, + }, + {'testcase_name': 'with callbacks', + 'generator': True, + 'reset_n_samples': False, + }, + ]) + def test_fit(self, generator, reset_n_samples): + """Tests fitting of BoltOnModel. + + Args: + generator: True for generator test, False for iterator test. + reset_n_samples: True to reset the n_samples to None, False does nothing + """ + loss = TestLoss(1, 1, 1) + optimizer = BoltOn(TestOptimizer(), loss) + n_classes = 2 + input_dim = 5 + epsilon = 1 + batch_size = 1 + n_samples = 10 + clf = _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + ) + self.assertEqual(hasattr(clf, 'layers'), True) + + @parameterized.named_parameters([ + {'testcase_name': 'generator fit', + 'generator': True, + }, + ]) + def test_fit_gen(self, generator): + """Tests the fit_generator method of BoltOnModel. + + Args: + generator: True to test with a generator dataset + """ + loss = TestLoss(1, 1, 1) + optimizer = TestOptimizer() + n_classes = 2 + input_dim = 5 + batch_size = 1 + n_samples = 10 + clf = models.BoltOnModel(n_classes) + clf.compile(optimizer, loss) + x = _cat_dataset( + n_samples, + input_dim, + n_classes, + batch_size, + generator=generator + ) + x = x.batch(batch_size) + x = x.shuffle(n_samples // 2) + clf.fit_generator(x, n_samples=n_samples) + self.assertEqual(hasattr(clf, 'layers'), True) + + @parameterized.named_parameters([ + {'testcase_name': 'iterator no n_samples', + 'generator': True, + 'reset_n_samples': True, + 'distribution': 'laplace' + }, + {'testcase_name': 'invalid distribution', + 'generator': True, + 'reset_n_samples': True, + 'distribution': 'not_valid' + }, + ]) + def test_bad_fit(self, generator, reset_n_samples, distribution): + """Tests fitting with invalid parameters, which should raise an error. 
+ + Args: + generator: True to test with generator, False is iterator + reset_n_samples: True to reset the n_samples param to None prior to + passing it to fit + distribution: distribution to get noise from. + """ + with self.assertRaises(ValueError): + loss = TestLoss(1, 1, 1) + optimizer = TestOptimizer() + n_classes = 2 + input_dim = 5 + epsilon = 1 + batch_size = 1 + n_samples = 10 + _do_fit( + n_samples, + input_dim, + n_classes, + epsilon, + generator, + batch_size, + reset_n_samples, + optimizer, + loss, + distribution + ) + + @parameterized.named_parameters([ + {'testcase_name': 'None class_weights', + 'class_weights': None, + 'class_counts': None, + 'num_classes': None, + 'result': 1}, + {'testcase_name': 'class weights array', + 'class_weights': [1, 1], + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, + {'testcase_name': 'class weights balanced', + 'class_weights': 'balanced', + 'class_counts': [1, 1], + 'num_classes': 2, + 'result': [1, 1]}, + ]) + def test_class_calculate(self, + class_weights, + class_counts, + num_classes, + result): + """Tests the BOltonModel calculate_class_weights method. + + Args: + class_weights: the class_weights to use + class_counts: count of number of samples for each class + num_classes: number of outputs neurons + result: expected result + """ + clf = models.BoltOnModel(1, 1) + expected = clf.calculate_class_weights(class_weights, + class_counts, + num_classes) + + if hasattr(expected, 'numpy'): + expected = expected.numpy() + self.assertAllEqual( + expected, + result + ) + @parameterized.named_parameters([ + {'testcase_name': 'class weight not valid str', + 'class_weights': 'not_valid', + 'class_counts': 1, + 'num_classes': 1, + 'err_msg': 'Detected string class_weights with value: not_valid'}, + {'testcase_name': 'no class counts', + 'class_weights': 'balanced', + 'class_counts': None, + 'num_classes': 1, + 'err_msg': 'Class counts must be provided if ' + 'using class_weights=balanced'}, + {'testcase_name': 'no num classes', + 'class_weights': 'balanced', + 'class_counts': [1], + 'num_classes': None, + 'err_msg': 'num_classes must be provided if ' + 'using class_weights=balanced'}, + {'testcase_name': 'class counts not array', + 'class_weights': 'balanced', + 'class_counts': 1, + 'num_classes': None, + 'err_msg': 'class counts must be a 1D array.'}, + {'testcase_name': 'class counts array, no num classes', + 'class_weights': [1], + 'class_counts': None, + 'num_classes': None, + 'err_msg': 'You must pass a value for num_classes if ' + 'creating an array of class_weights'}, + {'testcase_name': 'class counts array, improper shape', + 'class_weights': [[1], [1]], + 'class_counts': None, + 'num_classes': 2, + 'err_msg': 'Detected class_weights shape'}, + {'testcase_name': 'class counts array, wrong number classes', + 'class_weights': [1, 1, 1], + 'class_counts': None, + 'num_classes': 2, + 'err_msg': 'Detected array length:'}, + ]) + + def test_class_errors(self, + class_weights, + class_counts, + num_classes, + err_msg): + """Tests the BOltonModel calculate_class_weights method. + + This test passes invalid params which should raise the expected errors. + + Args: + class_weights: the class_weights to use. + class_counts: count of number of samples for each class. + num_classes: number of outputs neurons. + err_msg: The expected error message. 
+ """ + clf = models.BoltOnModel(1, 1) + with self.assertRaisesRegexp(ValueError, err_msg): # pylint: disable=deprecated-method + clf.calculate_class_weights(class_weights, + class_counts, + num_classes) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/bolt_on/optimizers.py b/tensorflow_privacy/privacy/bolt_on/optimizers.py new file mode 100644 index 00000000..eac6641d --- /dev/null +++ b/tensorflow_privacy/privacy/bolt_on/optimizers.py @@ -0,0 +1,388 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""BoltOn Optimizer for Bolt-on method.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.python.keras.optimizer_v2 import optimizer_v2 +from tensorflow.python.ops import math_ops +from tensorflow_privacy.privacy.bolt_on.losses import StrongConvexMixin + +_accepted_distributions = ['laplace'] # implemented distributions for noising + + +class GammaBetaDecreasingStep( + optimizer_v2.learning_rate_schedule.LearningRateSchedule): + """Computes LR as minimum of 1/beta and 1/(gamma * step) at each step. + + This is a required step for privacy guarantees. + """ + + def __init__(self): + self.is_init = False + self.beta = None + self.gamma = None + + def __call__(self, step): + """Computes and returns the learning rate. + + Args: + step: the current iteration number + + Returns: + decayed learning rate to minimum of 1/beta and 1/(gamma * step) as per + the BoltOn privacy requirements. + """ + if not self.is_init: + raise AttributeError('Please initialize the {0} Learning Rate Scheduler.' + 'This is performed automatically by using the ' + '{1} as a context manager, ' + 'as desired'.format(self.__class__.__name__, + BoltOn.__class__.__name__ + ) + ) + dtype = self.beta.dtype + one = tf.constant(1, dtype) + return tf.math.minimum(tf.math.reduce_min(one/self.beta), + one/(self.gamma*math_ops.cast(step, dtype)) + ) + + def get_config(self): + """Return config to setup the learning rate scheduler.""" + return {'beta': self.beta, 'gamma': self.gamma} + + def initialize(self, beta, gamma): + """Setups scheduler with beta and gamma values from the loss function. + + Meant to be used with .fit as the loss params may depend on values passed to + fit. + + Args: + beta: Smoothness value. See StrongConvexMixin + gamma: Strong Convexity parameter. See StrongConvexMixin. + """ + self.is_init = True + self.beta = beta + self.gamma = gamma + + def de_initialize(self): + """De initialize post fit, as another fit call may use other parameters.""" + self.is_init = False + self.beta = None + self.gamma = None + + +class BoltOn(optimizer_v2.OptimizerV2): + """Wrap another tf optimizer with BoltOn privacy protocol. + + BoltOn optimizer wraps another tf optimizer to be used + as the visible optimizer to the tf model. 
No matter the optimizer + passed, "BoltOn" enables the bolt-on model to control the learning rate + based on the strongly convex loss. + + To use the BoltOn method, you must: + 1. instantiate it with an instantiated tf optimizer and StrongConvexLoss. + 2. use it as a context manager around your .fit method internals. + + This can be accomplished by the following: + optimizer = tf.optimizers.SGD() + loss = privacy.bolt_on.losses.StrongConvexBinaryCrossentropy() + bolton = BoltOn(optimizer, loss) + with bolton(*args) as _: + model.fit() + The args required for the context manager can be found in the __call__ + method. + + For more details on the strong convexity requirements, see: + Bolt-on Differential Privacy for Scalable Stochastic Gradient + Descent-based Analytics by Xi Wu et. al. + """ + + def __init__(self, # pylint: disable=super-init-not-called + optimizer, + loss, + dtype=tf.float32, + ): + """Constructor. + + Args: + optimizer: Optimizer_v2 or subclass to be used as the optimizer + (wrapped). + loss: StrongConvexLoss function that the model is being compiled with. + dtype: dtype + """ + + if not isinstance(loss, StrongConvexMixin): + raise ValueError('loss function must be a Strongly Convex and therefore ' + 'extend the StrongConvexMixin.') + self._private_attributes = [ + '_internal_optimizer', + 'dtype', + 'noise_distribution', + 'epsilon', + 'loss', + 'class_weights', + 'input_dim', + 'n_samples', + 'layers', + 'batch_size', + '_is_init', + ] + self._internal_optimizer = optimizer + self.learning_rate = GammaBetaDecreasingStep() # use the BoltOn Learning + # rate scheduler, as required for privacy guarantees. This will still need + # to get values from the loss function near the time that .fit is called + # on the model (when this optimizer will be called as a context manager) + self.dtype = dtype + self.loss = loss + self._is_init = False + + def get_config(self): + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer.get_config() + + def project_weights_to_r(self, force=False): + """Normalize the weights to the R-ball. + + Args: + force: True to normalize regardless of previous weight values. + False to check if weights > R-ball and only normalize then. + + Raises: + Exception: If not called from inside this optimizer context. + """ + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') + radius = self.loss.radius() + for layer in self.layers: + weight_norm = tf.norm(layer.kernel, axis=0) + if force: + layer.kernel = layer.kernel / (weight_norm / radius) + else: + layer.kernel = tf.cond( + tf.reduce_sum(tf.cast(weight_norm > radius, dtype=self.dtype)) > 0, + lambda k=layer.kernel, w=weight_norm, r=radius: k / (w / r), # pylint: disable=cell-var-from-loop + lambda k=layer.kernel: k # pylint: disable=cell-var-from-loop + ) + + def get_noise(self, input_dim, output_dim): + """Sample noise to be added to weights for privacy guarantee. + + Args: + input_dim: the input dimensionality for the weights + output_dim: the output dimensionality for the weights + + Returns: + Noise in shape of layer's weights to be added to the weights. + + Raises: + Exception: If not called from inside this optimizer's context. 
+ """ + if not self._is_init: + raise Exception('This method must be called from within the optimizer\'s ' + 'context.') + loss = self.loss + distribution = self.noise_distribution.lower() + if distribution == _accepted_distributions[0]: # laplace + per_class_epsilon = self.epsilon / (output_dim) + l2_sensitivity = (2 * + loss.lipchitz_constant(self.class_weights)) / \ + (loss.gamma() * self.n_samples * self.batch_size) + unit_vector = tf.random.normal(shape=(input_dim, output_dim), + mean=0, + seed=1, + stddev=1.0, + dtype=self.dtype) + unit_vector = unit_vector / tf.math.sqrt( + tf.reduce_sum(tf.math.square(unit_vector), axis=0) + ) + + beta = l2_sensitivity / per_class_epsilon + alpha = input_dim # input_dim + gamma = tf.random.gamma([output_dim], + alpha, + beta=1 / beta, + seed=1, + dtype=self.dtype + ) + return unit_vector * gamma + raise NotImplementedError('Noise distribution: {0} is not ' + 'a valid distribution'.format(distribution)) + + def from_config(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer.from_config(*args, **kwargs) + + def __getattr__(self, name): + """Get attr. + + return _internal_optimizer off self instance, and everything else + from the _internal_optimizer instance. + + Args: + name: Name of attribute to get from this or aggregate optimizer. + + Returns: + attribute from BoltOn if specified to come from self, else + from _internal_optimizer. + """ + if name == '_private_attributes' or name in self._private_attributes: + return getattr(self, name) + optim = object.__getattribute__(self, '_internal_optimizer') + try: + return object.__getattribute__(optim, name) + except AttributeError: + raise AttributeError( + "Neither '{0}' nor '{1}' object has attribute '{2}'" + "".format(self.__class__.__name__, + self._internal_optimizer.__class__.__name__, + name) + ) + + def __setattr__(self, key, value): + """Set attribute to self instance if its the internal optimizer. + + Reroute everything else to the _internal_optimizer. + + Args: + key: attribute name + value: attribute value + """ + if key == '_private_attributes': + object.__setattr__(self, key, value) + elif key in self._private_attributes: + object.__setattr__(self, key, value) + else: + setattr(self._internal_optimizer, key, value) + + def _resource_apply_dense(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer._resource_apply_dense(*args, **kwargs) # pylint: disable=protected-access + + def _resource_apply_sparse(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer._resource_apply_sparse(*args, **kwargs) # pylint: disable=protected-access + + def get_updates(self, loss, params): + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + out = self._internal_optimizer.get_updates(loss, params) + self.project_weights_to_r() + return out + + def apply_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + out = self._internal_optimizer.apply_gradients(*args, **kwargs) + self.project_weights_to_r() + return out + + def minimize(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. 
See super/_internal_optimizer.""" + out = self._internal_optimizer.minimize(*args, **kwargs) + self.project_weights_to_r() + return out + + def _compute_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ,protected-access + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer._compute_gradients(*args, **kwargs) # pylint: disable=protected-access + + def get_gradients(self, *args, **kwargs): # pylint: disable=arguments-differ + """Reroutes to _internal_optimizer. See super/_internal_optimizer.""" + return self._internal_optimizer.get_gradients(*args, **kwargs) + + def __enter__(self): + """Context manager call at the beginning of with statement. + + Returns: + self, to be used in context manager + """ + self._is_init = True + return self + + def __call__(self, + noise_distribution, + epsilon, + layers, + class_weights, + n_samples, + batch_size): + """Accepts required values for bolton method from context entry point. + + Stores them on the optimizer for use throughout fitting. + + Args: + noise_distribution: the noise distribution to pick. + see _accepted_distributions and get_noise for possible values. + epsilon: privacy parameter. Lower gives more privacy but less utility. + layers: list of Keras/Tensorflow layers. Can be found as model.layers + class_weights: class_weights used, which may either be a scalar or 1D + tensor with dim == n_classes. + n_samples: number of rows/individual samples in the training set + batch_size: batch size used. + + Returns: + self, to be used by the __enter__ method for context. + """ + if epsilon <= 0: + raise ValueError('Detected epsilon: {0}. ' + 'Valid range is 0 < epsilon = l2_norm_clip, tf.float32) - 0.5 + + preprocessed_clipped_fraction_record = ( + self._clipped_fraction_query.preprocess_record( + params.clipped_fraction_params, was_clipped)) + + return preprocessed_sum_record, preprocessed_clipped_fraction_record + + def accumulate_preprocessed_record( + self, sample_state, preprocessed_record, weight=1): + """See base class.""" + preprocessed_sum_record, preprocessed_clipped_fraction_record = preprocessed_record + sum_state = self._sum_query.accumulate_preprocessed_record( + sample_state.sum_state, preprocessed_sum_record) + + clipped_fraction_state = self._clipped_fraction_query.accumulate_preprocessed_record( + sample_state.clipped_fraction_state, + preprocessed_clipped_fraction_record) + return self._SampleState(sum_state, clipped_fraction_state) + + def merge_sample_states(self, sample_state_1, sample_state_2): + """See base class.""" + return self._SampleState( + self._sum_query.merge_sample_states( + sample_state_1.sum_state, + sample_state_2.sum_state), + self._clipped_fraction_query.merge_sample_states( + sample_state_1.clipped_fraction_state, + sample_state_2.clipped_fraction_state)) + + def get_noised_result(self, sample_state, global_state): + """See base class.""" + gs = global_state + + noised_vectors, sum_state = self._sum_query.get_noised_result( + sample_state.sum_state, gs.sum_state) + del sum_state # Unused. To be set explicitly later. + + clipped_fraction_result, new_clipped_fraction_state = ( + self._clipped_fraction_query.get_noised_result( + sample_state.clipped_fraction_state, + gs.clipped_fraction_state)) + + # Unshift clipped percentile by 0.5. (See comment in accumulate_record.) + clipped_quantile = clipped_fraction_result + 0.5 + unclipped_quantile = 1.0 - clipped_quantile + + # Protect against out-of-range estimates. 
+ unclipped_quantile = tf.minimum(1.0, tf.maximum(0.0, unclipped_quantile)) + + # Loss function is convex, with derivative in [-1, 1], and minimized when + # the true quantile matches the target. + loss_grad = unclipped_quantile - global_state.target_unclipped_quantile + + new_l2_norm_clip = gs.l2_norm_clip - global_state.learning_rate * loss_grad + new_l2_norm_clip = tf.maximum(0.0, new_l2_norm_clip) + + new_sum_stddev = new_l2_norm_clip * global_state.noise_multiplier + new_sum_query_global_state = self._sum_query.make_global_state( + l2_norm_clip=new_l2_norm_clip, + stddev=new_sum_stddev) + + new_global_state = global_state._replace( + l2_norm_clip=new_l2_norm_clip, + sum_state=new_sum_query_global_state, + clipped_fraction_state=new_clipped_fraction_state) + + return noised_vectors, new_global_state + + +class QuantileAdaptiveClipAverageQuery(normalized_query.NormalizedQuery): + """DPQuery for average queries with adaptive clipping. + + Clipping norm is tuned adaptively to converge to a value such that a specified + quantile of updates are clipped. + + Note that we use "fixed-denominator" estimation: the denominator should be + specified as the expected number of records per sample. Accumulating the + denominator separately would also be possible but would be produce a higher + variance estimator. + """ + + def __init__( + self, + initial_l2_norm_clip, + noise_multiplier, + denominator, + target_unclipped_quantile, + learning_rate, + clipped_count_stddev, + expected_num_records): + """Initializes the AdaptiveClipAverageQuery. + + Args: + initial_l2_norm_clip: The initial value of clipping norm. + noise_multiplier: The multiplier of the l2_norm_clip to make the stddev of + the noise. + denominator: The normalization constant (applied after noise is added to + the sum). + target_unclipped_quantile: The desired quantile of updates which should be + clipped. + learning_rate: The learning rate for the clipping norm adaptation. A + rate of r means that the clipping norm will change by a maximum of r at + each step. The maximum is attained when |clip - target| is 1.0. + clipped_count_stddev: The stddev of the noise added to the clipped_count. + Since the sensitivity of the clipped count is 0.5, as a rule of thumb it + should be about 0.5 for reasonable privacy. + expected_num_records: The expected number of records, used to estimate the + clipped count quantile. + """ + numerator_query = QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip, + noise_multiplier, + target_unclipped_quantile, + learning_rate, + clipped_count_stddev, + expected_num_records) + super(QuantileAdaptiveClipAverageQuery, self).__init__( + numerator_query=numerator_query, + denominator=denominator) diff --git a/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py new file mode 100644 index 00000000..e7521d5f --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/quantile_adaptive_clip_sum_query_test.py @@ -0,0 +1,296 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for QuantileAdaptiveClipSumQuery."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow_privacy.privacy.analysis import privacy_ledger
+from tensorflow_privacy.privacy.dp_query import quantile_adaptive_clip_sum_query
+from tensorflow_privacy.privacy.dp_query import test_utils
+
+tf.enable_eager_execution()
+
+
+class QuantileAdaptiveClipSumQueryTest(tf.test.TestCase):
+
+  def test_sum_no_clip_no_noise(self):
+    record1 = tf.constant([2.0, 0.0])
+    record2 = tf.constant([-1.0, 1.0])
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=10.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+    query_result, _ = test_utils.run_query(query, [record1, record2])
+    result = query_result.numpy()
+    expected = [1.0, 1.0]
+    self.assertAllClose(result, expected)
+
+  def test_sum_with_clip_no_noise(self):
+    record1 = tf.constant([-6.0, 8.0])  # Clipped to [-3.0, 4.0].
+    record2 = tf.constant([4.0, -3.0])  # Not clipped.
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=5.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    query_result, _ = test_utils.run_query(query, [record1, record2])
+    result = query_result.numpy()
+    expected = [1.0, 1.0]
+    self.assertAllClose(result, expected)
+
+  def test_sum_with_noise(self):
+    record1, record2 = 2.71828, 3.14159
+    stddev = 1.0
+    clip = 5.0
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=clip,
+        noise_multiplier=stddev / clip,
+        target_unclipped_quantile=1.0,
+        learning_rate=0.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    noised_sums = []
+    for _ in range(1000):
+      query_result, _ = test_utils.run_query(query, [record1, record2])
+      noised_sums.append(query_result.numpy())
+
+    result_stddev = np.std(noised_sums)
+    self.assertNear(result_stddev, stddev, 0.1)
+
+  def test_average_no_noise(self):
+    record1 = tf.constant([5.0, 0.0])  # Clipped to [3.0, 0.0].
+    record2 = tf.constant([-1.0, 2.0])  # Not clipped.
+ + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery( + initial_l2_norm_clip=3.0, + noise_multiplier=0.0, + denominator=2.0, + target_unclipped_quantile=1.0, + learning_rate=0.0, + clipped_count_stddev=0.0, + expected_num_records=2.0) + query_result, _ = test_utils.run_query(query, [record1, record2]) + result = query_result.numpy() + expected_average = [1.0, 1.0] + self.assertAllClose(result, expected_average) + + def test_average_with_noise(self): + record1, record2 = 2.71828, 3.14159 + sum_stddev = 1.0 + denominator = 2.0 + clip = 3.0 + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipAverageQuery( + initial_l2_norm_clip=clip, + noise_multiplier=sum_stddev / clip, + denominator=denominator, + target_unclipped_quantile=1.0, + learning_rate=0.0, + clipped_count_stddev=0.0, + expected_num_records=2.0) + + noised_averages = [] + for _ in range(1000): + query_result, _ = test_utils.run_query(query, [record1, record2]) + noised_averages.append(query_result.numpy()) + + result_stddev = np.std(noised_averages) + avg_stddev = sum_stddev / denominator + self.assertNear(result_stddev, avg_stddev, 0.1) + + def test_adaptation_target_zero(self): + record1 = tf.constant([8.5]) + record2 = tf.constant([-7.25]) + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip=10.0, + noise_multiplier=0.0, + target_unclipped_quantile=0.0, + learning_rate=1.0, + clipped_count_stddev=0.0, + expected_num_records=2.0) + + global_state = query.initial_global_state() + + initial_clip = global_state.l2_norm_clip + self.assertAllClose(initial_clip, 10.0) + + # On the first two iterations, nothing is clipped, so the clip goes down + # by 1.0 (the learning rate). When the clip reaches 8.0, one record is + # clipped, so the clip goes down by only 0.5. After two more iterations, + # both records are clipped, and the clip norm stays there (at 7.0). + + expected_sums = [1.25, 1.25, 0.75, 0.25, 0.0] + expected_clips = [9.0, 8.0, 7.5, 7.0, 7.0] + for expected_sum, expected_clip in zip(expected_sums, expected_clips): + actual_sum, global_state = test_utils.run_query( + query, [record1, record2], global_state) + + actual_clip = global_state.l2_norm_clip + + self.assertAllClose(actual_clip.numpy(), expected_clip) + self.assertAllClose(actual_sum.numpy(), (expected_sum,)) + + def test_adaptation_target_one(self): + record1 = tf.constant([-1.5]) + record2 = tf.constant([2.75]) + + query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery( + initial_l2_norm_clip=0.0, + noise_multiplier=0.0, + target_unclipped_quantile=1.0, + learning_rate=1.0, + clipped_count_stddev=0.0, + expected_num_records=2.0) + + global_state = query.initial_global_state() + + initial_clip = global_state.l2_norm_clip + self.assertAllClose(initial_clip, 0.0) + + # On the first two iterations, both are clipped, so the clip goes up + # by 1.0 (the learning rate). When the clip reaches 2.0, only one record is + # clipped, so the clip goes up by only 0.5. After two more iterations, + # both records are clipped, and the clip norm stays there (at 3.0). 
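For reference, the clip schedules asserted in these adaptation tests follow directly from the update rule in quantile_adaptive_clip_sum_query.py: with zero noise on the clipped count, the estimated unclipped quantile is the exact fraction of unclipped records, and the clip moves by learning_rate times (unclipped_quantile - target_unclipped_quantile). Below is a minimal plain-Python sketch of that noiseless update; the helper name adaptive_clip_step is illustrative only and is not part of the library.

def adaptive_clip_step(clip, record_norms, target_unclipped_quantile,
                       learning_rate):
  # A record with norm >= clip counts as clipped, matching the 0.5-shifted
  # indicator used by the query; the clipped-count noise is assumed to be zero.
  unclipped_quantile = (
      sum(1.0 for n in record_norms if n < clip) / len(record_norms))
  loss_grad = unclipped_quantile - target_unclipped_quantile
  return max(0.0, clip - learning_rate * loss_grad)

clip = 0.0
for _ in range(5):
  clip = adaptive_clip_step(clip, [1.5, 2.75],
                            target_unclipped_quantile=1.0, learning_rate=1.0)
  # Produces 1.0, 2.0, 2.5, 3.0, 3.0 -- the expected_clips asserted below.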
+
+    expected_sums = [0.0, 0.0, 0.5, 1.0, 1.25]
+    expected_clips = [1.0, 2.0, 2.5, 3.0, 3.0]
+    for expected_sum, expected_clip in zip(expected_sums, expected_clips):
+      actual_sum, global_state = test_utils.run_query(
+          query, [record1, record2], global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      self.assertAllClose(actual_clip.numpy(), expected_clip)
+      self.assertAllClose(actual_sum.numpy(), (expected_sum,))
+
+  def test_adaptation_linspace(self):
+    # 21 records equally spaced from 0 to 10 in 0.5 increments.
+    # Test that with a decaying learning rate we converge to the correct
+    # median with error at most 0.25.
+    records = [tf.constant(x) for x in np.linspace(
+        0.0, 10.0, num=21, dtype=np.float32)]
+
+    learning_rate = tf.Variable(1.0)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=0.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=0.5,
+        learning_rate=learning_rate,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    for t in range(50):
+      tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
+      _, global_state = test_utils.run_query(query, records, global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      if t > 40:
+        self.assertNear(actual_clip, 5.0, 0.25)
+
+  def test_adaptation_all_equal(self):
+    # 20 equal records. Test that with a decaying learning rate we converge to
+    # that record and bounce around it.
+    records = [tf.constant(5.0)] * 20
+
+    learning_rate = tf.Variable(1.0)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=0.0,
+        noise_multiplier=0.0,
+        target_unclipped_quantile=0.5,
+        learning_rate=learning_rate,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    global_state = query.initial_global_state()
+
+    for t in range(50):
+      tf.assign(learning_rate, 1.0 / np.sqrt(t+1))
+      _, global_state = test_utils.run_query(query, records, global_state)
+
+      actual_clip = global_state.l2_norm_clip
+
+      if t > 40:
+        self.assertNear(actual_clip, 5.0, 0.25)
+
+  def test_ledger(self):
+    record1 = tf.constant([8.5])
+    record2 = tf.constant([-7.25])
+
+    population_size = tf.Variable(0)
+    selection_probability = tf.Variable(1.0)
+
+    query = quantile_adaptive_clip_sum_query.QuantileAdaptiveClipSumQuery(
+        initial_l2_norm_clip=10.0,
+        noise_multiplier=1.0,
+        target_unclipped_quantile=0.0,
+        learning_rate=1.0,
+        clipped_count_stddev=0.0,
+        expected_num_records=2.0)
+
+    query = privacy_ledger.QueryWithLedger(
+        query, population_size, selection_probability)
+
+    # First sample.
+    tf.assign(population_size, 10)
+    tf.assign(selection_probability, 0.1)
+    _, global_state = test_utils.run_query(query, [record1, record2])
+
+    expected_queries = [[10.0, 10.0], [0.5, 0.0]]
+    formatted = query.ledger.get_formatted_ledger_eager()
+    sample_1 = formatted[0]
+    self.assertAllClose(sample_1.population_size, 10.0)
+    self.assertAllClose(sample_1.selection_probability, 0.1)
+    self.assertAllClose(sample_1.queries, expected_queries)
+
+    # Second sample.
+ tf.assign(population_size, 20) + tf.assign(selection_probability, 0.2) + test_utils.run_query(query, [record1, record2], global_state) + + formatted = query.ledger.get_formatted_ledger_eager() + sample_1, sample_2 = formatted + self.assertAllClose(sample_1.population_size, 10.0) + self.assertAllClose(sample_1.selection_probability, 0.1) + self.assertAllClose(sample_1.queries, expected_queries) + + expected_queries_2 = [[9.0, 9.0], [0.5, 0.0]] + self.assertAllClose(sample_2.population_size, 20.0) + self.assertAllClose(sample_2.selection_probability, 0.2) + self.assertAllClose(sample_2.queries, expected_queries_2) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/dp_query/test_utils.py b/tensorflow_privacy/privacy/dp_query/test_utils.py new file mode 100644 index 00000000..18456b30 --- /dev/null +++ b/tensorflow_privacy/privacy/dp_query/test_utils.py @@ -0,0 +1,49 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Utility methods for testing private queries. + +Utility methods for testing private queries. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +def run_query(query, records, global_state=None, weights=None): + """Executes query on the given set of records as a single sample. + + Args: + query: A PrivateQuery to run. + records: An iterable containing records to pass to the query. + global_state: The current global state. If None, an initial global state is + generated. + weights: An optional iterable containing the weights of the records. + + Returns: + A tuple (result, new_global_state) where "result" is the result of the + query and "new_global_state" is the updated global state. + """ + if not global_state: + global_state = query.initial_global_state() + params = query.derive_sample_params(global_state) + sample_state = query.initial_sample_state(next(iter(records))) + if weights is None: + for record in records: + sample_state = query.accumulate_record(params, sample_state, record) + else: + for weight, record in zip(weights, records): + sample_state = query.accumulate_record( + params, sample_state, record, weight) + return query.get_noised_result(sample_state, global_state) diff --git a/tensorflow_privacy/privacy/optimizers/__init__.py b/tensorflow_privacy/privacy/optimizers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py new file mode 100644 index 00000000..fecfd5b2 --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer.py @@ -0,0 +1,239 @@ +# Copyright 2018, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Differentially private optimizers for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from distutils.version import LooseVersion +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.dp_query import gaussian_query + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + nest = tf.contrib.framework.nest +else: + nest = tf.nest + + +def make_optimizer_class(cls): + """Constructs a DP optimizer class from an existing one.""" + if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + parent_code = tf.train.Optimizer.compute_gradients.__code__ + child_code = cls.compute_gradients.__code__ + GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name + else: + parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access + child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access + GATE_OP = None # pylint: disable=invalid-name + if child_code is not parent_code: + tf.logging.warning( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + cls.__name__) + + class DPOptimizerClass(cls): + """Differentially private subclass of given class cls.""" + + def __init__( + self, + dp_sum_query, + num_microbatches=None, + unroll_microbatches=False, + *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args + **kwargs): + """Initialize the DPOptimizerClass. + + Args: + dp_sum_query: DPQuery object, specifying differential privacy + mechanism to use. + num_microbatches: How many microbatches into which the minibatch is + split. If None, will default to the size of the minibatch, and + per-example gradients will be computed. + unroll_microbatches: If true, processes microbatches within a Python + loop instead of a tf.while_loop. Can be used if using a tf.while_loop + raises an exception. + """ + super(DPOptimizerClass, self).__init__(*args, **kwargs) + self._dp_sum_query = dp_sum_query + self._num_microbatches = num_microbatches + self._global_state = self._dp_sum_query.initial_global_state() + # TODO(b/122613513): Set unroll_microbatches=True to avoid this bug. + # Beware: When num_microbatches is large (>100), enabling this parameter + # may cause an OOM error. + self._unroll_microbatches = unroll_microbatches + + def compute_gradients(self, + loss, + var_list, + gate_gradients=GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None, + gradient_tape=None): + if callable(loss): + # TF is running in Eager mode, check we received a vanilla tape. 
+ if not gradient_tape: + raise ValueError('When in Eager mode, a tape needs to be passed.') + + vector_loss = loss() + if self._num_microbatches is None: + self._num_microbatches = tf.shape(vector_loss)[0] + sample_state = self._dp_sum_query.initial_sample_state(var_list) + microbatches_losses = tf.reshape(vector_loss, + [self._num_microbatches, -1]) + sample_params = ( + self._dp_sum_query.derive_sample_params(self._global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + microbatch_loss = tf.reduce_mean(tf.gather(microbatches_losses, [i])) + grads = gradient_tape.gradient(microbatch_loss, var_list) + sample_state = self._dp_sum_query.accumulate_record( + sample_params, sample_state, grads) + return sample_state + + for idx in range(self._num_microbatches): + sample_state = process_microbatch(idx, sample_state) + + grad_sums, self._global_state = ( + self._dp_sum_query.get_noised_result( + sample_state, self._global_state)) + + def normalize(v): + return v / tf.cast(self._num_microbatches, tf.float32) + + final_grads = nest.map_structure(normalize, grad_sums) + + grads_and_vars = list(zip(final_grads, var_list)) + return grads_and_vars + + else: + # TF is running in graph mode, check we did not receive a gradient tape. + if gradient_tape: + raise ValueError('When in graph mode, a tape should not be passed.') + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be + # sampling from the dataset without replacement. + if self._num_microbatches is None: + self._num_microbatches = tf.shape(loss)[0] + + microbatches_losses = tf.reshape(loss, [self._num_microbatches, -1]) + sample_params = ( + self._dp_sum_query.derive_sample_params(self._global_state)) + + def process_microbatch(i, sample_state): + """Process one microbatch (record) with privacy helper.""" + grads, _ = zip(*super(cls, self).compute_gradients( + tf.reduce_mean(tf.gather(microbatches_losses, + [i])), var_list, gate_gradients, + aggregation_method, colocate_gradients_with_ops, grad_loss)) + grads_list = [ + g if g is not None else tf.zeros_like(v) + for (g, v) in zip(list(grads), var_list) + ] + sample_state = self._dp_sum_query.accumulate_record( + sample_params, sample_state, grads_list) + return sample_state + + if var_list is None: + var_list = ( + tf.trainable_variables() + tf.get_collection( + tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + + sample_state = self._dp_sum_query.initial_sample_state(var_list) + + if self._unroll_microbatches: + for idx in range(self._num_microbatches): + sample_state = process_microbatch(idx, sample_state) + else: + # Use of while_loop here requires that sample_state be a nested + # structure of tensors. In general, we would prefer to allow it to be + # an arbitrary opaque type. 
+ cond_fn = lambda i, _: tf.less(i, self._num_microbatches) + body_fn = lambda i, state: [tf.add(i, 1), process_microbatch(i, state)] # pylint: disable=line-too-long + idx = tf.constant(0) + _, sample_state = tf.while_loop(cond_fn, body_fn, [idx, sample_state]) + + grad_sums, self._global_state = ( + self._dp_sum_query.get_noised_result( + sample_state, self._global_state)) + + def normalize(v): + return tf.truediv(v, tf.cast(self._num_microbatches, tf.float32)) + + final_grads = nest.map_structure(normalize, grad_sums) + + return list(zip(final_grads, var_list)) + + return DPOptimizerClass + + +def make_gaussian_optimizer_class(cls): + """Constructs a DP optimizer with Gaussian averaging of updates.""" + + class DPGaussianOptimizerClass(make_optimizer_class(cls)): + """DP subclass of given class cls using Gaussian averaging.""" + + def __init__( + self, + l2_norm_clip, + noise_multiplier, + num_microbatches=None, + ledger=None, + unroll_microbatches=False, + *args, # pylint: disable=keyword-arg-before-vararg + **kwargs): + dp_sum_query = gaussian_query.GaussianSumQuery( + l2_norm_clip, l2_norm_clip * noise_multiplier) + + if ledger: + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, + ledger=ledger) + + super(DPGaussianOptimizerClass, self).__init__( + dp_sum_query, + num_microbatches, + unroll_microbatches, + *args, + **kwargs) + + @property + def ledger(self): + return self._dp_sum_query.ledger + + return DPGaussianOptimizerClass + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + AdagradOptimizer = tf.train.AdagradOptimizer + AdamOptimizer = tf.train.AdamOptimizer + GradientDescentOptimizer = tf.train.GradientDescentOptimizer +else: + AdagradOptimizer = tf.optimizers.Adagrad + AdamOptimizer = tf.optimizers.Adam + GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name + +DPAdagradOptimizer = make_optimizer_class(AdagradOptimizer) +DPAdamOptimizer = make_optimizer_class(AdamOptimizer) +DPGradientDescentOptimizer = make_optimizer_class(GradientDescentOptimizer) + +DPAdagradGaussianOptimizer = make_gaussian_optimizer_class(AdagradOptimizer) +DPAdamGaussianOptimizer = make_gaussian_optimizer_class(AdamOptimizer) +DPGradientDescentGaussianOptimizer = make_gaussian_optimizer_class( + GradientDescentOptimizer) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_eager_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_eager_test.py new file mode 100644 index 00000000..b2bf1b8f --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_eager_test.py @@ -0,0 +1,130 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
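The Gaussian optimizer classes defined above take the clipping norm and noise multiplier directly, so no DPQuery has to be built by hand. A minimal graph-mode (TF 1.x) usage sketch follows; the model, placeholders, and hyperparameter values are illustrative assumptions, not taken from the library or its tutorials.

import tensorflow as tf
from tensorflow_privacy.privacy.optimizers import dp_optimizer

features = tf.placeholder(tf.float32, shape=[None, 4])
labels = tf.placeholder(tf.float32, shape=[None, 1])
preds = tf.keras.layers.Dense(1, activation='linear')(features)

# Keep the loss vectorized (one entry per example) so the optimizer can clip
# and noise per-microbatch gradients before averaging them.
vector_loss = tf.squared_difference(labels, preds)

optimizer = dp_optimizer.DPGradientDescentGaussianOptimizer(
    l2_norm_clip=1.0,
    noise_multiplier=1.1,
    num_microbatches=1,
    learning_rate=0.1)
train_op = optimizer.minimize(loss=vector_loss)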
+"""Tests for differentially private optimizers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.dp_query import gaussian_query +from tensorflow_privacy.privacy.optimizers import dp_optimizer + + +class DPOptimizerEagerTest(tf.test.TestCase, parameterized.TestCase): + + def setUp(self): + tf.enable_eager_execution() + super(DPOptimizerEagerTest, self).setUp() + + def _loss_fn(self, val0, val1): + return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1) + + @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1, + [-2.5, -2.5]), + ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2, + [-2.5, -2.5]), + ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4, + [-2.5, -2.5]), + ('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-2.5, -2.5]), + ('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-2.5, -2.5]), + ('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]), + ('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]), + ('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]), + ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5])) + def testBaseline(self, cls, num_microbatches, expected_answer): + with tf.GradientTape(persistent=True) as gradient_tape: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger( + dp_sum_query, 1e6, num_microbatches / 1e6) + + opt = cls( + dp_sum_query, + num_microbatches=num_microbatches, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + grads_and_vars = opt.compute_gradients( + lambda: self._loss_fn(var0, data0), [var0], + gradient_tape=gradient_tape) + self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testClippingNorm(self, cls): + with tf.GradientTape(persistent=True) as gradient_tape: + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6) + + opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + + # Expected gradient is sum of differences. 
+ grads_and_vars = opt.compute_gradients( + lambda: self._loss_fn(var0, data0), [var0], + gradient_tape=gradient_tape) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testNoiseMultiplier(self, cls): + with tf.GradientTape(persistent=True) as gradient_tape: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6) + + opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + grads = [] + for _ in range(1000): + grads_and_vars = opt.compute_gradients( + lambda: self._loss_fn(var0, data0), [var0], + gradient_tape=gradient_tape) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py new file mode 100644 index 00000000..5237b613 --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_test.py @@ -0,0 +1,241 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for differentially private optimizers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import mock +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.dp_query import gaussian_query +from tensorflow_privacy.privacy.optimizers import dp_optimizer + + +class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + + def _loss(self, val0, val1): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1) + + # Parameters for testing: optimizer, num_microbatches, expected answer. 
+ @parameterized.named_parameters( + ('DPGradientDescent 1', dp_optimizer.DPGradientDescentOptimizer, 1, + [-2.5, -2.5]), + ('DPGradientDescent 2', dp_optimizer.DPGradientDescentOptimizer, 2, + [-2.5, -2.5]), + ('DPGradientDescent 4', dp_optimizer.DPGradientDescentOptimizer, 4, + [-2.5, -2.5]), + ('DPAdagrad 1', dp_optimizer.DPAdagradOptimizer, 1, [-2.5, -2.5]), + ('DPAdagrad 2', dp_optimizer.DPAdagradOptimizer, 2, [-2.5, -2.5]), + ('DPAdagrad 4', dp_optimizer.DPAdagradOptimizer, 4, [-2.5, -2.5]), + ('DPAdam 1', dp_optimizer.DPAdamOptimizer, 1, [-2.5, -2.5]), + ('DPAdam 2', dp_optimizer.DPAdamOptimizer, 2, [-2.5, -2.5]), + ('DPAdam 4', dp_optimizer.DPAdamOptimizer, 4, [-2.5, -2.5])) + def testBaseline(self, cls, num_microbatches, expected_answer): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger( + dp_sum_query, 1e6, num_microbatches / 1e6) + + opt = cls( + dp_sum_query, + num_microbatches=num_microbatches, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testClippingNorm(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6) + + opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + + # Expected gradient is sum of differences. + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testNoiseMultiplier(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + dp_sum_query = gaussian_query.GaussianSumQuery(4.0, 8.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6) + + opt = cls(dp_sum_query, num_microbatches=1, learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(1000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. 
+ self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + @mock.patch.object(tf, 'logging') + def testComputeGradientsOverrideWarning(self, mock_logging): + + class SimpleOptimizer(tf.train.Optimizer): + + def compute_gradients(self): + return 0 + + dp_optimizer.make_optimizer_class(SimpleOptimizer) + mock_logging.warning.assert_called_once_with( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + 'SimpleOptimizer') + + def testEstimator(self): + """Tests that DP optimizers work with tf.estimator.""" + + def linear_model_fn(features, labels, mode): + preds = tf.keras.layers.Dense( + 1, activation='linear', name='dense').apply(features['x']) + + vector_loss = tf.squared_difference(labels, preds) + scalar_loss = tf.reduce_mean(vector_loss) + dp_sum_query = gaussian_query.GaussianSumQuery(1.0, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger(dp_sum_query, 1e6, 1 / 1e6) + optimizer = dp_optimizer.DPGradientDescentOptimizer( + dp_sum_query, + num_microbatches=1, + learning_rate=1.0) + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss=vector_loss, global_step=global_step) + return tf.estimator.EstimatorSpec( + mode=mode, loss=scalar_loss, train_op=train_op) + + linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn) + true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) + true_bias = 6.0 + train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32) + + train_labels = np.matmul(train_data, + true_weights) + true_bias + np.random.normal( + scale=0.1, size=(200, 1)).astype(np.float32) + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={'x': train_data}, + y=train_labels, + batch_size=20, + num_epochs=10, + shuffle=True) + linear_regressor.train(input_fn=train_input_fn, steps=100) + self.assertAllClose( + linear_regressor.get_variable_value('dense/kernel'), + true_weights, + atol=1.0) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradOptimizer), + ('DPAdam', dp_optimizer.DPAdamOptimizer)) + def testUnrollMicrobatches(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + num_microbatches = 4 + + dp_sum_query = gaussian_query.GaussianSumQuery(1.0e9, 0.0) + dp_sum_query = privacy_ledger.QueryWithLedger( + dp_sum_query, 1e6, num_microbatches / 1e6) + + opt = cls( + dp_sum_query, + num_microbatches=num_microbatches, + learning_rate=2.0, + unroll_microbatches=True) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. 
+ gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-2.5, -2.5], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', dp_optimizer.DPGradientDescentGaussianOptimizer), + ('DPAdagrad', dp_optimizer.DPAdagradGaussianOptimizer), + ('DPAdam', dp_optimizer.DPAdamGaussianOptimizer)) + def testDPGaussianOptimizerClass(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls( + l2_norm_clip=4.0, + noise_multiplier=2.0, + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(1000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized.py new file mode 100644 index 00000000..7295e1dd --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized.py @@ -0,0 +1,153 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Vectorized differentially private optimizers for TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from distutils.version import LooseVersion +import tensorflow as tf + +if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + nest = tf.contrib.framework.nest + AdagradOptimizer = tf.train.AdagradOptimizer + AdamOptimizer = tf.train.AdamOptimizer + GradientDescentOptimizer = tf.train.GradientDescentOptimizer + parent_code = tf.train.Optimizer.compute_gradients.__code__ + GATE_OP = tf.train.Optimizer.GATE_OP # pylint: disable=invalid-name +else: + nest = tf.nest + AdagradOptimizer = tf.optimizers.Adagrad + AdamOptimizer = tf.optimizers.Adam + GradientDescentOptimizer = tf.optimizers.SGD # pylint: disable=invalid-name + parent_code = tf.optimizers.Optimizer._compute_gradients.__code__ # pylint: disable=protected-access + GATE_OP = None # pylint: disable=invalid-name + + +def make_vectorized_optimizer_class(cls): + """Constructs a vectorized DP optimizer class from an existing one.""" + if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): + child_code = cls.compute_gradients.__code__ + else: + child_code = cls._compute_gradients.__code__ # pylint: disable=protected-access + if child_code is not parent_code: + tf.logging.warning( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). 
Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + cls.__name__) + + class DPOptimizerClass(cls): + """Differentially private subclass of given class cls.""" + + def __init__( + self, + l2_norm_clip, + noise_multiplier, + num_microbatches=None, + *args, # pylint: disable=keyword-arg-before-vararg, g-doc-args + **kwargs): + """Initialize the DPOptimizerClass. + + Args: + l2_norm_clip: Clipping norm (max L2 norm of per microbatch gradients) + noise_multiplier: Ratio of the standard deviation to the clipping norm + num_microbatches: How many microbatches into which the minibatch is + split. If None, will default to the size of the minibatch, and + per-example gradients will be computed. + """ + super(DPOptimizerClass, self).__init__(*args, **kwargs) + self._l2_norm_clip = l2_norm_clip + self._noise_multiplier = noise_multiplier + self._num_microbatches = num_microbatches + + def compute_gradients(self, + loss, + var_list, + gate_gradients=GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None, + gradient_tape=None): + if callable(loss): + # TF is running in Eager mode + raise NotImplementedError('Vectorized optimizer unavailable for TF2.') + else: + # TF is running in graph mode, check we did not receive a gradient tape. + if gradient_tape: + raise ValueError('When in graph mode, a tape should not be passed.') + + batch_size = tf.shape(loss)[0] + if self._num_microbatches is None: + self._num_microbatches = batch_size + + # Note: it would be closer to the correct i.i.d. sampling of records if + # we sampled each microbatch from the appropriate binomial distribution, + # although that still wouldn't be quite correct because it would be + # sampling from the dataset without replacement. + microbatch_losses = tf.reshape(loss, [self._num_microbatches, -1]) + + if var_list is None: + var_list = ( + tf.trainable_variables() + tf.get_collection( + tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + + def process_microbatch(microbatch_loss): + """Compute clipped grads for one microbatch.""" + microbatch_loss = tf.reduce_mean(microbatch_loss) + grads, _ = zip(*super(DPOptimizerClass, self).compute_gradients( + microbatch_loss, + var_list, + gate_gradients, + aggregation_method, + colocate_gradients_with_ops, + grad_loss)) + grads_list = [ + g if g is not None else tf.zeros_like(v) + for (g, v) in zip(list(grads), var_list) + ] + # Clip gradients to have L2 norm of l2_norm_clip. + # Here, we use TF primitives rather than the built-in + # tf.clip_by_global_norm() so that operations can be vectorized + # across microbatches. + grads_flat = nest.flatten(grads_list) + squared_l2_norms = [tf.reduce_sum(tf.square(g)) for g in grads_flat] + global_norm = tf.sqrt(tf.add_n(squared_l2_norms)) + div = tf.maximum(global_norm / self._l2_norm_clip, 1.) 
+ clipped_flat = [g / div for g in grads_flat] + clipped_grads = nest.pack_sequence_as(grads_list, clipped_flat) + return clipped_grads + + clipped_grads = tf.vectorized_map(process_microbatch, microbatch_losses) + + def reduce_noise_normalize_batch(stacked_grads): + summed_grads = tf.reduce_sum(stacked_grads, axis=0) + noise_stddev = self._l2_norm_clip * self._noise_multiplier + noise = tf.random.normal(tf.shape(summed_grads), + stddev=noise_stddev) + noised_grads = summed_grads + noise + return noised_grads / tf.cast(self._num_microbatches, tf.float32) + + final_grads = nest.map_structure(reduce_noise_normalize_batch, + clipped_grads) + + return list(zip(final_grads, var_list)) + + return DPOptimizerClass + + +VectorizedDPAdagrad = make_vectorized_optimizer_class(AdagradOptimizer) +VectorizedDPAdam = make_vectorized_optimizer_class(AdamOptimizer) +VectorizedDPSGD = make_vectorized_optimizer_class(GradientDescentOptimizer) diff --git a/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py new file mode 100644 index 00000000..21f00e80 --- /dev/null +++ b/tensorflow_privacy/privacy/optimizers/dp_optimizer_vectorized_test.py @@ -0,0 +1,204 @@ +# Copyright 2019, The TensorFlow Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for differentially private optimizers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import mock +import numpy as np +import tensorflow as tf + +from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized +from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdagrad +from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPAdam +from tensorflow_privacy.privacy.optimizers.dp_optimizer_vectorized import VectorizedDPSGD + + +class DPOptimizerTest(tf.test.TestCase, parameterized.TestCase): + + def _loss(self, val0, val1): + """Loss function that is minimized at the mean of the input points.""" + return 0.5 * tf.reduce_sum(tf.squared_difference(val0, val1), axis=1) + + # Parameters for testing: optimizer, num_microbatches, expected answer. 
+ @parameterized.named_parameters( + ('DPGradientDescent 1', VectorizedDPSGD, 1, [-2.5, -2.5]), + ('DPGradientDescent 2', VectorizedDPSGD, 2, [-2.5, -2.5]), + ('DPGradientDescent 4', VectorizedDPSGD, 4, [-2.5, -2.5]), + ('DPAdagrad 1', VectorizedDPAdagrad, 1, [-2.5, -2.5]), + ('DPAdagrad 2', VectorizedDPAdagrad, 2, [-2.5, -2.5]), + ('DPAdagrad 4', VectorizedDPAdagrad, 4, [-2.5, -2.5]), + ('DPAdam 1', VectorizedDPAdam, 1, [-2.5, -2.5]), + ('DPAdam 2', VectorizedDPAdam, 2, [-2.5, -2.5]), + ('DPAdam 4', VectorizedDPAdam, 4, [-2.5, -2.5])) + def testBaseline(self, cls, num_microbatches, expected_answer): + with self.cached_session() as sess: + var0 = tf.Variable([1.0, 2.0]) + data0 = tf.Variable([[3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [-1.0, 0.0]]) + + opt = cls( + l2_norm_clip=1.0e9, + noise_multiplier=0.0, + num_microbatches=num_microbatches, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([1.0, 2.0], self.evaluate(var0)) + + # Expected gradient is sum of differences divided by number of + # microbatches. + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType(expected_answer, grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testClippingNorm(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0, 0.0]) + data0 = tf.Variable([[3.0, 4.0], [6.0, 8.0]]) + + opt = cls(l2_norm_clip=1.0, + noise_multiplier=0., + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0, 0.0], self.evaluate(var0)) + + # Expected gradient is sum of differences. + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads_and_vars = sess.run(gradient_op) + self.assertAllCloseAccordingToType([-0.6, -0.8], grads_and_vars[0][0]) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testNoiseMultiplier(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls(l2_norm_clip=4.0, + noise_multiplier=8.0, + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(5000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 4.0 * 8.0, 0.5) + + @mock.patch.object(tf, 'logging') + def testComputeGradientsOverrideWarning(self, mock_logging): + + class SimpleOptimizer(tf.train.Optimizer): + + def compute_gradients(self): + return 0 + + dp_optimizer_vectorized.make_vectorized_optimizer_class(SimpleOptimizer) + mock_logging.warning.assert_called_once_with( + 'WARNING: Calling make_optimizer_class() on class %s that overrides ' + 'method compute_gradients(). 
Check to ensure that ' + 'make_optimizer_class() does not interfere with overridden version.', + 'SimpleOptimizer') + + def testEstimator(self): + """Tests that DP optimizers work with tf.estimator.""" + + def linear_model_fn(features, labels, mode): + preds = tf.keras.layers.Dense( + 1, activation='linear', name='dense').apply(features['x']) + + vector_loss = tf.squared_difference(labels, preds) + scalar_loss = tf.reduce_mean(vector_loss) + optimizer = VectorizedDPSGD( + l2_norm_clip=1.0, + noise_multiplier=0., + num_microbatches=1, + learning_rate=1.0) + global_step = tf.train.get_global_step() + train_op = optimizer.minimize(loss=vector_loss, global_step=global_step) + return tf.estimator.EstimatorSpec( + mode=mode, loss=scalar_loss, train_op=train_op) + + linear_regressor = tf.estimator.Estimator(model_fn=linear_model_fn) + true_weights = np.array([[-5], [4], [3], [2]]).astype(np.float32) + true_bias = 6.0 + train_data = np.random.normal(scale=3.0, size=(200, 4)).astype(np.float32) + + train_labels = np.matmul(train_data, + true_weights) + true_bias + np.random.normal( + scale=0.1, size=(200, 1)).astype(np.float32) + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={'x': train_data}, + y=train_labels, + batch_size=20, + num_epochs=10, + shuffle=True) + linear_regressor.train(input_fn=train_input_fn, steps=100) + self.assertAllClose( + linear_regressor.get_variable_value('dense/kernel'), + true_weights, + atol=1.0) + + @parameterized.named_parameters( + ('DPGradientDescent', VectorizedDPSGD), + ('DPAdagrad', VectorizedDPAdagrad), + ('DPAdam', VectorizedDPAdam)) + def testDPGaussianOptimizerClass(self, cls): + with self.cached_session() as sess: + var0 = tf.Variable([0.0]) + data0 = tf.Variable([[0.0]]) + + opt = cls( + l2_norm_clip=4.0, + noise_multiplier=2.0, + num_microbatches=1, + learning_rate=2.0) + + self.evaluate(tf.global_variables_initializer()) + # Fetch params to validate initial values + self.assertAllClose([0.0], self.evaluate(var0)) + + gradient_op = opt.compute_gradients(self._loss(data0, var0), [var0]) + grads = [] + for _ in range(1000): + grads_and_vars = sess.run(gradient_op) + grads.append(grads_and_vars[0][0]) + + # Test standard deviation is close to l2_norm_clip * noise_multiplier. + self.assertNear(np.std(grads), 2.0 * 4.0, 0.5) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tutorials/bolton_tutorial.py b/tutorials/bolton_tutorial.py index 13181a6d..55c8682d 100644 --- a/tutorials/bolton_tutorial.py +++ b/tutorials/bolton_tutorial.py @@ -16,9 +16,9 @@ from __future__ import division from __future__ import print_function import tensorflow as tf # pylint: disable=wrong-import-position -from privacy.bolt_on import losses # pylint: disable=wrong-import-position -from privacy.bolt_on import models # pylint: disable=wrong-import-position -from privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position +from tensorflow_privacy.privacy.bolt_on import losses # pylint: disable=wrong-import-position +from tensorflow_privacy.privacy.bolt_on import models # pylint: disable=wrong-import-position +from tensorflow_privacy.privacy.bolt_on.optimizers import BoltOn # pylint: disable=wrong-import-position # ------- # First, we will create a binary classification dataset with a single output # dimension. 
The samples for each label are repeated data points at different diff --git a/tutorials/lm_dpsgd_tutorial.py b/tutorials/lm_dpsgd_tutorial.py index 67398ea9..d41dda3b 100644 --- a/tutorials/lm_dpsgd_tutorial.py +++ b/tutorials/lm_dpsgd_tutorial.py @@ -44,10 +44,10 @@ import tensorflow as tf import tensorflow_datasets as tfds -from privacy.analysis import privacy_ledger -from privacy.analysis.rdp_accountant import compute_rdp -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers import dp_optimizer +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers import dp_optimizer flags.DEFINE_boolean( 'dpsgd', True, 'If True, train with DP-SGD. If False, ' diff --git a/tutorials/mnist_dpsgd_tutorial.py b/tutorials/mnist_dpsgd_tutorial.py index f5864454..64f03c3f 100644 --- a/tutorials/mnist_dpsgd_tutorial.py +++ b/tutorials/mnist_dpsgd_tutorial.py @@ -26,10 +26,10 @@ import numpy as np import tensorflow as tf -from privacy.analysis import privacy_ledger -from privacy.analysis.rdp_accountant import compute_rdp_from_ledger -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers import dp_optimizer +from tensorflow_privacy.privacy.analysis import privacy_ledger +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp_from_ledger +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers import dp_optimizer if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): GradientDescentOptimizer = tf.train.GradientDescentOptimizer diff --git a/tutorials/mnist_dpsgd_tutorial_eager.py b/tutorials/mnist_dpsgd_tutorial_eager.py index 94b03d4b..07af6025 100644 --- a/tutorials/mnist_dpsgd_tutorial_eager.py +++ b/tutorials/mnist_dpsgd_tutorial_eager.py @@ -24,9 +24,9 @@ import numpy as np import tensorflow as tf -from privacy.analysis.rdp_accountant import compute_rdp -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): GradientDescentOptimizer = tf.train.GradientDescentOptimizer diff --git a/tutorials/mnist_dpsgd_tutorial_keras.py b/tutorials/mnist_dpsgd_tutorial_keras.py index 865fb9f7..89ce1dc7 100644 --- a/tutorials/mnist_dpsgd_tutorial_keras.py +++ b/tutorials/mnist_dpsgd_tutorial_keras.py @@ -25,9 +25,9 @@ import numpy as np import tensorflow as tf -from privacy.analysis.rdp_accountant import compute_rdp -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers.dp_optimizer import DPGradientDescentGaussianOptimizer if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): GradientDescentOptimizer = tf.train.GradientDescentOptimizer diff --git 
a/tutorials/mnist_dpsgd_tutorial_vectorized.py b/tutorials/mnist_dpsgd_tutorial_vectorized.py index 2b78f82f..a075cd43 100644 --- a/tutorials/mnist_dpsgd_tutorial_vectorized.py +++ b/tutorials/mnist_dpsgd_tutorial_vectorized.py @@ -26,9 +26,9 @@ import numpy as np import tensorflow as tf -from privacy.analysis.rdp_accountant import compute_rdp -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers import dp_optimizer_vectorized +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers import dp_optimizer_vectorized flags.DEFINE_boolean( diff --git a/tutorials/mnist_lr_tutorial.py b/tutorials/mnist_lr_tutorial.py index 62f446d5..c8bbf04a 100644 --- a/tutorials/mnist_lr_tutorial.py +++ b/tutorials/mnist_lr_tutorial.py @@ -35,9 +35,9 @@ import numpy as np import tensorflow as tf -from privacy.analysis.rdp_accountant import compute_rdp -from privacy.analysis.rdp_accountant import get_privacy_spent -from privacy.optimizers import dp_optimizer +from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp +from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent +from tensorflow_privacy.privacy.optimizers import dp_optimizer if LooseVersion(tf.__version__) < LooseVersion('2.0.0'): GradientDescentOptimizer = tf.train.GradientDescentOptimizer
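The tutorials above pair a DP optimizer with the RDP accountant to report the privacy guarantee after training. A hedged sketch of that accounting step is shown below; the sampling probability, noise multiplier, step count, and delta are placeholder values, not recommendations.

from tensorflow_privacy.privacy.analysis.rdp_accountant import compute_rdp
from tensorflow_privacy.privacy.analysis.rdp_accountant import get_privacy_spent

sampling_probability = 0.01  # batch_size / number of training examples.
orders = [1 + x / 10.0 for x in range(1, 100)] + list(range(12, 64))

# Accumulate the RDP of the sampled Gaussian mechanism over all steps, then
# convert it to an (epsilon, delta) guarantee at the target delta.
rdp = compute_rdp(q=sampling_probability,
                  noise_multiplier=1.1,
                  steps=10000,
                  orders=orders)
eps, _, opt_order = get_privacy_spent(orders, rdp, target_delta=1e-5)
print('Training satisfies ({:.2f}, 1e-5)-differential privacy.'.format(eps))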