eformer.ops.quantization.implicit_array_rsr

Randomized Sparse Representation (RSR) Operators for Binary and Ternary Matrices.

This module provides efficient implicit array representations for binary {0, 1} and ternary {-1, 0, 1} matrices using the one-hot encoding RSR method. This enables fast vector-matrix multiplication without materializing the full dense matrix.

Key Classes:
  • RSROperatorBinary: Implicit representation for binary matrices

  • RSROperatorTernary: Implicit representation for ternary matrices

The RSR method works by:
  1. Grouping matrix columns into blocks of size k

  2. Converting each block row to an integer index (0 to 2^k - 1)

  3. Creating one-hot encoded lookup tables for each block

  4. Using matrix multiplication with lookup tables for efficient computation

This approach reduces memory and computation for sparse binary/ternary matrices commonly found in quantized neural networks.
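The four steps above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation: `rsr_encode` and `rsr_dot` are hypothetical helper names, and the most-significant-bit-first ordering of each block row is an assumption made for the example.

```python
import jax
import jax.numpy as jnp


def rsr_encode(A, k):
    """Hypothetical helper: build one-hot maps of shape (num_blocks, n, 2**k)."""
    n, m = A.shape
    padding = (-m) % k  # zero columns needed to reach a multiple of k
    A_pad = jnp.pad(A, ((0, 0), (0, padding)))
    num_blocks = (m + padding) // k
    # Step 1: group columns into blocks of size k -> (num_blocks, n, k)
    blocks = A_pad.reshape(n, num_blocks, k).transpose(1, 0, 2)
    # Step 2: read each block row as a k-bit integer (assumed MSB first)
    weights = 2 ** jnp.arange(k - 1, -1, -1)
    idx = (blocks * weights).sum(axis=-1)  # (num_blocks, n), values in [0, 2**k - 1]
    # Step 3: one-hot encode the integer indices
    one_hot = jax.nn.one_hot(idx, 2**k, dtype=jnp.float32)
    return one_hot, padding


def rsr_dot(v, one_hot, k, padding):
    """Hypothetical helper: compute v @ A from the one-hot maps."""
    # Step 4a: sum the entries of v whose block rows share a bit pattern
    seg = jnp.einsum("n,bnp->bp", v, one_hot)  # (num_blocks, 2**k)
    # Step 4b: expand every pattern index back into its k bits -> (2**k, k)
    patterns = (jnp.arange(2**k)[:, None] >> jnp.arange(k - 1, -1, -1)) & 1
    out = (seg @ patterns.astype(v.dtype)).reshape(-1)
    return out[: out.shape[0] - padding]  # drop the padded columns


A = jnp.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=jnp.int32)
v = jnp.array([1.0, 2.0])
one_hot, padding = rsr_encode(A, k=3)  # 4 columns, k=3 -> 2 blocks, 2 padded columns
result = rsr_dot(v, one_hot, k=3, padding=padding)
```

Here `result` matches the dense product `v @ A`, while the dense matrix itself is only used to build the lookup tables.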

Example

>>> import jax.numpy as jnp
>>> from eformer.ops.quantization import RSROperatorBinary, RSROperatorTernary
>>>
>>> # Binary matrix
>>> A = jnp.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=jnp.int32)
>>> rsr = RSROperatorBinary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v)  # Efficient v @ A
>>>
>>> # Ternary matrix
>>> B = jnp.array([[-1, 0, 1], [1, -1, 0]], dtype=jnp.int32)
>>> rsr_ternary = RSROperatorTernary.from_matrix(B, k=4)
class eformer.ops.quantization.implicit_array_rsr.RSROperatorBinary(one_hot_maps: Array, k: int, padding: int, org_dtype: dtype, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)

Bases: ImplicitArray

Implicit Array for binary matrices using the one-hot RSR method.

This class provides an efficient implicit representation of binary matrices (containing only 0s and 1s) using Randomized Sparse Representation (RSR). The matrix is never stored in dense form; instead, it uses one-hot encoded lookup tables for efficient vector-matrix multiplication.

one_hot_maps

Preprocessed one-hot encoded block representations. Shape: (num_blocks, n, 2**k).

Type

jax.Array

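k

Block size for column grouping. Each block's lookup table has 2**k entries per row (see the shape of one_hot_maps), so memory grows exponentially with k, while a larger k leaves fewer blocks to process, which may be faster for some matrices.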

Type

int

padding

Number of zero columns added during preprocessing.

Type

int
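To make the trade-off between k, padding, and table size concrete, the sketch below computes the documented one_hot_maps shape for a given matrix size. `one_hot_maps_shape` is a hypothetical helper, and it assumes the columns are zero-padded up to the next multiple of k, which the attribute descriptions suggest but do not state outright.

```python
import math


def one_hot_maps_shape(n, m, k):
    """Hypothetical helper: expected (num_blocks, n, 2**k) shape and padding."""
    num_blocks = math.ceil(m / k)   # assumed: columns are split into ceil(m / k) blocks
    padding = num_blocks * k - m    # zero columns needed to fill the last block
    return (num_blocks, n, 2**k), padding


# A 2 x 3 matrix with k=4 needs one padded column and a single 16-wide table.
shape, padding = one_hot_maps_shape(n=2, m=3, k=4)
```

Halving k doubles the number of blocks but shrinks each table from 2**k to 2**(k/2) columns, so small matrices rarely benefit from a large k.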

org_dtype

Original dtype of the source matrix.

Type

jnp.dtype

Example

>>> import jax.numpy as jnp
>>> A = jnp.array([[0, 1, 1, 0], [1, 0, 0, 1]], dtype=jnp.int32)
>>> rsr = RSROperatorBinary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v)  # Computes v @ A efficiently
commute_ops: tp.ClassVar[bool] = True
dot(v: Array) → Array

Compute vector-matrix product v @ A without materializing A.

Parameters

v (jax.Array) – Input vector of shape (n,) where n is the number of rows in the original matrix.

Returns

Result vector of shape (m,) where m is the number of columns in the original matrix.

Return type

jax.Array

Raises

ValueError – If input vector shape doesn’t match matrix dimensions.

classmethod from_matrix(A: Array, k: int = 8, org_dtype: dtype = jnp.float32)

Create an RSROperatorBinary from a binary integer matrix.

Parameters
  • A (jax.Array) – Binary matrix with values in {0, 1}. Must have dtype int32 or int8.

  • k (int) – Block size for column grouping. Defaults to 8. The matrix columns are processed in groups of k.

  • org_dtype (jnp.dtype) – Original dtype to preserve for materialization. Defaults to float32.

Returns

The RSR-encoded representation.

Return type

RSROperatorBinary

Raises

TypeError – If input matrix is not integer type.

k: int
materialize() → Array

Reconstruct the original dense binary matrix.

This method reverses the RSR preprocessing to produce the original dense matrix. Useful for debugging, verification, or when the full matrix is needed for operations not supported by RSR.

Returns

Dense binary matrix of shape (n, m) with the stored dtype (org_dtype).

Return type

jax.Array
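Reversing the encoding can be sketched as follows. This is an illustration under assumptions rather than the library's code: `rsr_decode` is a hypothetical helper, and it assumes each one-hot index stores the block row as a k-bit integer with the most significant bit first.

```python
import jax
import jax.numpy as jnp


def rsr_decode(one_hot_maps, k, padding, dtype=jnp.float32):
    """Hypothetical helper: rebuild a dense (n, m) matrix from one-hot maps."""
    num_blocks, n, _ = one_hot_maps.shape
    # Recover the integer index of each block row -> (num_blocks, n)
    idx = jnp.argmax(one_hot_maps, axis=-1)
    # Unpack each index into its k bits (assumed MSB first) -> (num_blocks, n, k)
    shifts = jnp.arange(k - 1, -1, -1)
    bits = (idx[..., None] >> shifts) & 1
    # Put blocks back side by side and drop the padded columns
    dense = bits.transpose(1, 0, 2).reshape(n, num_blocks * k)
    return dense[:, : num_blocks * k - padding].astype(dtype)


# Rows [1, 0, 1] and [0, 1, 1], padded to 4 bits, read as 0b1010 = 10 and 0b0110 = 6.
maps = jax.nn.one_hot(jnp.array([[10, 6]]), 16)
dense = rsr_decode(maps, k=4, padding=1)
```

The decoded array has the full n x m footprint, which is why materialization is best kept to debugging and verification.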

one_hot_maps: Array
org_dtype: dtype
padding: int
classmethod quantize(A: Array, k: int = 8)

Create an RSROperatorBinary from an array by packing to 1-bit first.

Parameters
  • A (jax.Array) – Input array to quantize.

  • k (int) – Block size for RSR encoding. Defaults to 8.

Returns

The RSR-encoded binary matrix.

Return type

RSROperatorBinary

warn_on_materialize: tp.ClassVar[bool] = True
class eformer.ops.quantization.implicit_array_rsr.RSROperatorTernary(rsr_b1: RSROperatorBinary, rsr_b2: RSROperatorBinary, *, shape: tp.Sequence[int] | None = None, dtype: jnp.dtype = None)

Bases: ImplicitArray

Implicit Array for ternary matrices using decomposed binary RSR operators.

This class provides an efficient implicit representation of ternary matrices (values in {-1, 0, 1}) by decomposing them into two binary matrices: B1, which marks the +1 positions, and B2, which marks the -1 positions. The ternary dot product is then computed as v @ A = v @ B1 - v @ B2.
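The decomposition can be checked with dense arrays alone. The sketch below uses only jax.numpy and does not touch the RSR classes; the variable names `B1`, `B2`, `via_masks`, and `direct` are chosen for illustration.

```python
import jax.numpy as jnp

A = jnp.array([[-1, 0, 1], [1, -1, 0]], dtype=jnp.int32)
v = jnp.array([1.0, 2.0])

# Split the ternary matrix into two binary masks.
B1 = (A == 1).astype(jnp.int32)   # 1 where A is +1, else 0
B2 = (A == -1).astype(jnp.int32)  # 1 where A is -1, else 0

# The difference of the two binary products equals the ternary product.
via_masks = v @ B1 - v @ B2
direct = v @ A
```

Because `B1 - B2` reproduces `A` exactly, any binary-only operator can serve ternary matrices at the cost of two passes.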

rsr_b1

RSR representation of the positive mask (A == 1).

Type

RSROperatorBinary

rsr_b2

RSR representation of the negative mask (A == -1).

Type

RSROperatorBinary

Example

>>> import jax.numpy as jnp
>>> A = jnp.array([[-1, 0, 1], [1, -1, 0]], dtype=jnp.int32)
>>> rsr = RSROperatorTernary.from_matrix(A, k=4)
>>> v = jnp.array([1.0, 2.0])
>>> result = rsr.dot(v)  # Computes v @ A efficiently

Note

Memory usage is approximately 2x that of RSROperatorBinary since two binary RSR operators are stored internally.

commute_ops: tp.ClassVar[bool] = True
dot(v: Array) → Array

Compute vector-matrix product v @ A without materializing A.

Computes the ternary dot product as the difference of two binary dot products: v @ B1 - v @ B2.

Parameters

v (jax.Array) – Input vector of shape (n,) where n is the number of rows in the original matrix.

Returns

Result vector of shape (m,) where m is the number of columns in the original matrix.

Return type

jax.Array

classmethod from_matrix(A: Array, k: int = 8)

Create an RSROperatorTernary from a ternary integer matrix.

Parameters
  • A (jax.Array) – Ternary matrix with values in {-1, 0, 1}. Must have dtype int32 or int8.

  • k (int) – Block size for column grouping in the underlying binary RSR operators. Defaults to 8.

Returns

The decomposed RSR representation.

Return type

RSROperatorTernary

Raises

TypeError – If input matrix is not integer type.

materialize() → Array

Reconstruct the original dense ternary matrix.

Materializes both binary components and computes their difference to reconstruct the original ternary matrix.

Returns

Dense ternary matrix of shape (n, m) with values in {-1, 0, 1}.

Return type

jax.Array

rsr_b1: RSROperatorBinary
rsr_b2: RSROperatorBinary
warn_on_materialize: tp.ClassVar[bool] = True