/torchaudio-filters

High-pass and low-pass filters implemented as modules with torchaudio

Primary LanguagePythonGNU General Public License v3.0GPL-3.0

torchaudio-filters

High-pass and low-pass filters implemented as modules with torchaudio.

This small package offers a simple API to implement basic butterworth filters in PyTorch modules. Aims to maintain consistency with the PyTorch API (e.g. behaves similarly to torchaudio.transforms.Spectrogram) and uses torchaudio.functional.filtfilt under the hood.

Output has been verified to generally match the scipy.signal output up to 1e-2 units of precision in testing.

>>> from torch import nn
>>> from torchaudio_filters import LowPass, Pad
>>> sample_rate = 128  # Hz
>>> sample_secs = 10
>>> x = torch.rand(32, 21, sample_rate * sample_secs)  # batch_idx, channel_idx, timestep_idx
# Add pre-processors to your models using Sequential
# * Pad will pad input and unpad output
# * LowPass will filter to the frequencies below a given cutoff
>>> model = nn.Sequential(
...     Pad(
...         module=LowPass(sample_rate, 70)  # Low-pass filter
...         padlen=32,
...     )
...     encoder,
...     decoder,
...     ...,
... )
>>> y_hat = model(x)

Installation

Requires a relatively recent torchaudio to be installed (anything after 2.0.0 should be more than okay), and also needs scipy.

To install the latest stable version from PyPi:

pip install torchaudio-filters

Usage

This package includes the torch.nn.module objects LowPass, HighPass, BandPass, Notch, and Pad. These all assume the final dimension represents time, and apply Butterworth filters in their forward methods.

Let's generate a simple example on a sinusoidal 256 Hz signal with frequencies 5, 10, and 20 Hz:

>>> sample_rate = 256  # Hz
>>> sample_secs = 1.5  # s
>>> t = torch.arange(0, sample_secs, 1 / sample_rate)
>>> freqs = [5, 10, 20]
>>> x = sum(torch.sin(2 * torch.pi * f * t) for f in freqs)

We can design low-pass and high-pass filters and call them as any other module:

>>> lp = LowPass(15, sample_rate)
>>> x_lp = lp(x)
>>> hp = HighPass(15, sample_rate)
>>> x_hp = hp(x)
>>> notch = Notch(7, 12, sample_rate)
>>> x_notch = notch(x)
>>> bp = BandPass(7, 12, sample_rate)
>>> x_bp = bp(x)

Plotting these results:

>>> fig, (ax1, ax2, ax3) = plt.subplots(3, 1, sharex=True, figsize=(9.6, 6))
>>> ax1.plot(t, x)
>>> ax1.set_title('Input signal')
>>> ax2.plot(t, x_lp)
>>> ax2.plot(t, x_hp)
>>> ax2.set_title('LowPass / HighPass')
>>> ax3.plot(t, x_notch)
>>> ax3.plot(t, x_bp)
>>> ax3.set_title('Notch / BandPass')

fig00

To decrease the edge effect, a wrapper module Pad is also provided

>>> filt = torch.nn.Sequential(hp, notch)
>>> filt_pad = Pad(
...     filt,
...     padlen=128,
... )
>>> x_f = filt(x)
>>> x_fpad = filt_pad(x)
>>> fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True, figsize=(8, 6))
>>> ax1.plot(t, x)
>>> ax1.set_title('Input signal')
>>> ax2.plot(t, x_f)
>>> ax2.plot(t, x_fpad)
>>> ax2.set_title('Without pad / With pad')
>>> plt.show()

fig01

Contributing

Pull requests are most welcome! If there's any questions, open an issue so we can discuss further.

Dependencies and environments are managed with poetry. To get started developing for this package:

$ git clone https://github.com/BlakeJC94/torchaudio_filters
$ cd torchaudio_filters
$ poetry install
$ pytest

Branch main should be stable, all the latest changes will go onto the dev branch before being released on main.

  • Code is tested with pytest
  • Code is styled using black
  • Code is linted with ruff