diff --git a/tensorflow_privacy/privacy/analysis/dp_event_builder.py b/tensorflow_privacy/privacy/analysis/dp_event_builder.py index 722a1e4..53d4cc2 100644 --- a/tensorflow_privacy/privacy/analysis/dp_event_builder.py +++ b/tensorflow_privacy/privacy/analysis/dp_event_builder.py @@ -13,8 +13,6 @@ # limitations under the License. """Builder class for ComposedDpEvent.""" -import collections - from tensorflow_privacy.privacy.analysis import dp_event @@ -28,7 +26,8 @@ class DpEventBuilder(object): """ def __init__(self): - self._events = collections.OrderedDict() + # A list of (event, count) pairs. + self._event_counts = [] self._composed_event = None def compose(self, event: dp_event.DpEvent, count: int = 1): @@ -46,33 +45,32 @@ class DpEventBuilder(object): if count < 1: raise ValueError(f'`count` must be positive. Found {count}.') - if isinstance(event, dp_event.ComposedDpEvent): - for composed_event in event.events: - self.compose(composed_event, count) + if isinstance(event, dp_event.NoOpDpEvent): + return elif isinstance(event, dp_event.SelfComposedDpEvent): self.compose(event.event, count * event.count) - elif isinstance(event, dp_event.NoOpDpEvent): - return else: - current_count = self._events.get(event, 0) - self._events[event] = current_count + count + if self._event_counts and self._event_counts[-1][0] == event: + new_event_count = (event, self._event_counts[-1][1] + count) + self._event_counts[-1] = new_event_count + else: + self._event_counts.append((event, count)) self._composed_event = None def build(self) -> dp_event.DpEvent: """Builds and returns the composed DpEvent represented by the builder.""" if not self._composed_event: - self_composed_events = [] - for event, count in self._events.items(): + events = [] + for event, count in self._event_counts: if count == 1: - self_composed_events.append(event) + events.append(event) else: - self_composed_events.append( - dp_event.SelfComposedDpEvent(event, count)) - if not self_composed_events: + events.append(dp_event.SelfComposedDpEvent(event, count)) + if not events: self._composed_event = dp_event.NoOpDpEvent() - elif len(self_composed_events) == 1: - self._composed_event = self_composed_events[0] + elif len(events) == 1: + self._composed_event = events[0] else: - self._composed_event = dp_event.ComposedDpEvent(self_composed_events) + self._composed_event = dp_event.ComposedDpEvent(events) return self._composed_event diff --git a/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py b/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py index a10d4bb..dd8a5f2 100644 --- a/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py +++ b/tensorflow_privacy/privacy/analysis/dp_event_builder_test.py @@ -20,8 +20,6 @@ from tensorflow_privacy.privacy.analysis import dp_event_builder _gaussian_event = dp_event.GaussianDpEvent(1.0) _poisson_event = dp_event.PoissonSampledDpEvent(_gaussian_event, 0.1) _self_composed_event = dp_event.SelfComposedDpEvent(_gaussian_event, 3) -_composed_event = dp_event.ComposedDpEvent( - [_self_composed_event, _poisson_event]) class DpEventBuilderTest(absltest.TestCase): @@ -50,22 +48,27 @@ class DpEventBuilderTest(absltest.TestCase): def test_compose_heterogenous(self): builder = dp_event_builder.DpEventBuilder() + builder.compose(_poisson_event) builder.compose(_gaussian_event) - builder.compose(_poisson_event) builder.compose(_gaussian_event, 2) - self.assertEqual(_composed_event, builder.build()) + builder.compose(_poisson_event) + expected_event = dp_event.ComposedDpEvent( + [_poisson_event, _self_composed_event, _poisson_event]) + self.assertEqual(expected_event, builder.build()) - def test_compose_complex(self): + def test_compose_composed(self): builder = dp_event_builder.DpEventBuilder() - builder.compose(_gaussian_event, 2) - builder.compose(_composed_event) + composed_event = dp_event.ComposedDpEvent( + [_gaussian_event, _poisson_event, _self_composed_event]) + builder.compose(_gaussian_event) + builder.compose(composed_event) + builder.compose(composed_event, 2) + builder.compose(_poisson_event) builder.compose(_poisson_event) - builder.compose(_composed_event, 2) - expected_event = dp_event.ComposedDpEvent([ - dp_event.SelfComposedDpEvent(_gaussian_event, 11), - dp_event.SelfComposedDpEvent(_poisson_event, 4) - ]) + _gaussian_event, + dp_event.SelfComposedDpEvent(composed_event, 3), + dp_event.SelfComposedDpEvent(_poisson_event, 2)]) self.assertEqual(expected_event, builder.build())