Jax

Latest version: v0.4.26



0.2.27

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27).

* Breaking changes:
* Support for NumPy 1.18 has been dropped, per the
[deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html).
Please upgrade to a supported NumPy version.
* The `host_callback` primitives have been simplified to drop the
special autodiff handling for `hcb.id_tap` and `id_print`.
From now on, only the primals are tapped. The old behavior can be
obtained (for a limited time) by setting the ``JAX_HOST_CALLBACK_AD_TRANSFORMS``
environment variable or the `--jax_host_callback_ad_transforms` flag.
Documentation was also added describing how to implement the old behavior
using JAX custom AD APIs ({jax-issue}`8678`).
* Sorting now matches the behavior of NumPy for ``0.0`` and ``NaN`` regardless of the
bit representation. In particular, ``0.0`` and ``-0.0`` are now treated as equivalent,
where previously ``-0.0`` was treated as less than ``0.0``. Additionally, all ``NaN``
representations are now treated as equivalent and sorted to the end of the array.
Previously, negative ``NaN`` values were sorted to the front of the array, and ``NaN``
values with different internal bit representations were not treated as equivalent but
were sorted according to those bit patterns ({jax-issue}`9178`).
* {func}`jax.numpy.unique` now treats ``NaN`` values in the same way as `np.unique` in
NumPy versions 1.21 and newer: at most one ``NaN`` value will appear in the uniquified
output ({jax-issue}`9184`).
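
The sorting and `jax.numpy.unique` changes above can be checked with a short sketch. It assumes a recent JAX release, where both behaviors still hold, and that negating `jnp.nan` in Python yields a NaN with its sign bit set (true for IEEE 754 floats on common platforms).

```python
import jax.numpy as jnp

# A positive NaN, a NaN with the sign bit set (assumed from Python float
# negation), a positive zero and a negative zero.
x = jnp.array([jnp.nan, 2.0, -jnp.nan, 0.0, -0.0, 1.0])

sorted_x = jnp.sort(x)
# Both NaNs end up at the end regardless of their sign bit,
# and 0.0 / -0.0 are ordered as equal values.
print(sorted_x)

uniq = jnp.unique(jnp.array([jnp.nan, 1.0, jnp.nan, 1.0]))
# At most one NaN survives, matching np.unique in NumPy 1.21+.
print(uniq)
```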

* Bug fixes:
* `host_callback` now supports `ad_checkpoint.checkpoint` ({jax-issue}`8907`).

* New features:
* Added `jax.block_until_ready` ({jax-issue}`8941`).
* Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`.
If set, JAX dumps the MHLO/HLO IR it generates for each computation to a
file under the given path.
* Added `jax.ensure_compile_time_eval` to the public API ({jax-issue}`7987`).
* jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change
the lowering for associative reductions, e.g., `jnp.cumsum`, so that they behave
like JAX on CPU and GPU (using an associative scan). See the jax2tf README
for more details ({jax-issue}`9189`).
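
A minimal sketch of the two new public APIs, `jax.block_until_ready` and `jax.ensure_compile_time_eval`, assuming a recent JAX release. The timing pattern and the function `f` are illustrative only; the control-flow example follows the pattern shown in the `ensure_compile_time_eval` documentation.

```python
import time

import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

start = time.perf_counter()
y = x @ x  # dispatched asynchronously: Python may continue before y is computed
jax.block_until_ready(y)  # block until the result is actually available
elapsed = time.perf_counter() - start
print(f"matmul took {elapsed:.4f}s")


@jax.jit
def f(x):
    with jax.ensure_compile_time_eval():
        # Operations on constants inside this context run eagerly at trace
        # time, so z is a concrete value rather than a tracer...
        y = jnp.sin(3.0)
        z = jnp.sin(y)
        # ...which lets it drive ordinary Python control flow under jit.
        # Operations on the traced argument x are still staged as usual.
        if z > 0:
            return jnp.sin(x)
        else:
            return jnp.cos(x)


print(f(0.5))
```

Without the context manager, `z > 0` would be a traced value and the `if` statement would raise a concretization error under `jax.jit`.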


jaxlib 0.1.75 (Dec 8, 2021)
* New features:
* Support for Python 3.10.

0.2.26

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26).

* Bug fixes:
* Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with
`FILL_OR_DROP` semantics, as documented. This primarily affects the
reverse-mode derivative, where gradients corresponding to out-of-bounds
indices will now be returned as 0 ({jax-issue}`8634`).
* jax2tf will force the converted code to use XLA for the code fragments
under `jax.jit`, e.g., most `jax.numpy` functions ({jax-issue}`7839`).
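
The `FILL_OR_DROP` behavior of `jax.ops.segment_sum` can be seen in a small sketch, assuming a recent JAX release where the default `mode` still drops out-of-bounds segment ids. The arrays and the helper `total` are hypothetical.

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0])
# Segment id 5 is out of bounds when num_segments=2.
segment_ids = jnp.array([0, 1, 5])


def total(d):
    # Hypothetical scalar loss over the segment sums.
    return jax.ops.segment_sum(d, segment_ids, num_segments=2).sum()


sums = jax.ops.segment_sum(data, segment_ids, num_segments=2)
grads = jax.grad(total)(data)
print(sums)   # the out-of-bounds element is dropped: values 1.0 and 2.0
print(grads)  # its gradient is 0: values 1.0, 1.0, 0.0
```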

jaxlib 0.1.74 (Nov 17, 2021)
* Enabled peer-to-peer copies between GPUs. Previously, GPU copies were bounced via
the host, which is usually slower.
* Added experimental MLIR Python bindings for use by JAX.

0.2.25

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25).

* New features:
* (Experimental) `jax.distributed.initialize` exposes the multi-host GPU backend.
* `jax.random.permutation` supports a new `independent` keyword argument
({jax-issue}`8430`).
* Added `jax.lax.linalg.qdwh`.
* Breaking changes:
* Moved `jax.experimental.stax` to `jax.example_libraries.stax`.
* Moved `jax.experimental.optimizers` to `jax.example_libraries.optimizers`.
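
A minimal sketch of the `independent` argument and the new `example_libraries` import path, assuming a recent JAX release that still ships `jax.example_libraries`. The arrays and step size are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers  # formerly jax.experimental.optimizers

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# independent=True shuffles the entries of each row separately,
# instead of applying one permutation of columns to every row.
shuffled = jax.random.permutation(key, x, axis=1, independent=True)

# The optimizers module works as before; only its import path moved.
opt_init, opt_update, get_params = optimizers.sgd(step_size=0.1)
```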

0.2.24

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24).

* New features:
* `jax.random.choice` and `jax.random.permutation` now support
multidimensional arrays and an optional `axis` argument ({jax-issue}`8158`)
* Breaking changes:
* `jax.numpy.take` and `jax.numpy.take_along_axis` now require array-like inputs
(see {jax-issue}`7737`)
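
A short sketch of the `axis` argument to `jax.random.choice` and of the array-like requirement for `jax.numpy.take`, assuming a recent JAX release in which passing a Python list to `jnp.take` raises `TypeError`. The variable names are illustrative.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(12).reshape(3, 4)

# Draw 2 of the 4 columns without replacement; the sampled axis
# is replaced by `shape`, giving a (3, 2) result.
cols = jax.random.choice(key, x, shape=(2,), replace=False, axis=1)

# jnp.take requires array-like inputs, so convert Python lists first.
values = [10, 20, 30]
picked = jnp.take(jnp.asarray(values), jnp.array([0, 2]))

try:
    jnp.take(values, jnp.array([0, 2]))  # a plain list is rejected
    list_accepted = True
except TypeError:
    list_accepted = False
```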

jaxlib 0.1.73 (Oct 18, 2021)

* Multiple cuDNN versions are now supported for jaxlib GPU `cuda11` wheels.
* cuDNN 8.2 or newer. We recommend using the cuDNN 8.2 wheel if your cuDNN
installation is new enough, since it supports additional functionality.
* cuDNN 8.0.5 or newer.

* Breaking changes:
* The install commands for GPU jaxlib are as follows:

```bash
pip install --upgrade pip

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Installs the wheel compatible with CUDA 11 and cuDNN 8.0.5 or newer.
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

0.2.22

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22).
* Breaking changes:
* Static arguments to `jax.pmap` must now be hashable.

Unhashable static arguments have long been disallowed on `jax.jit`, but they
were still permitted on `jax.pmap`; `jax.pmap` compared unhashable static
arguments using object identity.

This behavior is a footgun, since comparing arguments using
object identity leads to recompilation each time the object identity
changes. Instead, we now ban unhashable arguments: if a user of `jax.pmap`
wants to compare static arguments by object identity, they can define
`__hash__` and `__eq__` methods on their objects that do that, or wrap their
objects in an object that has those operations with object identity
semantics. Another option is to use `functools.partial` to encapsulate the
unhashable static arguments into the function object.
* `jax.util.partial` was an accidental export that has now been removed. Use
`functools.partial` from the Python standard library instead.
* Deprecations:
* The functions `jax.ops.index_update`, `jax.ops.index_add` etc. are
deprecated and will be removed in a future JAX release. Please use
[the `.at` property on JAX arrays](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html)
instead, e.g., `x.at[idx].set(y)`. For now, these functions produce a
`DeprecationWarning`.
* New features:
* An optimized C++ code-path improving the dispatch time for `pmap` is now the
default when using jaxlib 0.1.72 or newer. The feature can be disabled using
the `--experimental_cpp_pmap` flag (or `JAX_CPP_PMAP` environment variable).
* `jax.numpy.unique` now supports an optional `fill_value` argument ({jax-issue}`8121`)
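
The workarounds for unhashable `jax.pmap` static arguments, the `.at` replacement for `jax.ops.index_update`/`index_add`, and the new `fill_value` argument can be sketched as follows. This assumes a recent JAX release; `ByIdentity`, `scale`, `scale_dict` and `config` are hypothetical names used only for illustration.

```python
import functools

import jax
import jax.numpy as jnp


class ByIdentity:
    """Hypothetical wrapper: hashes and compares an object by identity."""

    def __init__(self, obj):
        self.obj = obj

    def __hash__(self):
        return id(self.obj)

    def __eq__(self, other):
        return isinstance(other, ByIdentity) and self.obj is other.obj


def scale(x, config):
    return x * config.obj["factor"]


def scale_dict(x, config):
    return x * config["factor"]


config = {"factor": 2.0}  # a dict is unhashable, so it cannot be a static argument
x = jnp.ones((jax.local_device_count(), 3))

# Option 1: wrap the unhashable object so it hashes by object identity.
out1 = jax.pmap(scale, static_broadcasted_argnums=(1,))(x, ByIdentity(config))

# Option 2: bind the unhashable object into the function with functools.partial.
out2 = jax.pmap(functools.partial(scale_dict, config=config))(x)

# The .at property replaces the deprecated jax.ops.index_* functions.
v = jnp.zeros(5)
v = v.at[2].set(7.0)                # formerly jax.ops.index_update(v, 2, 7.0)
v = v.at[jnp.array([0, 1])].add(1.0)  # formerly jax.ops.index_add(...)

# fill_value pads the fixed-size output of jnp.unique when size is given.
u = jnp.unique(jnp.array([3, 1, 3]), size=4, fill_value=0)
```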

jaxlib 0.1.72 (Oct 12, 2021)
* Breaking changes:
* Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports
CUDA 11.1+.
* Bug fixes:
* Fixes https://github.com/google/jax/issues/7461, which caused wrong
outputs on all platforms due to incorrect buffer aliasing inside the XLA
compiler.

0.2.21

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21).
* Breaking changes:
* `jax.api` has been removed. Functions that were available as `jax.api.*`
were aliases for functions in `jax.*`; please use the functions in
`jax.*` instead.
* `jax.partial` and `jax.lax.partial` were accidental exports that have now
been removed. Use `functools.partial` from the Python standard library
instead.
* Boolean scalar indices now raise a `TypeError`; previously this silently
returned wrong results ({jax-issue}`7925`).
* Many more `jax.numpy` functions now require array-like inputs, and will error
if passed a list ({jax-issue}`7747` {jax-issue}`7802` {jax-issue}`7907`).
See {jax-issue}`7737` for a discussion of the rationale behind this change.
* When inside a transformation such as `jax.jit`, `jax.numpy.array` always
stages the array it produces into the traced computation. Previously,
`jax.numpy.array` would sometimes produce an on-device array, even under
a `jax.jit` decorator. This change may break code that used JAX arrays to
perform shape or index computations that must be known statically; the
workaround is to perform such computations using classic NumPy arrays
instead.
* `jnp.ndarray` is now a true base class for JAX arrays. In particular, this
means that for a standard NumPy array `x`, `isinstance(x, jnp.ndarray)` will
now return `False` ({jax-issue}`7927`).
* New features:
* Added {func}`jax.numpy.insert` implementation ({jax-issue}`7936`).
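
Several of the 0.2.21 changes can be exercised together in a small sketch, assuming a recent JAX release where these behaviors still hold (in particular that `jnp.sum` rejects a plain Python list with `TypeError`). The function `flatten` is a hypothetical example of doing shape arithmetic with classic NumPy under `jax.jit`.

```python
import numpy as np

import jax
import jax.numpy as jnp

# jnp.ndarray is a true base class: NumPy arrays are not instances of it.
is_jax_np = isinstance(np.ones(3), jnp.ndarray)
is_jax_jnp = isinstance(jnp.ones(3), jnp.ndarray)

# Many jax.numpy functions reject Python lists; pass arrays instead.
try:
    jnp.sum([1, 2, 3])
    list_accepted = True
except TypeError:
    list_accepted = False
total = jnp.sum(jnp.array([1, 2, 3]))


@jax.jit
def flatten(x):
    # Shape arithmetic uses classic NumPy, so it stays a static Python
    # value instead of being staged into the traced computation.
    n = int(np.prod(x.shape))
    return jnp.reshape(x, (n,))


flat = flatten(jnp.ones((2, 3)))

# jnp.insert mirrors np.insert: insert 10 before index 1.
inserted = jnp.insert(jnp.array([1, 2, 3]), 1, 10)
```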


© 2024 Safety CLI Cybersecurity Inc. All Rights Reserved.