Zig library for streaming dequantized MXFP4 tensors from Hugging Face's safetensors file format.
- On initialization, the module parses the provided safetensors file's header and identifies the MXFP4 tensors.
- Dequantization is done on the fly, in a streaming fashion.
- Dequantization uses SIMD instructions.
- The output provides a reader for each MXFP4 tensor, following the modern `std.Io.Reader` interface from Zig 0.15.1.
See example.zig for a basic usage example with a provided test file.
- tensorReaders.zig: Main entry point; initializes the module and provides the full set of tensor readers for a given safetensors file.
- tensorReader.zig: Streaming MXFP4 tensor reader implementation.
- dequantization.zig: Core MXFP4 dequantization logic with SIMD instructions.
- safetensors.zig: SafeTensors file format parser.
- mxfp4Config.zig: MXFP4 tensor configuration extraction.
MXFP4 (Microscaling FP4) is a floating-point format that encodes tensor values at an average of 4.25 bits per value.
The format consists of:
- blocks of 32 FP4 values,
- one u8 scale factor per block, shared by all 32 values in that block.
The bit layout is as follows:
- S1E2M1 for the block values,
- S0E8M0 for the scale values.
Official OCP MXFP4 specification
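To make the layout concrete, here is a minimal, language-agnostic dequantization sketch in Python (the library itself is Zig). The function name and the low-nibble-first packing order are assumptions for illustration, not this library's API:

```python
# E2M1 code -> value lookup (1 sign bit, 2 exponent bits, 1 mantissa bit).
# Codes 0-7 are the positive values, 8-15 their negated counterparts.
FP4_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
              -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequantize_block(packed: bytes, scale_byte: int) -> list[float]:
    """Dequantize one MXFP4 block: 16 packed bytes (32 FP4 codes)
    plus one E8M0 scale byte, whose value is 2 ** (scale_byte - 127)."""
    scale = 2.0 ** (scale_byte - 127)
    out = []
    for b in packed:  # two FP4 codes per byte; low nibble first is assumed here
        out.append(FP4_VALUES[b & 0x0F] * scale)
        out.append(FP4_VALUES[b >> 4] * scale)
    return out
```

With `scale_byte = 127` the scale is 2^0 = 1, so each code maps directly to its E2M1 value; incrementing the scale byte by one doubles every value in the block.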
Safetensors is a Hugging Face format that serializes tensor values in the following way:
- First 8 bytes: the header size, as a little-endian u64
- Header: JSON metadata for all tensors, including their value offsets
- Rest of the file: raw tensor values
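This layout can be sketched with a small parser in Python (the helper name `parse_safetensors` is hypothetical; in this library the actual parsing lives in safetensors.zig):

```python
import json
import struct

def parse_safetensors(data: bytes):
    """Split a safetensors blob into its JSON header and raw tensor bytes."""
    # First 8 bytes: header size as a little-endian u64.
    (header_len,) = struct.unpack_from("<Q", data, 0)
    # Header: UTF-8 JSON mapping tensor names to dtype, shape, and data_offsets.
    header = json.loads(data[8:8 + header_len].decode("utf-8"))
    # Rest of the file: raw tensor values; data_offsets are relative to here.
    tensor_bytes = data[8 + header_len:]
    return header, tensor_bytes
```

Each header entry's `data_offsets` pair indexes into the raw tensor region, so a streaming reader can seek directly to one tensor's bytes without touching the rest of the file.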
- How to profile this system for performance?
- Is there a more efficient way to load the FP4 values into the SIMD vectors, e.g. with a vectorized table lookup? Right now this is done in a for loop.
- Are there memory-handling subtleties that can be improved?
- Andrew Kelley's talks draw an important distinction between logic above vs. below the vtable in reader interface implementations: running logic above the vtable allows for more powerful compilation. In this reader implementation, the main logic sits below the vtable, in the stream method. What would above-the-vtable logic look like here? Should it be methods adapted to this use case, like `streamNextBlocks(num_blocks)` or `streamFullTensor`?