This codebase is being actively maintained. Please create an issue if you have problems using it.
Code for using Augmented Shapiro-Wilk Stopping (ASWS), as well as code for the paper "Statistically Significant Stopping of Neural Network Training".
All data files are included under losses and each folder. The main Augmented Shapiro-Wilk Stopping criterion is implemented in analysis.py, along with several helper functions and wrappers. The other comparison heuristics are also included in analysis.py, along with their wrappers. grapher.py contains all the code for generating the graphs used in the paper, and earlystopping_calculator.py includes code for generating tables and calculating some statistics from the data. hyperparameter_search.py contains all the code used to execute the grid-search on the ASWS method, along with the grid-search for the other heuristics.
If you would like to try our code, just run
pip3 install git+
To determine the ASWS stopping point of a model, use the analysis.py file. To perform the stop criterion test at any point during model training, you can do:
    from ASWS.analysis import aswt_stopping

    test_acc = []  # for storing model accuracies
    for i in training_epochs:
        model.train()
        test_accuracy = model.evaluate(test_set)
        test_acc.append(test_accuracy)
        gamma = 0.5  # fill hyperparameters as desired
        num_data = 20
        slack_prop = 0.1
        count = 20
        # only run the test once more than `count` accuracies have been recorded
        if len(test_acc) > count:
            aswt_stop_criterion = aswt_stopping(test_acc, gamma, count, num_data, slack_prop=slack_prop)
            if aswt_stop_criterion:
                print("Stop Training")
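The loop above prints a message but keeps training. As a minimal sketch of how a stopping check could actually end the loop, the helper below takes any criterion as a callable and returns the epoch at which it first fires. The names `run_until_stop` and `plateau` are hypothetical, and `plateau` is a simple range check used only as a stand-in so the example runs without this package; it is not the ASWS test.

```python
def run_until_stop(accuracy_stream, should_stop, count):
    """Consume per-epoch test accuracies and return (stop_epoch, history).

    should_stop(history) -> bool is called only once more than `count`
    accuracies have been collected, mirroring the length check above.
    stop_epoch is None if the criterion never fires.
    """
    history = []
    for epoch, acc in enumerate(accuracy_stream, start=1):
        history.append(acc)
        if len(history) > count and should_stop(history):
            return epoch, history
    return None, history


def plateau(history, window=5, tol=0.01):
    # Stand-in criterion (NOT ASWS): the last `window` accuracies span less than `tol`.
    recent = history[-window:]
    return max(recent) - min(recent) < tol


# Simulated accuracies: ten epochs of steady improvement, then a flat plateau.
simulated = [0.50 + 0.04 * i for i in range(10)] + [0.90] * 20
stop_epoch, seen = run_until_stop(simulated, plateau, count=5)
print(stop_epoch, len(seen))
```

With the package installed, the same helper could presumably be driven by the real criterion by passing `lambda h: aswt_stopping(h, gamma, count, num_data, slack_prop=slack_prop)` as `should_stop`.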
If you have already finished training the model and want to determine the ASWS stopping point, you need a CSV with the columns Epoch, Training Loss, Training Acc, Test Loss, Test Acc. You can then use the following example:
    from ASWS.analysis import get_aswt_stopping_point_of_model, read_file

    _, _, _, test_acc = read_file("modelaccuracy.csv")
    gamma = 0.5  # fill hyperparameters as desired
    num_data = 20
    slack_prop = 0.1
    count = 20
    stop_epoch, stop_accuracy = get_aswt_stopping_point_of_model(test_acc, gamma=gamma, num_data=num_data, count=count, slack_prop=slack_prop)
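If your training run did not already log metrics in that layout, a file with those five columns can be produced with the standard library. This is only a sketch: the helper name `write_history_csv` is hypothetical, the numbers are made up, and both the header row and the 1-based epoch numbering are assumptions about what read_file expects, so check them against the CSVs shipped in this repository.

```python
import csv

# Column names taken from the text above; writing them as a header row is an assumption.
COLUMNS = ["Epoch", "Training Loss", "Training Acc", "Test Loss", "Test Acc"]


def write_history_csv(path, history):
    """Write per-epoch metrics to a CSV with the columns listed above.

    `history` is a list of (train_loss, train_acc, test_loss, test_acc)
    tuples, one per epoch, in training order.
    """
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(COLUMNS)
        for epoch, row in enumerate(history, start=1):  # 1-based epochs (assumption)
            writer.writerow([epoch, *row])


# Made-up numbers, only to show the layout.
example_history = [
    (2.10, 0.31, 2.05, 0.33),
    (1.60, 0.48, 1.70, 0.45),
    (1.25, 0.59, 1.52, 0.52),
]
write_history_csv("modelaccuracy.csv", example_history)
```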
The pytorch-training folder contains the driver file for training each model, along with the model files, which contain the network definitions. The main.py file can be run out of the box for the models listed in the paper. The model to train is specified via the --model argument. All learning rate schedulers listed in the paper are available (via --schedule step, etc.), and the ASWS learning rate scheduler is available via --schedule ASWT. The corresponding ASWS hyperparameters are passed in at the command line (for example --gamma, --num_data, and --slack_prop).
To recreate the GoogLeNet ASWT 1 scheduler from the paper, you can use the following command:
python3 main.py --model GoogLeNet --schedule ASWT --gamma 0.76 --num_data 19 --slack_prop 0.05 --lr 0.1
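To make the shape of that command concrete, here is a hypothetical argparse parser that accepts the same flags. Only the flag names and the example values come from the command above; the types, defaults, and the `build_parser` name are assumptions, and the actual parser in main.py may differ.

```python
import argparse


def build_parser():
    # Hypothetical parser: flag names mirror the command above;
    # types and defaults are assumptions, not taken from main.py.
    parser = argparse.ArgumentParser(description="Sketch of a training CLI")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--schedule", type=str, default="step")
    parser.add_argument("--gamma", type=float, default=0.5)
    parser.add_argument("--num_data", type=int, default=20)
    parser.add_argument("--slack_prop", type=float, default=0.1)
    parser.add_argument("--lr", type=float, default=0.1)
    return parser


# Parse the GoogLeNet command shown above, passed as an explicit argument list.
args = build_parser().parse_args(
    "--model GoogLeNet --schedule ASWT --gamma 0.76 "
    "--num_data 19 --slack_prop 0.05 --lr 0.1".split()
)
print(args.model, args.schedule, args.gamma, args.num_data, args.slack_prop, args.lr)
```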