Gradient of SO3.log() gives NaN when w=0 #9
Thanks! Yeah, that looks like a clear oversight on my end. If you make a PR I'd be happy to merge it; otherwise I can make the fix and add it to the tests later this week.
Ending2015a added a commit to Ending2015a/jaxlie that referenced this issue on Mar 7, 2023.
brentyi added a commit that referenced this issue on Mar 10, 2023.
Hello, thank you very much for this amazing library.
I found that a NaN appears at line 381 when computing the gradient of `SO3.log()`.
`jaxlie/jaxlie/_so3.py`, lines 379 to 387 in ad93513
The following is a small example that produces the NaN:
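A self-contained sketch that reproduces the same kind of failure outside jaxlie. `so3_log_unsafe` and `EPS` are hypothetical stand-ins that only mirror the structure discussed in this issue (a small-angle Taylor branch that divides by `w`, chosen via `jnp.where`); they are not jaxlie's actual code. A rotation by π about the x-axis has quaternion (w, x, y, z) = (cos(π/2), sin(π/2), 0, 0) = (0, 1, 0, 0), so it hits `w = 0` exactly:

```python
import jax
import jax.numpy as jnp

# Hypothetical small-angle threshold; jaxlie chooses its own per-dtype epsilon.
EPS = 1e-6


def so3_log_unsafe(wxyz: jax.Array) -> jax.Array:
    """Unit quaternion (w, x, y, z) -> rotation vector, with an unguarded Taylor branch."""
    w, xyz = wxyz[0], wxyz[1:]
    norm_sq = jnp.sum(xyz**2)
    use_taylor = norm_sq < EPS
    norm_safe = jnp.sqrt(jnp.where(use_taylor, 1.0, norm_sq))
    atan_factor = jnp.where(
        use_taylor,
        # Series for 2 * atan(n / w) / n near n = 0; divides by the raw w.
        2.0 / w - 2.0 / 3.0 * norm_sq / w**3,
        # Simplified: the angle is not wrapped to [-pi, pi].
        2.0 * jnp.arctan2(norm_safe, w) / norm_safe,
    )
    return atan_factor * xyz


q = jnp.array([0.0, 1.0, 0.0, 0.0])  # rotation by pi about x, so w = 0
value = so3_log_unsafe(q)  # forward pass is fine: the NaN branch is masked out
grad = jax.grad(lambda wxyz: so3_log_unsafe(wxyz).sum())(q)
print(value)
print(grad)  # contains NaN: the masked branch's zero cotangent meets an infinite derivative
```

The forward value is the expected rotation vector of length π along x; only the reverse pass is poisoned, which is why the bug shows up in `jax.grad` but not in plain evaluation.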
I think the reason might be that `jnp.where` does not actually block unsafe gradients (e.g. `x/0`), as described in the official JAX FAQ. The same thing happens at line 381: when the rotation angle approaches `pi` or `-pi`, `w` becomes `0`, which produces exactly such an `x/0` term and thus the bad gradient, even though `jnp.where` never selects that result.

To fix this issue, I suggest substituting a `safe_w = 1.0` whenever `use_taylor` is `False`, before computing `atan_factor`.
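A sketch of the suggested fix, using a hypothetical stand-in (`so3_log_safe` and `EPS` are illustrative names, not jaxlie's actual code). It applies the workaround the JAX FAQ recommends for this pitfall, often called the double-`where` trick: the input of the branch that `jnp.where` discards is itself routed through a `jnp.where`, so that branch is always evaluated at a harmless point (`w_safe = 1.0`) and its zero cotangent never meets an infinite derivative.

```python
import jax
import jax.numpy as jnp

# Hypothetical small-angle threshold; jaxlie chooses its own per-dtype epsilon.
EPS = 1e-6


def so3_log_safe(wxyz: jax.Array) -> jax.Array:
    """Unit quaternion (w, x, y, z) -> rotation vector, with a guarded Taylor branch."""
    w, xyz = wxyz[0], wxyz[1:]
    norm_sq = jnp.sum(xyz**2)
    use_taylor = norm_sq < EPS
    norm_safe = jnp.sqrt(jnp.where(use_taylor, 1.0, norm_sq))
    # Suggested fix: when the Taylor branch is not selected, evaluate it at w = 1.0
    # instead of the real w, which is 0 for rotations by +/- pi.
    w_safe = jnp.where(use_taylor, w, 1.0)
    atan_factor = jnp.where(
        use_taylor,
        2.0 / w_safe - 2.0 / 3.0 * norm_sq / w_safe**3,
        # Simplified: the angle is not wrapped to [-pi, pi].
        2.0 * jnp.arctan2(norm_safe, w) / norm_safe,
    )
    return atan_factor * xyz


q = jnp.array([0.0, 1.0, 0.0, 0.0])  # rotation by pi about x, so w = 0
grad = jax.grad(lambda wxyz: so3_log_safe(wxyz).sum())(q)
print(grad)  # finite everywhere
```

The selected (non-Taylor) branch is untouched, so forward values are unchanged; only the discarded branch now sees a finite input. The placeholder `1.0` is arbitrary: any nonzero value keeps `2.0 / w_safe` and `w_safe**3` finite.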