Gradient of SO3.log() gives NaN when w=0 #9

Closed
Ending2015a opened this issue Mar 7, 2023 · 2 comments

@Ending2015a
Contributor

Hello, thank you very much for this amazing library.

I found that a NaN occurs at line 381 when computing the gradient of SO3.log().

jaxlie/jaxlie/_so3.py

Lines 379 to 387 in ad93513

atan_factor = jnp.where(
    use_taylor,
    2.0 / w - 2.0 / 3.0 * norm_sq / w**3,
    jnp.where(
        jnp.abs(w) < get_epsilon(w.dtype),
        jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
        2.0 * atan_n_over_w / norm_safe,
    ),
)
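For context, atan_factor appears to be the scalar that multiplies the vector part v = (x, y, z) of a unit quaternion q = (w, v) in the standard quaternion logarithm. A sketch of the math under that assumption, with n = ‖v‖ (so norm_sq = n²) and f a name introduced here for the factor:

```latex
\begin{aligned}
\log q &= f(w, n)\, v, \qquad f(w, n) = \frac{2\arctan(n/w)}{n}, \\
f(w, n) &\approx \frac{2}{w} - \frac{2}{3}\,\frac{n^2}{w^3} \quad (n \to 0), \\
\lim_{w \to 0^{\pm}} f(w, n) &= \pm\frac{\pi}{n}.
\end{aligned}
```

The Taylor term divides by w. That is harmless where it is selected (small n, so |w| ≈ 1 for a unit quaternion), but it evaluates to ±∞ at a rotation angle of ±π, where w = cos(θ/2) = 0.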

The following minimal example produces NaN:

import jax
import jax.numpy as jnp
from jaxlie import SO3

a = jnp.array([jnp.pi, 0, 0], dtype=jnp.float32)
def func(x):
    return SO3.exp(x).log().sum()
print(jax.grad(func)(a))  # ===> [nan nan nan]
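The underlying jnp.where behavior can be isolated without jaxlie. The sketch below uses a toy reciprocal function; unsafe_recip and safe_recip are hypothetical names chosen for illustration, not taken from the issue or the JAX docs.

```python
import jax
import jax.numpy as jnp


def unsafe_recip(x):
    # jnp.where evaluates BOTH branches; at x == 0 the discarded one is 1 / 0 = inf.
    return jnp.where(x == 0.0, 0.0, 1.0 / x)


def safe_recip(x):
    # Replace the problematic input first, so the discarded branch stays finite.
    x_safe = jnp.where(x == 0.0, 1.0, x)
    return jnp.where(x == 0.0, 0.0, 1.0 / x_safe)


x0 = jnp.array(0.0)
print(unsafe_recip(x0))            # forward pass is fine: 0.0
print(jax.grad(unsafe_recip)(x0))  # nan: zero cotangent times an infinite derivative
print(jax.grad(safe_recip)(x0))    # zero, no NaN
```

The forward value is correct in both versions; only the backward pass differs, because reverse mode multiplies the zero cotangent of the unselected branch by that branch's (infinite) local derivative.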

I think the cause is that jnp.where does not block unsafe gradients (e.g. from x/0) coming from the branch it discards, as described in the official JAX FAQ. This is exactly what happens at line 381: when the rotation angle approaches π or -π, w approaches 0, so the Taylor branch still evaluates 2.0 / w and divides by w**3 even though that branch is not selected, and its gradient becomes NaN. To fix this issue, I suggest introducing a safe_w that equals w when use_taylor is True and 1.0 otherwise, and using it in the Taylor branch when computing atan_factor:

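# Give the Taylor branch a dummy w = 1.0 whenever it is not selected, so its
# discarded gradient never involves a division by zero.
safe_w = jnp.where(use_taylor, w, 1.0)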
atan_factor = jnp.where(
    use_taylor,
    2.0 / safe_w - 2.0 / 3.0 * norm_sq / safe_w**3,
    jnp.where(
        jnp.abs(w) < get_epsilon(w.dtype),
        jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
        2.0 * atan_n_over_w / norm_safe,
    ),
)
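To check the suggested change in isolation, the sketch below re-implements only the atan_factor expression as a standalone function and compares gradients at w = 0 with and without safe_w. The function name atan_factor, the EPS constant (standing in for get_epsilon), and the lines deriving use_taylor, norm_safe, and atan_n_over_w are assumptions reconstructed for illustration, not copied from jaxlie.

```python
import jax
import jax.numpy as jnp

EPS = 1e-5  # Hypothetical stand-in for get_epsilon(w.dtype).


def atan_factor(w, norm_sq, use_safe_w):
    # Assumed setup: use the Taylor expansion only when the vector part is tiny.
    use_taylor = norm_sq < EPS
    norm_safe = jnp.sqrt(jnp.where(use_taylor, 1.0, norm_sq))
    atan_n_over_w = jnp.arctan2(jnp.where(w < 0, -norm_safe, norm_safe), jnp.abs(w))
    # The proposed fix: the Taylor branch only sees the real w when it is selected.
    w_t = jnp.where(use_taylor, w, 1.0) if use_safe_w else w
    return jnp.where(
        use_taylor,
        2.0 / w_t - 2.0 / 3.0 * norm_sq / w_t**3,
        jnp.where(
            jnp.abs(w) < EPS,
            jnp.where(w > 0, 1.0, -1.0) * jnp.pi / norm_safe,
            2.0 * atan_n_over_w / norm_safe,
        ),
    )


w0 = jnp.array(0.0)  # rotation angle of pi: w = cos(pi / 2) = 0
n2 = jnp.array(1.0)  # unit quaternion, so |(x, y, z)|^2 = 1
grad_old = jax.grad(lambda w: atan_factor(w, n2, use_safe_w=False))(w0)
grad_new = jax.grad(lambda w: atan_factor(w, n2, use_safe_w=True))(w0)
print(grad_old)  # nan
print(grad_new)  # finite
```

Only the Taylor branch needs the substitution: the other two branches never divide by w, so their discarded gradients are already finite.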
@brentyi
Owner

brentyi commented Mar 7, 2023

Thanks!

Yeah, that looks like a clear oversight on my end. If you make a PR I'd be happy to merge it; otherwise I can make the fix and add a test later this week.

Ending2015a added a commit to Ending2015a/jaxlie that referenced this issue Mar 7, 2023
@Ending2015a
Contributor Author

@brentyi
I have opened a PR for it. Thank you.

brentyi added a commit that referenced this issue Mar 10, 2023
* fixed nan #9

* Add test + isort / black

---------

Co-authored-by: Brent Yi <yibrenth@gmail.com>