README.md
setup.cfg
setup.py
jax/__init__.py
jax/abstract_arrays.py
jax/ad_util.py
jax/api.py
jax/api_util.py
jax/config.py
jax/core.py
jax/custom_derivatives.py
jax/dlpack.py
jax/dtypes.py
jax/flatten_util.py
jax/lax_linalg.py
jax/lax_reference.py
jax/lazy.py
jax/linear_util.py
jax/pprint_util.py
jax/profiler.py
jax/random.py
jax/source_info_util.py
jax/test_util.py
jax/tree_util.py
jax/util.py
jax/version.py
jax.egg-info/PKG-INFO
jax.egg-info/SOURCES.txt
jax.egg-info/dependency_links.txt
jax.egg-info/requires.txt
jax.egg-info/top_level.txt
jax/experimental/__init__.py
jax/experimental/callback.py
jax/experimental/doubledouble.py
jax/experimental/host_callback.py
jax/experimental/jet.py
jax/experimental/loops.py
jax/experimental/ode.py
jax/experimental/optimizers.py
jax/experimental/optix.py
jax/experimental/stax.py
jax/experimental/vectorize.py
jax/experimental/jax2tf/__init__.py
jax/experimental/jax2tf/jax2tf.py
jax/experimental/jax2tf/tests/__init__.py
jax/experimental/jax2tf/tests/control_flow_ops_test.py
jax/experimental/jax2tf/tests/jax2tf_test.py
jax/experimental/jax2tf/tests/primitive_harness.py
jax/experimental/jax2tf/tests/primitives_test.py
jax/experimental/jax2tf/tests/savedmodel_test.py
jax/experimental/jax2tf/tests/stax_test.py
jax/experimental/jax2tf/tests/tf_test_util.py
jax/image/__init__.py
jax/image/scale.py
jax/interpreters/__init__.py
jax/interpreters/ad.py
jax/interpreters/batching.py
jax/interpreters/invertible_ad.py
jax/interpreters/masking.py
jax/interpreters/parallel.py
jax/interpreters/partial_eval.py
jax/interpreters/pxla.py
jax/interpreters/sharded_jit.py
jax/interpreters/xla.py
jax/lax/__init__.py
jax/lax/lax.py
jax/lax/lax_control_flow.py
jax/lax/lax_fft.py
jax/lax/lax_parallel.py
jax/lib/__init__.py
jax/lib/xla_bridge.py
jax/nn/__init__.py
jax/nn/functions.py
jax/nn/initializers.py
jax/numpy/__init__.py
jax/numpy/_util.py
jax/numpy/fft.py
jax/numpy/lax_numpy.py
jax/numpy/linalg.py
jax/numpy/polynomial.py
jax/numpy/vectorize.py
jax/ops/__init__.py
jax/ops/scatter.py
jax/scipy/__init__.py
jax/scipy/linalg.py
jax/scipy/ndimage.py
jax/scipy/signal.py
jax/scipy/special.py
jax/scipy/optimize/__init__.py
jax/scipy/optimize/_bfgs.py
jax/scipy/optimize/_line_search.py
jax/scipy/optimize/_minimize.py
jax/scipy/sparse/__init__.py
jax/scipy/sparse/linalg.py
jax/scipy/stats/__init__.py
jax/scipy/stats/bernoulli.py
jax/scipy/stats/beta.py
jax/scipy/stats/cauchy.py
jax/scipy/stats/dirichlet.py
jax/scipy/stats/expon.py
jax/scipy/stats/gamma.py
jax/scipy/stats/geom.py
jax/scipy/stats/laplace.py
jax/scipy/stats/logistic.py
jax/scipy/stats/multivariate_normal.py
jax/scipy/stats/norm.py
jax/scipy/stats/pareto.py
jax/scipy/stats/poisson.py
jax/scipy/stats/t.py
jax/scipy/stats/uniform.py
jax/third_party/__init__.py
jax/third_party/numpy/__init__.py
jax/third_party/numpy/linalg.py
jax/tools/__init__.py
jax/tools/jax_to_hlo.py