fix: clean up type handling#248
Conversation
Merging this PR will not alter performance
|
| Mode | Benchmark | BASE |
HEAD |
Efficiency | |
|---|---|---|---|---|---|
| ⚡ | WallTime | test_benchmarks_lanczos_interp[xval-conserve_dc-run] |
126.5 µs | 95.8 µs | +32.01% |
| ⚡ | WallTime | test_benchmarks_lanczos_interp[xval-no_conserve_dc-run] |
123.7 µs | 92.6 µs | +33.52% |
| ⚡ | WallTime | test_benchmarks_lanczos_interp[kval-no_conserve_dc-run] |
57.9 µs | 43.7 µs | +32.7% |
| ❌ | WallTime | test_benchmarks_metacal[run] |
20.2 ms | 34.8 ms | -41.84% |
| ❌ | WallTime | test_benchmark_spergel_conv[run] |
169.2 ms | 242.8 ms | -30.3% |
| ⚡ | WallTime | test_benchmark_moffat_init[run] |
103 µs | 58.9 µs | +75.04% |
| ❌ | WallTime | test_benchmark_moffat_conv[run] |
195.3 ms | 285.9 ms | -31.7% |
Tip
Investigate this regression by commenting @codspeedbot fix this regression on this PR, or directly use the CodSpeed MCP with your agent.
Comparing typing-inits-cleanup (1aeb448) with main (6ebfcb8)
…lopers/JAX-GalSim into typing-inits-cleanup
|
@ismael-mendoza This PR is based on top of #243. I plan to merge it after #243. Comments welcome! |
ismael-mendoza
left a comment
There was a problem hiding this comment.
Thanks Matt! I just have a few minor comments/questions and perhaps one small bug in _cast_to_static_numeric_scalar?
|
|
||
| def __mul__(self, other): | ||
| if not (has_tracers(other) or isinstance(other, NON_COMPLEX_TYPES)): | ||
| if isinstance(other, (Angle, AngleUnit)): |
There was a problem hiding this comment.
Could you explain why this is the check here? I guess more specifically, why not ensure that other is scalar instead?
There was a problem hiding this comment.
So the goal in this PR is to let arrays and tracers pass through without trying to detect which is which. We can do that for the most part (except for a few cases like BoundsI). So the goal of the check here is guard against common errors of handling Angles and AngleUnits.
| self.deltax = cast_to_float(self.deltax) | ||
| self.deltay = cast_to_float(self.deltay) | ||
| if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): | ||
| raise TypeError("BoundsI must be initialized with integer values") | ||
| self.deltax = int(cast_to_int(self.deltax)) | ||
| self.deltay = int(cast_to_int(self.deltay)) | ||
| self.deltax = cast_to_int(self.deltax) | ||
| self.deltay = cast_to_int(self.deltay) |
There was a problem hiding this comment.
so if I understand correctly, if someone tries to have non static deltax or deltay they will now get a ConcretizationTypeError, with no error raised from JAX Galsim explicitly. Is that the intention since we warn the user in the LAX docs and we want to avoid using has_tracers?
There was a problem hiding this comment.
Yes that is the intention. I am further cleaning this up in #250.
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Co-authored-by: Ismael Mendoza <11745764+ismael-mendoza@users.noreply.github.com>
Clarified descriptions of numeric data types in JAX.
Add error handling for photon array size exceeding limit
|
pre-commit.ci autofix |
|
OK @ismael-mendoza this one is ready for another look! |
This PR cleans up the type handling.
The goal is to have objects do one of three things:
intorfloat) or a numpy scalar type (i.e.,isinstance(x, np.number),isinstance(x, np.integer), etc.; see https://numpy.org/doc/stable/reference/arrays.scalars.html#built-in-scalar-types), convert it to a python type of the appropriate kind.jax.numpy.astype(x, ...).This set of rules is simple to understand, consistently handles array scalars (i.e., arrays with zero dimensions) in that any numpy array scalar is converted to a jax array scalar, and transparently handles JAX tracing.
The diff on the PR results in fewer lines of code, which is a nice side effect!
TODO:
try...exceptblocks, explicitly test for typeshas_tracersfunction in favor of transparent handlingcast_to_python_float/cast_to_python_intoutside of FITS handlingcloses #246