Jax

Latest version: v0.4.26


0.2.8

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8).
* New features:
  * Add {func}`jax.closure_convert` for use with higher-order custom
    derivative functions. ({jax-issue}`5244`)
  * Add {func}`jax.experimental.host_callback.call` to call a custom Python
    function on the host and return a result to the device computation.
    ({jax-issue}`5243`)
* Bug fixes:
  * `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
    complex inputs ({jax-issue}`5156`).
  * `host_callback.id_tap` now also works with `jax.pmap`. There is an
    optional parameter for `id_tap` and `id_print` to request that the
    device from which the value is tapped be passed as a keyword argument
    to the tap function ({jax-issue}`5182`).
* Breaking changes:
  * `jax.numpy.pad` now takes keyword arguments: `constant_values` can no longer be
    passed positionally. In addition, passing unsupported keyword arguments raises an error.
  * Changes for {func}`jax.experimental.host_callback.id_tap` ({jax-issue}`5243`):
    * Removed support for `kwargs` for {func}`jax.experimental.host_callback.id_tap`.
      (This support had been deprecated for a few months.)
    * Changed the printing of tuples for {func}`jax.experimental.host_callback.id_print`
      to use '(' instead of '['.
    * Changed {func}`jax.experimental.host_callback.id_print` in the presence of JVP
      to print a pair of primal and tangent. Previously, there were two separate
      print operations for the primals and the tangent.
    * `host_callback.outfeed_receiver` has been removed (it is not necessary,
      and was deprecated a few months ago).
* New features:
  * New flag (`jax_debug_infs`) for debugging `inf`, analogous to that for `NaN` ({jax-issue}`5224`).

0.2.7

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7).
* New features:
  * Add `jax.device_put_replicated`.
  * Add multi-host support to `jax.experimental.sharded_jit`.
  * Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`.
  * Add support for building on Windows platforms.
  * Add support for general `in_axes` and `out_axes` in `jax.pmap`.
  * Add complex support for `jax.numpy.linalg.slogdet`.
* Bug fixes:
  * Fix higher-than-second-order derivatives of `jax.numpy.sinc` at zero.
  * Fix some hard-to-hit bugs around symbolic zeros in transpose rules.
* Breaking changes:
  * `jax.experimental.optix` has been deleted in favor of the standalone
    `optax` Python package.
  * Indexing JAX arrays with non-tuple sequences now raises a `TypeError`. This type of indexing
    has been deprecated in NumPy since v1.16, and in JAX since v0.2.4.
    See {jax-issue}`4564`.
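
A minimal sketch of several 0.2.7 items, assuming a recent JAX such as v0.4.26 (`jax.device_put_replicated` may be deprecated in later releases). The arrays `x`, `a`, `b`, `m`, and `y` are illustrative. The code uses however many local devices are available, so it also runs on a single CPU.

```python
import numpy as np
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# jax.device_put_replicated copies one array to every listed device,
# stacking the copies along a new leading axis.
x = jnp.arange(3.0)
replicated = jax.device_put_replicated(x, jax.local_devices())
print(replicated.shape)  # (n, 3)

# jax.pmap with non-default in_axes/out_axes: map over axis 1 of `a`,
# pass `b` unchanged to every device (None), and place the mapped
# axis at position 1 of the output.
a = jnp.arange(3.0 * n).reshape(3, n)
b = jnp.ones(3)
out = jax.pmap(lambda u, v: u + v, in_axes=(1, None), out_axes=1)(a, b)
# `out` has shape (3, n) and equals a + b[:, None].

# Complex support in jnp.linalg.slogdet: det(m) = 2j here.
m = np.array([[1j, 0], [0, 2]], dtype=np.complex64)
sign, logabsdet = jnp.linalg.slogdet(m)
np_sign, np_logabsdet = np.linalg.slogdet(m)

# Indexing with a non-tuple sequence raises TypeError; use a tuple instead.
y = jnp.arange(6).reshape(2, 3)
try:
    y[[0, slice(None)]]
    list_index_rejected = False
except TypeError:
    list_index_rejected = True
row = y[(0, slice(None))]  # same as y[0, :]
```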

0.2.6

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6).
* New features:
  * Add support for shape-polymorphic tracing for the `jax.experimental.jax2tf` converter.
    See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).
* Breaking change cleanup:
  * Raise an error on non-hashable static arguments for `jax.jit` and
    `xla_computation`. See [cb48f42](https://github.com/google/jax/commit/cb48f42).
  * Improve consistency of type promotion behavior ({jax-issue}`4744`):
    * Adding a complex Python scalar to a JAX floating-point number respects the precision of
      the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously
      it returned `complex128`.
    * Results of type promotion with three or more terms involving uint64, a signed int, and a third type
      are now independent of the order of arguments. For example,
      `jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)` and
      `jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)` both return `float16`, where previously
      the first returned `float64` and the second returned `float16`.
  * The contents of the (undocumented) `jax.lax_linalg` linear algebra module
    are now exposed publicly as `jax.lax.linalg`.
  * `jax.random.PRNGKey` now produces the same results in and out of JIT compilation
    ({jax-issue}`4877`).
    This required changing the result for a given seed in a few particular cases:
    * With `jax_enable_x64=False`, negative seeds passed as Python integers now return a different result
      outside JIT mode. For example, `jax.random.PRNGKey(-1)` previously returned
      `[4294967295, 4294967295]`, and now returns `[0, 4294967295]`. This matches the behavior in JIT.
    * Seeds outside the range representable by `int64` outside JIT now result in an `OverflowError`
      rather than a `TypeError`. This matches the behavior in JIT.

    To recover the keys returned previously for negative integers with `jax_enable_x64=False`
    outside JIT, you can use:

    ```python
    from jax import random

    # PRNGKey(-1) now returns [0, 4294967295]; overwriting the first word
    # reproduces the old key [4294967295, 4294967295].
    key = random.PRNGKey(-1).at[0].set(0xFFFFFFFF)
    ```

  * DeviceArray now raises `RuntimeError` instead of `ValueError` when trying
    to access its value after it has been deleted.
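
The sketch below checks several 0.2.6 behaviors against a recent JAX (around v0.4.26), assuming the default `jax_enable_x64=False`. The helper `scale`, the matrix `spd`, and the other sample values are hypothetical. Because the exact exception type for a non-hashable static argument is not stated above, the code catches both `TypeError` and `ValueError`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
from jax.lax import linalg as lax_linalg

# A complex Python scalar respects the precision of a JAX float32.
z = jnp.float32(1) + 1j
print(z.dtype)  # complex64

# Promotion involving uint64, a signed int, and float16 is order-independent.
t1 = jnp.result_type(jnp.uint64, jnp.int64, jnp.float16)
t2 = jnp.result_type(jnp.float16, jnp.uint64, jnp.int64)
print(t1, t2)  # float16 float16

# jax.lax.linalg is public; here, a Cholesky factorization of a
# hypothetical symmetric positive-definite matrix.
spd = jnp.array([[4.0, 2.0], [2.0, 3.0]])
chol = lax_linalg.cholesky(spd)  # lower-triangular factor

# PRNGKey gives the same key in and out of JIT.
key_eager = random.PRNGKey(-1)
key_jit = jax.jit(random.PRNGKey)(-1)

# Non-hashable static arguments raise an error.
def scale(factors, x):  # hypothetical helper
    return x * len(factors)

scale_jit = jax.jit(scale, static_argnums=0)
try:
    scale_jit([1, 2], jnp.ones(2))  # a list is not hashable
    unhashable_rejected = False
except (TypeError, ValueError):
    unhashable_rejected = True

# Accessing a deleted array raises RuntimeError.
arr = jnp.ones(3)
arr.delete()
try:
    np.asarray(arr)
    deleted_access_error = None
except RuntimeError:
    deleted_access_error = "RuntimeError"
```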

jaxlib 0.1.58 (January 12ish 2021)

* Fixed a bug that meant JAX sometimes returned platform-specific types (e.g.,
  `np.cint`) instead of standard types (e.g., `np.int32`). ({jax-issue}`4903`)
* Fixed a crash when constant-folding certain int16 operations. ({jax-issue}`4971`)
* Added an `is_leaf` predicate to {func}`pytree.flatten`.
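
In current JAX the `is_leaf` predicate is reachable through `jax.tree_util.tree_flatten`; the sketch below assumes that entry point rather than the jaxlib-level `pytree.flatten`. The `tree` value is an arbitrary example. Any node for which the predicate returns `True` is kept whole instead of being flattened further.

```python
import jax

tree = {"a": (1, 2), "b": [3, (4, 5)]}

# Default flattening recurses into tuples and lists.
leaves_default, _ = jax.tree_util.tree_flatten(tree)

# With is_leaf, every tuple is treated as a single leaf.
leaves_custom, treedef = jax.tree_util.tree_flatten(
    tree, is_leaf=lambda node: isinstance(node, tuple)
)

# The treedef still round-trips to the original structure.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves_custom)
```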

jaxlib 0.1.57 (November 12 2020)

* Fixed manylinux2010 compliance issues in GPU wheels.
* Switched the CPU FFT implementation from Eigen to PocketFFT.
* Fixed a bug where the hash of bfloat16 values was not correctly initialized
  and could change ({jax-issue}`4651`).
* Added support for retaining ownership when passing arrays to DLPack ({jax-issue}`4636`).
* Fixed a bug in batched triangular solves with sizes greater than 128 but not
  a multiple of 128.
* Fixed a bug when performing concurrent FFTs on multiple GPUs ({jax-issue}`3518`).
* Fixed a bug in the profiler where some tools were missing ({jax-issue}`4427`).
* Dropped support for CUDA 10.0.

0.2.5

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5).
* Improvements:
  * Ensure that `check_jaxpr` does not perform any floating-point operations (FLOPs). See {jax-issue}`4650`.
  * Expanded the set of JAX primitives converted by jax2tf.
    See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md).

0.2.4

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4).
* Improvements:
  * Add support for `remat` to `jax.experimental.host_callback`. See {jax-issue}`4608`.
* Deprecations:
  * Indexing with non-tuple sequences is now deprecated, following a similar deprecation in NumPy.
    In a future release, this will result in a `TypeError`. See {jax-issue}`4564`.

jaxlib 0.1.56 (October 14, 2020)

0.2.3

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3).
* The reason for another release so soon is that we need to temporarily roll back a
  new `jit` fast path while we look into a performance degradation.
