Flax

Latest version: v0.8.3

0.6.1

-----
- Adds `axis_name` and `axis_index_groups` to LayerNorm and GroupNorm by copybara-service in [2402](https://github.com/google/flax/pull/2402)
- Plumb `spmd_axis_name` through `transforms.vmap` to JAX `vmap` by copybara-service in [2398](https://github.com/google/flax/pull/2398)
- Support multiple inputs in Flax lifted `vjp`/`custom_vjp` by copybara-service in [2399](https://github.com/google/flax/pull/2399)
- Improve `tabulate` by cgarciae in [2316](https://github.com/google/flax/pull/2316)
- Add `path_aware_map` function by cgarciae in [2371](https://github.com/google/flax/pull/2371)
- Add `static_argnums` to `nn.checkpoint` by cgarciae in [2457](https://github.com/google/flax/pull/2457)
- Add `count_include_pad` argument to `flax.linen.pooling.avg_pool` by dslisleedh in [2451](https://github.com/google/flax/pull/2451)
- Add `perturb()` to allow capturing intermediate gradients by IvyZX in [2476](https://github.com/google/flax/pull/2476)

0.6.0

-----

- Removed deprecated optimizers in `flax.optim` package.
- Moved `flax.optim.dynamic_scale` to `flax.training.dynamic_scale`.
- Switched to using `jax.named_scope` for all profile naming and removed some unnecessary stack traces.

0.5.3

-----
New features:
- Added `nn.switch` as a lifted version of `jax.lax.switch`.
- Added a method for detecting the use of "init" functions.
- Added checkpointing support for `jax.experimental.GlobalDeviceArray`, a useful array type for multiprocess/multihost computing.
- Added an async option to `save_checkpoint()` for single-process scenarios.
- Improved documentation pages.

Bug fixes:
- Fixed variable aliasing in `put_variable`.
- Fixed missing passthrough of the `unroll` argument in `nn.scan`.
- Fixed the MNIST example.

0.5.2

-----
- Fixes missing PyYAML dependency.

0.5.1

-----
New features:
- Added `nn.tabulate` and `Module.tabulate` to generate rich representations of the network structure.

0.5.0

-----
- Added `flax.jax_utils.pad_shard_unpad()` by lucasb-eyer
- Implemented [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md).
  This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32.
  This is especially useful for complex numbers: the standard Modules no longer truncate
  complex numbers to their real component by default. Instead, the complex dtype is preserved.


Bug fixes:
- Fixed support for JAX's `experimental_name_stack`.

Breaking changes:
- In rare cases the dtype of a layer can change due to [default dtype FLIP](https://github.com/google/flax/blob/main/docs/flip/1777-default-dtype.md). See the "Backward compatibility" section of the proposal for more information.

© 2024 Safety CLI Cybersecurity Inc. All Rights Reserved.