PyTorch Deep Learning Template
In this article, we present a deep learning template based on PyTorch. This template aims to make it easier for you to start a new deep learning computer vision project with PyTorch. The main features are:
- modularity: we split each logical piece into a separate Python submodule
- data augmentation: we included imgaug
- ready to go: by using poutyne, a Keras-like framework, you don't have to write any training loop
- torchsummary to show a summary of your models
- reduce the learning rate on a plateau
- auto-saving the best model
- experiment tracking with comet
- logging using Python's logging module
- a playground notebook to quickly test and play around
Clone the repo and go inside it. Then, run:
pip install -r requirements.txt
Let's face it: data scientists are usually not software engineers, and they often end up with spaghetti code, most of the time in a big, unusable Jupyter notebook. With this repo, we propose a clean example of how your code should be split and modularized to make scalability and shareability possible. In this example, we will try to classify Darth Vader and Luke Skywalker. We have 100 images per class, gathered using Google Images. The dataset is here. You just have to extract it into this folder and run main.py. We are fine-tuning resnet18, and it should be able to reach > 90% accuracy in 5 to 10 epochs.
The template has the following structure:
.
├── callbacks // here you can create your custom callbacks
├── checkpoint // where we store the trained models
├── data // here we define our dataset
│   └── transformation // custom transformations, e.g. resize and data augmentation
├── dataset // the data
│   ├── train
│   └── val
├── logger.py // where we define our logger
├── losses // custom losses
├── main.py
├── models // here we create our models
│   ├── MyCNN.py
│   ├── resnet.py
│   └── utils.py
├── playground.ipynb // a notebook that can be used to quickly experiment with things
├── Project.py // a class that represents the project structure
├── README.md
├── requirements.txt
├── test // you should always perform some basic testing
│   └── test_myDataset.py
└── utils.py // utility functions
We strongly encourage you to play around with the template.
Keep your structure clean and concise
Every deep learning project has at least three main steps:
- data gathering/processing
- modeling
- training/evaluating
A good idea is to store all the paths to locations of interest, e.g. the dataset folder, in a shared class that can be accessed by any module in the project. You should never hardcode any paths; always define them once and import them. That way, if you later change your structure, you will only have to modify one file.
If we have a look at Project.py, we can see how we defined data_dir and checkpoint_dir once and for all. We are using the 'new' pathlib Path API, which supports different operating systems out of the box and also makes it easier to join and concatenate paths.
For example, if we want to know the data location, we can do:
from Project import Project

project = Project()
print(project.data_dir)  # /foo/baa/…/dataset
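To make the idea concrete, here is a minimal sketch of what such a Project class could look like. It is an assumption, not the template's actual Project.py: the dataclass layout, the base_dir field, and the "dataset"/"checkpoint" folder names (taken from the tree above) are illustrative choices. Every path is derived from a single root, so moving the project only changes one value.

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class Project:
    """Hypothetical sketch: one class that owns every path in the project."""

    # Root of the project; by default, the folder containing this file.
    base_dir: Path = field(default_factory=lambda: Path(__file__).resolve().parent)

    @property
    def data_dir(self) -> Path:
        # Path's "/" operator joins segments with the right separator on any OS
        return self.base_dir / "dataset"

    @property
    def checkpoint_dir(self) -> Path:
        return self.base_dir / "checkpoint"
```

Joining further paths stays readable, e.g. `Project().data_dir / "train"` points at the training split without any string concatenation.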
In the data package, you can define your own Dataset, as always by subclassing torch.utils.data.Dataset and exposing transformations and utilities to work with your data.
In our example, we directly used torchvision, but we also included a skeleton for a custom Dataset.
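As an illustration, a custom Dataset skeleton might look like the following. This is a hypothetical sketch, not the template's own skeleton: the class name MyDataset and the "one sub-folder per class" layout (dataset/train/&lt;class&gt;/*.jpg, the same convention torchvision's ImageFolder uses) are assumptions. Images are opened lazily in `__getitem__`, so only file paths are kept in memory.

```python
from pathlib import Path
from typing import Callable, List, Optional, Tuple

from PIL import Image
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Hypothetical skeleton: one sub-folder per class, e.g. dataset/train/<class>/*.jpg."""

    def __init__(self, root: Path, transform: Optional[Callable] = None):
        self.transform = transform
        # Class names are the sub-folder names, sorted so labels are deterministic
        self.classes = sorted(p.name for p in Path(root).iterdir() if p.is_dir())
        self.samples: List[Tuple[Path, int]] = [
            (img_path, label)
            for label, cls in enumerate(self.classes)
            for img_path in sorted((Path(root) / cls).iterdir())
            if img_path.is_file()
        ]

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int):
        path, label = self.samples[idx]
        image = Image.open(path).convert("RGB")  # load one image at a time
        if self.transform is not None:
            image = self.transform(image)
        return image, label
```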
You usually have to do some preprocessing on the data, e.g. resize the images and apply data augmentation. All your transformations should go inside ./data/transformation. In our template, we included a wrapper for imgaug.
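The reason a wrapper is needed: imgaug augmenters work on numpy arrays, while a torchvision-style pipeline passes PIL images from one transform to the next. Below is a minimal sketch of such an adapter, assuming (as imgaug's augmenters do) that the augmenter exposes `augment_image(array) -> array`. The class name ImgAugTransform is hypothetical and not necessarily what the template uses.

```python
from typing import Any

import numpy as np
from PIL import Image


class ImgAugTransform:
    """Hypothetical adapter: PIL image -> numpy -> imgaug-style augmenter -> PIL image."""

    def __init__(self, augmenter: Any):
        # Assumed to expose augment_image(np.ndarray) -> np.ndarray, like imgaug augmenters
        self.augmenter = augmenter

    def __call__(self, image: Image.Image) -> Image.Image:
        array = np.asarray(image)                        # PIL -> H x W x C uint8 array
        augmented = self.augmenter.augment_image(array)  # augmentation happens on numpy
        return Image.fromarray(np.ascontiguousarray(augmented))  # back to PIL
```

In practice you would pass something like `iaa.Sequential([iaa.Fliplr(0.5), iaa.Affine(rotate=(-10, 10))])` (with `import imgaug.augmenters as iaa`) and place the wrapper before the tensor conversion step of the training transforms only.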
As you know, you have to create a DataLoader to feed your data into the model. In the data/__init__.py file, we expose a very simple function, get_dataloaders, that automatically configures the train, val, and test data loaders using a few parameters.
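Here is one way such a helper could work. The signature below (val_split, test_split, batch_size, num_workers, seed) is an assumption for illustration, not the template's actual get_dataloaders: it splits one Dataset with `torch.utils.data.random_split` (seeded, so the split is reproducible) and wraps each split in a DataLoader, shuffling only the training set.

```python
from typing import Tuple

import torch
from torch.utils.data import DataLoader, Dataset, random_split


def get_dataloaders(
    dataset: Dataset,
    val_split: float = 0.2,
    test_split: float = 0.1,
    batch_size: int = 32,
    num_workers: int = 0,
    seed: int = 42,
) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """Hypothetical sketch: split one dataset into train/val/test loaders."""
    n = len(dataset)
    n_val = int(n * val_split)
    n_test = int(n * test_split)
    n_train = n - n_val - n_test  # the remainder goes to training
    generator = torch.Generator().manual_seed(seed)  # reproducible split
    train_ds, val_ds, test_ds = random_split(
        dataset, [n_train, n_val, n_test], generator=generator
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_dl, val_dl, test_dl
```

Note that all three splits share the underlying dataset's transform; if augmentation must apply to the training set only, building separate Dataset objects (e.g. from the dataset/train and dataset/val folders) is the simpler design.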
Sometimes you may need to define custom losses; you can include them in the ./losses package.
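As a hypothetical example of what could live in ./losses, here is a minimal multi-class focal loss written as an `nn.Module`. The class name and the gamma default are assumptions, not part of the template. Focal loss scales each sample's cross-entropy by (1 - p_t)^gamma, where p_t is the predicted probability of the true class, so confident, easy examples contribute less.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Hypothetical custom loss: cross-entropy down-weighted by (1 - p_t) ** gamma."""

    def __init__(self, gamma: float = 2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # logits: (batch, num_classes); target: (batch,) integer class indices
        log_probs = F.log_softmax(logits, dim=1)
        log_pt = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)  # log p of the true class
        pt = log_pt.exp()
        # With gamma = 0 this reduces to the mean cross-entropy
        return (-((1 - pt) ** self.gamma) * log_pt).mean()
```

Because it is a regular `nn.Module`, it can be used wherever `nn.CrossEntropyLoss()` would be.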
Sometimes you may also need to define custom metrics.
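A hypothetical example: balanced accuracy (mean per-class recall), written as a plain function of `(y_pred, y_true)`. We assume here that poutyne accepts such callables as batch metrics, e.g. through the `batch_metrics` argument of its Model; check the poutyne docs for your version. The function name is illustrative and not part of the template.

```python
import torch


def balanced_accuracy(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Hypothetical batch metric: mean per-class recall, in percent."""
    preds = y_pred.argmax(dim=1)  # (batch,) predicted class indices from logits
    recalls = [
        (preds[y_true == c] == c).float().mean()  # recall of class c within this batch
        for c in y_true.unique()                  # only classes present in the batch
    ]
    return torch.stack(recalls).mean() * 100.0
```

Since it is computed per batch, an epoch-level value obtained by averaging over batches is only an approximation of the true balanced accuracy; an epoch metric that accumulates predictions would be exact.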
We included Python's logging module. You can import and use it as follows:
from logger import logger

logger.info('print() is for noobs')
All your models go inside ./models. In our case, we have a very basic CNN, and we override the resnet18 function to provide a frozen model to fine-tune.
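"Frozen" here means the pretrained weights are not updated; only a new classification head is trained. A minimal sketch of that idea follows, assuming a model whose final layer is an `nn.Linear` stored in `model.fc` (true for torchvision's ResNets). The helper name freeze_and_replace_head is hypothetical, not the template's resnet18.

```python
import torch.nn as nn


def freeze_and_replace_head(model: nn.Module, n_classes: int) -> nn.Module:
    """Hypothetical helper: freeze all weights, then attach a fresh trainable head."""
    for param in model.parameters():
        param.requires_grad = False  # pretrained backbone stays fixed
    # A newly created layer has requires_grad=True, so only the head is trained
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model
```

On torchvision 0.13 or newer, usage would look like `freeze_and_replace_head(resnet18(weights=ResNet18_Weights.DEFAULT), n_classes=2)`, with both names imported from `torchvision.models`. Remember to give the optimizer only the parameters that still require gradients.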
In our case, we kept things simple: all the training and evaluation logic is inside main.py, where we used poutyne as the main library. We already defined a useful list of callbacks:
- learning rate scheduler
- auto-save of the best model
- early stopping
Usually, this is all you need!
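To show what these three callbacks do for you, here is a plain-PyTorch sketch of the equivalent logic. This is not poutyne's implementation and not the template's main.py; the function name fit, the Adam optimizer, and the patience/factor values are assumptions. It reduces the learning rate when the validation loss plateaus, keeps the best weights, and stops early after `patience` epochs without improvement.

```python
import copy
from typing import Optional

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader


def fit(model: nn.Module, train_dl: DataLoader, val_dl: DataLoader,
        epochs: int = 10, lr: float = 1e-3, patience: int = 3) -> float:
    """Hypothetical loop mirroring the three callbacks; returns the best val loss."""
    # Only optimize trainable parameters (useful for frozen backbones)
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.1, patience=1)
    loss_fn = nn.CrossEntropyLoss()
    best_loss = float("inf")
    best_state: Optional[dict] = None
    epochs_without_improvement = 0

    for _ in range(epochs):
        model.train()
        for x, y in train_dl:
            optimizer.zero_grad()
            loss_fn(model(x), y).backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            total = sum(loss_fn(model(x), y).item() * len(x) for x, y in val_dl)
        val_loss = total / len(val_dl.dataset)

        scheduler.step(val_loss)  # 1) reduce the learning rate on a plateau

        if val_loss < best_loss:  # 2) keep a copy of the best model so far
            best_loss, epochs_without_improvement = val_loss, 0
            best_state = copy.deepcopy(model.state_dict())
        else:  # 3) early stopping
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)  # restore the best weights
    return best_loss
```

Here the best weights are kept in memory; persisting them would be a `torch.save(best_state, ...)` call into the checkpoint folder. With poutyne, all of this collapses into passing the corresponding callbacks to its training call.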
You may need to create custom callbacks; with poutyne this is very easy, since it supports a Keras-like API. Your custom callbacks should go inside ./callbacks. For example, we have created one to update Comet every epoch.
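A sketch of what such a callback might look like, assuming poutyne's Keras-style `Callback` base class with an `on_epoch_end(epoch_number, logs)` hook, and a Comet `Experiment` object exposing `log_metrics(dict, epoch=...)`. The class name CometCallback is hypothetical, and the import fallback only exists so the sketch runs without poutyne installed.

```python
from typing import Any, Dict

try:
    from poutyne import Callback  # assumed import path for recent poutyne versions
except ImportError:  # minimal stand-in so this sketch runs without poutyne
    class Callback:
        pass


class CometCallback(Callback):
    """Hypothetical callback: push each epoch's logged metrics to a Comet experiment."""

    def __init__(self, experiment: Any):
        super().__init__()
        self.experiment = experiment  # e.g. a comet_ml.Experiment instance

    def on_epoch_end(self, epoch_number: int, logs: Dict[str, float]) -> None:
        # logs carries values such as the training and validation losses and metrics;
        # the epoch number is passed separately, so drop it from the metrics dict
        metrics = {k: v for k, v in logs.items() if k != "epoch"}
        self.experiment.log_metrics(metrics, epoch=epoch_number)
```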
Track your experiment
We are using Comet to automatically track our models' results. This is what Comet's dashboard looks like after a few model runs.
main.py produces the following output:
We also created several utility functions to plot both datasets and data loaders. They are in utils.py. For example, calling show_dl on our train and val datasets produces the following outputs. As you can see, data augmentation is correctly applied to the train set.
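A possible implementation of such a plotting helper is sketched below; it is an assumption, not the template's utils.py. The helper make_image_grid is hypothetical: it tiles a (B, C, H, W) batch into one image, and show_dl plots the first batch of a data loader with matplotlib, assuming RGB tensors with values in [0, 1] (normalized images would need to be un-normalized first).

```python
import math

import torch
from torch.utils.data import DataLoader


def make_image_grid(images: torch.Tensor, n_cols: int = 4) -> torch.Tensor:
    """Tile a (B, C, H, W) batch into a single (C, n_rows * H, n_cols * W) image."""
    b, c, h, w = images.shape
    n_rows = math.ceil(b / n_cols)
    grid = images.new_zeros((c, n_rows * h, n_cols * w))  # unused cells stay black
    for i in range(b):
        row, col = divmod(i, n_cols)
        grid[:, row * h:(row + 1) * h, col * w:(col + 1) * w] = images[i]
    return grid


def show_dl(dl: DataLoader, n_cols: int = 4) -> None:
    """Hypothetical sketch: plot the first batch of a data loader."""
    import matplotlib.pyplot as plt  # lazy import: make_image_grid works without matplotlib

    images, labels = next(iter(dl))
    grid = make_image_grid(images, n_cols)
    plt.imshow(grid.permute(1, 2, 0).clamp(0, 1).numpy())  # CHW -> HWC for matplotlib
    plt.title(" ".join(str(int(label)) for label in labels))
    plt.axis("off")
    plt.show()
```

Calling it on the train and val loaders side by side is a quick visual check that augmentation only affects the training images.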
I hope you found some useful information and that this template will help you on your next amazing project :)
Let me know if you have some ideas/suggestions to improve it.
Thank you for reading