diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py index e05774a..a2cb56f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report.py @@ -134,5 +134,5 @@ def _validate_results(results: Iterable[AttackResults]): for attack_results in results: if not attack_results or not attack_results.privacy_report_metadata: raise ValueError('Privacy metadata is not defined.') - if not attack_results.privacy_report_metadata.epoch_num: + if attack_results.privacy_report_metadata.epoch_num is None: raise ValueError('epoch_num in metadata is not defined.') diff --git a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py index 8b3bdf2..326cad0 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/privacy_report_test.py @@ -52,6 +52,14 @@ class PrivacyReportTest(absltest.TestCase): fpr=np.array([1.0, 1.0, 0.0]), thresholds=np.array([0, 1, 2]))) + self.results_epoch_0 = AttackResults( + single_attack_results=[self.imperfect_classifier_result], + privacy_report_metadata=PrivacyReportMetadata( + accuracy_train=0.4, + accuracy_test=0.3, + epoch_num=0, + model_variant_label='default')) + self.results_epoch_10 = AttackResults( single_attack_results=[self.imperfect_classifier_result], privacy_report_metadata=PrivacyReportMetadata( @@ -87,14 +95,14 @@ class PrivacyReportTest(absltest.TestCase): def test_single_metric_plot_by_epochs(self): fig = privacy_report.plot_by_epochs( - AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)), - ['AUC']) + AttackResultsCollection((self.results_epoch_0, self.results_epoch_10, + self.results_epoch_15)), ['AUC']) # extract data from figure. auc_data = fig.gca().lines[0].get_data() # X axis lists epoch values - np.testing.assert_array_equal(auc_data[0], [10, 15]) + np.testing.assert_array_equal(auc_data[0], [0, 10, 15]) # Y axis lists AUC values - np.testing.assert_array_equal(auc_data[1], [0.5, 1.0]) + np.testing.assert_array_equal(auc_data[1], [0.5, 0.5, 1.0]) # Check the title self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')