Fast (Differentiable) Soft DTW for PyTorch using CUDA

Dynamic time warping (DTW) is a dynamic programming algorithm that measures the dissimilarity between two time series by finding an optimal alignment between them. The algorithm was originally applied to speech recognition.
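To make the recurrence concrete, here is a minimal (and deliberately naive) PyTorch sketch of classic DTW with squared-difference costs; the function name and structure are illustrative only, not the released implementation:

```python
import torch

def dtw(x, y):
    """Naive DTW between two 1-D sequences x and y (illustrative sketch only)."""
    n, m = len(x), len(y)
    # Pairwise squared-difference costs between all elements of x and y.
    D = (x.view(n, 1) - y.view(1, m)) ** 2
    # R[i, j] holds the cost of the best alignment of x[:i] and y[:j].
    R = torch.full((n + 1, m + 1), float("inf"))
    R[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            R[i, j] = D[i - 1, j - 1] + torch.min(
                torch.stack((R[i - 1, j], R[i, j - 1], R[i - 1, j - 1]))
            )
    return R[n, m]
```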

At ICML 2017, Marco Cuturi and Mathieu Blondel proposed a differentiable formulation of this algorithm that is very useful in optimization problems involving temporal sequences. They call this differentiable formulation soft DTW, and I've been using it quite extensively for a project that I'm working on.
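The key change in their formulation is to replace the hard minimum in the DTW recurrence with a smoothed, differentiable soft minimum controlled by a parameter gamma. Roughly, in the same illustrative style as the sketch above:

```python
import torch

def soft_min(values, gamma):
    """Smoothed minimum: -gamma * log(sum(exp(-values / gamma))).
    As gamma approaches 0 this recovers the hard minimum of classic DTW."""
    return -gamma * torch.logsumexp(-values / gamma, dim=0)
```

Swapping this in for `torch.min` in the recurrence above yields a discrepancy that is differentiable with respect to both input sequences, which is what makes it usable as a training loss.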

My primary deep learning framework is PyTorch, and although multiple implementations already exist (e.g. this or this), they were a bit slow for my use case, and I couldn't run as many experiments as I wanted due to speed constraints. Considering that soft DTW is very similar to the original DTW, for which many efficient implementations already exist, I set out to write my own implementation that was faster than the existing ones. Naturally, a CUDA implementation was the first thing that came to mind.

One obvious way to parallelize DTW is to compute multiple DTW(x, y) queries in parallel, but that's a very obvious approach 😀 Interestingly, there is already related work in the literature that demonstrates how to parallelize the seemingly sequential workflow of dynamic programming. Without going too much into the details, the basic idea is to process the anti-diagonals of the cost matrix in parallel: every cell on a given diagonal depends only on the two previous diagonals, so all of its cells can be computed at once by separate threads (a rough sketch follows below). This, coupled with computing multiple DTW(x, y) queries in parallel, yields some serious speedups.
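Here is a rough, CPU-side PyTorch sketch of the diagonal-wise ("wavefront") evaluation, just to illustrate the dependency structure; an actual CUDA kernel would instead assign one thread per cell of the active diagonal (and, say, one block per (x, y) pair), but it follows the same pattern. All names here are illustrative, not taken from the released code:

```python
import torch

def soft_dtw_diagonal(D, gamma=1.0):
    """Soft DTW evaluated one anti-diagonal at a time (illustrative sketch).
    D is an (n, m) matrix of pairwise costs between the two sequences."""
    n, m = D.shape
    R = torch.full((n + 1, m + 1), float("inf"))
    R[0, 0] = 0.0
    # Cells with i + j == p depend only on diagonals p - 1 and p - 2,
    # so every cell of a diagonal can be updated at the same time.
    for p in range(2, n + m + 1):
        i = torch.arange(max(1, p - m), min(n, p - 1) + 1)
        j = p - i
        prev = torch.stack((R[i - 1, j], R[i, j - 1], R[i - 1, j - 1]))
        R[i, j] = D[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
    return R[n, m]
```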

Today, I decided to publicly release my implementation, in the hope of helping others who may be interested in using this algorithm in their PyTorch projects. The project is available on GitHub.
