Jax

Latest version: v0.4.26

Page 8 of 17

0.3.10

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10).

jaxlib 0.3.10 (May 3, 2022)
* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10).
* Changes
* [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a)
fixes an issue in the MHLO canonicalizer that caused constant folding to
take a long time or crash for certain programs.

0.3.9

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9).
* Changes
* Added support for fully asynchronous checkpointing for GlobalDeviceArray.

0.3.8

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8).
* Changes
* {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
* {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.pinv` on TPUs now accepts complex input.
* {func}`jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
* {func}`jax.scipy.cluster.vq.vq` has been added.
* `jax.experimental.maps.mesh` has been deleted.
Use `jax.experimental.maps.Mesh` instead; see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
for more information.
* {func}`jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when
`mode='r'`, in order to match the behavior of `scipy.linalg.qr` ({jax-issue}`10452`).
* {func}`jax.numpy.take_along_axis` now takes an optional `mode` parameter
that specifies the behavior of out-of-bounds indexing. By default,
invalid values (e.g., NaN) will be returned for out-of-bounds indices. In
previous versions of JAX, invalid indices were clamped into range. The
previous behavior can be restored by passing `mode="clip"`.
* {func}`jax.numpy.take` now defaults to `mode="fill"`, which returns
invalid values (e.g., NaN) for out-of-bounds indices.
* Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics.
This has no effect on the scatter operation itself, but it means that the
gradient of a scatter now yields zero cotangents for out-of-bounds indices.
Previously, out-of-bounds indices were clamped into range for the gradient,
which was not mathematically correct.
* {func}`jax.numpy.take_along_axis` now raises a `TypeError` if its indices
are not of an integer type, matching the behavior of
{func}`numpy.take_along_axis`. Previously non-integer indices were silently
cast to integers.
* {func}`jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument
is not of an integer type, matching the behavior of
{func}`numpy.ravel_multi_index`. Previously non-integer `dims` was silently
cast to integers.
* {func}`jax.numpy.split` now raises a `TypeError` if its `axis` argument
is not of an integer type, matching the behavior of
{func}`numpy.split`. Previously non-integer `axis` was silently
cast to integers.
* {func}`jax.numpy.indices` now raises a `TypeError` if its dimensions
are not of an integer type, matching the behavior of
{func}`numpy.indices`. Previously non-integer dimensions were silently
cast to integers.
* {func}`jax.numpy.diag` now raises a `TypeError` if its `k` argument
is not of an integer type, matching the behavior of
{func}`numpy.diag`. Previously non-integer `k` was silently
cast to integers.
* Added {func}`jax.random.orthogonal`.
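The out-of-bounds indexing changes listed above can be sketched as follows. This is a minimal illustration, assuming a JAX release at or after 0.3.8; the array values and the names `x`, `idx`, `scatter_idx`, and `total_after_add` are made up for the example and do not come from the release notes.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds for a length-3 array

# Gather: the default mode is "fill", so out-of-bounds reads give NaN for floats.
filled = jnp.take(x, idx)
# mode="clip" restores the older behavior of clamping indices into range.
clipped = jnp.take(x, idx, mode="clip")

# take_along_axis follows the same convention.
along_fill = jnp.take_along_axis(x, idx, axis=0)
along_clip = jnp.take_along_axis(x, idx, axis=0, mode="clip")

# Scatter: an update at an out-of-bounds index is dropped, leaving x unchanged.
scattered = x.at[jnp.array([5])].set(99.0)

# Gradient of a scatter: the out-of-bounds update receives a zero cotangent.
scatter_idx = jnp.array([0, 5])

def total_after_add(updates):
    return x.at[scatter_idx].add(updates).sum()

grads = jax.grad(total_after_add)(jnp.array([1.0, 2.0]))

# Non-integer indices are rejected instead of being silently cast.
try:
    jnp.take_along_axis(x, jnp.array([0.0, 1.0]), axis=0)
    raised_type_error = False
except TypeError:
    raised_type_error = True

print(filled, clipped, scattered, grads, raised_type_error)
```

Array-valued indices are used for the scatter on purpose, so the example exercises the runtime out-of-bounds handling rather than any static index checking.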
* Deprecations
* Many functions and objects available in {mod}`jax.test_util` are now deprecated and will raise a
warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`,
`format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and
`_default_tolerance` ({jax-issue}`10389`). These, along with previously-deprecated `JaxTestCase`,
`JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release.
Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found
in e.g. {mod}`unittest`, {mod}`absl.testing`, {mod}`numpy.testing`, etc. JAX-specific functionality
such as device checking can be replaced through the use of public APIs such as {func}`jax.devices`.
Many of the deprecated utilities will still exist in {mod}`jax._src.test_util`, but these are not
public APIs and as such may be changed or removed without notice in future releases.
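As a rough sketch of the migration described above, a test that previously leaned on `jax.test_util` helpers might be rewritten with `unittest`, `numpy.testing`, and `jax.devices` alone. The test class, the tolerances, and the choice of `jnp.sin` are hypothetical and chosen only for illustration; they are not taken from the JAX test suite.

```python
import unittest

import numpy as np
import jax
import jax.numpy as jnp


class SineTest(unittest.TestCase):
    def test_matches_numpy(self):
        # numpy.testing stands in for helpers such as check_close.
        x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
        np.testing.assert_allclose(jnp.sin(x), np.sin(x), rtol=1e-5, atol=1e-6)

    def test_gpu_only(self):
        # jax.devices() stands in for device_under_test / skip_on_devices.
        platforms = {d.platform for d in jax.devices()}
        if "gpu" not in platforms:
            self.skipTest("requires a GPU backend")
        y = jnp.ones(4).sum()
        self.assertEqual(float(y), 4.0)
```

Such a module can be run with `python -m unittest`; the GPU-only case is reported as skipped on CPU-only machines.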

0.3.7

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7).
* Changes:
* Fixed a performance problem if the indices passed to
{func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`10281`).
* {func}`jax.scipy.special.expit` and {func}`jax.scipy.special.logit` now
require their arguments to be scalars or JAX arrays. They also now promote
integer arguments to floating point.
* The `DeviceArray.tile()` method is deprecated, because NumPy arrays do not have a
`tile()` method. Use {func}`jax.numpy.tile` instead
({jax-issue}`10266`).
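The `expit`/`logit` promotion and the `tile` replacement noted above might look like this in practice. This is a small sketch with made-up input values, assuming a JAX release at or after 0.3.7.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

ints = jnp.array([0, 1, 2])

# Integer input is promoted to a floating-point dtype before evaluation.
probs = expit(ints)

# logit inverts expit, up to floating-point error.
roundtrip = logit(probs)

# Function form that replaces the deprecated DeviceArray.tile() method.
tiled = jnp.tile(ints, 2)

print(probs.dtype, roundtrip, tiled)
```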

jaxlib 0.3.7 (April 15, 2022)
* Changes:
* Linux wheels are now built conforming to the `manylinux2014` standard, instead
of `manylinux2010`.

0.3.6

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6).
* Changes:
* Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU
pod. Fixes [10218](https://github.com/google/jax/issues/10218).
* Deprecations:
* {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`10278`
for an alternative API.

0.3.5

* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5).
* Changes:
* Added {func}`jax.random.loggamma` and improved the behavior of {func}`jax.random.beta`
and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`9906`).
* The private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace ({jax-issue}`10029`).
* Added array creation routines {func}`jax.numpy.frombuffer`, {func}`jax.numpy.fromfunction`,
and {func}`jax.numpy.fromstring` ({jax-issue}`10049`).
* `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` ({jax-issue}`10069`).
* Added {func}`jax.scipy.linalg.rsf2csf`.
* Deprecations:
* {func}`jax.nn.normalize` is being deprecated. Use {func}`jax.nn.standardize` instead ({jax-issue}`9899`).
* {func}`jax.tree_util.tree_multimap` is deprecated. Use {func}`jax.tree_util.tree_map` instead ({jax-issue}`5746`).
* `jax.experimental.sharded_jit` is deprecated and will be removed soon. Use `pjit` instead.
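The replacements named in the deprecations above, together with one of the new array creation routines, can be sketched as follows. The parameter trees, the step size of 0.1, and the input values are hypothetical and serve only to show the call shapes.

```python
import jax
import jax.numpy as jnp

# tree_map accepts several trees of the same structure, which covers the
# use case of the deprecated tree_multimap.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
grads = {"w": jnp.full((2, 2), 0.5), "b": jnp.ones(2)}
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# standardize replaces the deprecated jax.nn.normalize: it shifts and scales
# the input to roughly zero mean and unit variance along the last axis.
x = jnp.array([1.0, 2.0, 3.0, 4.0])
z = jax.nn.standardize(x)

# frombuffer interprets a bytes-like object as a 1-D array.
raw = jnp.frombuffer(b"\x01\x02\x03", dtype=jnp.uint8)

print(updated, z, raw)
```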

jaxlib 0.3.5 (April 7, 2022)
* Bug fixes
* Fixed a bug where double-precision complex-to-real IRFFTs would mutate their
input buffers on GPU ({jax-issue}`9946`).
* Fixed incorrect constant-folding of complex scatters ({jax-issue}`10159`).
