Mammoth

An Extendible (General) Continual Learning Framework based on PyTorch.

[Benchmark preview images: seq_mnist, seq_cifar10, seq_tinyimg, perm_mnist, rot_mnist, mnist360]

Setup

  • Use ./utils/main.py to run experiments.
  • Pass the --load_best_args argument to use the best hyperparameters reported in the paper (Dark Experience for General Continual Learning: a Strong, Simple Baseline).
  • New models can be added to the models/ folder.
  • New datasets can be added to the datasets/ folder.

Models

  • Gradient Episodic Memory (GEM)
  • A-GEM
  • A-GEM with Reservoir (A-GEM-R)
  • Experience Replay (ER)
  • Meta-Experience Replay (MER)
  • Function Distance Regularization (FDR)
  • Greedy gradient-based Sample Selection (GSS)
  • Hindsight Anchor Learning (HAL)
  • Incremental Classifier and Representation Learning (iCaRL)
  • online Elastic Weight Consolidation (oEWC)
  • Synaptic Intelligence (SI)
  • Learning without Forgetting (LwF)
  • Progressive Neural Networks (PNN)
  • Dark Experience Replay (DER)
  • Dark Experience Replay++ (DER++)
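Several of the rehearsal methods above keep a small memory of past examples. The "with Reservoir" in A-GEM-R refers to filling that memory by reservoir sampling. The sketch below shows reservoir sampling in plain Python, assuming a buffer only needs `add` and `sample`. The class `ReservoirBuffer` and its methods are hypothetical and are not Mammoth's own buffer API.

```python
import random


class ReservoirBuffer:
    """Fixed-size replay memory filled by reservoir sampling (hypothetical sketch)."""

    def __init__(self, capacity, seed=0):
        self.capacity = capacity
        self.examples = []
        self.num_seen = 0  # number of stream examples offered so far
        self.rng = random.Random(seed)

    def add(self, example):
        # After n insertions, every example seen so far is in the buffer with
        # equal probability capacity / n, without knowing the stream length.
        if len(self.examples) < self.capacity:
            self.examples.append(example)
        else:
            j = self.rng.randint(0, self.num_seen)  # uniform over 0..num_seen inclusive
            if j < self.capacity:
                self.examples[j] = example
        self.num_seen += 1

    def sample(self, batch_size):
        # Draw a replay mini-batch without replacement.
        k = min(batch_size, len(self.examples))
        return self.rng.sample(self.examples, k)


buf = ReservoirBuffer(capacity=200)
for x in range(10_000):
    buf.add(x)
print(len(buf.examples), buf.num_seen)  # 200 10000
```

A stored item can be any Python object. For example, a method that also replays past network outputs could store `(input, label, logits)` tuples in the same structure.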

Datasets

Class-IL / Task-IL settings

  • Sequential MNIST
  • Sequential CIFAR-10
  • Sequential Tiny ImageNet
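Class-IL and Task-IL differ at test time. In Task-IL, the model is told which task an example belongs to, so it only has to choose among that task's classes. In Class-IL, it must choose among all classes.

The sketch below illustrates this under one assumption: 10 classes split into 5 tasks of 2 consecutive classes. The helpers `split_classes` and `predict` are hypothetical and do not reflect how Mammoth builds its datasets.

```python
def split_classes(num_classes, classes_per_task):
    # Task t owns classes [t * classes_per_task, (t + 1) * classes_per_task).
    return [list(range(start, start + classes_per_task))
            for start in range(0, num_classes, classes_per_task)]


def predict(logits, task_classes=None):
    """Class-IL: every class competes.
    Task-IL: the task identity is given, so only that task's classes compete."""
    candidates = range(len(logits)) if task_classes is None else task_classes
    return max(candidates, key=lambda c: logits[c])


tasks = split_classes(num_classes=10, classes_per_task=2)
print(tasks)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

# A test example from task 1 (true class 3) whose logits favour a class of task 4.
logits = [0.1, 0.0, 0.2, 1.5, 0.3, 0.1, 0.0, 0.4, 0.2, 2.0]
print(predict(logits))            # 9 -> wrong under Class-IL
print(predict(logits, tasks[1]))  # 3 -> right under Task-IL
```

Restricting the candidate classes can only help the model, which is why Task-IL is the easier of the two settings.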

Domain-IL settings

  • Permuted MNIST
  • Rotated MNIST
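In the Domain-IL benchmarks, every task keeps the same ten digit classes but changes the input distribution. For Permuted MNIST, this is commonly done by applying one fixed random pixel permutation per task, as in the sketch below. Rotated MNIST works analogously, with one rotation angle per task. The function names, the number of tasks and the seed are illustrative assumptions, not Mammoth's implementation.

```python
import random


def make_permutations(num_tasks, num_pixels=28 * 28, seed=0):
    # One fixed permutation of pixel indices per task, reproducible via the seed.
    rng = random.Random(seed)
    perms = []
    for _ in range(num_tasks):
        perm = list(range(num_pixels))
        rng.shuffle(perm)
        perms.append(perm)
    return perms


def permute_image(flat_image, perm):
    # Output pixel i takes the value of input pixel perm[i]; the label is unchanged.
    return [flat_image[p] for p in perm]


perms = make_permutations(num_tasks=3)
image = list(range(28 * 28))  # stand-in for a flattened 28x28 digit
task_2_image = permute_image(image, perms[2])
print(sorted(task_2_image) == image)  # True: same pixels, new positions
```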

General Continual Learning setting

  • MNIST-360

Results

Continual Learning Results
Buffer  Method  S-CIFAR-10          S-Tiny-ImageNet     P-MNIST   R-MNIST   S-MNIST
                Class-IL  Task-IL   Class-IL  Task-IL   Domain-IL Domain-IL Class-IL  Task-IL
-       JOINT   92.20     98.31     59.99     82.04     94.33     95.76     95.57     99.51
-       SGD     19.62     61.02     7.92      18.31     40.70     67.66     19.60     94.94
-       oEWC    19.49     68.29     7.58      19.20     75.79     77.35     20.46     98.39
-       SI      19.48     68.05     6.58      36.32     65.86     71.91     19.27     96.00
-       LwF     19.61     63.29     8.46      15.85     -         -         19.62     94.11
-       PNN     -         95.13     -         67.84     -         -         -         99.23
200     ER      44.79     91.19     8.49      38.17     72.37     85.01     80.43     97.86
200     MER     -         -         -         -         -         -         81.47     98.05
200     GEM     25.54     90.44     -         -         66.93     80.80     80.11     97.78
200     A-GEM   20.04     83.88     8.07      22.77     66.42     81.91     45.72     98.61
200     iCaRL   49.02     88.99     7.53      28.19     -         -         70.51     98.28
200     FDR     30.91     91.01     8.70      40.36     74.77     85.22     79.43     97.66
200     GSS     39.07     88.80     -         -         63.72     79.50     38.90     95.02
200     HAL     32.36     82.51     -         -         74.15     84.02     84.70     97.96
200     DER     61.93     91.40     11.87     40.22     81.74     90.04     84.55     98.80
200     DER++   64.88     91.92     10.96     40.87     83.58     90.43     85.61     98.76
500     ER      57.74     93.61     9.99      48.64     80.60     88.91     86.12     99.04
500     MER     -         -         -         -         -         -         88.35     98.43
500     GEM     26.20     92.16     -         -         76.88     81.15     85.99     98.71
500     A-GEM   22.67     89.48     8.06      25.33     67.56     80.31     46.66     98.93
500     iCaRL   47.55     88.22     9.38      31.55     -         -         70.10     98.32
500     FDR     28.71     93.29     10.54     49.88     83.18     89.67     85.87     97.54
500     GSS     49.73     91.02     -         -         76.00     81.58     49.76     97.71
500     HAL     41.79     84.54     -         -         80.13     85.00     87.21     98.03
500     DER     70.51     93.40     17.75     51.78     87.29     92.24     90.54     98.84
500     DER++   72.70     93.88     19.38     51.91     88.21     92.77     91.00     98.94
5120    ER      82.47     96.98     27.40     67.29     89.90     93.45     93.40     99.33
5120    MER     -         -         -         -         -         -         94.57     99.27
5120    GEM     25.26     95.55     -         -         87.42     88.57     95.11     99.44
5120    A-GEM   21.99     90.10     7.96      26.22     73.32     80.18     54.24     98.93
5120    iCaRL   55.07     92.23     14.08     40.83     -         -         70.60     98.32
5120    FDR     19.70     94.32     28.97     68.01     90.87     94.19     87.47     97.79
5120    GSS     67.27     94.19     -         -         82.22     85.24     89.39     98.33
5120    HAL     59.12     88.51     -         -         89.20     91.17     89.52     98.35
5120    DER     83.81     95.43     36.73     69.50     91.66     94.14     94.90     99.29
5120    DER++   85.24     96.12     39.02     69.84     92.26     94.65     95.30     99.47
MNIST-360: General Continual Learning

JOINT  SGD    Buffer  ER     MER    A-GEM-R  GSS    DER    DER++
82.98  19.09  200     49.27  48.58  28.34    43.92  55.22  54.16
              500     65.04  62.21  28.13    54.45  69.11  69.62
              1000    75.18  70.91  29.21    63.84  75.97  76.03
