eformer.ops.quantization.implicit_array_rsr
Randomized Sparse Representation (RSR) Operators for Binary and Ternary Matrices.
This module provides efficient implicit array representations for binary {0, 1} and ternary {-1, 0, 1} matrices using the one-hot encoding RSR method. This enables fast vector-matrix multiplication without materializing the full dense matrix.
- Key Classes:
RSROperatorBinary: Implicit representation for binary matrices
RSROperatorTernary: Implicit representation for ternary matrices
- The RSR method works by:
Grouping matrix columns into blocks of size k
Converting each block row to an integer index (0 to 2^k - 1)
Creating one-hot encoded lookup tables for each block
Using matrix multiplication with lookup tables for efficient computation
This approach reduces memory and computation for sparse binary/ternary matrices commonly found in quantized neural networks.
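The four steps above can be sketched in plain NumPy. This is an illustrative reconstruction and not the eformer implementation: the helper names `encode_blocks` and `rsr_dot`, the right-side zero padding, and the most-significant-bit-first ordering inside each block are all assumptions made for the example.

```python
import numpy as np

def encode_blocks(A, k):
    """Steps 1-3 (assumed layout): pad, index each block row, one-hot encode."""
    n, m = A.shape
    padding = (-m) % k                               # zero columns appended on the right
    A_pad = np.pad(A, ((0, 0), (0, padding)))
    num_blocks = A_pad.shape[1] // k
    blocks = A_pad.reshape(n, num_blocks, k).transpose(1, 0, 2)  # (num_blocks, n, k)
    weights = 1 << np.arange(k - 1, -1, -1)          # first column is the high bit
    idx = blocks @ weights                           # (num_blocks, n), values in [0, 2**k - 1]
    one_hot = np.eye(2**k, dtype=np.float32)[idx]    # (num_blocks, n, 2**k)
    return one_hot, padding

def rsr_dot(v, one_hot, k, padding):
    """Step 4: v @ A from per-block segment sums and a shared bit table."""
    # bit_table[j] holds the k bits of the integer j, high bit first.
    bit_table = ((np.arange(2**k)[:, None] >> np.arange(k - 1, -1, -1)) & 1).astype(np.float32)
    seg = np.einsum("n,bnj->bj", v, one_hot)         # sum of v over rows sharing a bit pattern
    out = (seg @ bit_table).reshape(-1)              # (num_blocks * k,)
    return out[: out.size - padding]                 # drop the padded columns

A = np.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=np.int32)
v = np.array([1.0, 2.0], dtype=np.float32)
one_hot, padding = encode_blocks(A, k=4)
result = rsr_dot(v, one_hot, k=4, padding=padding)   # matches the dense product v @ A
```

The dense matrix is only needed once, during encoding; afterwards the product uses the one-hot tables and a bit table that depends only on k.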
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import RSROperatorBinary, RSROperatorTernary
>>>
>>> # Binary matrix
>>> A = jnp.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=jnp.int32)
>>> rsr = RSROperatorBinary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v) # Efficient v @ A
>>>
>>> # Ternary matrix
>>> B = jnp.array([[-1, 0, 1], [1, -1, 0]], dtype=jnp.int32)
>>> rsr_ternary = RSROperatorTernary.from_matrix(B, k=4)
- class eformer.ops.quantization.implicit_array_rsr.RSROperatorBinary(one_hot_maps: Array, k: int, padding: int, org_dtype: dtype, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)
Bases: ImplicitArray
Implicit Array for binary matrices using the one-hot RSR method.
This class provides an efficient implicit representation of binary matrices (containing only 0s and 1s) using Randomized Sparse Representation (RSR). The matrix is never stored in dense form; instead, it uses one-hot encoded lookup tables for efficient vector-matrix multiplication.
- one_hot_maps
Preprocessed one-hot encoded block representations. Shape: (num_blocks, n, 2**k).
- Type
jax.Array
- k
Block size for column grouping. Each block's lookup table has 2**k columns, so larger k uses more memory but may be faster for some matrices.
- Type
int
- padding
Number of zero columns added during preprocessing.
- Type
int
- org_dtype
Original dtype of the source matrix.
- Type
jnp.dtype
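As a rough illustration of how k, padding, and the shape of one_hot_maps relate, the sketch below computes the layout for a given matrix size. The helper `rsr_layout` is hypothetical, and it assumes the column count is padded up to the next multiple of k, which the attribute descriptions suggest but do not state outright.

```python
def rsr_layout(n: int, m: int, k: int):
    """Return (padding, num_blocks, one_hot_shape) for an (n, m) matrix."""
    padding = (-m) % k                # zero columns needed to reach a multiple of k (assumed)
    num_blocks = (m + padding) // k
    return padding, num_blocks, (num_blocks, n, 2**k)

# A 2 x 3 matrix with k=4 needs one padding column and a single block.
layout_small = rsr_layout(2, 3, 4)

# Doubling k halves num_blocks but squares the last dimension of the tables.
layout_k4 = rsr_layout(1024, 1024, 4)
layout_k8 = rsr_layout(1024, 1024, 8)
```

Comparing `layout_k4` and `layout_k8` shows why the choice of k is a trade-off between the number of blocks and the size of each table.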
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import RSROperatorBinary
>>> A = jnp.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=jnp.int32)
>>> rsr = RSROperatorBinary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v)  # Computes v @ A efficiently
- commute_ops: tp.ClassVar[bool] = True
- classmethod from_matrix(A: jax.Array, k: int = 8, org_dtype: jnp.dtype = jnp.float32)
Create an RSROperatorBinary from a binary integer matrix.
- Parameters
A (jax.Array) – Binary matrix with values in {0, 1}. Must have dtype int32 or int8.
k (int) – Block size for column grouping. Defaults to 8. The matrix columns are processed in groups of k.
org_dtype (jnp.dtype) – Original dtype to preserve for materialization. Defaults to float32.
- Returns
The RSR-encoded representation.
- Return type
RSROperatorBinary
- Raises
TypeError – If input matrix is not integer type.
- k: int
- materialize() -> Array
Reconstruct the original dense binary matrix.
This method reverses the RSR preprocessing to produce the original dense matrix. Useful for debugging, verification, or when the full matrix is needed for operations not supported by RSR.
- Returns
Dense binary matrix of shape (n, m) in the stored org_dtype.
- Return type
jax.Array
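To make the reversal concrete, here is a minimal NumPy sketch of how a one-hot block encoding could be decoded back into a dense 0/1 matrix. It is not the library's code: `decode_one_hot` is a hypothetical helper, and the high-bit-first ordering and right-side padding are assumptions.

```python
import numpy as np

def decode_one_hot(one_hot, k, padding):
    """Invert a (num_blocks, n, 2**k) one-hot encoding to a dense 0/1 matrix."""
    num_blocks, n, _ = one_hot.shape
    idx = one_hot.argmax(axis=-1)                  # (num_blocks, n) block-row indices
    shifts = np.arange(k - 1, -1, -1)              # high bit first (assumed ordering)
    bits = (idx[..., None] >> shifts) & 1          # (num_blocks, n, k)
    dense = bits.transpose(1, 0, 2).reshape(n, num_blocks * k)
    return dense[:, : num_blocks * k - padding]    # drop the padded columns

# Indices 6 (0110) and 9 (1001) encode the two rows of a 2 x 4 binary matrix.
one_hot = np.eye(16, dtype=np.float32)[np.array([[6, 9]])]
A = decode_one_hot(one_hot, k=4, padding=0)
```

A round trip like this is mainly useful for verifying an encoding, since it rebuilds the dense matrix that the implicit representation is meant to avoid.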
- one_hot_maps: Array
- padding: int
- classmethod quantize(A: Array, k: int = 8)
Create an RSROperatorBinary from an array by first packing it to 1-bit values.
- Parameters
A (jax.Array) – Input array to quantize.
k (int) – Block size for RSR encoding. Defaults to 8.
- Returns
The RSR-encoded binary matrix.
- Return type
RSROperatorBinary
- warn_on_materialize: tp.ClassVar[bool] = True
- class eformer.ops.quantization.implicit_array_rsr.RSROperatorTernary(rsr_b1: RSROperatorBinary, rsr_b2: RSROperatorBinary, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)
Bases: ImplicitArray
Implicit Array for ternary matrices using decomposed binary RSR operators.
This class provides an efficient implicit representation of ternary matrices (containing values {-1, 0, 1}) by decomposing them into two binary matrices: one for positive values (+1) and one for negative values (-1). The ternary dot product is computed as:
v @ A = v @ B1 - v @ B2
where B1 marks the +1 positions and B2 marks the -1 positions.
- rsr_b1
RSR representation of the positive mask (A == 1).
- Type
RSROperatorBinary
- rsr_b2
RSR representation of the negative mask (A == -1).
- Type
RSROperatorBinary
Example
>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import RSROperatorTernary
>>> A = jnp.array([[-1, 0, 1], [1, -1, 0]], dtype=jnp.int32)
>>> rsr = RSROperatorTernary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v)  # Computes v @ A efficiently
Note
Memory usage is approximately 2x that of RSROperatorBinary since two binary RSR operators are stored internally.
- commute_ops: tp.ClassVar[bool] = True
- dot(v: Array) -> Array
Compute vector-matrix product v @ A without materializing A.
Computes the ternary dot product as the difference of two binary dot products: v @ B1 - v @ B2.
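The identity behind this method can be checked with dense NumPy arrays. The sketch below uses the mask definitions from the attribute descriptions (A == 1 and A == -1); it works on dense matrices purely for illustration and does not call the RSR operators themselves.

```python
import numpy as np

A = np.array([[-1, 0, 1], [1, -1, 0]], dtype=np.int32)
v = np.array([1.0, 2.0])

B1 = (A == 1).astype(np.int32)    # positive mask
B2 = (A == -1).astype(np.int32)   # negative mask

result = v @ B1 - v @ B2          # same value as the dense product v @ A
reconstructed = B1 - B2           # the same masks also recover A itself
```

Because each mask is a binary matrix, every term on the right-hand side is a binary vector-matrix product, which is why two binary operators suffice for the ternary case.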
- classmethod from_matrix(A: Array, k: int = 8)
Create an RSROperatorTernary from a ternary integer matrix.
- Parameters
A (jax.Array) – Ternary matrix with values in {-1, 0, 1}. Must have dtype int32 or int8.
k (int) – Block size for column grouping in the underlying binary RSR operators. Defaults to 8.
- Returns
The decomposed RSR representation.
- Return type
RSROperatorTernary
- Raises
TypeError – If input matrix is not integer type.
- materialize() -> Array
Reconstruct the original dense ternary matrix.
Materializes both binary components and computes their difference to reconstruct the original ternary matrix.
- Returns
Dense ternary matrix of shape (n, m) with values in {-1, 0, 1}.
- Return type
jax.Array
- rsr_b1: RSROperatorBinary
- rsr_b2: RSROperatorBinary
- warn_on_materialize: tp.ClassVar[bool] = True