(no title)
HMUNACHI | 1 year ago
- A wide array of blocks and layers, facilitating the creation of customised transformer models from scratch.
- An extensive selection of models like Gemma, LLaMA, Mistral, Mixtral, GPT-3, GPT-4 (inferred), T5, Whisper, ViT, Mixers, GAT, CLIP, and more, catering to a variety of tasks and applications.
- Data-parallel distributed trainers, including RLHF, so developers can efficiently train large-scale models on multiple GPUs or TPUs without writing manual training loops.
- Dataloaders, making data handling for JAX/Flax more straightforward and effective.
- Custom layers not found in Flax/JAX, such as RoPE, GQA, MQA, and Swin attention, allowing for more flexible model development.
- GPU/TPU-accelerated classical ML models like PCA, KMeans, Regression, Gaussian Processes etc., akin to scikit-learn on GPU.
- Modular design so users can blend elements from various models, such as GPT, Mixtral, and LLaMA 2, to craft unique hybrid transformer models.
- True random number generators in JAX which do not need the verbose code (examples shown in next sections).
- A range of advanced algorithms for NLP and computer vision tasks, such as Gaussian Blur, BLEU, Tokenizer etc.
- Each model is in a single file with no external dependencies, so the source code can also be easily used.
- There are experimental features (like the Mamba architecture and RLHF) in the repo which are not available via the package, pending tests.
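For context on the "verbose code" point, here is a minimal sketch of how random numbers are drawn in stock JAX, using only `jax.random` calls. It does not show this package's own API, and the variable names and shapes are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Stock JAX keeps randomness explicit: every draw consumes a key,
# and fresh keys are derived by splitting an existing one.
key = jax.random.PRNGKey(0)                # seed the generator state
key, k_w, k_x = jax.random.split(key, 3)   # derive independent subkeys

w = jax.random.normal(k_w, (4, 2))         # weights drawn with one subkey
x = jax.random.uniform(k_x, (3, 4))        # inputs drawn with another
y = x @ w                                  # shape (3, 2)

# Reusing a key reproduces the same numbers, which is why the
# split-and-pass bookkeeping above is needed for every new draw.
same = jax.random.normal(k_w, (4, 2))
```

The key threading is what makes JAX randomness reproducible under `jit` and across devices, but it is also the boilerplate that a convenience wrapper would hide.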
I'd appreciate feedback if you have the time. A dev pre-release is available via pip; it is best suited to building models with no more than 1B params. There are lots of improvements still to make. If you're feeling generous, you can contribute or leave a star.
Thanks.