fix softmax issue
This commit is contained in:
parent
f677c9c440
commit
e547a10eec
1 changed files with 392 additions and 379 deletions
|
@ -51,14 +51,14 @@
|
||||||
"id": "-B5ZvlSqqLaR"
|
"id": "-B5ZvlSqqLaR"
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
|
||||||
" \u003ctd\u003e\n",
|
" <td>\n",
|
||||||
" \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
|
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
|
||||||
" \u003c/td\u003e\n",
|
" </td>\n",
|
||||||
" \u003ctd\u003e\n",
|
" <td>\n",
|
||||||
" \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n",
|
" <a target=\"_blank\" href=\"https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/membership_inference_attack/codelab.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
|
||||||
" \u003c/td\u003e\n",
|
" </td>\n",
|
||||||
"\u003c/table\u003e"
|
"</table>"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -80,7 +80,7 @@
|
||||||
},
|
},
|
||||||
"source": [
|
"source": [
|
||||||
"## Setup\n",
|
"## Setup\n",
|
||||||
"First, set this notebook's runtime to use a GPU, under Runtime \u003e Change runtime type \u003e Hardware accelerator. Then, begin importing the necessary libraries."
|
"First, set this notebook's runtime to use a GPU, under Runtime > Change runtime type > Hardware accelerator. Then, begin importing the necessary libraries."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -173,7 +173,7 @@
|
||||||
"def small_cnn(input_shape: Tuple[int],\n",
|
"def small_cnn(input_shape: Tuple[int],\n",
|
||||||
" num_classes: int,\n",
|
" num_classes: int,\n",
|
||||||
" num_conv: int,\n",
|
" num_conv: int,\n",
|
||||||
" activation: Text = 'relu') -\u003e tf.keras.models.Sequential:\n",
|
" activation: Text = 'relu') -> tf.keras.models.Sequential:\n",
|
||||||
" \"\"\"Setup a small CNN for image classification.\n",
|
" \"\"\"Setup a small CNN for image classification.\n",
|
||||||
"\n",
|
"\n",
|
||||||
" Args:\n",
|
" Args:\n",
|
||||||
|
@ -265,8 +265,8 @@
|
||||||
"logits_test = model.predict(x_test, batch_size=batch_size)\n",
|
"logits_test = model.predict(x_test, batch_size=batch_size)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print('Apply softmax to get probabilities from logits...')\n",
|
"print('Apply softmax to get probabilities from logits...')\n",
|
||||||
"prob_train = special.softmax(logits_train)\n",
|
"prob_train = special.softmax(logits_train, axis=1)\n",
|
||||||
"prob_test = special.softmax(logits_test)\n",
|
"prob_test = special.softmax(logits_test, axis=1)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"print('Compute losses...')\n",
|
"print('Compute losses...')\n",
|
||||||
"cce = tf.keras.backend.categorical_crossentropy\n",
|
"cce = tf.keras.backend.categorical_crossentropy\n",
|
||||||
|
@ -365,8 +365,21 @@
|
||||||
},
|
},
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.6.10"
|
||||||
|
},
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"stem_cell": {
|
"stem_cell": {
|
||||||
"cell_type": "raw",
|
"cell_type": "raw",
|
||||||
|
@ -378,5 +391,5 @@
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 0
|
"nbformat_minor": 1
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue