Skip to content

Commit

Permalink
Consistency: project_all() -> normalize_all()
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Aug 29, 2022
1 parent 794cd23 commit 0d3c208
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion examples/se3_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def step(self: State) -> Tuple[jnp.ndarray, State]:
new_params = optax.apply_updates(self.params, updates) # type: ignore

# Project back to manifold.
new_params = jaxlie.manifold.project_all(new_params)
new_params = jaxlie.manifold.normalize_all(new_params)

elif self.algorithm == "exponential_coordinates":
# If we parameterize with exponential coordinates, we can
Expand Down
4 changes: 2 additions & 2 deletions jaxlie/manifold/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from ._backprop import grad, value_and_grad, zero_tangents
from ._deltas import rminus, rplus, rplus_jacobian_parameters_wrt_delta
from ._tree_utils import project_all
from ._tree_utils import normalize_all

__all__ = [
"grad",
Expand All @@ -9,5 +9,5 @@
"rminus",
"rplus",
"rplus_jacobian_parameters_wrt_delta",
"project_all",
"normalize_all",
]
2 changes: 1 addition & 1 deletion jaxlie/manifold/_tree_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _map_group_trees(
PytreeType = TypeVar("PytreeType")


def project_all(pytree: PytreeType) -> PytreeType:
def normalize_all(pytree: PytreeType) -> PytreeType:
"""Call `.normalize()` on each Lie group instance in a pytree.
Results in a naive projection of each group instance to its respective manifold.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_manifold.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,11 @@ def test_rminus_auto_vmap():
assert_arrays_close(deltas[0], -deltas[1])


def test_project():
def test_normalize():
container = {"key": (jaxlie.SO3(jnp.array([2.0, 0.0, 0.0, 0.0])),)}
container_valid = {"key": (jaxlie.SO3(jnp.array([1.0, 0.0, 0.0, 0.0])),)}
with pytest.raises(AssertionError):
assert_transforms_close(container["key"][0], container_valid["key"][0])
assert_transforms_close(
jaxlie.manifold.project_all(container)["key"][0], container_valid["key"][0]
jaxlie.manifold.normalize_all(container)["key"][0], container_valid["key"][0]
)

0 comments on commit 0d3c208

Please sign in to comment.