JAX

Latest version: v0.4.26

0.4.2

* Breaking changes
* Deleted `jax.experimental.callback`
* Operations on dimensions under jax2tf shape polymorphism have been
generalized to work in more scenarios by converting the symbolic
dimensions to JAX arrays. Operations involving symbolic dimensions and
`np.ndarray` can now raise errors when the result is used as a shape value
({jax-issue}`14106`).
* jaxpr objects now raise an error on attribute setting in order to avoid
problematic mutations ({jax-issue}`14102`).

* Changes
* {func}`jax2tf.call_tf` has a new parameter `has_side_effects` (default `True`)
that can be used to declare whether an instance can be removed or replicated
by JAX optimizations such as dead-code elimination ({jax-issue}`13980`).
* Added more support for floordiv and mod for jax2tf shape polymorphism. Previously,
certain division operations resulted in errors in the presence of symbolic dimensions
({jax-issue}`14108`).

jaxlib 0.4.2 (Jan 24, 2023)

* Changes
* Set `JAX_USE_PJRT_C_API_ON_TPU=1` to enable the new Cloud TPU runtime, featuring
automatic device memory defragmentation.

0.4.1

* Changes
* Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}`version-support-policy`.
* We introduce `jax.Array` which is a unified array type that subsumes
`DeviceArray`, `ShardedDeviceArray`, and `GlobalDeviceArray` types in JAX.
The `jax.Array` type helps make parallelism a core feature of JAX,
simplifies and unifies JAX internals, and allows us to unify `jit` and
`pjit`. `jax.Array` has been enabled by default in JAX 0.4 and makes some
breaking changes to the `pjit` API. The [jax.Array migration
guide](https://jax.readthedocs.io/en/latest/jax_array_migration.html) can
help you migrate your codebase to `jax.Array`. You can also look at the
[Distributed arrays and automatic parallelization](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html)
tutorial to understand the new concepts.
* `PartitionSpec` and `Mesh` are now out of experimental. The new API endpoints
are `jax.sharding.PartitionSpec` and `jax.sharding.Mesh`.
`jax.experimental.maps.Mesh` and `jax.experimental.PartitionSpec` are
deprecated and will be removed in 3 months.
* The new public endpoint for `with_sharding_constraint` is
`jax.lax.with_sharding_constraint`.
* If using ABSL flags together with `jax.config`, the ABSL flag values are no
longer read or written after the JAX configuration options are initially
populated from the ABSL flags. This change improves performance of reading
`jax.config` options, which are used pervasively in JAX.
* For TF lowering, the `jax2tf.call_tf` function now uses the first TF
device of the same platform as the embedding JAX computation.
Previously, it used the 0th device of the JAX default backend.
* A number of `jax.numpy` functions now have their arguments marked as
positional-only, matching NumPy.
* `jnp.msort` is now deprecated, following the deprecation of `np.msort` in NumPy 1.24.
It will be removed in a future release, in accordance with the {ref}`api-compatibility`
policy. It can be replaced with `jnp.sort(a, axis=0)`.
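
The sharding endpoints and the `jnp.msort` replacement listed above can be sketched roughly as follows. This is a minimal illustration rather than an excerpt from the JAX documentation: the mesh axis name `"data"`, the array shapes, and the `scale` function are assumptions chosen for the example, and it uses `jax.sharding.NamedSharding` to attach a `PartitionSpec` to a `Mesh`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Build a 1-D mesh over all local devices. The axis name "data" is an
# illustrative choice, not something the changelog prescribes.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))
row_sharding = NamedSharding(mesh, PartitionSpec("data", None))

# The row count is a multiple of the device count so the leading axis
# can be split evenly across the mesh.
rows = devices.size * 4
x = jnp.arange(rows * 2, dtype=jnp.float32).reshape(rows, 2)
x = jax.device_put(x, row_sharding)

@jax.jit
def scale(v):
    # Ask the compiler to keep the value sharded along "data".
    v = jax.lax.with_sharding_constraint(v, row_sharding)
    return v * 2.0

y = scale(x)  # y is a jax.Array

# Replacement for the deprecated jnp.msort: sort along the first axis.
a = jnp.array([[3, 1], [2, 4], [1, 0]])
sorted_cols = jnp.sort(a, axis=0)
```

On a single-device machine the mesh has one entry and the sharding is trivial, but the same code path is exercised.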

jaxlib 0.4.1 (Dec 13, 2022)

* Changes
* Support for Python 3.7 has been dropped, in accordance with JAX's
{ref}`version-support-policy`.
* The behavior of `XLA_PYTHON_CLIENT_MEM_FRACTION=.XX` has been changed to allocate XX% of
the total GPU memory instead of the previous behavior of using currently available GPU memory
to calculate preallocation. Please refer to
[GPU memory allocation](https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html) for
more details.
* The deprecated method `.block_host_until_ready()` has been removed. Use
`.block_until_ready()` instead.
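
As a rough sketch of the two items above, the snippet below sets the memory-fraction variable before JAX is imported and waits for a result with `.block_until_ready()`. The value `.50` and the matrix size are illustrative assumptions; on a CPU-only machine the environment variable simply has no effect.

```python
import os

# Assumption for illustration: request 50% of total GPU memory. The variable
# is set before importing JAX so it is in place when the backend starts.
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", ".50")

import jax.numpy as jnp

x = jnp.ones((256, 256))

# JAX dispatches work asynchronously; block_until_ready() waits for the
# computation to finish and returns the array.
y = (x @ x).block_until_ready()
```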

0.4.0

* The release was yanked.

jaxlib 0.4.0 (Dec 12, 2022)

* The release was yanked.

0.3.25

* Changes
* {func}`jax.numpy.linalg.pinv` now supports the `hermitian` option.
* {func}`jax.scipy.linalg.hessenberg` is now supported on CPU only. Requires
jaxlib > 0.3.24.
* New functions {func}`jax.lax.linalg.hessenberg`,
{func}`jax.lax.linalg.tridiagonal`, and
{func}`jax.lax.linalg.householder_product` were added. Householder reduction
is currently CPU-only and tridiagonal reductions are supported on CPU and
GPU only.
* The gradients of `svd` and `jax.numpy.linalg.pinv` are now computed more
economically for non-square matrices.
* Breaking Changes
* Deleted the `jax_experimental_name_stack` config option.
* A string `axis_names` argument to the
{class}`jax.experimental.maps.Mesh` constructor is now converted into a
singleton tuple instead of being unpacked into a sequence of
single-character axis names.
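
To illustrate the `hermitian` option of `jax.numpy.linalg.pinv` mentioned in the changes above, here is a small sketch. The matrix is an arbitrary symmetric example chosen for illustration; for an invertible symmetric matrix the pseudo-inverse should coincide with the ordinary inverse up to floating-point error.

```python
import jax.numpy as jnp

# A small real symmetric (hence Hermitian) matrix, chosen for illustration.
a = jnp.array([[2.0, 1.0],
               [1.0, 2.0]])

# hermitian=True tells pinv that the input is Hermitian, so it can use an
# eigendecomposition-based path instead of a general SVD.
a_pinv = jnp.linalg.pinv(a, hermitian=True)

# For an invertible matrix, a @ pinv(a) should be close to the identity.
identity_check = a @ a_pinv
```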

jaxlib 0.3.25 (Nov 15, 2022)

* Changes
* Added support for tridiagonal reductions on CPU and GPU.
* Added support for upper Hessenberg reductions on CPU.
* Bugs
* Fixed a bug that caused frames in tracebacks captured by JAX to be
incorrectly mapped to source lines under Python 3.10+.

0.3.24

* Changes
* JAX should be faster to import. We now import scipy lazily, which accounted
for a significant fraction of JAX's import time.
* The environment variable `JAX_PERSISTENT_CACHE_MIN_COMPILE_TIME_SECS=$N` can be
set to limit the number of cache entries written to the persistent cache.
By default, only computations that take 1 second or more to compile are
cached.
* Added {func}`jax.scipy.stats.mode`.
* The default device order used by `pmap` on TPU if no order is specified now
matches `jax.devices()` for single-process jobs. Previously the
two orderings differed, which could lead to unnecessary copies or
out-of-memory errors. Requiring the orderings to agree simplifies matters.
* Breaking Changes
* {func}`jax.numpy.gradient` now behaves like most other functions in {mod}`jax.numpy`,
and forbids passing lists or tuples in place of arrays ({jax-issue}`12958`)
* Functions in {mod}`jax.numpy.linalg` and {mod}`jax.numpy.fft` now uniformly
require inputs to be array-like: i.e. lists and tuples cannot be used in place
of arrays. Part of {jax-issue}`7737`.
* Deprecations
* `jax.sharding.MeshPspecSharding` has been renamed to `jax.sharding.NamedSharding`.
The `jax.sharding.MeshPspecSharding` name will be removed in 3 months.
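
A short sketch of two of the items above: the new `jax.scipy.stats.mode` function and the array-only calling convention of `jax.numpy.gradient`. The sample values are made up for illustration.

```python
import jax.numpy as jnp
from jax.scipy import stats

# jax.scipy.stats.mode returns the most frequent value and its count.
samples = jnp.array([1, 2, 2, 3, 2])
result = stats.mode(samples)
most_common = int(result.mode)
occurrences = int(result.count)

# jnp.gradient expects an array; per the breaking change above, wrap list
# or tuple data with jnp.array(...) before passing it in.
values = jnp.array([1.0, 4.0, 9.0, 16.0])
grad = jnp.gradient(values)
```

With unit spacing, `jnp.gradient` uses one-sided differences at the two ends and central differences in the interior, following NumPy's convention.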

jaxlib 0.3.24 (Nov 4, 2022)

* Changes
* Buffer donation now works on CPU. This may break code that marked buffers
for donation on CPU but relied on donation not being implemented.
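
The practical consequence of the donation change can be sketched as below. The function `increment` and the array shape are hypothetical names and values chosen for illustration; the point is that an input marked with `donate_argnums` should be treated as consumed after the call, on CPU as on other backends.

```python
import jax
import jax.numpy as jnp

def increment(buf):
    # Output has the same shape and dtype as the input, so the input
    # buffer is a candidate for reuse when it is donated.
    return buf + 1.0

# Mark argument 0 as donated: JAX may reuse its memory for the output.
increment_donated = jax.jit(increment, donate_argnums=0)

x = jnp.zeros((4,))
y = increment_donated(x)

# From here on, only `y` is used. Code that kept reading `x` after the
# call relied on donation being a no-op on CPU, which no longer holds.
```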

0.3.23

* Changes
* Update Colab TPU driver version for new jaxlib release.
