
Commit 68ff2c1

Authored by: liruilong940607, rtabrizi, Ruilong Li
Abstracting out the GS heuristics into a Strategy class (#278)
* [DRAFT] gsplat refactor (#269) * initial commit, rough design of modules * moved strategy.py into gsplat dir * latest commit for 1:1 meeting * removed ops, losses, and strategies now inherit from an abstract 'Strategy' class * GSs init handled in private method * format * class Gaussian3D * WIP: _strategy * simple_trainer no longer uses Gaussians object and instead uses ParameterDict with newly-added 'activations' dict * simple_trainer no longer uses Gaussians object and instead uses ParameterDict with newly-added 'activations' dict * created mcmc and revised strategy, simplified training scripts * update _strategy * minor * implemented ruilongs' suggested changes, no longer using separate class for revised ADS heuristics * added docstrings, removed overriding activations attribute, fixed device argument passing * mcmc and default are all trainable * cleanup * minor * minor * minor * minor * hardcode activations into strategy * minor cleanup * minor * seperate the strategy in to individual files * change optimizers from List to Dict * switch to dataclass * stateless strategy * minor cleanup * revised_opacity * _update_param_with_optimizer * individual ops * means3d -> means * starting docs * docs * docs * docs * docs * docs * minor fix to pass test * move BINOMS to strategy state * nits * nit fix * viewcode * relocation func cleanup * nits in docs * cleanup imports * minor cleanup * support screen size pruning / splitting --------- Co-authored-by: Ryan Tabrizi <[email protected]> Co-authored-by: Ruilong Li <[email protected]> Co-authored-by: rtabrizi <[email protected]>
1 parent: e1de4c3 · commit: 68ff2c1

File tree

16 files changed (+1167, -427 lines)


docs/source/apis/strategy.rst

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
Densification Strategy
===================================

.. currentmodule:: gsplat

In `gsplat`, we abstract the densification and pruning process of Gaussian training
into a strategy. A strategy is a class that defines how the Gaussian parameters
(along with their optimizers) should be updated (split, pruned, etc.) during training.

A typical training workflow using :class:`DefaultStrategy` looks like this:

.. code-block:: python

    from gsplat import DefaultStrategy, rasterization

    # Define Gaussian parameters and optimizers
    params: Dict[str, torch.nn.Parameter] | torch.nn.ParameterDict = ...
    optimizers: Dict[str, torch.optim.Optimizer] = ...

    # Initialize the strategy
    strategy = DefaultStrategy()

    # Check the sanity of the parameters and optimizers
    strategy.check_sanity(params, optimizers)

    # Initialize the strategy state
    strategy_state = strategy.initialize_state()

    # Training loop
    for step in range(1000):
        # Forward pass
        render_image, render_alpha, info = rasterization(...)

        # Pre-backward step
        strategy.step_pre_backward(params, optimizers, strategy_state, step, info)

        # Compute the loss
        loss = ...

        # Backward pass
        loss.backward()

        # Post-backward step
        strategy.step_post_backward(params, optimizers, strategy_state, step, info)

A strategy updates the Gaussian parameters as well as their optimizers in place,
so it has specific expectations about the format of the parameters and the optimizers.
It is designed to work with the Gaussians defined as either a Dict of
`torch.nn.Parameter <https://pytorch.org/docs/stable/generated/torch.nn.parameter.Parameter.html>`_
or a
`torch.nn.ParameterDict <https://pytorch.org/docs/stable/generated/torch.nn.ParameterDict.html>`_
with at least the following keys: {"means", "scales", "quats", "opacities"}. On top of these attributes,
an arbitrary number of extra attributes is supported. Besides the parameters, a strategy also
expects a Dict of `torch.optim.Optimizer <https://pytorch.org/docs/stable/optim.html>`_
with the same keys as the parameters, and each optimizer should correspond to exactly
one learnable parameter.

For example, the following is a valid format for the parameters and the optimizers
that can be used with our strategies:

.. code-block:: python

    N = 100
    params = torch.nn.ParameterDict({
        "means": Tensor(N, 3), "scales": Tensor(N), "quats": Tensor(N, 4), "opacities": Tensor(N),
        "colors": Tensor(N, 25, 3), "features1": Tensor(N, 128), "features2": Tensor(N, 64),
    })
    optimizers = {k: torch.optim.Adam([p], lr=1e-3) for k, p in params.items()}

Below are the strategies that are currently implemented in `gsplat`:

.. autoclass:: DefaultStrategy
    :members:

.. autoclass:: MCMCStrategy
    :members:
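
As a rough illustration of how :class:`MCMCStrategy` would slot into the training loop
above, here is a minimal sketch. It is not taken from the committed docs: it assumes that
``MCMCStrategy`` accepts a ``cap_max`` argument limiting the number of Gaussians, and that
``step_post_backward`` additionally takes the current learning rate of the ``means``
parameter (used to scale the relocation noise); check the class docstring for the exact
signature.

.. code-block:: python

    from gsplat import MCMCStrategy, rasterization

    # Same params / optimizers format as described above
    strategy = MCMCStrategy(cap_max=1_000_000)  # assumed: cap on the number of Gaussians
    strategy.check_sanity(params, optimizers)
    strategy_state = strategy.initialize_state()

    for step in range(1000):
        render_image, render_alpha, info = rasterization(...)
        loss = ...
        loss.backward()
        # MCMC relocates/adds Gaussians and injects noise after the backward pass;
        # passing the current learning rate here is an assumption about the API.
        strategy.step_post_backward(
            params, optimizers, strategy_state, step, info, lr=1e-3
        )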

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
     "sphinx.ext.intersphinx",
     "sphinxcontrib.bibtex",
     "sphinxcontrib.video",
+    "sphinx.ext.viewcode",
 ]

 intersphinx_mapping = {
