This is a minimalistic refactoring of the official 3D Gaussian splatting (3DGS) repository. It follows PyTorch conventions and allows for easy customization and extension.
It is meant for researchers who want to experiment with 3D Gaussian splatting and need a clean and easy-to-understand codebase to start from.
Assuming the CUDA Toolkit is installed, you can inspect or directly run `install.sh` to install the required dependencies and compile the CUDA kernels.
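Before running `install.sh`, it can help to confirm that the CUDA compiler is reachable. The sketch below is a hypothetical helper. It is not part of this repository and not what `install.sh` does. It looks for `nvcc` on `PATH` and extracts the release number from `nvcc --version`, assuming that output contains the usual `release X.Y` wording.

```python
import re
import shutil
import subprocess
from typing import Optional


def parse_nvcc_release(version_text: str) -> Optional[str]:
    """Extract the 'X.Y' from a 'release X.Y' fragment of `nvcc --version` output."""
    match = re.search(r"release (\d+\.\d+)", version_text)
    return match.group(1) if match else None


def find_cuda_release() -> Optional[str]:
    """Return the CUDA release reported by nvcc, or None if nvcc is not on PATH."""
    nvcc = shutil.which("nvcc")
    if nvcc is None:
        return None
    result = subprocess.run([nvcc, "--version"], capture_output=True, text=True, check=False)
    return parse_nvcc_release(result.stdout)


if __name__ == "__main__":
    release = find_cuda_release()
    if release is None:
        print("nvcc not found: install the CUDA Toolkit before running install.sh")
    else:
        print(f"Found CUDA {release}")
```

If the kernels are built as PyTorch CUDA extensions, as in the original 3DGS repository, compare the reported release with `torch.version.cuda`, which is the CUDA version PyTorch was built against. A mismatch between the two is a common cause of build errors.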
`example.py` shows how to train a 3DGS model using the original training pipeline.
To customize the pipeline, `GaussianModel` can be used like any other PyTorch model, and the training loop can be written from scratch. Below is a minimal example:
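```python
import torch

from gs.core.GaussianModel import GaussianModel
from gs.helpers.loss import l1_loss
from gs.io.colmap import load

# Load the training cameras and the sparse point cloud from a COLMAP reconstruction
cameras, pointcloud = load('your_dataset/')

# Initialize the Gaussians from the point cloud and move them to the GPU
model = GaussianModel.from_point_cloud(pointcloud).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, eps=1e-15)

for i in range(5000):
    # Cycle through the training cameras
    camera = cameras[i % len(cameras)]

    # Render from this camera (calls GaussianModel.forward) and compare to the ground truth
    rendered = model(camera)
    loss = l1_loss(rendered, camera.image)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

    # Release cached, unused GPU memory (optional; calling it every step costs some speed)
    torch.cuda.empty_cache()
```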
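The minimal loop above uses a single fixed learning rate. Because the loop is ordinary PyTorch code, a schedule can be added by hand. For example, the original 3DGS training script decays the learning rate of the Gaussian positions exponentially. As far as we know, its defaults run from 1.6e-4 to 1.6e-6 over 30,000 iterations. The helper below is a hypothetical sketch of that log-linear decay, not a function from `gs`. The example values are squeezed into the 5,000 iterations used above.

```python
import math


def expon_lr(step: int, lr_init: float, lr_final: float, max_steps: int) -> float:
    """Interpolate log(lr) linearly from lr_init at step 0 to lr_final at max_steps.

    Steps outside [0, max_steps] are clamped, so the rate stays at lr_final afterwards.
    """
    t = min(max(step / max_steps, 0.0), 1.0)
    return math.exp((1.0 - t) * math.log(lr_init) + t * math.log(lr_final))


if __name__ == "__main__":
    # Hypothetical values: the original position learning-rate range over 5000 steps
    for step in (0, 1000, 2500, 5000):
        print(step, expon_lr(step, lr_init=1.6e-4, lr_final=1.6e-6, max_steps=5000))
```

Inside the loop, the rate could be updated each iteration with `for group in optimizer.param_groups: group["lr"] = expon_lr(i, 1.6e-4, 1.6e-6, 5000)`. The minimal example puts all parameters in one group, so this would decay every parameter. The original script applies the schedule only to the positions, which requires a separate parameter group for them. Halfway through, the rate equals the geometric mean of the two endpoints, here 1.6e-5.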
The codebase is structured as follows:
- `gs/`: The Gaussian splatting module.
  - `core/`: Core data structures and functions for rendering 3DGS models.
    - `BaseCamera.py`: Base class that represents a camera used for training 3DGS models.
    - `BasePointCloud.py`: Base class for point clouds used for initializing 3DGS models.
    - `GaussianModel.py`: The 3DGS model refactored as an `nn.Module`. Call `forward` with a camera to render the model.
  - `io/`: Functions for importing and exporting image and point cloud data.
    - `colmap/`: Functions for importing COLMAP reconstructions into `BaseCamera`- and `BasePointCloud`-compliant objects.
  - `trainers/`: Training scripts for 3DGS models.
    - `basic/`: Re-implementations of the original training script.
  - `helpers/`: General functions for rendering and training 3DGS models.
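`GaussianModel.py` holds the per-Gaussian parameters that the renderer consumes. In 3D Gaussian splatting, each Gaussian's covariance is built as Σ = R S Sᵀ Rᵀ from a scale vector (S = diag(scale)) and a rotation quaternion (R). The pure-Python sketch below shows that construction for a single Gaussian. The names `quat_to_rotmat` and `covariance_3d` are hypothetical and not taken from `gs`. A real model would batch this on the GPU. As far as we know, the original implementation also stores scales in log space and normalizes quaternions before this step.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def quat_to_rotmat(q: Sequence[float]) -> Matrix:
    """Rotation matrix for a quaternion given as (w, x, y, z); q is normalized first."""
    norm = math.sqrt(sum(c * c for c in q))
    w, x, y, z = (c / norm for c in q)
    return [
        [1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)],
        [2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
        [2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)],
    ]


def covariance_3d(scale: Sequence[float], q: Sequence[float]) -> Matrix:
    """Covariance Sigma = R S S^T R^T of one Gaussian, with S = diag(scale)."""
    r = quat_to_rotmat(q)
    # M = R S: column j of R multiplied by scale[j]
    m = [[r[i][j] * scale[j] for j in range(3)] for i in range(3)]
    # Sigma = M M^T, which is symmetric and positive semi-definite by construction
    return [[sum(m[i][k] * m[j][k] for k in range(3)) for j in range(3)] for i in range(3)]


if __name__ == "__main__":
    # Gaussian stretched along y, then rotated 90 degrees about the z axis
    half = math.sqrt(0.5)
    for row in covariance_3d((1.0, 2.0, 3.0), (half, 0.0, 0.0, half)):
        print([round(v, 6) for v in row])
```

Optimizing the scale and quaternion instead of the covariance entries directly keeps Σ a valid covariance matrix throughout gradient descent.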
- `.ply` import/export
- Live visualization using nerfstudio's `viser` module
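`.ply` is the usual interchange format for point clouds and for trained 3DGS models. The hypothetical helpers below illustrate the file layout; they are not the `gs` implementation. They write, and read back, an ASCII PLY file with one `vertex` element that holds positions and 8-bit colors. Trained 3DGS models are typically saved as binary little-endian PLY with many more per-vertex properties (opacity, scales, rotations, spherical-harmonic coefficients), which this sketch does not cover.

```python
import os
import tempfile
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]
Color = Tuple[int, int, int]

VERTEX_PROPERTIES = [
    "property float x",
    "property float y",
    "property float z",
    "property uchar red",
    "property uchar green",
    "property uchar blue",
]


def write_ascii_ply(path: str, points: Sequence[Point], colors: Sequence[Color]) -> None:
    """Write points with RGB colors as a single 'vertex' element in ASCII PLY."""
    if len(points) != len(colors):
        raise ValueError("points and colors must have the same length")
    header = ["ply", "format ascii 1.0", f"element vertex {len(points)}", *VERTEX_PROPERTIES, "end_header"]
    with open(path, "w", encoding="ascii") as f:
        f.write("\n".join(header) + "\n")
        for (x, y, z), (r, g, b) in zip(points, colors):
            f.write(f"{x} {y} {z} {r} {g} {b}\n")


def read_ascii_ply(path: str) -> Tuple[List[Point], List[Color]]:
    """Read a file produced by write_ascii_ply (same property order assumed)."""
    with open(path, "r", encoding="ascii") as f:
        lines = f.read().splitlines()
    end = lines.index("end_header")
    count = next(int(line.split()[2]) for line in lines[:end] if line.startswith("element vertex"))
    points: List[Point] = []
    colors: List[Color] = []
    for line in lines[end + 1 : end + 1 + count]:
        v = line.split()
        points.append((float(v[0]), float(v[1]), float(v[2])))
        colors.append((int(v[3]), int(v[4]), int(v[5])))
    return points, colors


if __name__ == "__main__":
    pts = [(0.0, 0.0, 0.0), (1.5, -2.25, 3.0)]
    cols = [(255, 0, 0), (0, 128, 255)]
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "points.ply")
        write_ascii_ply(path, pts, cols)
        print(read_ascii_ply(path))
```

A binary variant would keep the same header, change the format line to `binary_little_endian 1.0`, and pack each vertex with fixed-size fields instead of writing text.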