Bugfix for the case where epoch_num is 0 with an accompanying test.

PiperOrigin-RevId: 346072261
This commit is contained in:
David Marn 2020-12-07 05:36:17 -08:00 committed by A. Unique TensorFlower
parent e7c21abb09
commit fcac288849
2 changed files with 13 additions and 5 deletions

View file

@ -134,5 +134,5 @@ def _validate_results(results: Iterable[AttackResults]):
for attack_results in results: for attack_results in results:
if not attack_results or not attack_results.privacy_report_metadata: if not attack_results or not attack_results.privacy_report_metadata:
raise ValueError('Privacy metadata is not defined.') raise ValueError('Privacy metadata is not defined.')
if not attack_results.privacy_report_metadata.epoch_num: if attack_results.privacy_report_metadata.epoch_num is None:
raise ValueError('epoch_num in metadata is not defined.') raise ValueError('epoch_num in metadata is not defined.')

View file

@ -52,6 +52,14 @@ class PrivacyReportTest(absltest.TestCase):
fpr=np.array([1.0, 1.0, 0.0]), fpr=np.array([1.0, 1.0, 0.0]),
thresholds=np.array([0, 1, 2]))) thresholds=np.array([0, 1, 2])))
self.results_epoch_0 = AttackResults(
single_attack_results=[self.imperfect_classifier_result],
privacy_report_metadata=PrivacyReportMetadata(
accuracy_train=0.4,
accuracy_test=0.3,
epoch_num=0,
model_variant_label='default'))
self.results_epoch_10 = AttackResults( self.results_epoch_10 = AttackResults(
single_attack_results=[self.imperfect_classifier_result], single_attack_results=[self.imperfect_classifier_result],
privacy_report_metadata=PrivacyReportMetadata( privacy_report_metadata=PrivacyReportMetadata(
@ -87,14 +95,14 @@ class PrivacyReportTest(absltest.TestCase):
def test_single_metric_plot_by_epochs(self): def test_single_metric_plot_by_epochs(self):
fig = privacy_report.plot_by_epochs( fig = privacy_report.plot_by_epochs(
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), AttackResultsCollection((self.results_epoch_0, self.results_epoch_10,
['AUC']) self.results_epoch_15)), ['AUC'])
# extract data from figure. # extract data from figure.
auc_data = fig.gca().lines[0].get_data() auc_data = fig.gca().lines[0].get_data()
# X axis lists epoch values # X axis lists epoch values
np.testing.assert_array_equal(auc_data[0], [10, 15]) np.testing.assert_array_equal(auc_data[0], [0, 10, 15])
# Y axis lists AUC values # Y axis lists AUC values
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0]) np.testing.assert_array_equal(auc_data[1], [0.5, 0.5, 1.0])
# Check the title # Check the title
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch') self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')