Continuous Wavelet Transforms in PyTorch
This is a PyTorch implementation for the wavelet analysis outlined in Torrence and Compo (BAMS, 1998). The code builds upon the excellent implementation of Aaron O'Leary by adding a PyTorch filter bank wrapper to enable fast convolution on the GPU. Specifically, the code was written to speed-up the CWT computation for a large number of 1D signals and relies on torch.nn.Conv1d for convolution.
Citation
If you found this code useful, please cite our paper Repetition Estimation (IJCV, 2019):
@article{runia2019repetition,
title={Repetition estimation},
author={Runia, Tom FH and Snoek, Cees GM and Smeulders, Arnold WM},
journal={International Journal of Computer Vision},
volume={127},
number={9},
pages={1361--1383},
year={2019},
publisher={Springer}
}
Usage
In addition to the PyTorch implementation defined in WaveletTransformTorch
the original SciPy version is also included in WaveletTransform
for completeness. As the GPU implementation highly benefits from parallelization, the cwt
and power
methods expect signal batches of shape [num_signals,signal_length]
instead of individual signals.
import numpy as np
from wavelets_pytorch.transform import WaveletTransform # SciPy version
from wavelets_pytorch.transform import WaveletTransformTorch # PyTorch version
dt = 0.1 # sampling frequency
dj = 0.125 # scale distribution parameter
batch_size = 32 # how many signals to process in parallel
# Batch of signals to process
batch = [batch_size x signal_length]
# Initialize wavelet filter banks (scipy and torch implementation)
wa_scipy = WaveletTransform(dt, dj)
wa_torch = WaveletTransformTorch(dt, dj, cuda=True)
# Performing wavelet transform (and compute scalogram)
cwt_scipy = wa_scipy.cwt(batch)
cwt_torch = wa_torch.cwt(batch)
# For plotting, see the examples/plot.py function.
# ...
Supported Wavelets
The wavelet implementations are taken from here. Default is the Morlet wavelet.
Benchmark
Performing parallel CWT computation on the GPU using PyTorch results in a significant speed-up. Increasing the batch size will give faster runtimes. The plot below shows a comaprison between the scipy
versus torch
implementation as function of the batch size N
and input signal length. These results were obtained on a powerful Linux desktop with NVIDIA Titan X GPU.
Installation
Clone and install:
git clone https://github.com/tomrunia/PyTorchWavelets.git
cd PyTorchWavelets
pip install -r requirements.txt
python setup.py install
Requirements
- Python 2.7 or 3.6 (other versions might also work)
- Numpy (developed with 1.14.1)
- Scipy (developed with 1.0.0)
- PyTorch >= 0.4.0
The core of the PyTorch implementation relies on the torch.nn.Conv1d
module.