mesh-transformer-jax


  • A Haiku library using the `xmap` operator in JAX for model parallelism of transformers.
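Model parallelism here means splitting a single layer's weights across devices rather than replicating them. The sketch below illustrates that idea in plain Python, without JAX, Haiku, or `xmap`. It assumes a Megatron-style split of a transformer feed-forward block: the first weight matrix is sharded by columns and the second by rows, and the per-shard partial outputs are summed. All helper names are hypothetical. The final summation stands in for the all-reduce (for example `jax.lax.psum` over a named axis) that a named-axis operator such as `xmap` would perform across devices.

```python
# Illustrative sketch only: simulates tensor-parallel sharding of an MLP block
# with plain Python lists. Not the library's code; helper names are hypothetical.

def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both given as lists of rows."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def relu(m):
    """Elementwise ReLU."""
    return [[max(0.0, v) for v in row] for row in m]

def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def split_cols(w, n):
    """Split w column-wise into n equal shards (assumes n divides the column count)."""
    size = len(w[0]) // n
    return [[row[s * size:(s + 1) * size] for row in w] for s in range(n)]

def split_rows(w, n):
    """Split w row-wise into n equal shards (assumes n divides the row count)."""
    size = len(w) // n
    return [w[s * size:(s + 1) * size] for s in range(n)]

def mlp_dense(x, w1, w2):
    """Reference: the unsharded feed-forward block relu(x @ w1) @ w2."""
    return matmul(relu(matmul(x, w1)), w2)

def mlp_sharded(x, w1, w2, n_shards):
    """Each 'device' holds one column shard of w1 and the matching row shard of w2."""
    w1_shards = split_cols(w1, n_shards)
    w2_shards = split_rows(w2, n_shards)
    # Each shard computes its partial output independently (no communication).
    partials = [matmul(relu(matmul(x, a)), b) for a, b in zip(w1_shards, w2_shards)]
    # Summing the partials plays the role of the cross-device all-reduce.
    out = partials[0]
    for p in partials[1:]:
        out = add(out, p)
    return out

x = [[1, -2, 3]]
w1 = [[1, 0, -1, 2],
      [0, 1, 1, -1],
      [2, -1, 0, 1]]
w2 = [[1, 2],
      [0, -1],
      [3, 1],
      [-2, 0]]
print(mlp_dense(x, w1, w2))
print(mlp_sharded(x, w1, w2, n_shards=2))
```

Because ReLU is elementwise, each shard can apply it to its own slice of hidden units without exchanging data. Only the final sum requires communication between devices.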