
Nd4j.scatterUpdate has a large overhead

ebeaufay opened this issue

Issue Description

My use case is a kind of EmbeddingLayer where the dictionary is small (fewer than 1M entries): I'm implementing a HashGridEncoding layer for training NeRFs.

Nd4j.scatterUpdate is slow; it is even slower than a naive CPU implementation.

Nd4j.scatterUpdate seems to run in constant time relative to the size of the dictionary and in linear time relative to the number of updates.

This is a problem for an EmbeddingLayer with a small dictionary and a large batch size.
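The operation being benchmarked is a row-wise scatter-add: for each update row i, weights[indices[i], :] += updates[i, :], and repeated indices accumulate into the same row (this is what the assert in the test class checks). The plain-Java reference below is a hypothetical sketch, not ND4J code, and assumes flat row-major float[] storage. It makes the cost model explicit: the loop performs indices.length times cols additions and never touches the rest of the table, so its run time grows with the number of updates and not with the dictionary size.

```java
/**
 * Hypothetical single-threaded reference for a row-wise scatter-add (not ND4J code).
 * Matrices are stored row-major in flat float[] arrays.
 */
final class ScatterAddReference {

    private ScatterAddReference() {}

    /**
     * weights: rows x cols, updates: indices.length x cols.
     * For every i: weights[indices[i], :] += updates[i, :]; duplicate indices accumulate.
     */
    static void scatterAdd(float[] weights, int cols, int[] indices, float[] updates) {
        if (updates.length != indices.length * cols) {
            throw new IllegalArgumentException("updates must hold indices.length * cols values");
        }
        for (int i = 0; i < indices.length; i++) {
            int dst = indices[i] * cols; // first element of the target weight row
            int src = i * cols;          // first element of the i-th update row
            for (int j = 0; j < cols; j++) {
                weights[dst + j] += updates[src + j];
            }
        }
    }
}
```

For example, with a 5 x 2 table filled with 6, 1000 indices all equal to 1 and updates filled with 4, row 1 ends up as [4006, 4006] while every other row stays at 6, a small-scale version of the test class setup.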

Expected behavior

Nd4j.scatterUpdate should be fast, with minimal overhead.
I would expect it to perform on par with Nd4j.pullRows, which is a lot faster (although pullRows is also a bottleneck in the forward pass).

Encountered behavior

Nd4j.scatterUpdate runs in almost constant time relative to the size of the array being updated, but it has a large overhead.

Up to around 1M weights, this:

    public static void doScatterUpdate(INDArray weights, INDArray indices, INDArray updates){
        // scatter-add each row of updates into the weight row selected by indices
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weights, indices, updates, 1);
    }

is slower than this:

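    public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates){
        // reconstructed: start from the current weights so the result matches scatterUpdate
        float[][] tempWeights = weights.toFloatMatrix();
        // one float[] per column of updates, filled below
        float[][] tempUpdates = new float[dictionarySize[1]][];
        updates = updates.dup('f');

        for (int j = 0; j < dictionarySize[1]; j++) {
            INDArray column = updates.getColumn(j);
            tempUpdates[j] = column.toFloatVector(); // reconstructed: the original right-hand side was lost
        }

        // naive scatter-add on the JVM heap
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < dictionarySize[1]; j++) {
                tempWeights[indexes[i]][j] += tempUpdates[j][i]; // reconstructed loop body
            }
        }

        INDArray reshape = Nd4j.create(tempWeights);
        weights.assign(reshape); // reconstructed: write the result back into weights
    }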

by a factor of up to 50x (see the test class below).

The CPU workaround doesn't scale.
Nd4j.scatterUpdate runs in linear time relative to the number of updates, but even for a small number of updates it takes an abnormal amount of time. Perhaps some optimization is possible.
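One direction that might help in this small-dictionary, large-batch regime (an assumption on my part, not how ND4J implements the op): because the weight table is small, each worker can scatter-add its slice of the updates into a private copy of the table, and the copies are summed at the end, so duplicate indices never make two threads write to the same memory. The class and method names below are hypothetical, and the code is plain Java on flat row-major arrays rather than ND4J; the GPU analogue would be a per-block reduction, which this sketch does not show.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/**
 * Hypothetical sketch (not ND4J code): parallel scatter-add for a small weight table.
 * Each worker accumulates its slice of the updates into a private copy of the table;
 * the calling thread then sums those copies into weights, so no atomics are needed.
 */
final class ChunkedScatterAdd {

    private ChunkedScatterAdd() {}

    /** weights: rows x cols, updates: indices.length x cols, both flat row-major. */
    static void scatterAdd(float[] weights, int cols, int[] indices, float[] updates, int threads) {
        if (threads < 1) {
            throw new IllegalArgumentException("threads must be >= 1");
        }
        int n = indices.length;
        int chunk = (n + threads - 1) / threads; // update rows per worker
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<float[]>> partials = new ArrayList<>();
            for (int t = 0; t < threads; t++) {
                final int from = Math.min(n, t * chunk);
                final int to = Math.min(n, from + chunk);
                partials.add(pool.submit(() -> {
                    float[] local = new float[weights.length]; // private accumulator
                    for (int i = from; i < to; i++) {
                        int dst = indices[i] * cols;
                        int src = i * cols;
                        for (int j = 0; j < cols; j++) {
                            local[dst + j] += updates[src + j];
                        }
                    }
                    return local;
                }));
            }
            // sequential reduction: add each partial table into the real weights
            for (Future<float[]> partial : partials) {
                float[] local = partial.get();
                for (int k = 0; k < weights.length; k++) {
                    weights[k] += local[k];
                }
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("interrupted while waiting for workers", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("worker failed", e.getCause());
        } finally {
            pool.shutdown();
        }
    }
}
```

The extra memory is threads times the table size, and the final reduction also costs threads times the table size, so this only pays off when the table is much smaller than the update batch, as in the case described here. The additions are reassociated, so for non-integer updates the result can differ from a sequential loop in the last bits.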

Since Nd4j.pullRows does essentially the mirror-image job (gathering rows instead of scattering them), it might be worth comparing the two implementations: pullRows is much faster.
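To make the comparison concrete: a gather reads rows at the given indices and writes each output row exactly once, whereas a scatter-add has to accumulate into rows that many updates share (in the test class, every index is 1). The helper below is a hypothetical plain-Java gather over flat row-major arrays, not ND4J's pullRows implementation; it only illustrates the access pattern, and the write-contention difference is a plausible explanation for part of the gap, not a confirmed diagnosis.

```java
/**
 * Hypothetical sketch of a row gather (not ND4J code): out[i, :] = source[indices[i], :].
 * Matrices are stored row-major in flat float[] arrays.
 */
final class RowGather {

    private RowGather() {}

    static float[] gatherRows(float[] source, int cols, int[] indices) {
        float[] out = new float[indices.length * cols];
        for (int i = 0; i < indices.length; i++) {
            // each output row is written by exactly one iteration, so iterations are independent
            System.arraycopy(source, indices[i] * cols, out, i * cols, cols);
        }
        return out;
    }
}
```

Because no two iterations write to the same output row, this loop parallelizes trivially; the scatter-add direction needs atomics or a reduction whenever indices repeat.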

Version Information


OS: Windows
Nvidia RTX 3060 laptop GPU (driver:
cuDNN is also installed, in case it is used at all

Test class

import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
// ScatterUpdate (which declares UpdateOp) must also be imported; its package depends on the ND4J version

public class test {

    static int[] dictionarySize = {10000,2};
    static int indexesSize = 1000000;

    public static void main(String[] args) {

        // every update targets row 1 of the weights
        INDArray indices = Nd4j.create(indexesSize).assign(1).castTo(DataType.INT32);
        int[] indexes = indices.toIntVector(); // reconstructed: the original right-hand side was lost

        // expected value of row 1 is 6 + 4 * indexesSize (asserts only run with -ea)
        INDArray weights = Nd4j.create(dictionarySize).assign(6);
        INDArray updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdate(weights, indices, updates);
        assert weights.getFloat(1,0) == 6 + 4 * indexesSize;

        weights = Nd4j.create(dictionarySize).assign(6);
        updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
        doScatterUpdateCPUWorkaround(weights, indexes, updates);
        assert weights.getFloat(1,0) == 6 + 4 * indexesSize;

        // warmup
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);

            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }

        long start = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdate(weights, indices, updates);
        }
        System.out.println("scatterUpdates Nd4j : "+(System.currentTimeMillis()-start)+ " ms");

        start = System.currentTimeMillis();
        for (int i = 0; i < 10; i++) {
            weights = Nd4j.create(dictionarySize).assign(6);
            updates = Nd4j.create(indexesSize, weights.size(1)).assign(4);
            doScatterUpdateCPUWorkaround(weights, indexes, updates);
        }
        System.out.println("scatterUpdates CPU : "+(System.currentTimeMillis()-start)+ " ms");
    }

    public static void doScatterUpdate(INDArray weights, INDArray indices, INDArray updates){
        // scatter-add each row of updates into the weight row selected by indices
        Nd4j.scatterUpdate(ScatterUpdate.UpdateOp.ADD, weights, indices, updates, 1);
    }

    public static void doScatterUpdateCPUWorkaround(INDArray weights, int[] indexes, INDArray updates){
        // reconstructed: start from the current weights so the result matches scatterUpdate
        float[][] tempWeights = weights.toFloatMatrix();
        // one float[] per column of updates, filled below
        float[][] tempUpdates = new float[dictionarySize[1]][];
        updates = updates.dup('f');

        for (int j = 0; j < dictionarySize[1]; j++) {
            INDArray column = updates.getColumn(j);
            tempUpdates[j] = column.toFloatVector(); // reconstructed: the original right-hand side was lost
        }

        // naive scatter-add on the JVM heap
        for (int i = 0; i < indexes.length; i++) {
            for (int j = 0; j < dictionarySize[1]; j++) {
                tempWeights[indexes[i]][j] += tempUpdates[j][i]; // reconstructed loop body
            }
        }

        INDArray reshape = Nd4j.create(tempWeights);
        weights.assign(reshape); // reconstructed: write the result back into weights
    }
}


I have little experience with C++, so letting me contribute to this part is a risk. However, if you tell me there is no demand or no time to improve this, I will have a look.