Jax

Latest version: v0.4.26



Page 12 of 17

0.2.14

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14).
* New features:
* The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`.
* A new configuration option, `JAX_TRACEBACK_FILTERING`, controls how JAX filters
tracebacks.
* A new traceback filtering mode using `__tracebackhide__` is now enabled by
default in sufficiently recent versions of IPython.
* The {func}`jax2tf.convert` supports shape polymorphism even when the
unknown dimensions are used in arithmetic operations, e.g., `jnp.reshape(-1)`
({jax-issue}`6827`).
* The {func}`jax2tf.convert` now attaches custom attributes with source-location
information to the TF ops it generates, so the code that XLA produces after jax2tf
conversion carries the same location information as code compiled directly by JAX.
* New SciPy function {py:func}`jax.scipy.special.lpmn`.

* Bug fixes:
* The {func}`jax2tf.convert` now ensures that it uses the same typing rules
for Python scalars and for choosing 32-bit vs. 64-bit computations
as JAX ({jax-issue}`6883`).
* The {func}`jax2tf.convert` now scopes the `enable_xla` conversion parameter
properly to apply only during the just-in-time conversion
({jax-issue}`6720`).
* The {func}`jax2tf.convert` now converts `lax.dot_general` using the
`XlaDot` TensorFlow op, for better fidelity w.r.t. JAX numerical precision
({jax-issue}`6717`).
* The {func}`jax2tf.convert` now has support for inequality comparisons and
min/max for complex numbers ({jax-issue}`6892`).
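
The traceback filtering mode can also be chosen from Python instead of through the environment variable. This is a minimal sketch, assuming a JAX release in which the option is exposed as the `jax_traceback_filtering` config flag and accepts `"off"` and `"auto"` (the accepted values have changed across versions). The function `bad_reshape` is a hypothetical example that deliberately triggers a shape error inside `jit`.

```python
import jax
import jax.numpy as jnp

# Assumption: "off" disables filtering, so JAX-internal frames stay visible
# in tracebacks raised from transformed code.
jax.config.update("jax_traceback_filtering", "off")


@jax.jit
def bad_reshape(x):
    # Hypothetical error: a size-4 array cannot be reshaped to size 3.
    return x.reshape(3)


try:
    bad_reshape(jnp.ones(4))
    error_message = None
except Exception as err:  # the exact exception type is not the point here
    error_message = str(err)

print(error_message)

# Assumption: "auto" restores the default filtering behavior.
jax.config.update("jax_traceback_filtering", "auto")
```

Setting `JAX_TRACEBACK_FILTERING=off` in the environment before importing JAX should have the same effect for a whole process.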

jaxlib 0.1.67 (May 17 2021)

jaxlib 0.1.66 (May 11 2021)

* New features:
* CUDA 11.1 wheels are now supported on all CUDA 11 versions 11.1 or higher.

NVIDIA now promises compatibility between CUDA minor releases starting with
CUDA 11.1. This means that JAX can release a single CUDA 11.1 wheel that
is compatible with CUDA 11.2 and 11.3.

There is no longer a separate jaxlib release for CUDA 11.2 (or higher); use
the CUDA 11.1 wheel for those versions (cuda111).
* Jaxlib now bundles `libdevice.10.bc` in CUDA wheels. There should be no need
to point JAX to a CUDA installation to find this file.
* Added automatic support for static keyword arguments to the {func}`jit`
implementation.
* Added support for pretransformation exception traces.
* Initial support for pruning unused arguments from {func}`jit`-transformed
computations.
Pruning is still a work in progress.
* Improved the string representation of {class}`PyTreeDef` objects.
* Added support for XLA's variadic ReduceWindow.
* Bug fixes:
* Fixed a bug in the remote cloud TPU support when large numbers of arguments
are passed to a computation.
* Fixed a bug where JAX garbage collection was not triggered by
{func}`jit`-transformed functions.
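
`PyTreeDef` objects are what `jax.tree_util` returns when it flattens a container into leaves. The sketch below, using an illustrative nested dict, shows where such an object comes from. The exact printed form of a `PyTreeDef` may differ between jaxlib versions, so the only thing the comments rely on is that leaves are shown as `*`.

```python
import jax

# A nested container: a dict holding a scalar and a tuple of two scalars.
tree = {"a": 1, "b": (2, 3)}

# Flattening splits the container into a list of leaves and a PyTreeDef
# describing the structure (dict keys are visited in sorted order).
leaves, treedef = jax.tree_util.tree_flatten(tree)

# The string form shows the container structure, with `*` marking each leaf.
print(treedef)

# The same PyTreeDef can rebuild the structure around new leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, [x * 10 for x in leaves])
print(rebuilt)
```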

0.2.13

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13).
* New features:
* When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static
keyword arguments. A new `static_argnames` option has been added to specify
keyword arguments as static.
* {func}`jax.numpy.nonzero` has a new optional `size` argument that allows it to
be used within `jit` ({jax-issue}`6501`).
* {func}`jax.numpy.unique` now supports the `axis` argument ({jax-issue}`6532`).
* {func}`jax.experimental.host_callback.call` now supports `pjit.pjit` ({jax-issue}`6569`).
* Added {func}`jax.scipy.linalg.eigh_tridiagonal` that computes the
eigenvalues of a tridiagonal matrix. Only eigenvalues are supported at
present.
* The order of the filtered and unfiltered stack traces in exceptions has been
changed. The traceback attached to an exception thrown from JAX-transformed
code is now filtered, with an `UnfilteredStackTrace` exception
containing the original trace as the `__cause__` of the filtered exception.
Filtered stack traces now also work with Python 3.6.
* If an exception is thrown by code that has been transformed by reverse-mode
automatic differentiation, JAX now attempts to attach as a `__cause__` of
the exception a `JaxStackTraceBeforeTransformation` object that contains the
stack trace that created the original operation in the forward pass.
Requires jaxlib 0.1.66.

* Breaking changes:
* The following functions have been renamed. The old names remain available as
aliases, so existing code should not break, but the aliases will eventually be
removed, so please update your code.
* `host_id` --> {func}`~jax.process_index`
* `host_count` --> {func}`~jax.process_count`
* `host_ids` --> `range(jax.process_count())`
* Similarly, the argument to {func}`~jax.local_devices` has been renamed from
`host_id` to `process_index`.
* Arguments to {func}`jax.jit` other than the function are now marked as
keyword-only. This change is to prevent accidental breakage when arguments
are added to `jit`.
* Bug fixes:
* The {func}`jax2tf.convert` now works in the presence of gradients for functions
with integer inputs ({jax-issue}`6360`).
* Fixed an assertion failure in {func}`jax2tf.call_tf` when used with a captured
`tf.Variable` ({jax-issue}`6572`).
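
Two of the 0.2.13 features can be illustrated together: marking a keyword argument static with `static_argnames`, and giving `jnp.nonzero` a fixed `size` so that it works under `jit`. This is a minimal sketch assuming jaxlib 0.1.66 or later; the functions `integer_power` and `first_two_nonzero` are hypothetical examples, not part of JAX.

```python
from functools import partial

import jax
import jax.numpy as jnp


# `power` is keyword-only and marked static, so at trace time it is a concrete
# Python int and can drive ordinary Python control flow such as `range`.
@partial(jax.jit, static_argnames=["power"])
def integer_power(x, *, power):
    result = jnp.ones_like(x)
    for _ in range(power):
        result = result * x
    return result


cubes = integer_power(jnp.arange(4.0), power=3)


# Under jit, output shapes must be known at trace time. `size` fixes how many
# indices are returned; missing entries are padded with `fill_value` (default 0).
@jax.jit
def first_two_nonzero(x):
    (indices,) = jnp.nonzero(x, size=2)
    return indices


idx = first_two_nonzero(jnp.array([0, 3, 0, 5, 7]))
padded = first_two_nonzero(jnp.array([0, 0, 4]))

print(cubes.tolist())   # [0.0, 1.0, 8.0, 27.0]
print(idx.tolist())     # [1, 3]
print(padded.tolist())  # [2, 0]
```

Each distinct value of `power` triggers a separate compilation, which is the usual trade-off of static arguments.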

jaxlib 0.1.65 (April 7 2021)

0.2.12

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12).
* New features:
* New profiling APIs: {func}`jax.profiler.start_trace`,
{func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace`.
* {func}`jax.lax.reduce` is now differentiable.
* Breaking changes:
* The minimum jaxlib version is now 0.1.64.
* The names of some profiler APIs have changed. The old names remain available as
aliases, so existing code should not break, but the aliases will eventually be
removed, so please update your code.
* `TraceContext` --> {func}`~jax.profiler.TraceAnnotation`
* `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation`
* `trace_function` --> {func}`~jax.profiler.annotate_function`
* Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md)
for more information.
* Python integers larger than the maximum `int64` value will now lead to an overflow
in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`6047`).
* Outside X64 mode, Python integers outside the range representable by `int32` will now lead to an
`OverflowError` rather than having their value silently truncated.
* Bug fixes:
* `host_callback` now supports empty arrays in arguments and results ({jax-issue}`6262`).
* {func}`jax.random.randint` now clips rather than wraps out-of-bounds limits, and can
now generate integers in the full range of the specified dtype ({jax-issue}`5868`).
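
The renamed profiler APIs and the new integer-overflow behavior can be sketched as follows. This assumes default settings with X64 mode disabled; `normalize` and the annotation name `"normalize_step"` are hypothetical. Without an active profiler trace, the annotations simply run the wrapped code.

```python
import jax
import jax.numpy as jnp


# `annotate_function` (formerly `trace_function`) labels each call of the
# function in a profiler trace.
@jax.profiler.annotate_function
def normalize(x):
    return x / jnp.linalg.norm(x)


# `TraceAnnotation` (formerly `TraceContext`) labels an arbitrary code region.
with jax.profiler.TraceAnnotation("normalize_step"):
    unit = normalize(jnp.array([3.0, 4.0]))

# Assumption: X64 mode is off, so the default integer type is int32 and a
# Python int that does not fit is rejected instead of silently truncated.
try:
    jnp.asarray(2**31)
    overflowed = False
except OverflowError:
    overflowed = True

print(unit.tolist())  # approximately [0.6, 0.8] in float32
print(overflowed)
```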

0.2.11

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11).
* New features:
* [6112](https://github.com/google/jax/pull/6112) added context managers:
`jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`,
`jax.debug_infs`, `jax.log_compiles`.
* [6085](https://github.com/google/jax/pull/6085) added `jnp.delete`

* Bug fixes:
* [6136](https://github.com/google/jax/pull/6136) generalized
`jax.flatten_util.ravel_pytree` to handle integer dtypes.
* [6129](https://github.com/google/jax/issues/6129) fixed a bug with handling
some constants like `enum.IntEnum` values.
* [6145](https://github.com/google/jax/pull/6145) fixed batching issues with
incomplete beta functions.
* [6014](https://github.com/google/jax/pull/6014) fixed host-to-device (H2D)
transfers during tracing.
* [6165](https://github.com/google/jax/pull/6165) avoids `OverflowError`s when
converting some large Python integers to floats.
* Breaking changes:
* The minimum jaxlib version is now 0.1.62.
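
Two of the 0.2.11 additions can be sketched together: `jnp.delete`, and one of the new context managers, `jax.debug_nans`. This assumes the context manager takes the new setting as its argument (`jax.debug_nans(True)`) and that NaN checking is otherwise off by default. The `jnp.log(-1.0)` call is just a convenient way to produce a NaN.

```python
import jax
import jax.numpy as jnp

# `jnp.delete` mirrors `numpy.delete`: remove the element at index 1.
trimmed = jnp.delete(jnp.arange(5), 1)

# Inside the block, an operation that produces a NaN raises FloatingPointError.
caught_nan = False
with jax.debug_nans(True):
    try:
        jnp.log(jnp.array(-1.0))
    except FloatingPointError:
        caught_nan = True

# Outside the block the default applies again and the NaN propagates silently.
silent = jnp.log(jnp.array(-1.0))

print(trimmed.tolist())         # [0, 2, 3, 4]
print(caught_nan)               # True
print(bool(jnp.isnan(silent)))  # True
```

Scoping the check to a `with` block keeps the overhead of NaN checking limited to the code under investigation.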


jaxlib 0.1.64 (March 18 2021)

jaxlib 0.1.63 (March 17 2021)

0.2.10

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10).
* New features:
* {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods.
* {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods.
* Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions
from JAX ({jax-issue}`5627` and
[README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)).
* Extended the batching rule for `lax.pad` to support batching of the padding values.
* Bug fixes:
* {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`5768`)
* Breaking changes:
* JAX's promotion rules were adjusted to make promotion more consistent and
invariant to JIT. In particular, binary operations can now result in weakly-typed
values when appropriate. The main user-visible effect of the change is that
some operations result in outputs of different precision than before; for
example the expression `jnp.bfloat16(1) + 0.1 * jnp.arange(10)`
previously returned a `float64` array, and now returns a `bfloat16` array.
JAX's type promotion behavior is described at {ref}`type-promotion`.
* {func}`jax.numpy.linspace` now computes the floor of integer values, i.e.,
rounding towards -inf rather than 0. This change was made to match NumPy
1.20.0.
* {func}`jax.numpy.i0` no longer accepts complex numbers. Previously the
function computed the absolute value of complex arguments. This change was
made to match the semantics of NumPy 1.20.0.
* Several {mod}`jax.numpy` functions no longer accept tuples or lists in place
of array arguments: {func}`jax.numpy.pad`, {func}`jax.numpy.ravel`,
{func}`jax.numpy.repeat`, {func}`jax.numpy.reshape`.
In general, {mod}`jax.numpy` functions should be used with scalars or array arguments.
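
The three user-visible breaking changes above can be checked directly. This sketch assumes default settings (X64 mode off) and reuses the promotion expression quoted in the entry. It also compares the integer `linspace` result with NumPy's, and shows the explicit `jnp.asarray` conversion that replaces passing a plain list.

```python
import numpy as np
import jax.numpy as jnp

# Weak typing: the Python float 0.1 does not force a wider float type, so the
# explicitly typed bfloat16 operand determines the result dtype.
mixed = jnp.bfloat16(1) + 0.1 * jnp.arange(10)
print(mixed.dtype)  # bfloat16

# Integer linspace rounds toward -inf (floor), as NumPy >= 1.20 does:
# the raw values are -2, -0.67, 0.67, 2.
ints = jnp.linspace(-2, 2, 4, dtype=jnp.int32)
reference = np.linspace(-2, 2, 4, dtype=int)
print(ints.tolist(), reference.tolist())  # [-2, -1, 0, 2] [-2, -1, 0, 2]

# Lists are rejected where an array is expected; convert them explicitly.
try:
    jnp.reshape([1, 2, 3, 4], (2, 2))
    rejected_list = False
except TypeError:
    rejected_list = True

grid = jnp.reshape(jnp.asarray([1, 2, 3, 4]), (2, 2))
print(rejected_list)  # True
print(grid.tolist())  # [[1, 2], [3, 4]]
```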

jaxlib 0.1.62 (March 9 2021)

* New features:
* jaxlib wheels are now built to require AVX instructions on x86-64 machines
by default. If you want to use JAX on a machine that doesn't support AVX,
you can build a jaxlib from source using the `--target_cpu_features` flag
to `build.py`. `--target_cpu_features` also replaces
`--enable_march_native`.

jaxlib 0.1.61 (February 12 2021)

jaxlib 0.1.60 (February 3 2021)

* Bug fixes:
* Fixed a memory leak when converting CPU DeviceArrays to NumPy arrays. The
memory leak was present in jaxlib releases 0.1.58 and 0.1.59.
* `bool`, `int8`, and `uint8` are now considered safe to cast to the
`bfloat16` NumPy extension type.

0.2.9

* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9).
* New features:
* Extended the {mod}`jax.experimental.loops` module with support for pytrees, and improved
error checking and error messages.
* Added {func}`jax.experimental.enable_x64` and {func}`jax.experimental.disable_x64`.
These are context managers that temporarily enable or disable X64 mode
within a session.
* Breaking changes:
* {func}`jax.ops.segment_sum` now drops segment IDs that are out of range rather
than wrapping them into the segment ID space. This was done for performance
reasons.
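
The scoped X64 context manager and the new `segment_sum` handling of out-of-range IDs can be sketched together. This assumes a JAX version that still provides `jax.experimental.enable_x64` (it was available at least through the 0.4.x series) and that X64 mode is off globally. The data and segment IDs are illustrative.

```python
import jax.numpy as jnp
from jax import ops
from jax.experimental import enable_x64

# With num_segments=2, valid IDs are 0 and 1; the entry with ID 5 is dropped
# instead of being wrapped back into range.
data = jnp.array([1.0, 2.0, 3.0])
segment_ids = jnp.array([0, 1, 5])
sums = ops.segment_sum(data, segment_ids, num_segments=2)
print(sums.tolist())  # [1.0, 2.0]

# 64-bit mode applies only inside the context manager.
with enable_x64():
    wide = jnp.arange(3.0)
narrow = jnp.arange(3.0)
print(wide.dtype, narrow.dtype)  # float64 float32
```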

jaxlib 0.1.59 (January 15 2021)

