Flax

Latest version: v0.8.2

0.8.1

-----
- Added a default collection to `make_rng`.
- Added `InstanceNorm` (a short sketch follows this list) and renamed `channel_axes` to `feature_axes`.
- Added norm equivalence tests.
- Added `Module.module_paths` and its documentation.
- Made `Sequential.__call__` compact.
- Added `nn.compact_name_scope` v3.
- Added explicit control over the `frozen` and `slots` settings in `flax.struct.dataclass`.
- Replaced `jax.tree_util.tree_map` with mapping over leaves.
- Fixed docs and docstrings.
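
A minimal sketch of the new `InstanceNorm` layer, assuming the default `feature_axes`; the input shape and values are illustrative, not from the release notes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# (batch, height, width, features): InstanceNorm normalizes each
# feature channel over the spatial axes, independently per example.
x = jnp.ones((2, 8, 8, 3))

layer = nn.InstanceNorm()
variables = layer.init(jax.random.key(0), x)
y = layer.apply(variables, x)  # same shape as x
```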

0.8.0

-----
- Added [NNX](https://github.com/google/flax/tree/main/flax/experimental/nnx#nnx), a neural network library for JAX that provides a simple yet powerful module system that adheres to standard Python semantics. Its aim is to combine the robustness of Linen with a simplified, Pythonic API akin to that of PyTorch.
- Added the `nn.compact_name_scope` decorator, which enables methods to act as compact name scopes, as regular Haiku methods do. This makes porting Haiku code easier.
- Added a `copy()` method to `Module` (see the first sketch after this list). This is a user-friendly version of the internal `clone()` method with better defaults for common use cases.
- Added [`BatchApply`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#batchapply) class.
- Added a `sow_weights` option to the attention layers.
- Added the [`MultiHeadAttention`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/_autosummary/flax.linen.MultiHeadAttention.html) alias (see the second sketch after this list).
- Added kwargs support for `nn.jit`.
- Deprecated the `normalize` activation function in favor of `standardize`.
- Added the `GeGLU` activation function.
- Added `Enum` support to the `tabulate` function.
- Added a simple, argument-only lifted `nn.grad` function.
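
A minimal sketch of the new `copy()` method, assuming it accepts attribute overrides the way the internal `clone()` does:

```python
import flax.linen as nn

dense = nn.Dense(features=16)

# copy() returns a new module; the keyword override below
# (an assumption based on clone()) replaces the original attribute.
wider = dense.copy(features=32)
```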
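
And a sketch of the `MultiHeadAttention` alias used for self-attention; `num_heads`, `qkv_features`, and the shapes are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 16, 32))  # (batch, length, features)

# With a single input, the layer attends over itself.
layer = nn.MultiHeadAttention(num_heads=8, qkv_features=32)
variables = layer.init(jax.random.key(0), x)
y = layer.apply(variables, x)
```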

0.7.5

-----
- Added reporting of forward and backward pass FLOPs of modules and submodules in `linen.Module.tabulate` and `summary.tabulate`, in the new `flops` and `vjp_flops` table columns. Pass `compute_flops=True` and/or `compute_vjp_flops=True` to include these columns (see the first sketch after this list).
- Refactored `MultiHeadDotProductAttention`'s call method signature by adding `inputs_k` and `inputs_v` args and switching `inputs_kv`, `mask`, and `deterministic` to keyword arguments. See more details in [3389](https://github.com/google/flax/discussions/3389).
- Adopted new typed PRNG keys throughout Flax: this essentially involved changing uses of `jax.random.PRNGKey` to `jax.random.key` (see [JEP 9263](https://github.com/google/jax/pull/17297) for details, and the second sketch after this list). If you notice dispatch performance regressions after this change, be sure to update `jax` to version 0.4.16 or newer.
- Added a `has_improved` field to `EarlyStopping` and changed the return signature of `EarlyStopping.update` from a tuple to just the updated instance. See more details in [3385](https://github.com/google/flax/pull/3385).
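
A minimal sketch of the new FLOPs columns; the model and input shape are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=4)
x = jnp.ones((1, 8))

# compute_flops / compute_vjp_flops add `flops` and `vjp_flops`
# columns to the printed summary table.
print(model.tabulate(jax.random.key(0), x,
                     compute_flops=True, compute_vjp_flops=True))
```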
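
And a sketch of the typed-key change; both forms produce working keys, but the typed form is what Flax now uses internally:

```python
import jax

legacy_key = jax.random.PRNGKey(0)  # raw uint32 array of shape (2,)
typed_key = jax.random.key(0)       # scalar array with a key dtype

# Both work with the jax.random samplers.
x = jax.random.normal(typed_key, (3,))
```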

0.7.4

-----
New features:
- Added QK-normalization to `MultiHeadDotProductAttention` (a short sketch follows this list).
- Allowed `apply`'s `method` argument to accept submodules.
- Added the module path to `nn.Module`.
- [JAX] Generate new type of PRNG keys.
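
A minimal sketch of QK-normalization; the flag name `normalize_qk` and the shapes are my assumptions, not taken from the release notes:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((2, 10, 64))  # (batch, length, features)

# normalize_qk (assumed flag name) layer-normalizes queries and keys
# before the attention dot product, which can stabilize training.
attn = nn.MultiHeadDotProductAttention(num_heads=8, normalize_qk=True)
variables = attn.init(jax.random.key(0), x)
y = attn.apply(variables, x)
```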

Bug fixes:
- Directly call the original method if the method interceptor stack is empty.
- Fixed a stack overflow when loading a pickled module.
- Improved `kw_only_dataclass`.
- Allowed a pass-through implementation of the state dict.
- Promoted `dot_general` injections from a function to a module.

0.7.2

-----
New features:
- Made the `add_or_replace` argument of `flax.core.copy` optional.
- Added a `use_fast_variance` option to `GroupNorm` and `BatchNorm` to allow disabling the fast variance computation (a short sketch follows this list).
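
A minimal sketch of disabling fast variance on `GroupNorm`; `num_groups` and the input shape are illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

x = jnp.ones((4, 32))  # 32 features, grouped 4 ways

# use_fast_variance=False trades speed for the more numerically
# stable two-pass variance computation.
norm = nn.GroupNorm(num_groups=4, use_fast_variance=False)
variables = norm.init(jax.random.key(0), x)
y = norm.apply(variables, x)
```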

Bug fixes:
- Used `field_specifiers` instead of `field_descriptors` in `dataclass_transform`.
- Fixed `nn.Module` typing.
- [JAX] Replaced uses of `jax.experimental.pjit.with_sharding_constraint` with `jax.lax.with_sharding_constraint`.

0.7.1

-----
Breaking changes:
- Migrated Flax from returning `FrozenDict`s to returning regular dicts (a short sketch follows). More details can be found in this [announcement](https://github.com/google/flax/discussions/3191).
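
A minimal sketch of what the migration means in practice; the model is illustrative:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
variables = model.init(jax.random.key(0), jnp.ones((1, 4)))

# init/apply now return plain dicts, so ordinary dict operations
# work without unfreezing.
print(type(variables))  # <class 'dict'>
kernel = variables['params']['kernel']
```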

New features:
- Adopted pyink for code formatting.
- Added the dict migration guide to the index.
- Added a scan-over-layers section.
- Exposed options to customize `rich.Table`.
- Added support for initializing carry variables in `scan`.
- Let Flax-Orbax not port the shape of `target` arrays when porting the `target` shardings.

Bug fixes:
- Used `import orbax.checkpoint`, which is a better import pattern.
- Used `import orbax.checkpoint as ocp` to avoid the verbosity of writing `orbax.checkpoint` every time.
- [linen] Added an alternative, more numerically stable variance calculation to `LayerNorm`.
- [linen] Minor cleanup to normalization code.
- Fixed a norm calculation bug for 0-rank arrays.
- [JAX] Removed references to `jax.config.jax_array`.
- [linen] Used `stack` instead of `concatenate` in `compute_stats` to handle the scalar stats case.
- [linen] More minor cleanup in normalization `compute_stats`.
- Fixed warnings from the Atari gym.
- Refactored `TypeHandler` to operate over batches of values rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches.
- Fixed carry slice logic.
- Made the flax_basics guide use utility functions.
- Fixed a checkpointing guide error at head.
- Improved the `scan` docs.
