mobicham | 1 year ago
All the linear-quantization methods have meta-data, including the 1.58bit paper. You can control the quality vs. memory usage trade-off by adjusting the group-size. However, the meta-data is not the same thing as the quantized weights, for many reasons:
> The meta-data size doesn't change the fact that you can do binary/ternary matmul, which is the most important thing in this story.
> The meta-data size doesn't increase the actual compute: these are point-wise operations and even if you have 1 scalar you still need to multiply the same amount of weights.
> Meta-data is offloaded to the CPU with pinned-memory, which allows non-blocking transfers. Technically, you can trigger the copy in the previous layer and then synchronize, which makes it almost seamless. I did some experiments with CUDA streams that worked very well on an older machine, but then I tried a better machine and the transfer was much faster. Obviously if you are trying it on Google Colab it's very slow for this reason.
> Smaller models like Llama2-7B are very hard to directly quantize at very low bits, so they need a lower group-size to function well. Larger models (like what we showed for Mixtral), can be quantized to 2-bit on the fly, without any data, and still work very well. So basically larger models are less sensitive to extreme quantization and you could use a much larger group-size. I still think that the meta-data size is really not a big deal for the reasons I have explained above.
> There are many other ways to increase the group-size or even get rid of it altogether; many ideas are available, but they need a lot of experimentation.
> Binary/ternary CUDA matmul kernels don't exist yet. The current code is implementing the dequantization step in CUDA but then uses torch.matmul as fp16. I tried doing matmul at low-bits with CUDA but it is very difficult to even beat cuBLAS with fp16, especially for a novice CUDA coder like me :)
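To make the group-size vs. memory trade-off from the points above concrete, here is a rough sketch of the storage arithmetic. It assumes one fp16 scale and one fp16 zero-point per group (32 bits of meta-data per group) stored unquantized; that layout, the function names, and the example numbers are illustrative assumptions, not figures from this work.

```python
def bits_per_weight(nbits, group_size, meta_bits_per_group=32):
    """Effective storage per weight when every group carries its own meta-data.

    meta_bits_per_group=32 assumes one fp16 scale plus one fp16 zero-point
    per group (an assumption for illustration only).
    """
    return nbits + meta_bits_per_group / group_size


def model_size_gb(num_weights, nbits, group_size, meta_bits_per_group=32):
    """Total size in GB (1 GB = 1e9 bytes) of quantized weights plus meta-data."""
    total_bits = num_weights * bits_per_weight(nbits, group_size, meta_bits_per_group)
    return total_bits / 8 / 1e9


if __name__ == "__main__":
    # The weight bits stay fixed; only the meta-data share changes with group-size.
    for group_size in (8, 16, 64, 128):
        print(group_size, bits_per_weight(2, group_size))
```

Under these assumptions the weight tensor itself stays at `nbits` per weight regardless of group-size; only the separately stored meta-data grows as groups get smaller.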
Please note: this is early experimental work. Since it showed promising results, we wanted to share it with the community first as we progress. There are still a lot of things to be done and we are actively working on it, despite the very limited resources we have.
Happy to answer any questions here!
vladf|1 year ago
1 Could you post the full memory use of the methods? E.g. you include quip metadata in its GB but not hqq metadata in its GB.
2 If you have to go to cpu to shift and scale, how did you get latency lower than pure on device? Was this bsz1? No speculative decoding?
3 How can LoRA absorb shifts while only increasing the rank by 1 if you have a shift per group?
mobicham|1 year ago
1- The answer is still ~1.7GB. You only need meta-data of a single nn.Linear at a time. There are 32x(4+3) = 224 layers quantized, so you need an additional (3GB - 1.7GB)/224 = 1.3GB/224 ~ 5.8MB, which is negligible.
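The per-layer estimate above can be reproduced with a few lines. This sketch assumes the meta-data is spread evenly over the quantized layers, uses 1 GB = 1000 MB, and reads the "4+3" as attention and MLP linear layers per block; the constant and function names are made up for illustration.

```python
NUM_BLOCKS = 32     # transformer blocks, from "32x(4+3)"
ATTN_LINEARS = 4    # assumed: attention linear layers per block
MLP_LINEARS = 3     # assumed: MLP linear layers per block
TOTAL_GB = 3.0      # quantized weights + all meta-data
WEIGHTS_GB = 1.7    # quantized weights only


def per_layer_meta_mb(total_gb, weights_gb, num_layers):
    """Average meta-data per quantized layer in MB, assuming an even split."""
    return (total_gb - weights_gb) * 1000 / num_layers


num_layers = NUM_BLOCKS * (ATTN_LINEARS + MLP_LINEARS)   # 224
extra_mb = per_layer_meta_mb(TOTAL_GB, WEIGHTS_GB, num_layers)

if __name__ == "__main__":
    print(num_layers, round(extra_mb, 1))
```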
2- As the table says, that's the forward pass. (Batch-size=1, context-size=1024). Forward pass means there's no caching and no decoding logic. The actual model generation speed should be much faster with caching + decoding logic like speculative decoding, and using vLLM instead of HF. And even with all of that, a much larger model like Mixtral with the same group-size of 8 offloaded to the CPU works quite well on a 4090.
You mean why it's faster than Quip# even though Quip# runs entirely on-device? Because dequantization with HQQ is a simple linear operation. It's not even using a fused kernel; only the dequantization part is done on CUDA.
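As an illustration of what "a simple linear operation" means here, the sketch below quantizes one group with plain min-max round-to-nearest and dequantizes it as `(q - zero) * scale`. The way scale and zero are chosen is a stand-in and not necessarily how HQQ picks them; the function names are hypothetical.

```python
def quantize_group(values, nbits=2):
    """Min-max quantize one group to integers in [0, 2**nbits - 1].

    Plain round-to-nearest, used only as a stand-in for illustration.
    """
    qmax = 2 ** nbits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero = -lo / scale
    q = [min(qmax, max(0, round(v / scale + zero))) for v in values]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    """Dequantization is one subtract and one multiply per weight."""
    return [(qi - zero) * scale for qi in q]


group = [-1.0, -0.5, 0.5, 1.0]
q, scale, zero = quantize_group(group, nbits=2)
restored = dequantize_group(q, scale, zero)

if __name__ == "__main__":
    print(q, scale, zero)
    print(restored)
```

The per-group `scale` and `zero` are the meta-data discussed in this thread; the dequantized values would then go through a regular fp16 matmul.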
3- LoRA absorbs the -zero x shift term, not the shift itself; the shift is still there, as it is in the BitNet/1.58 work. As the paragraph explains, the math ignores reshaping to keep it simple and easy to read.
Let's say you have a 4096x4096 matrix with grouping done channel-wise (but no reshaping). The -zero x shift part is a rank-1 matrix (4096x1 .dot 1x4096) and the LoRA data is (4096xr .dot rx4096), so you can merge them exactly into (4096x(r+1) .dot (r+1)x4096).
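The exact merge can be checked on a toy example. In this sketch, `u` (m x 1) and `v` (1 x n) stand for the two factors of the rank-1 term and `A` (m x r), `B` (r x n) for the low-rank adapter; the tiny integer matrices and helper names are made up for illustration, with 4x4 standing in for 4096x4096.

```python
def matmul(a, b):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def add(a, b):
    """Element-wise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


def merge_rank1(A, B, u, v):
    """Return (A', B') such that A' @ B' == A @ B + u @ v.

    A: m x r, B: r x n, u: m x 1, v: 1 x n.
    """
    A_merged = [row_a + row_u for row_a, row_u in zip(A, u)]  # m x (r+1)
    B_merged = B + v                                          # (r+1) x n
    return A_merged, B_merged


# Toy sizes: m = n = 4, r = 2.
A = [[1, 2], [3, 4], [5, 6], [7, 8]]
B = [[1, 0, 2, 1], [0, 1, 1, 3]]
u = [[1], [2], [3], [4]]
v = [[2, -1, 0, 5]]

A_merged, B_merged = merge_rank1(A, B, u, v)
separate = add(matmul(A, B), matmul(u, v))
merged = matmul(A_merged, B_merged)

if __name__ == "__main__":
    print(merged == separate)
```

Appending one column to the left factor and one row to the right factor is all the merge needs, which is why the rank only grows by 1 in the channel-wise case.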
The point of that paragraph is to show two things:
- Compared to the BitNet formulation, the additional zero-point (which is necessary to get good quantization results on pre-trained models with minimum calibration) has a negligible overhead.
- More importantly, it explains how we even got the idea of adding low-rank adapters: it's not because LoRA is popular, it's because the zero-point alone results in a rank-1 matrix error which is not enough to express the quantization error. As the rank tends to min(num_rows, num_cols), the error goes down. So if we increase the rank by r via low-rank adapters, we would expect better results.
Now, if we include the reshape with a group-size lower than num_rows, the -zero x shift part is a rank-n matrix (4096xn .dot nx4096). It's not possible to properly estimate the rank n because that depends heavily on the nature of the weight matrix, but in the end the LoRA part will be (4096x(n+r) .dot (n+r)x4096). We only use a LoRA rank of 8 for MLPs, which are the larger matrices, so even if you double or even 4x it to, let's say, n+r=32, it's still just 1/128=0.78% of the original matrix.
Whether or not you merge -zero x scale with the low-rank adapters doesn't matter much; that would depend heavily on which fused kernel implementation performs best.
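For reference, the 1/128 figure is the merged rank divided by the matrix dimension (32/4096). The sketch below reproduces it and, as an extra check that is not part of the original estimate, also counts the entries of both low-rank factors against a square 4096x4096 matrix, which gives twice that fraction. The function names are hypothetical.

```python
from fractions import Fraction


def rank_fraction(rank, dim):
    """Merged rank relative to the matrix dimension."""
    return Fraction(rank, dim)


def factor_param_fraction(rank, rows, cols):
    """Entries in both low-rank factors relative to the full rows x cols matrix."""
    return Fraction(rank * (rows + cols), rows * cols)


rank_share = rank_fraction(32, 4096)                  # 1/128
param_share = factor_param_fraction(32, 4096, 4096)   # 1/64

if __name__ == "__main__":
    print(rank_share, f"{float(rank_share):.2%}")
    print(param_share, f"{float(param_share):.2%}")
```

Either way of counting keeps the overhead in the low single-digit percent range for n+r=32.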
mikeravkine|1 year ago
It's getting tougher to use older, cheaper GPUs (Pascal/Maxwell) with modern quantization schemes so anything you can do to keep kernels compatible with SM52 and SM61 would be greatly appreciated.
mobicham|1 year ago