tiny-cuda-nn-32

A fork of the lightning fast C++/CUDA neural network framework which uses FLOAT32 all the time.


Tiny CUDA Neural Networks

THIS IS A FORKED VERSION OF TCNN WHICH USES FLOAT32 ALL THE TIME.

This fork is renamed to tinycudann32. Install it with the following command:

pip install git+https://github.com/yousiki/tiny-cuda-nn-32/#subdirectory=bindings/torch

Usage is the same as for the original tinycudann, except that this fork uses FLOAT32 instead of FLOAT16. This may reduce performance, but it can avoid NaN values in some cases.
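
To illustrate why half precision can end in NaN, the sketch below rounds values through the IEEE 754 half and single formats using only Python's standard library. It is not how the framework computes anything; the helper names and the value 300.0 are assumptions chosen so that the square exceeds the largest finite half-precision value (65504).

```python
import math
import struct


def round_trip(value, fmt):
    """Round value to a struct float format, saturating to +/-inf on overflow."""
    try:
        return struct.unpack(fmt, struct.pack(fmt, value))[0]
    except OverflowError:
        # struct refuses out-of-range values; hardware would produce infinity.
        return math.copysign(math.inf, value)


def to_half(value):
    """Round to IEEE 754 half precision (FLOAT16)."""
    return round_trip(value, "<e")


def to_single(value):
    """Round to IEEE 754 single precision (FLOAT32)."""
    return round_trip(value, "<f")


activation = 300.0  # hypothetical intermediate value

half_square = to_half(to_half(activation) ** 2)        # overflows to inf
single_square = to_single(to_single(activation) ** 2)  # stays finite

half_diff = half_square - half_square        # inf - inf gives NaN
single_diff = single_square - single_square  # 0.0

print("half:  ", half_square, half_diff)
print("single:", single_square, single_diff)
```

The same intermediate value stays finite in single precision, which is the trade-off this fork makes at the cost of speed.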

BELOW IS THE ORIGINAL README

This is a small, self-contained framework for training and querying neural networks. Most notably, it contains a lightning fast "fully fused" multi-layer perceptron (technical paper), a versatile multiresolution hash encoding (technical paper), as well as support for various other input encodings, losses, and optimizers.

Performance

Figure: Fully fused networks vs. TensorFlow v2.5.0 w/ XLA. Measured on 64 (solid line) and 128 (dashed line) neurons wide multi-layer perceptrons on an RTX 3090. Generated by benchmarks/bench_ours.cu and benchmarks/bench_tensorflow.py using data/config_oneblob.json.

Usage

Tiny CUDA neural networks have a simple C++/CUDA API:

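#include <tiny-cuda-nn/common.h>
#include <tiny-cuda-nn/config.h> // declares create_from_config
#include <iostream>              // std::cout in the training loop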

// Configure the model
nlohmann::json config = {
	{"loss", {
		{"otype", "L2"}
	}},
	{"optimizer", {
		{"otype", "Adam"},
		{"learning_rate", 1e-3},
	}},
	{"encoding", {
		{"otype", "HashGrid"},
		{"n_levels", 16},
		{"n_features_per_level", 2},
		{"log2_hashmap_size", 19},
		{"base_resolution", 16},
		{"per_level_scale", 2.0},
	}},
	{"network", {
		{"otype", "FullyFusedMLP"},
		{"activation", "ReLU"},
		{"output_activation", "None"},
		{"n_neurons", 64},
		{"n_hidden_layers", 2},
	}},
};

using namespace tcnn;

auto model = create_from_config(n_input_dims, n_output_dims, config);

// Train the model
GPUMatrix<float> training_batch_inputs(n_input_dims, batch_size);
GPUMatrix<float> training_batch_targets(n_output_dims, batch_size);

for (int i = 0; i < n_training_steps; ++i) {
	generate_training_batch(&training_batch_inputs, &training_batch_targets); // <-- your code

	float loss;
	model.trainer->training_step(training_batch_inputs, training_batch_targets, &loss);
	std::cout << "iteration=" << i << " loss=" << loss << std::endl;
}

// Use the model
GPUMatrix<float> inference_inputs(n_input_dims, batch_size);
generate_inputs(&inference_inputs); // <-- your code

GPUMatrix<float> inference_outputs(n_output_dims, batch_size);
model.network->inference(inference_inputs, inference_outputs);
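
The HashGrid settings in the configuration above describe a geometric progression of grid resolutions. As a rough sketch, assuming the per-level rule N_l = floor(base_resolution * per_level_scale^l) from the Instant Neural Graphics Primitives paper and a 2D input as in the image example, the levels can be listed in Python as follows. The helper name and the simple dense-versus-hashed test are illustrative assumptions, not the framework's exact implementation.

```python
import math


def grid_levels(n_levels, base_resolution, per_level_scale,
                log2_hashmap_size, n_dims):
    """List resolution and storage mode of each level (simplified sketch)."""
    table_size = 2 ** log2_hashmap_size
    levels = []
    for level in range(n_levels):
        resolution = math.floor(base_resolution * per_level_scale ** level)
        n_vertices = (resolution + 1) ** n_dims
        levels.append({
            "level": level,
            "resolution": resolution,
            # A level needs hashing once its vertices no longer fit the table.
            "hashed": n_vertices > table_size,
            "entries": min(n_vertices, table_size),
        })
    return levels


# Values taken from the "encoding" block of the configuration above.
levels = grid_levels(
    n_levels=16,
    base_resolution=16,
    per_level_scale=2.0,
    log2_hashmap_size=19,
    n_dims=2,  # assumption: 2D input, as in the image example
)

for entry in levels:
    print(entry)
```

Under these assumptions the coarse levels fit in the table directly, while the finer levels share table entries through hashing.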

Example: learning a 2D image

We provide a sample application where an image function (x,y) -> (R,G,B) is learned. It can be run via

tiny-cuda-nn/build$ ./mlp_learning_an_image ../data/images/albert.jpg ../data/config_hash.json

producing an image every 1000 training steps. Each 1000 steps should take roughly 0.42 seconds with the default configuration on an RTX 3090.

[Figure: reconstructions after 10 steps (4.2 ms), 100 steps (42 ms), and 1000 steps (420 ms), next to the reference image.]
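
Learning an image function (x,y) -> (R,G,B) amounts to regressing pixel colors from randomly drawn coordinates. The sketch below shows one way such a training batch could be generated in plain Python. The nearest-pixel lookup, the function name, and the tiny 2x2 stand-in image are assumptions for illustration; the sample application may sample and interpolate differently.

```python
import random


def sample_batch(image, batch_size, rng):
    """Draw random (x, y) in [0, 1)^2 and look up the nearest pixel's RGB."""
    height, width = len(image), len(image[0])
    inputs, targets = [], []
    for _ in range(batch_size):
        x, y = rng.random(), rng.random()
        col = min(int(x * width), width - 1)
        row = min(int(y * height), height - 1)
        inputs.append((x, y))            # network input
        targets.append(image[row][col])  # regression target
    return inputs, targets


# Hypothetical 2x2 RGB image standing in for a real photograph.
image = [
    [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0)],
    [(0.0, 0.0, 1.0), (1.0, 1.0, 1.0)],
]

inputs, targets = sample_batch(image, batch_size=4, rng=random.Random(0))
for xy, rgb in zip(inputs, targets):
    print(xy, "->", rgb)
```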

Requirements

  • An NVIDIA GPU; tensor cores increase performance when available. All shown results come from an RTX 3090.
  • A C++14 capable compiler. The following choices are recommended and have been tested:
    • Windows: Visual Studio 2019
    • Linux: GCC/G++ 7.5 or higher
  • CUDA v10.2 or higher and CMake v3.21 or higher.
  • The fully fused MLP component of this framework requires a very large amount of shared memory in its default configuration. It will likely only work on an RTX 3090, an RTX 2080 Ti, or high-end enterprise GPUs. Lower end cards must reduce the n_neurons parameter or use the CutlassMLP (better compatibility but slower) instead.
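
The two workarounds for lower-end cards mentioned above are changes to the "network" block of the configuration. A minimal sketch, assuming the configuration is held as a Python dictionary; the helper names are hypothetical and 32 is only an example width:

```python
def use_cutlass(network_config):
    """Return a copy of the network config that selects CutlassMLP."""
    fallback = dict(network_config)
    fallback["otype"] = "CutlassMLP"
    return fallback


def reduce_width(network_config, n_neurons):
    """Return a copy of the network config with fewer neurons per layer."""
    narrower = dict(network_config)
    narrower["n_neurons"] = n_neurons
    return narrower


# Same values as the "network" block in the Usage section.
network = {
    "otype": "FullyFusedMLP",
    "activation": "ReLU",
    "output_activation": "None",
    "n_neurons": 64,
    "n_hidden_layers": 2,
}

print(use_cutlass(network))
print(reduce_width(network, 32))
```

Both helpers return copies, so the original configuration stays available for GPUs that can run it.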

If you are using Linux, install the following packages:

sudo apt-get install build-essential git

We also recommend installing CUDA in /usr/local/ and adding the CUDA installation to your PATH. For example, if you have CUDA 11.4, add the following to your ~/.bashrc:

export PATH="/usr/local/cuda-11.4/bin:$PATH"
export LD_LIBRARY_PATH="/usr/local/cuda-11.4/lib64:$LD_LIBRARY_PATH"
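
After editing ~/.bashrc and opening a new shell, a quick check like the following can confirm that the CUDA compiler is visible. This is a small standard-library sketch, not part of the framework; the function names are hypothetical, and the LD_LIBRARY_PATH check simply looks for entries containing "cuda".

```python
import os
import shutil


def find_nvcc():
    """Return the path of the nvcc found on PATH, or None if there is none."""
    return shutil.which("nvcc")


def cuda_lib_dirs():
    """Return LD_LIBRARY_PATH entries that mention 'cuda'."""
    entries = os.environ.get("LD_LIBRARY_PATH", "").split(os.pathsep)
    return [entry for entry in entries if "cuda" in entry.lower()]


print("nvcc:", find_nvcc() or "not found on PATH")
print("CUDA library dirs:", cuda_lib_dirs() or "none in LD_LIBRARY_PATH")
```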

Compilation (Windows & Linux)

Begin by cloning this repository and all its submodules using the following command:

$ git clone --recursive https://github.com/nvlabs/tiny-cuda-nn
$ cd tiny-cuda-nn

Then, use CMake to build the project: (on Windows, this must be in a developer command prompt)

tiny-cuda-nn$ cmake . -B build
tiny-cuda-nn$ cmake --build build --config RelWithDebInfo -j

PyTorch extension

tiny-cuda-nn comes with a PyTorch extension that allows using the fast MLPs and input encodings from within a Python context. These bindings can be significantly faster than full Python implementations; in particular for the multiresolution hash encoding.

The overhead of Python/PyTorch can nonetheless be substantial. For example, the bundled mlp_learning_an_image example is ~2x slower through PyTorch than native CUDA.

Begin by setting up a Python 3.X environment with a recent, CUDA-enabled version of PyTorch. Then, invoke

pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch

Alternatively, if you would like to install from a local clone of tiny-cuda-nn, invoke

tiny-cuda-nn$ cd bindings/torch
tiny-cuda-nn/bindings/torch$ python setup.py install

Upon success, you can use tiny-cuda-nn models as in the following example:

import commentjson as json
import tinycudann as tcnn
import torch

with open("data/config_hash.json") as f:
	config = json.load(f)

# Option 1: efficient Encoding+Network combo.
model = tcnn.NetworkWithInputEncoding(
	n_input_dims, n_output_dims,
	config["encoding"], config["network"]
)

# Option 2: separate modules. Slower but more flexible.
encoding = tcnn.Encoding(n_input_dims, config["encoding"])
network = tcnn.Network(encoding.n_output_dims, n_output_dims, config["network"])
model = torch.nn.Sequential(encoding, network)
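
The configuration does not have to come from a file. Since the example above passes config["encoding"] and config["network"] as plain dictionaries, an equivalent configuration can be sketched inline. The values below mirror the C++ example in the Usage section; the input and output sizes are assumptions for the 2D image task.

```python
import json

# Mirrors the "encoding" and "network" blocks of the C++ example.
config = {
    "encoding": {
        "otype": "HashGrid",
        "n_levels": 16,
        "n_features_per_level": 2,
        "log2_hashmap_size": 19,
        "base_resolution": 16,
        "per_level_scale": 2.0,
    },
    "network": {
        "otype": "FullyFusedMLP",
        "activation": "ReLU",
        "output_activation": "None",
        "n_neurons": 64,
        "n_hidden_layers": 2,
    },
}

n_input_dims = 2   # assumption: (x, y) coordinates
n_output_dims = 3  # assumption: (R, G, B) colors

# With the extension installed, the model would then be created as above:
# model = tcnn.NetworkWithInputEncoding(
#     n_input_dims, n_output_dims, config["encoding"], config["network"]
# )

# The dictionary serializes to the same JSON layout used by the config files.
print(json.dumps(config, indent=2))
```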

See samples/mlp_learning_an_image_pytorch.py for an example.
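
Because the resulting model is a regular PyTorch module, it can be trained with a standard PyTorch loop. The following is a minimal sketch rather than the bundled sample: the function name, the sample_batch callback, the plain L2 loss, and the hyperparameters are all assumptions.

```python
def fit(model, sample_batch, n_steps=1000, learning_rate=1e-3):
    """Train model with Adam on batches returned by sample_batch().

    sample_batch is a user-supplied callable (hypothetical here) that
    returns a pair (inputs, targets) of tensors on the model's device.
    """
    import torch  # imported lazily so this file loads without PyTorch

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    loss_value = float("nan")
    for _ in range(n_steps):
        inputs, targets = sample_batch()
        prediction = model(inputs)
        # Match the target dtype to whatever precision the model returns.
        loss = torch.nn.functional.mse_loss(
            prediction, targets.to(prediction.dtype)
        )
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
    return loss_value
```

Calling fit(model, sample_batch) with either of the two model options above would return the loss of the final step.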

Components

Following is a summary of the components of this framework. The JSON documentation lists configuration options.

Networks

| Component | Source file | Description |
| --- | --- | --- |
| Fully fused MLP | src/fully_fused_mlp.cu | Lightning fast implementation of small multi-layer perceptrons (MLPs). |
| CUTLASS MLP | src/cutlass_mlp.cu | MLP based on CUTLASS' GEMM routines. Slower than fully-fused, but handles larger networks and is still reasonably fast. |

Input encodings

| Component | Source file | Description |
| --- | --- | --- |
| Composite | include/tiny-cuda-nn/encodings/composite.h | Allows composing multiple encodings. Can be, for example, used to assemble the Neural Radiance Caching encoding [Müller et al. 2021]. |
| Frequency | include/tiny-cuda-nn/encodings/frequency.h | NeRF's [Mildenhall et al. 2020] positional encoding applied equally to all dimensions. |
| Grid | include/tiny-cuda-nn/encodings/grid.h | Encoding based on trainable multiresolution grids. Used for Instant Neural Graphics Primitives [Müller et al. 2022]. The grids can be backed by hashtables, dense storage, or tiled storage. |
| Identity | include/tiny-cuda-nn/encodings/identity.h | Leaves values untouched. |
| Oneblob | include/tiny-cuda-nn/encodings/oneblob.h | From Neural Importance Sampling [Müller et al. 2019] and Neural Control Variates [Müller et al. 2020]. |
| SphericalHarmonics | include/tiny-cuda-nn/encodings/spherical_harmonics.h | A frequency-space encoding that is more suitable to direction vectors than component-wise ones. |
| TriangleWave | include/tiny-cuda-nn/encodings/triangle_wave.h | Low-cost alternative to NeRF's encoding. Used in Neural Radiance Caching [Müller et al. 2021]. |

Losses

| Component | Source file | Description |
| --- | --- | --- |
| L1 | include/tiny-cuda-nn/losses/l1.h | Standard L1 loss. |
| Relative L1 | include/tiny-cuda-nn/losses/relative_l1.h | Relative L1 loss normalized by the network prediction. |
| MAPE | include/tiny-cuda-nn/losses/mape.h | Mean absolute percentage error (MAPE). The same as Relative L1, but normalized by the target. |
| SMAPE | include/tiny-cuda-nn/losses/smape.h | Symmetric mean absolute percentage error (SMAPE). The same as Relative L1, but normalized by the mean of the prediction and the target. |
| L2 | include/tiny-cuda-nn/losses/l2.h | Standard L2 loss. |
| Relative L2 | include/tiny-cuda-nn/losses/relative_l2.h | Relative L2 loss normalized by the network prediction [Lehtinen et al. 2018]. |
| Relative L2 Luminance | include/tiny-cuda-nn/losses/relative_l2_luminance.h | Same as above, but normalized by the luminance of the network prediction. Only applicable when network prediction is RGB. Used in Neural Radiance Caching [Müller et al. 2021]. |
| Cross Entropy | include/tiny-cuda-nn/losses/cross_entropy.h | Standard cross entropy loss. Only applicable when the network prediction is a PDF. |
| Variance | include/tiny-cuda-nn/losses/variance_is.h | Standard variance loss. Only applicable when the network prediction is a PDF. |

Optimizers

| Component | Source file | Description |
| --- | --- | --- |
| Adam | include/tiny-cuda-nn/optimizers/adam.h | Implementation of Adam [Kingma and Ba 2014], generalized to AdaBound [Luo et al. 2019]. |
| Novograd | include/tiny-cuda-nn/optimizers/novograd.h | Implementation of Novograd [Ginsburg et al. 2019]. |
| SGD | include/tiny-cuda-nn/optimizers/sgd.h | Standard stochastic gradient descent (SGD). |
| Shampoo | include/tiny-cuda-nn/optimizers/shampoo.h | Implementation of the 2nd order Shampoo optimizer [Gupta et al. 2018] with home-grown optimizations as well as those by Anil et al. [2020]. |
| Average | include/tiny-cuda-nn/optimizers/average.h | Wraps another optimizer and computes a linear average of the weights over the last N iterations. The average is used for inference only (does not feed back into training). |
| Batched | include/tiny-cuda-nn/optimizers/batched.h | Wraps another optimizer, invoking the nested optimizer once every N steps on the averaged gradient. Has the same effect as increasing the batch size but requires only a constant amount of memory. |
| EMA | include/tiny-cuda-nn/optimizers/ema.h | Wraps another optimizer and computes an exponential moving average of the weights. The average is used for inference only (does not feed back into training). |
| Exponential Decay | include/tiny-cuda-nn/optimizers/exponential_decay.h | Wraps another optimizer and performs piecewise-constant exponential learning-rate decay. |
| Lookahead | include/tiny-cuda-nn/optimizers/lookahead.h | Wraps another optimizer, implementing the lookahead algorithm [Zhang et al. 2019]. |
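
The relative losses listed above differ mainly in their normalization term. The per-element sketch below follows the descriptions in the summary; the function names and the small epsilon that guards against division by zero are assumptions, and the framework's own headers remain the reference for the exact definitions.

```python
def relative_l1(prediction, target, epsilon=1e-2):
    """Absolute error normalized by the prediction."""
    return abs(prediction - target) / (abs(prediction) + epsilon)


def mape(prediction, target, epsilon=1e-2):
    """Absolute error normalized by the target."""
    return abs(prediction - target) / (abs(target) + epsilon)


def smape(prediction, target, epsilon=1e-2):
    """Absolute error normalized by the mean of prediction and target."""
    mean_magnitude = 0.5 * (abs(prediction) + abs(target))
    return abs(prediction - target) / (mean_magnitude + epsilon)


def relative_l2(prediction, target, epsilon=1e-2):
    """Squared error normalized by the squared prediction."""
    return (prediction - target) ** 2 / (prediction ** 2 + epsilon)


# Hypothetical prediction/target pair to compare the four normalizations.
prediction, target = 2.0, 1.0
for loss in (relative_l1, mape, smape, relative_l2):
    print(loss.__name__, loss(prediction, target))
```

In an actual training setup the normalization by the prediction would typically be treated as a constant during backpropagation, which this scalar sketch does not model.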

License and Citation

This framework is licensed under the BSD 3-clause license. Please see LICENSE.txt for details.

If you use it in your research, we would appreciate a citation via

@misc{tiny-cuda-nn,
    Author = {Thomas M\"uller},
    Year = {2021},
    Note = {https://github.com/nvlabs/tiny-cuda-nn},
    Title = {Tiny {CUDA} Neural Network Framework}
}

For business inquiries, please visit our website and submit the form: NVIDIA Research Licensing

Publications

This framework powers the following publications:

Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
Thomas Müller, Alex Evans, Christoph Schied, Alexander Keller
ACM Transactions on Graphics (SIGGRAPH), July 2022
Website / Paper / Code / Video / BibTeX

Extracting Triangular 3D Models, Materials, and Lighting From Images
Jacob Munkberg, Jon Hasselgren, Tianchang Shen, Jun Gao, Wenzheng Chen, Alex Evans, Thomas Müller, Sanja Fidler
CVPR (Oral), June 2022
Website / Paper / Video / BibTeX

Real-time Neural Radiance Caching for Path Tracing
Thomas Müller, Fabrice Rousselle, Jan Novák, Alexander Keller
ACM Transactions on Graphics (SIGGRAPH), August 2021
Paper / GTC talk / Video / Interactive results viewer / BibTeX

Acknowledgments

Special thanks go to the NRC authors for helpful discussions and to Nikolaus Binder for providing part of the infrastructure of this framework, as well as for help with utilizing TensorCores from within CUDA.