-----
Breaking changes:
- Migrating Flax from returning FrozenDicts to returning regular dicts. See this [announcement](https://github.com/google/flax/discussions/3191) for details.
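
For code that has to accept either return type during the migration, one option is to normalize the variables tree to plain dicts at the boundary. Flax provides `flax.core.unfreeze` and `flax.core.freeze` for converting between the two. The sketch below is only an illustration of the idea, written against the standard library so it runs without Flax installed. The helper name `to_plain_dict` is hypothetical, and `MappingProxyType` is an assumed stand-in for a read-only `FrozenDict`.

```python
from collections.abc import Mapping
from types import MappingProxyType


def to_plain_dict(tree):
    """Recursively copy any nested Mapping into regular, mutable dicts.

    Hypothetical helper, not part of Flax. It relies only on the input
    implementing the Mapping interface.
    """
    if isinstance(tree, Mapping):
        return {key: to_plain_dict(value) for key, value in tree.items()}
    # Leaves (arrays, scalars, lists) are returned unchanged, not copied.
    return tree


# Assumed stand-in for an immutable variables tree: read-only at every level.
frozen_like = MappingProxyType(
    {"params": MappingProxyType({"kernel": [1.0, 2.0], "bias": 0.0})}
)

params = to_plain_dict(frozen_like)
params["params"]["bias"] = 1.0  # In-place update works on the plain copy.
```

Because the helper only checks for `Mapping`, it passes regular dicts through as equivalent copies, so call sites do not need to know which type they received.
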
New features:
- Use pyink
- Add dict migration guide to index
- Add scan over layers section
- Expose options to customize rich.Table
- Add support for initializing carry variables in scan
- Let Flax-Orbax not port the shape of `target` arrays when it ports the `target` shardings.
Bug fixes:
- Use `import orbax.checkpoint`, which is a better import pattern.
- Use `import orbax.checkpoint as ocp` to avoid the verbosity of writing `orbax.checkpoint` every time.
- [linen] Add alternative, more numerically stable, variance calculation to `LayerNorm`.
- [linen] Minor cleanup to normalization code.
- Fix norm calculation bug for 0-rank arrays.
- [JAX] Remove references to jax.config.jax_array.
- [linen] Use `stack` instead of `concatenate` in `compute_stats`, to handle scalar stats case.
- [linen] More minor cleanup in normalization `compute_stats`.
- Fix warnings from atari gym.
- Refactor TypeHandler to operate over batches of values, rather than individual ones. This allows more flexibility for implementations that may operate more efficiently on batches.
- Fix carry slice logic
- Make flax_basics guide use utility functions
- Fix checkpointing guide error at head
- Improve scan docs
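
The `LayerNorm` entry above refers to a more numerically stable variance calculation. As a rough illustration of why that matters (a generic sketch, not the Flax implementation), the snippet below compares the single-pass formula `E[x^2] - E[x]^2` with the two-pass formula `E[(x - mean)^2]` on inputs whose mean is large relative to their spread. The function names are hypothetical.

```python
def variance_fast(xs):
    """Single-pass style: E[x^2] - E[x]^2. Prone to cancellation error."""
    n = len(xs)
    mean = sum(xs) / n
    mean_of_squares = sum(x * x for x in xs) / n
    return mean_of_squares - mean * mean


def variance_two_pass(xs):
    """Two-pass style: E[(x - mean)^2]. Subtracts the mean before squaring."""
    n = len(xs)
    mean = sum(xs) / n
    return sum((x - mean) ** 2 for x in xs) / n


# Large mean (1e9 + 10), small spread. The exact population variance is 22.5.
data = [1e9 + 4, 1e9 + 7, 1e9 + 13, 1e9 + 16]

stable = variance_two_pass(data)  # Recovers 22.5 exactly for this input.
fast = variance_fast(data)        # Loses the answer to rounding near 1e18.
print(stable, fast)
```

The squares in the fast formula are around 1e18, where adjacent double-precision values are 128 apart, so the difference of two such numbers cannot resolve a variance of 22.5. The trade-off is that the two-pass form needs the mean before it can accumulate the squared deviations.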