README.md
setup.cfg
setup.py
jax/__init__.py
jax/abstract_arrays.py
jax/ad_checkpoint.py
jax/api_util.py
jax/cloud_tpu_init.py
jax/config.py
jax/core.py
jax/custom_derivatives.py
jax/dlpack.py
jax/dtypes.py
jax/errors.py
jax/flatten_util.py
jax/jaxpr_util.py
jax/linear_util.py
jax/prng.py
jax/profiler.py
jax/py.typed
jax/random.py
jax/test_util.py
jax/tree_util.py
jax/util.py
jax/version.py
jax.egg-info/PKG-INFO
jax.egg-info/SOURCES.txt
jax.egg-info/dependency_links.txt
jax.egg-info/not-zip-safe
jax.egg-info/requires.txt
jax.egg-info/top_level.txt
jax/_src/__init__.py
jax/_src/abstract_arrays.py
jax/_src/ad_checkpoint.py
jax/_src/ad_util.py
jax/_src/api.py
jax/_src/api_util.py
jax/_src/cloud_tpu_init.py
jax/_src/config.py
jax/_src/custom_derivatives.py
jax/_src/dlpack.py
jax/_src/dtypes.py
jax/_src/errors.py
jax/_src/flatten_util.py
jax/_src/lax_reference.py
jax/_src/pretty_printer.py
jax/_src/prng.py
jax/_src/profiler.py
jax/_src/random.py
jax/_src/source_info_util.py
jax/_src/test_util.py
jax/_src/traceback_util.py
jax/_src/tree_util.py
jax/_src/util.py
jax/_src/image/__init__.py
jax/_src/image/scale.py
jax/_src/lax/__init__.py
jax/_src/lax/control_flow.py
jax/_src/lax/fft.py
jax/_src/lax/lax.py
jax/_src/lax/linalg.py
jax/_src/lax/other.py
jax/_src/lax/parallel.py
jax/_src/lax/polar.py
jax/_src/lib/__init__.py
jax/_src/lib/xla_bridge.py
jax/_src/nn/__init__.py
jax/_src/nn/functions.py
jax/_src/nn/initializers.py
jax/_src/numpy/__init__.py
jax/_src/numpy/fft.py
jax/_src/numpy/lax_numpy.py
jax/_src/numpy/linalg.py
jax/_src/numpy/polynomial.py
jax/_src/numpy/util.py
jax/_src/numpy/vectorize.py
jax/_src/ops/__init__.py
jax/_src/ops/scatter.py
jax/_src/scipy/__init__.py
jax/_src/scipy/eigh.py
jax/_src/scipy/fft.py
jax/_src/scipy/linalg.py
jax/_src/scipy/ndimage.py
jax/_src/scipy/signal.py
jax/_src/scipy/special.py
jax/_src/scipy/optimize/__init__.py
jax/_src/scipy/optimize/_lbfgs.py
jax/_src/scipy/optimize/bfgs.py
jax/_src/scipy/optimize/line_search.py
jax/_src/scipy/optimize/minimize.py
jax/_src/scipy/sparse/__init__.py
jax/_src/scipy/sparse/linalg.py
jax/_src/scipy/stats/__init__.py
jax/_src/scipy/stats/bernoulli.py
jax/_src/scipy/stats/beta.py
jax/_src/scipy/stats/betabinom.py
jax/_src/scipy/stats/cauchy.py
jax/_src/scipy/stats/chi2.py
jax/_src/scipy/stats/dirichlet.py
jax/_src/scipy/stats/expon.py
jax/_src/scipy/stats/gamma.py
jax/_src/scipy/stats/geom.py
jax/_src/scipy/stats/laplace.py
jax/_src/scipy/stats/logistic.py
jax/_src/scipy/stats/multivariate_normal.py
jax/_src/scipy/stats/nbinom.py
jax/_src/scipy/stats/norm.py
jax/_src/scipy/stats/pareto.py
jax/_src/scipy/stats/poisson.py
jax/_src/scipy/stats/t.py
jax/_src/scipy/stats/uniform.py
jax/_src/third_party/__init__.py
jax/_src/third_party/numpy/__init__.py
jax/_src/third_party/numpy/linalg.py
jax/experimental/__init__.py
jax/experimental/callback.py
jax/experimental/djax.py
jax/experimental/host_callback.py
jax/experimental/jet.py
jax/experimental/loops.py
jax/experimental/maps.py
jax/experimental/ode.py
jax/experimental/optimizers.py
jax/experimental/pjit.py
jax/experimental/stax.py
jax/experimental/x64_context.py
jax/experimental/compilation_cache/__init__.py
jax/experimental/compilation_cache/cache_interface.py
jax/experimental/compilation_cache/compilation_cache.py
jax/experimental/compilation_cache/file_system_cache.py
jax/experimental/jax2tf/__init__.py
jax/experimental/jax2tf/call_tf.py
jax/experimental/jax2tf/impl_no_xla.py
jax/experimental/jax2tf/jax2tf.py
jax/experimental/jax2tf/shape_poly.py
jax/experimental/jax2tf/examples/__init__.py
jax/experimental/jax2tf/examples/keras_reuse_main.py
jax/experimental/jax2tf/examples/keras_reuse_main_test.py
jax/experimental/jax2tf/examples/mnist_lib.py
jax/experimental/jax2tf/examples/saved_model_lib.py
jax/experimental/jax2tf/examples/saved_model_main.py
jax/experimental/jax2tf/examples/saved_model_main_test.py
jax/experimental/jax2tf/examples/serving/__init__.py
jax/experimental/jax2tf/examples/serving/model_server_request.py
jax/experimental/jax2tf/tests/__init__.py
jax/experimental/jax2tf/tests/call_tf_test.py
jax/experimental/jax2tf/tests/control_flow_ops_test.py
jax/experimental/jax2tf/tests/jax2tf_limitations.py
jax/experimental/jax2tf/tests/jax2tf_test.py
jax/experimental/jax2tf/tests/jax_primitives_coverage_test.py
jax/experimental/jax2tf/tests/primitive_harness.py
jax/experimental/jax2tf/tests/primitives_test.py
jax/experimental/jax2tf/tests/savedmodel_test.py
jax/experimental/jax2tf/tests/shape_poly_test.py
jax/experimental/jax2tf/tests/sharding_test.py
jax/experimental/jax2tf/tests/stax_test.py
jax/experimental/jax2tf/tests/tf_test_util.py
jax/experimental/sparse/__init__.py
jax/experimental/sparse/ad.py
jax/experimental/sparse/bcoo.py
jax/experimental/sparse/ops.py
jax/experimental/sparse/transform.py
jax/image/__init__.py
jax/interpreters/__init__.py
jax/interpreters/ad.py
jax/interpreters/batching.py
jax/interpreters/invertible_ad.py
jax/interpreters/masking.py
jax/interpreters/partial_eval.py
jax/interpreters/pxla.py
jax/interpreters/sharded_jit.py
jax/interpreters/xla.py
jax/lax/__init__.py
jax/lax/linalg.py
jax/lib/__init__.py
jax/lib/xla_bridge.py
jax/nn/__init__.py
jax/nn/initializers.py
jax/numpy/__init__.py
jax/numpy/fft.py
jax/numpy/linalg.py
jax/ops/__init__.py
jax/scipy/__init__.py
jax/scipy/fft.py
jax/scipy/linalg.py
jax/scipy/ndimage.py
jax/scipy/signal.py
jax/scipy/special.py
jax/scipy/optimize/__init__.py
jax/scipy/sparse/__init__.py
jax/scipy/sparse/linalg.py
jax/scipy/stats/__init__.py
jax/scipy/stats/bernoulli.py
jax/scipy/stats/beta.py
jax/scipy/stats/betabinom.py
jax/scipy/stats/cauchy.py
jax/scipy/stats/chi2.py
jax/scipy/stats/dirichlet.py
jax/scipy/stats/expon.py
jax/scipy/stats/gamma.py
jax/scipy/stats/geom.py
jax/scipy/stats/laplace.py
jax/scipy/stats/logistic.py
jax/scipy/stats/multivariate_normal.py
jax/scipy/stats/nbinom.py
jax/scipy/stats/norm.py
jax/scipy/stats/pareto.py
jax/scipy/stats/poisson.py
jax/scipy/stats/t.py
jax/scipy/stats/uniform.py
jax/tools/__init__.py
jax/tools/colab_tpu.py
jax/tools/jax_to_hlo.py