Bugfix for the case where epoch_num is 0 with an accompanying test.
PiperOrigin-RevId: 346072261
This commit is contained in:
parent
e7c21abb09
commit
fcac288849
2 changed files with 13 additions and 5 deletions
|
@ -134,5 +134,5 @@ def _validate_results(results: Iterable[AttackResults]):
|
||||||
for attack_results in results:
|
for attack_results in results:
|
||||||
if not attack_results or not attack_results.privacy_report_metadata:
|
if not attack_results or not attack_results.privacy_report_metadata:
|
||||||
raise ValueError('Privacy metadata is not defined.')
|
raise ValueError('Privacy metadata is not defined.')
|
||||||
if not attack_results.privacy_report_metadata.epoch_num:
|
if attack_results.privacy_report_metadata.epoch_num is None:
|
||||||
raise ValueError('epoch_num in metadata is not defined.')
|
raise ValueError('epoch_num in metadata is not defined.')
|
||||||
|
|
|
@ -52,6 +52,14 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
fpr=np.array([1.0, 1.0, 0.0]),
|
fpr=np.array([1.0, 1.0, 0.0]),
|
||||||
thresholds=np.array([0, 1, 2])))
|
thresholds=np.array([0, 1, 2])))
|
||||||
|
|
||||||
|
self.results_epoch_0 = AttackResults(
|
||||||
|
single_attack_results=[self.imperfect_classifier_result],
|
||||||
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
accuracy_train=0.4,
|
||||||
|
accuracy_test=0.3,
|
||||||
|
epoch_num=0,
|
||||||
|
model_variant_label='default'))
|
||||||
|
|
||||||
self.results_epoch_10 = AttackResults(
|
self.results_epoch_10 = AttackResults(
|
||||||
single_attack_results=[self.imperfect_classifier_result],
|
single_attack_results=[self.imperfect_classifier_result],
|
||||||
privacy_report_metadata=PrivacyReportMetadata(
|
privacy_report_metadata=PrivacyReportMetadata(
|
||||||
|
@ -87,14 +95,14 @@ class PrivacyReportTest(absltest.TestCase):
|
||||||
|
|
||||||
def test_single_metric_plot_by_epochs(self):
|
def test_single_metric_plot_by_epochs(self):
|
||||||
fig = privacy_report.plot_by_epochs(
|
fig = privacy_report.plot_by_epochs(
|
||||||
AttackResultsCollection((self.results_epoch_10, self.results_epoch_15)),
|
AttackResultsCollection((self.results_epoch_0, self.results_epoch_10,
|
||||||
['AUC'])
|
self.results_epoch_15)), ['AUC'])
|
||||||
# extract data from figure.
|
# extract data from figure.
|
||||||
auc_data = fig.gca().lines[0].get_data()
|
auc_data = fig.gca().lines[0].get_data()
|
||||||
# X axis lists epoch values
|
# X axis lists epoch values
|
||||||
np.testing.assert_array_equal(auc_data[0], [10, 15])
|
np.testing.assert_array_equal(auc_data[0], [0, 10, 15])
|
||||||
# Y axis lists AUC values
|
# Y axis lists AUC values
|
||||||
np.testing.assert_array_equal(auc_data[1], [0.5, 1.0])
|
np.testing.assert_array_equal(auc_data[1], [0.5, 0.5, 1.0])
|
||||||
# Check the title
|
# Check the title
|
||||||
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
|
self.assertEqual(fig._suptitle.get_text(), 'Vulnerability per Epoch')
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue