- Implementation using JAX automatic differentiation (no CUDA dependency)
- Functionality equivalent to the original implementation
Run the following command in an environment with the uv project management tool installed:

    uv sync

Requirements: a GPU with at least 8 GB of VRAM (16 GB or more recommended).
Prepare a reconstructed COLMAP format dataset with the following directory structure:

    <colmap dataset>
    ├─ images
    └─ sparse
       └─ 0
          ├─ cameras.bin
          ├─ images.bin
          ├─ points3D.bin
          └─ project.ini

The T&T+DB COLMAP dataset (650 MB) is small and easy to work with.
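Before training, a quick check that a dataset matches this layout can save a failed run. The sketch below is a hypothetical helper (check_colmap_layout is not part of the repository) that only tests for the directories and files listed in the tree above; it does not parse or validate the COLMAP binaries themselves.

```python
import sys
from pathlib import Path

# Files listed under sparse/0 in the expected directory layout.
EXPECTED_SPARSE_FILES = ("cameras.bin", "images.bin", "points3D.bin", "project.ini")


def check_colmap_layout(dataset_dir):
    """Return the missing entries (empty list if the layout looks complete).

    Hypothetical helper, not from this repository: it only checks that the
    expected paths exist, not that the files are valid COLMAP output.
    """
    root = Path(dataset_dir)
    missing = []
    if not (root / "images").is_dir():
        missing.append("images/")
    sparse0 = root / "sparse" / "0"
    for name in EXPECTED_SPARSE_FILES:
        if not (sparse0 / name).is_file():
            missing.append(f"sparse/0/{name}")
    return missing


if __name__ == "__main__" and len(sys.argv) > 1:
    problems = check_colmap_layout(sys.argv[1])
    print("Layout OK" if not problems else "Missing: " + ", ".join(problems))
```

Saved under a hypothetical name such as check_layout.py, it could be run as uv run check_layout.py ./tandt_db/tandt/train.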
Run the following command:

    uv run train.py <colmap_dataset_path>

For example, to train on the train scene from the T&T+DB COLMAP dataset:

    uv run train.py ./tandt_db/tandt/train

On a GeForce RTX 5070 Ti, this takes approximately 30 minutes.
By default, the trained parameters are saved to ./output/params_final in the repository root. Intermediate optimization parameters and progress images are also saved every 500 iterations.
train.py includes several improvements over the original method; see the comments at the top of the file for details. train_original.py is the original implementation.
If you encounter out-of-memory (OOM) errors, reduce memory consumption by adjusting the following:

- Adjust tile chunk processing
  - Modify tile_chanks in ./config/default.json
  - Divides memory-intensive batch processing into smaller chunks
  - Smaller values result in faster execution, since the divided chunks are processed sequentially
- Reduce the image resolution (recommended: about 1000x600 pixels)

      uv run train.py <colmap_dataset_path> --image_scale 0.7

- Lower the maximum number of Gaussians
  - Modify max_gaussians in ./config/default.json
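The first two knobs in the list above can be sketched in plain Python. Both helpers are hypothetical and not taken from the repository: process_in_chunks assumes tile_chanks is the number of chunks (consistent with the note that smaller values run faster), and suggest_image_scale assumes --image_scale multiplies both the width and the height of each image.

```python
import math


def process_in_chunks(items, n_chunks, render_batch):
    """Run render_batch over items in n_chunks sequential pieces.

    Hypothetical helper: render_batch only ever receives one chunk, so its
    working memory is bounded by the chunk size, while more chunks mean
    more sequential steps (slower execution).
    """
    n_chunks = max(1, min(n_chunks, len(items)))  # at least one, at most one per item
    chunk_size = math.ceil(len(items) / n_chunks) if items else 1
    results = []
    for start in range(0, len(items), chunk_size):
        results.extend(render_batch(items[start:start + chunk_size]))
    return results


def suggest_image_scale(width, height, target=(1000, 600)):
    """Largest scale <= 1.0 that brings (width, height) within target.

    Assumes --image_scale multiplies both image dimensions (assumption).
    """
    return min(1.0, target[0] / width, target[1] / height)


if __name__ == "__main__":
    print(process_in_chunks(list(range(10)), 3, lambda batch: [2 * x for x in batch]))
    print(suggest_image_scale(2000, 1200))
```

Splitting into chunks never changes the result, only how much work is in flight at once; for a 2000x1200 image the suggested scale is 0.5, which gives the recommended 1000x600.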
To reduce runtime memory usage (at the cost of slower execution), set these environment variables:

    XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform uv run train.py <colmap_dataset_path>

To visualize the trained Gaussians, run:

    uv run viewer_gl.py -f <parameter_directory>

Example:

    uv run viewer_gl.py -f ./output/params_final

To visualize using the same rendering logic as training, run:

    uv run viewer_jax.py -f ./output/params_final

Mouse:
- Left Drag: Rotate
- Right Drag: Pan
- Middle Drag: Roll
- Scroll: Forward/Backward
Arrow Keys:
- Up/Down: Change parameters
  - Loads files from the same directory as the specified parameter directory
- Left/Right: Change camera pose
[Demo videos: change_view.mp4, change_params.mp4]
This implementation differs from the original because JAX JIT compilation requires arrays with static shapes. As a result:

- There is a maximum number of Gaussians that can be registered per tile
- Color blending is limited to a fixed maximum number of Gaussians
- When Gaussians are densely packed in certain tiles, some may not be rendered because of this limit, potentially causing artifacts
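The per-tile limit described above can be sketched in plain Python. This is a simplification, not the repository's renderer: the names bin_gaussians, blend and capacity, and the (tile_id, depth, gaussian_id) input format, are hypothetical; colors are single greyscale values; and each Gaussian's alpha is a constant rather than being evaluated per pixel. What it shows is why a static shape forces truncation: every tile gets exactly capacity slots, padded with -1, and any overlaps beyond that are dropped before blending.

```python
def bin_gaussians(entries, n_tiles, capacity):
    """Build fixed-size, depth-sorted per-tile lists of Gaussian ids.

    entries: iterable of (tile_id, depth, gaussian_id) overlaps (hypothetical format).
    Returns (table, dropped): table[t] always has exactly `capacity` slots,
    padded with -1, and dropped holds (tile_id, gaussian_id) pairs that did not fit.
    """
    per_tile = [[] for _ in range(n_tiles)]
    for tile_id, depth, gid in entries:
        per_tile[tile_id].append((depth, gid))

    table, dropped = [], []
    for tile_id, overlaps in enumerate(per_tile):
        overlaps.sort()  # nearest first, for front-to-back blending
        kept = [gid for _, gid in overlaps[:capacity]]
        dropped.extend((tile_id, gid) for _, gid in overlaps[capacity:])
        table.append(kept + [-1] * (capacity - len(kept)))  # static shape
    return table, dropped


def blend(slot_ids, colors, alphas, background=0.0):
    """Front-to-back alpha compositing over one tile's fixed slots."""
    out, transmittance = 0.0, 1.0
    for gid in slot_ids:
        if gid < 0:  # padding slot: contributes nothing
            continue
        out += transmittance * alphas[gid] * colors[gid]
        transmittance *= 1.0 - alphas[gid]
    return out + transmittance * background


if __name__ == "__main__":
    overlaps = [(0, 2.0, 0), (0, 1.0, 1), (0, 3.0, 2), (1, 0.5, 3)]
    table, dropped = bin_gaussians(overlaps, n_tiles=2, capacity=2)
    print(table)    # [[1, 0], [3, -1]]
    print(dropped)  # [(0, 2)]
    print(blend(table[0], colors=[1.0, 0.5, 0.2, 0.0], alphas=[0.5] * 4))  # 0.5
```

Because Gaussian 2 never occupies a slot in tile 0, the blended value does not depend on its color or opacity, so its gradient from that view is zero. This is the training effect described in the next paragraph.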
During training, Gaussians that are not rendered in a particular iteration do not have their parameters updated in that iteration. However, this is usually compensated for by updates in iterations that use other viewpoints.
Compared with CUDA implementations, this implementation runs more slowly and uses more device memory.


