-----
New features:
- Added `nn.switch` as a lifted version of `jax.lax.switch`.
- Added a method for detecting the use of "init" functions.
- Added checkpointing support for `jax.experimental.GlobalDeviceArray`, a useful array type for multiprocess/multihost computing.
- Added an async option to `save_checkpoint()` for single-process scenarios.
- Improved documentation pages.
Bug fixes:
- Fixed variable aliasing in `put_variable`.
- Fixed the `unroll` argument of `nn.scan` not being passed through.
- Fixed the MNIST example.