Compare commits
5 commits
master
...
akemi/lira
Author | SHA1 | Date | |
---|---|---|---|
72fcf6732a | |||
1a61c2a36a | |||
9a0611fb37 | |||
9c67e0abef | |||
f3af0a3a5c |
4 changed files with 62 additions and 31 deletions
|
@ -8,14 +8,19 @@ by Nicholas Carlini, Steve Chien, Milad Nasr, Shuang Song, Andreas Terzis, and F
|
|||
|
||||
### INSTALLING
|
||||
|
||||
You will need to install fairly standard dependencies
|
||||
You will need to install fairly standard dependencies and python 3.11 minimum.
|
||||
|
||||
`pip install scipy, sklearn, numpy, matplotlib`
|
||||
```
|
||||
pip install scipy scikit-learn numpy matplotlib tensorflow tensorflow_datasets
|
||||
|
||||
and also some machine learning framework to train models. We train our models
|
||||
with JAX + ObJAX so you will need to follow build instructions for that
|
||||
https://github.com/google/objax
|
||||
https://objax.readthedocs.io/en/latest/installation_setup.html
|
||||
# This needs to be separate
|
||||
pip install objax
|
||||
|
||||
RELEASE_URL="https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
|
||||
JAX_VERSION=`python3 -c 'import jax; print(jax.__version__)'`
|
||||
pip uninstall -y jaxlib
|
||||
pip install -f $RELEASE_URL jax[cuda]==$JAX_VERSION
|
||||
```
|
||||
|
||||
### RUNNING THE CODE
|
||||
|
||||
|
|
|
@ -11,20 +11,41 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
mkdir -p logs
|
||||
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15
|
||||
SECONDS=0
|
||||
|
||||
echo '======== 1 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 2>&1 | tee logs/log_0
|
||||
echo '======== 2 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 2>&1 | tee logs/log_1
|
||||
echo '======== 3 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 2>&1 | tee logs/log_2
|
||||
echo '======== 4 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 2>&1 | tee logs/log_3
|
||||
echo '======== 5 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 2>&1 | tee logs/log_4
|
||||
echo '======== 6 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 2>&1 | tee logs/log_5
|
||||
echo '======== 7 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 2>&1 | tee logs/log_6
|
||||
echo '======== 8 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 2>&1 | tee logs/log_7
|
||||
echo '======== 9 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 2>&1 | tee logs/log_8
|
||||
echo '======== 10 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 2>&1 | tee logs/log_9
|
||||
echo '======== 11 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 2>&1 | tee logs/log_10
|
||||
echo '======== 12 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 2>&1 | tee logs/log_11
|
||||
echo '======== 13 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 2>&1 | tee logs/log_12
|
||||
echo '======== 14 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 2>&1 | tee logs/log_13
|
||||
echo '======== 15 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 2>&1 | tee logs/log_14
|
||||
echo '======== 16 of 16 ========'
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 2>&1 | tee logs/log_15
|
||||
|
||||
echo "COMPLETE: Took ${SECONDS} seconds"
|
||||
|
|
|
@ -11,22 +11,25 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
mkdir -p logs
|
||||
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 0 --logdir exp/cifar10 &> logs/log_0 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 1 --logdir exp/cifar10 &> logs/log_1 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 2 --logdir exp/cifar10 &> logs/log_2 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 3 --logdir exp/cifar10 &> logs/log_3 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 4 --logdir exp/cifar10 &> logs/log_4 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 5 --logdir exp/cifar10 &> logs/log_5 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 6 --logdir exp/cifar10 &> logs/log_6 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 7 --logdir exp/cifar10 &> logs/log_7 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 8 --logdir exp/cifar10 &> logs/log_8 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 9 --logdir exp/cifar10 &> logs/log_9 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 10 --logdir exp/cifar10 &> logs/log_10 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 11 --logdir exp/cifar10 &> logs/log_11 &
|
||||
CUDA_VISIBLE_DEVICES='4' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='5' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='6' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='7' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
||||
CUDA_VISIBLE_DEVICES='0' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 12 --logdir exp/cifar10 &> logs/log_12 &
|
||||
CUDA_VISIBLE_DEVICES='1' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 13 --logdir exp/cifar10 &> logs/log_13 &
|
||||
CUDA_VISIBLE_DEVICES='2' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 14 --logdir exp/cifar10 &> logs/log_14 &
|
||||
CUDA_VISIBLE_DEVICES='3' python3 -u train.py --dataset=cifar10 --epochs=100 --save_steps=20 --arch wrn28-2 --num_experiments 16 --expid 15 --logdir exp/cifar10 &> logs/log_15 &
|
||||
wait;
|
||||
|
|
|
@ -66,7 +66,9 @@ class TrainLoop(objax.Module):
|
|||
for k, v in kv.items():
|
||||
if jn.isnan(v):
|
||||
raise ValueError('NaN, try reducing learning rate', k)
|
||||
if summary is not None:
|
||||
if summary is not None and v.ndim == 1:
|
||||
summary.scalar(k, float(v[0]))
|
||||
elif summary is not None:
|
||||
summary.scalar(k, float(v))
|
||||
|
||||
def train(self, num_train_epochs: int, train_size: int, train: DataSet, test: DataSet, logdir: str, save_steps=100, patience=None):
|
||||
|
|
Loading…
Reference in a new issue