diff --git a/tensorflow_privacy/privacy/membership_inference_attack/models.py b/tensorflow_privacy/privacy/membership_inference_attack/models.py index 00f9eb6..b778b7f 100644 --- a/tensorflow_privacy/privacy/membership_inference_attack/models.py +++ b/tensorflow_privacy/privacy/membership_inference_attack/models.py @@ -183,7 +183,7 @@ def _get_average_ranks(logits: Iterator[np.ndarray], def _get_ranks_for_sequence(logits: np.ndarray, - labels: np.ndarray) -> List: + labels: np.ndarray) -> List[float]: """Returns ranks for a sequence. Args: