Bugfix for the case where epoch_num is 0 with an accompanying test.
PiperOrigin-RevId: 346072261
This commit is contained in:
parent
e7c21abb09
commit
fcac288849
2 changed files with 13 additions and 5 deletions
|
@ -134,5 +134,5 @@ def _validate_results(results: Iterable[AttackResults]):
|
|||
for attack_results in results:
|
||||
if not attack_results or not attack_results.privacy_report_metadata:
|
||||
raise ValueError('Privacy metadata is not defined.')
|
||||
if not attack_results.privacy_report_metadata.epoch_num:
|
||||
if attack_results.privacy_report_metadata.epoch_num is None:
|
||||
raise ValueError('epoch_num in metadata is not defined.')
|
||||
|
|
|
@ -52,6 +52,14 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
fpr=np.array([1.0, 1.0, 0.0]),
|
||||
thresholds=np.array([0, 1, 2])))
|
||||
|
||||
self.results_epoch_0 = AttackResults(
|
||||
single_attack_results=[self.imperfect_classifier_result],
|
||||
privacy_report_metadata=PrivacyReportMetadata(
|
||||
accuracy_train=0.4,
|
||||
accuracy_test=0.3,
|
||||
epoch_num=0,
|
||||
model_variant_label='default'))
|
||||
|
||||
self.results_epoch_10 = AttackResults(
|
||||
single_attack_results=[self.imperfect_classifier_result],
|
||||
privacy_report_metadata=PrivacyReportMetadata(
|
||||
|
@ -87,14 +95,14 @@ class PrivacyReportTest(absltest.TestCase):
|
|||
|
||||
def test_single_metric_plot_by_epochs(self):
|
||||
fig = privacy_report.plot_by_epochs(
|
||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
||||
['AUC'])
|
||||
AttackResultsCollection((self.results_epoch_0, self.results_epoch_10,
|
||||
self.results_epoch_15)), ['AUC'])
|
||||
# extract data from figure.
|
||||
auc_data = fig.gca().lines[0].get_data()
|
||||
# X axis lists epoch values
|
||||
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
||||
np.testing.assert_array_equal(auc_data[0], [0, 10, 15])
|
||||
# Y axis lists AUC values
|
||||
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
||||
np.testing.assert_array_equal(auc_data[1], [0.5, 0.5, 1.0])
|
||||
# Check the title
|
||||
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
|
||||
|
||||
|
|
Loading…
Reference in a new issue