Jax

Latest version: v0.4.26


0.4.8

* Breaking changes
  * A major component of the Cloud TPU runtime has been upgraded. This enables
    the following new features on Cloud TPU:
    * {func}`jax.debug.print`, {func}`jax.debug.callback`, and
      {func}`jax.debug.breakpoint` now work on Cloud TPU.
    * Automatic TPU memory defragmentation.

    {mod}`jax.experimental.host_callback` is no longer supported on Cloud TPU
    with the new runtime component. Please file an issue on the [JAX issue
    tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs
    are insufficient for your use case.

    The old runtime component will remain available for at least the next three
    months; you can select it by setting the environment variable
    `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new
    runtime for any reason, please let us know on the [JAX issue
    tracker](https://github.com/google/jax/issues).

* Changes
  * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7.

* Deprecations
  * CUDA 11.4 support has been dropped. JAX GPU wheels only support
    CUDA 11.8 and CUDA 12. Older CUDA versions may work if jaxlib is built
    from source.
  * The `global_arg_shapes` argument of `pmap` only worked with `sharded_jit`
    and has been removed from `pmap`. Please migrate to `pjit` and remove
    `global_arg_shapes` from your `pmap` calls.
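The snippet below is a minimal sketch of the `jax.debug` APIs named above, assuming a recent JAX release on any backend (CPU included). The names `scaled_sin`, `record`, and `collected` are illustrative and not part of JAX. `jax.debug.breakpoint` is left out because it opens an interactive debugger.

```python
import numpy as np
import jax
import jax.numpy as jnp

collected = []  # host-side buffer filled by the callback (illustrative)


def record(value):
    # Runs on the host and receives a concrete value, not a tracer.
    collected.append(np.asarray(value))


@jax.jit
def scaled_sin(x):
    y = jnp.sin(x)
    # Prints the runtime value of y, even inside a jitted function.
    jax.debug.print("y = {y}", y=y)
    # Sends y back to an ordinary Python function on the host.
    jax.debug.callback(record, y)
    return 2.0 * y


out = scaled_sin(jnp.zeros(3))
jax.effects_barrier()  # wait until all pending callbacks have run
```

Callbacks may run asynchronously, which is why the sketch calls `jax.effects_barrier()` before reading `collected`.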

0.4.7

* Changes
  * As per the [jax.Array migration guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html#jax-array-migration),
    `jax.config.jax_array` can no longer be disabled.
  * `jax.config.jax_jit_pjit_api_merge` can no longer be disabled.
  * {func}`jax.experimental.jax2tf.convert` now supports the `native_serialization`
    parameter, which uses JAX's native lowering to StableHLO to obtain a
    StableHLO module for the entire JAX function instead of lowering each JAX
    primitive to a TensorFlow op. This simplifies the internals and increases
    confidence that what you serialize matches the JAX native semantics.
    See the [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
    As part of this change, the config flag `--jax2tf_default_experimental_native_lowering`
    has been renamed to `--jax2tf_native_serialization`.
  * JAX now depends on `ml_dtypes`, which contains definitions of NumPy types
    like bfloat16. These definitions were previously internal to JAX, but have
    been split into a separate package to facilitate sharing them with other
    projects.
  * JAX now requires NumPy 1.21 or newer and SciPy 1.7 or newer.

* Deprecations
  * The type `jax.numpy.DeviceArray` is deprecated. Use `jax.Array` instead,
    for which it is an alias.
  * The type `jax.interpreters.pxla.ShardedDeviceArray` is deprecated. Use
    `jax.Array` instead.
  * Passing additional arguments to {func}`jax.numpy.ndarray.at` by position is deprecated.
    For example, instead of `x.at[i].get(True)`, use `x.at[i].get(indices_are_sorted=True)`.
  * `jax.interpreters.xla.device_put` is deprecated. Please use `jax.device_put`.
  * `jax.interpreters.pxla.device_put` is deprecated. Please use `jax.device_put`.
  * `jax.experimental.pjit.FROM_GDA` is deprecated. Please pass in sharded
    `jax.Array`s as input and remove the `in_shardings` argument to `pjit`, since
    it is optional.
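A minimal sketch of the replacement spellings for several of the deprecations above, assuming a recent JAX release in which `jax.Array` is the only array type; the array values and indices are arbitrary:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(5.0)

# jax.Array replaces the deprecated DeviceArray / ShardedDeviceArray types.
assert isinstance(x, jax.Array)

# Extra arguments to .at[...].get() are passed by keyword, not by position.
idx = jnp.array([1, 3])
y = x.at[idx].get(indices_are_sorted=True)

# jax.device_put replaces jax.interpreters.xla/pxla.device_put.
z = jax.device_put(x, jax.devices()[0])
```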

jaxlib 0.4.7 (March 27, 2023)

Changes:
* jaxlib now depends on `ml_dtypes`, which contains definitions of NumPy types
like bfloat16. These definitions were previously internal to JAX, but have
been split into a separate package to facilitate sharing them with other
projects.
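As a small illustration of the `ml_dtypes` dependency, the sketch below (assuming `ml_dtypes` is installed, as it is when installed as a jaxlib dependency) builds a bfloat16 NumPy array on the host and passes it to JAX, which uses the same dtype definition:

```python
import ml_dtypes
import numpy as np
import jax.numpy as jnp

# bfloat16 comes from ml_dtypes and behaves like a regular NumPy dtype.
host = np.array([1.0, 2.5, 3.0], dtype=ml_dtypes.bfloat16)

# Moving the array into JAX keeps the same bfloat16 dtype.
device = jnp.asarray(host)
```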

0.4.6

* Changes
  * `jax.tree_util` now contains a set of APIs that allow users to define keys for their
    custom pytree nodes. These include:
    * `tree_flatten_with_path`, which flattens a tree and returns not only each leaf but
      also its key path.
    * `tree_map_with_path`, which maps a function that takes the key path as an argument.
    * `register_pytree_with_keys`, which registers what the key paths and leaves should look
      like in a custom pytree node.
    * `keystr`, which pretty-prints a key path.

  * {func}`jax2tf.call_tf` has a new parameter `output_shape_dtype` (default `None`)
    that can be used to declare the output shape and type of the result. This enables
    {func}`jax2tf.call_tf` to work in the presence of shape polymorphism ({jax-issue}`14734`).

* Deprecations
  * The old key-path APIs in `jax.tree_util` are deprecated and will be removed 3 months
    from Mar 10 2023:
    * `register_keypaths`: use {func}`jax.tree_util.register_pytree_with_keys` instead.
    * `AttributeKeyPathEntry`: use `GetAttrKey` instead.
    * `GetitemKeyPathEntry`: use `SequenceKey` or `DictKey` instead.
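The following sketch exercises the key-path APIs listed above on a nested dict containing a custom node. The `Point` class and its flatten/unflatten helpers are hypothetical; `GetAttrKey` is the key type the deprecation notes recommend in place of `AttributeKeyPathEntry`.

```python
from jax.tree_util import (
    GetAttrKey,
    keystr,
    register_pytree_with_keys,
    tree_flatten_with_path,
    tree_map_with_path,
)


class Point:
    """Hypothetical custom pytree node with two children, x and y."""

    def __init__(self, x, y):
        self.x = x
        self.y = y


def point_flatten_with_keys(p):
    # Return (key, child) pairs plus auxiliary data (none needed here).
    return ((GetAttrKey("x"), p.x), (GetAttrKey("y"), p.y)), None


def point_unflatten(aux_data, children):
    return Point(*children)


register_pytree_with_keys(Point, point_flatten_with_keys, point_unflatten)

tree = {"a": [1, 2], "p": Point(3, 4)}

# Each leaf comes back together with its key path.
leaves_with_paths, treedef = tree_flatten_with_path(tree)
paths = [keystr(path) for path, _ in leaves_with_paths]

# Scale only the leaves stored under the "p" entry.
scaled = tree_map_with_path(
    lambda path, leaf: leaf * 10 if keystr(path).startswith("['p']") else leaf,
    tree,
)
```

Here `paths` should read `['a'][0]`, `['a'][1]`, `['p'].x`, `['p'].y`: dict keys render as subscripts, sequence positions as indices, and `GetAttrKey` as attribute access.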

jaxlib 0.4.6 (Mar 9, 2023)

0.4.5

* Deprecations
  * `jax.sharding.OpShardingSharding` has been renamed to `jax.sharding.GSPMDSharding`.
    `jax.sharding.OpShardingSharding` will be removed 3 months after Feb 17, 2023.
  * The following `jax.Array` methods are deprecated and will be removed 3 months from
    Feb 23 2023:
    * `jax.Array.broadcast`: use {func}`jax.lax.broadcast` instead.
    * `jax.Array.broadcast_in_dim`: use {func}`jax.lax.broadcast_in_dim` instead.
    * `jax.Array.split`: use {func}`jax.numpy.split` instead.
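For code that still calls the deprecated `jax.Array` methods, the replacement functions can be used as follows. This is a sketch assuming a recent JAX release; the shapes and split count are arbitrary:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(6.0)

# jax.lax.broadcast prepends new leading dimensions (here one of size 2).
b = jax.lax.broadcast(x, (2,))

# jax.lax.broadcast_in_dim maps x's only axis onto output axis 1.
c = jax.lax.broadcast_in_dim(x, (3, 6), broadcast_dimensions=(1,))

# jax.numpy.split divides x into three equal pieces.
parts = jnp.split(x, 3)
```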

0.4.4

* Changes
  * The implementations of `jit` and `pjit` have been merged. Merging `jit` and `pjit`
    changes the internals of JAX without affecting its public API.
    Previously, `jit` was a final style primitive. Final style means that the creation
    of the jaxpr was delayed as much as possible and transformations were stacked
    on top of each other. With the `jit`-`pjit` implementation merge, `jit`
    becomes an initial style primitive, which means that we trace to a jaxpr
    as early as possible. For more information, see
    [this section in autodidax](https://jax.readthedocs.io/en/latest/autodidax.html#on-the-fly-final-style-and-staged-initial-style-processing).
    Moving to initial style should simplify JAX's internals and make
    development of features like dynamic shapes easier.
    You can disable the merge only via an environment variable, e.g.
    `os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'`.
    An environment variable is required because the merge takes effect when JAX
    is imported, so the variable must be set before `jax` is imported.
  * The `axis_resources` argument of `with_sharding_constraint` is deprecated.
    Please use `shardings` instead. No change is needed if you pass
    `axis_resources` as a positional argument; if you pass it as a keyword
    argument, rename it to `shardings`. `axis_resources` will be removed 3 months
    after Feb 13, 2023.
  * Added the {mod}`jax.typing` module, with tools for type annotations of JAX
    functions.
  * The following names have been deprecated:
    * `jax.xla.Device` and `jax.interpreters.xla.Device`: use `jax.Device`.
    * `jax.experimental.maps.Mesh`: use `jax.sharding.Mesh`.
    * `jax.experimental.pjit.NamedSharding`: use `jax.sharding.NamedSharding`.
    * `jax.experimental.pjit.PartitionSpec`: use `jax.sharding.PartitionSpec`.
    * `jax.interpreters.pxla.Mesh`: use `jax.sharding.Mesh`.
    * `jax.interpreters.pxla.PartitionSpec`: use `jax.sharding.PartitionSpec`.
* Breaking changes
  * The `initial` argument to reduction functions like {func}`jax.numpy.sum`
    is now required to be a scalar, consistent with the corresponding NumPy API.
    The previous behavior of broadcasting the output against non-scalar `initial`
    values was an unintentional implementation detail ({jax-issue}`14446`).
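A combined sketch of several 0.4.4 changes: the `jax.sharding` names, the `shardings` keyword of `with_sharding_constraint`, a `jax.typing` annotation, and a scalar `initial`. It assumes a recent JAX release in which `with_sharding_constraint` is exposed as `jax.lax.with_sharding_constraint`, and it uses a single-device mesh (the axis name `data` and the function name `double` are arbitrary) so that it runs on one CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.typing import ArrayLike

# A 1-D mesh over a single device, so this runs on any machine.
mesh = Mesh(np.array(jax.devices()[:1]), ("data",))
sharding = NamedSharding(mesh, PartitionSpec("data"))


@jax.jit
def double(x: ArrayLike) -> jax.Array:
    y = jnp.asarray(x) * 2
    # The keyword is `shardings` (formerly `axis_resources`).
    return jax.lax.with_sharding_constraint(y, shardings=sharding)


out = double(jnp.arange(4.0))

# `initial` must be a scalar; it is folded into the reduction result.
total = jnp.sum(jnp.arange(4.0), initial=10.0)
```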

jaxlib 0.4.4 (Feb 16, 2023)
* Breaking changes
  * Support for NVIDIA Kepler series GPUs has been removed from the default
    `jaxlib` builds. If Kepler support is needed, it is still possible to
    build `jaxlib` from source with Kepler support (via the
    `--cuda_compute_capabilities=sm_35` option to `build.py`); however, note
    that CUDA 12 has completely dropped support for Kepler GPUs.

0.4.3

* Breaking changes
  * Deleted {func}`jax.scipy.linalg.polar_unitary`, which was a deprecated JAX
    extension to the scipy API. Use {func}`jax.scipy.linalg.polar` instead.

* Changes
  * Added {func}`jax.scipy.stats.rankdata`.
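A brief sketch of the two SciPy-compatible functions mentioned above, assuming a recent JAX release; the input values are arbitrary:

```python
import jax.numpy as jnp
from jax.scipy.linalg import polar
from jax.scipy.stats import rankdata

a = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Polar decomposition a = u @ p, with u unitary and p positive semidefinite.
u, p = polar(a)

# Ranks with ties averaged (the default method, as in SciPy).
ranks = rankdata(jnp.array([10, 30, 20, 20]))
```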

jaxlib 0.4.3 (Feb 8, 2023)
* `jax.Array` now has the non-blocking `is_ready()` method, which returns `True`
if the array is ready (see also {func}`jax.block_until_ready`).
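A minimal sketch contrasting the non-blocking `is_ready()` with the blocking `block_until_ready()` method. The matrix size is arbitrary, and whether the first check prints `True` or `False` depends on how quickly the backend finishes:

```python
import jax.numpy as jnp

x = jnp.ones((512, 512)) @ jnp.ones((512, 512))

# is_ready() returns immediately; it may be False while the computation runs.
print("ready right away?", x.is_ready())

# Block until the result is available; afterwards is_ready() is True.
x.block_until_ready()
```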
