seanghay/uvr-mdx-infer

Can I run these onnx model in browser with onnxruntime-web?

Closed this issue · 1 comments

These onnx models are awsome ! I want to run UVR-MDX-NET-Inst_HQ_3.onnx in browser with onnxruntime-web like what transformer.js is doing.

I've try my best to infer in web, but the output always not correct comparing with the output running out with python.

Anyone could help me out?

Here are my code running in web

<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Audio Processing with ONNX</title>
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@latest/dist/ort.min.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/fft.js@4.0.4/lib/fft.min.js"></script>
</head>
<body>
  <h1>Audio Processing with ONNX</h1>
  <input type="file" id="audioInput" accept=".wav">
  <button onclick="processAudio()">Process Audio</button>
  <div id="output"></div>

  <script>
    async function preprocessAudio(file) {
      const audioContext = new (window.AudioContext || window.webkitAudioContext)();
      const arrayBuffer = await file.arrayBuffer();
      const audioBuffer = await audioContext.decodeAudioData(arrayBuffer);


      const channelData = audioBuffer.getChannelData(0);


      const n_fft = 4096;
      const hop_length = 1024;
      const hannWin = hannWindow(n_fft);
      const stftData = stft(channelData, n_fft, hop_length, hannWin);


      const requiredSize = 4 * 3072 * 256;  // 3145728

      const inputTensorData = new Float32Array(requiredSize);

      let index = 0;
      for (let i = 0; i < stftData.length; i++) {
        const complexArray = stftData[i];
        for (let j = 0; j < complexArray.length; j++) {
          if (index < requiredSize) {
            inputTensorData[index++] = complexArray[j];
          }
        }
      }


      return new ort.Tensor('float32', inputTensorData, [1, 4, 3072, 256]);
    }

    function hannWindow(length) {
      const win = new Float32Array(length);
      for (let i = 0; i < length; i++) {
        win[i] = 0.5 * (1 - Math.cos((2 * Math.PI * i) / (length - 1)));
      }
      return win;
    }

    function stft(signal, n_fft, hop_length, win) {
      const stftData = [];
      const fft = new FFT(n_fft);
      for (let i = 0; i < signal.length - n_fft; i += hop_length) {
        const segment = signal.slice(i, i + n_fft);
        const windowed = segment.map((sample, index) => sample * win[index]);
        const out = fft.createComplexArray();
        fft.realTransform(out, windowed);
        fft.completeSpectrum(out);
        stftData.push(out);
      }
      return stftData;
    }

    async function processAudio() {
      const fileInput = document.getElementById('audioInput');
      const file = fileInput.files[0];
      if (!file) {
        alert('Please select a WAV file first.');
        return;
      }

      const inputTensor = await preprocessAudio(file);
      const session = await ort.InferenceSession.create('UVR-MDX-NET-Inst_HQ_3.onnx');
      const inputs = { input: inputTensor };
      const results = await session.run(inputs);
      const outputData = results.output.data;


      const istftData = istft(outputData);


      playAudio(istftData);
    }

    function istft(data) {
      const n_fft = 4096;
      const hop_length = 1024;
      const hannWin = hannWindow(n_fft);
      const numSegments = Math.floor(data.length / (2 * n_fft));
      const ifft = new FFT(n_fft);

      const istftData = new Float32Array(numSegments * hop_length);
      for (let i = 0; i < numSegments; i++) {
        const real = data.slice(i * 2 * n_fft, (i + 1) * 2 * n_fft);
        const imag = data.slice((i + 1) * 2 * n_fft, (i + 2) * 2 * n_fft);
        const complex = new Float32Array(real.length * 2);
        for (let j = 0; j < real.length; j++) {
          complex[j * 2] = real[j];
          complex[j * 2 + 1] = imag[j];
        }
        const out = ifft.createComplexArray();
        ifft.inverseTransform(out, complex);
        const segment = out.map((val, index) => val * hannWin[index]);

        for (let j = 0; j < segment.length; j++) {
          istftData[i * hop_length + j] += segment[j];
        }
      }
      return istftData;
    }

    function playAudio(data) {
      const audioContext = new (window.AudioContext || window.webkitAudioContext)();
      const buffer = audioContext.createBuffer(1, data.length, 44100);
      buffer.copyToChannel(data, 0);
      const source = audioContext.createBufferSource();
      source.buffer = buffer;
      source.connect(audioContext.destination);
      source.start();
    }
  </script>
</body>
</html>```

I had the same idea but I never figured out the solution for STFT/inverse STFT function for JavaScript.