118 changes: 94 additions & 24 deletions inference/kernel.py
@@ -6,6 +6,38 @@
from triton import Config


+def convert_scale_to_ue8m0(scale_fp32: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a float32 scale tensor to UE8M0 format (uint8 exponent).
+
+    Args:
+        scale_fp32 (torch.Tensor): Scale tensor in float32 format.
+
+    Returns:
+        torch.Tensor: Scale tensor in uint8 format (exponent + 127).
+    """
+    scale_fp32_clamped = torch.clamp(scale_fp32, min=1e-38)
+    exponent = torch.ceil(torch.log2(scale_fp32_clamped))
+    exponent_biased = (exponent + 127).to(torch.int32)
+    exponent_biased = torch.clamp(exponent_biased, 0, 255)
+    return exponent_biased.to(torch.uint8)
+
+
+def convert_scale_from_ue8m0(scale_uint8: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a UE8M0 format scale tensor (uint8 exponent) to float32.
+
+    Args:
+        scale_uint8 (torch.Tensor): Scale tensor in uint8 format (exponent + 127).
+
+    Returns:
+        torch.Tensor: Scale tensor in float32 format (2^exponent).
+    """
+    exponent = scale_uint8.to(torch.int32) - 127
+    scale_fp32 = torch.pow(2.0, exponent.to(torch.float32))
+    return scale_fp32


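As a quick illustration of the round trip through the two helpers above (a minimal sketch; values chosen by hand): UE8M0 is lossy, since encoding rounds each scale up to the nearest power of two at or above it, so decoding never under-scales.

```python
import torch

scales = torch.tensor([0.75, 1.0, 3.0], dtype=torch.float32)
encoded = convert_scale_to_ue8m0(scales)     # biased exponents: 127, 127, 129
decoded = convert_scale_from_ue8m0(encoded)  # powers of two:    1.0, 1.0, 4.0
assert torch.all(decoded >= scales)          # round-up guarantees no under-scaling
```
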
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
"""
@@ -23,16 +55,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
-    amax = tl.max(tl.abs(x))  # reduction
-    amax = tl.maximum(amax, 1e-4)  # clamp to 1e-4
+    amax = tl.max(tl.abs(x))
+    amax = tl.maximum(amax, 1e-4)
s = amax / 448.
+    if scale_fmt == "ue8m0":
+        exp = tl.math.ceil(tl.math.log2(s))
+        s = tl.math.exp2(exp)
+        exp_int = exp.to(tl.int32) + 127
+        tl.store(s_ptr + pid, exp_int.to(s_ptr.dtype.element_ty))
+    else:
+        tl.store(s_ptr + pid, s)
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
-    tl.store(s_ptr + pid, s)


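For readers without a GPU handy, the arithmetic of the new `ue8m0` branch can be emulated for a single block in plain PyTorch. This is a sketch of the math only, not the Triton kernel; the helper name `act_quant_block_ref` is made up for illustration.

```python
import torch

def act_quant_block_ref(x_block: torch.Tensor):
    """Emulates one BLOCK_SIZE slice of act_quant_kernel with scale_fmt='ue8m0'."""
    amax = x_block.float().abs().max().clamp(min=1e-4)
    s = amax / 448.0                     # 448 is the max normal value of FP8 E4M3
    exp = torch.ceil(torch.log2(s))      # round the scale up to a power of two
    s = torch.exp2(exp)
    y = (x_block / s).to(torch.float8_e4m3fn)
    s_u8 = (exp.to(torch.int32) + 127).to(torch.uint8)  # biased exponent, as stored
    return y, s_u8
```
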
def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -51,14 +86,15 @@ def act_quant(x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> Tuple[torch.Tensor, torch.Tensor]:
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
-    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
+    scale_dtype = torch.uint8 if scale_fmt == "ue8m0" else torch.float32
+    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=scale_dtype)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
return y, s


@triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr, scale_fmt: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.

@@ -81,32 +117,39 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
offs = offs_m[:, None] * N + offs_n[None, :]
mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
-    s = tl.load(s_ptr + pid_m * n + pid_n)
+    if scale_fmt == "ue8m0":
+        s_uint8 = tl.load(s_ptr + pid_m * n + pid_n)
+        exp = s_uint8.to(tl.int32) - 127
+        s = tl.math.exp2(exp.to(tl.float32))
+    else:
+        s = tl.load(s_ptr + pid_m * n + pid_n)
y = x * s
tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
-    Dequantizes the given weight tensor using the provided scale tensor.
+    Dequantizes the input tensor `x` using the provided scaling factors `s`.

Args:
-        x (torch.Tensor): The quantized weight tensor of shape (M, N).
-        s (torch.Tensor): The scale tensor of shape (M//block_size, N//block_size).
-        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
+        x (torch.Tensor): The quantized weight tensor of shape (M, N).
+        s (torch.Tensor): The scaling factors of shape (M//block_size, N//block_size), float32 or uint8 (UE8M0).
+        block_size (int, optional): The size of the blocks to be used for dequantization. Default is 128.
+        scale_fmt (Optional[str], optional): Scale format, "ue8m0" for uint8 biased exponents or None for float32. Default is None.

Returns:
-        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
-
-    Raises:
-        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
+        torch.Tensor: The dequantized tensor of the same shape as `x`.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
+    if scale_fmt is None:
+        scale_fmt = "ue8m0" if s.dtype == torch.uint8 else None
+    if scale_fmt == "ue8m0" and s.dtype != torch.uint8:
+        s = convert_scale_to_ue8m0(s)
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size, scale_fmt=scale_fmt)
return y
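
Since the wrapper now infers `ue8m0` from a `uint8` scale tensor, existing call sites keep working without an extra argument. A usage sketch (shapes are illustrative assumptions; the Triton kernel needs a CUDA device):

```python
import torch

# 256x256 FP8 weight with one scale per 128x128 block -> a 2x2 scale grid.
w_q = torch.randn(256, 256, device="cuda").to(torch.float8_e4m3fn)
s_u8 = torch.full((2, 2), 129, dtype=torch.uint8, device="cuda")  # 129 - 127 = 2 -> scale 4.0

w = weight_dequant(w_q, s_u8)  # scale_fmt inferred as "ue8m0" from the uint8 dtype
```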


@@ -122,22 +165,26 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+                    BLOCK_SIZE_K: tl.constexpr,
+                    scale_fmt: tl.constexpr):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.

Args:
-        a_ptr (tl.tensor): Pointer to the first input matrix A.
-        b_ptr (tl.tensor): Pointer to the second input matrix B.
+        a_ptr (tl.tensor): Pointer to the first input matrix A (FP8).
+        b_ptr (tl.tensor): Pointer to the second input matrix B (FP8).
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
+            (uint8 if scale_fmt == "ue8m0", float32 otherwise)
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
+            (uint8 if scale_fmt == "ue8m0", float32 otherwise)
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
+        scale_fmt (tl.constexpr): Scale format ("ue8m0" for uint8 exponent format, None for float32).

Returns:
None
@@ -157,9 +204,20 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
-        a_s = tl.load(a_s_ptrs)
-        b_s = tl.load(b_s_ptrs)
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        if scale_fmt == "ue8m0":
+            a_s_uint8 = tl.load(a_s_ptrs)
+            b_s_uint8 = tl.load(b_s_ptrs)
+            a_exp = a_s_uint8.to(tl.int32) - 127
+            b_exp = b_s_uint8.to(tl.int32) - 127
+            combined_exp = a_exp[:, None] + b_exp[None, :]
+            scale = tl.math.exp2(combined_exp.to(tl.float32))
+            accumulator += tl.dot(a, b) * scale
+        else:
+            a_s = tl.load(a_s_ptrs)
+            b_s = tl.load(b_s_ptrs)
+            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
@@ -172,7 +230,7 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
tl.store(c_ptrs, c, mask=mask)

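The `ue8m0` branch above exploits 2^a · 2^b = 2^(a+b): the two per-block exponents are combined with one integer add and a single `exp2`, instead of decoding each scale separately. A small sanity check of that identity (illustrative values):

```python
import torch

a_u8 = torch.tensor([126, 130], dtype=torch.uint8)  # unbiased exponents -1 and 3
b_u8 = torch.tensor([127, 125], dtype=torch.uint8)  # unbiased exponents  0 and -2

a_exp = a_u8.to(torch.int32) - 127
b_exp = b_u8.to(torch.int32) - 127
combined = torch.exp2((a_exp[:, None] + b_exp[None, :]).float())
separate = torch.exp2(a_exp.float())[:, None] * torch.exp2(b_exp.float())[None, :]
assert torch.equal(combined, separate)  # exact, since all factors are powers of two
```
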

-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
+def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor, scale_fmt: Optional[str] = None):
"""
Perform a matrix multiplication using FP8 precision.

@@ -181,16 +239,28 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor, scale_fmt: Optional[str] = None):
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
+        scale_fmt (Optional[str], optional): Scale format, "ue8m0" for uint8 biased exponents or None for float32. Default is None.

Returns:
torch.Tensor: The result of the matrix multiplication.
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
+    if scale_fmt is None:
+        if a_s.dtype == torch.uint8 or b_s.dtype == torch.uint8:
+            scale_fmt = "ue8m0"
+        else:
+            scale_fmt = None
+    if scale_fmt == "ue8m0":
+        if a_s.dtype != torch.uint8:
+            a_s = convert_scale_to_ue8m0(a_s)
+        if b_s.dtype != torch.uint8:
+            b_s = convert_scale_to_ue8m0(b_s)

K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
-    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
+    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K, scale_fmt=scale_fmt)
return c
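
Taken together, a hedged end-to-end sketch of the UE8M0 path through this file (shapes and the constant weight scale are illustrative assumptions, not a faithful quantization; CUDA required):

```python
import torch

x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)  # (M, K) activations
w = torch.randn(512, 256, device="cuda")                      # (N, K) weight

w_q = w.to(torch.float8_e4m3fn)
w_s = convert_scale_to_ue8m0(torch.full((4, 2), 0.01, device="cuda"))  # per-128x128-block scales

x_q, x_s = act_quant(x, block_size=128, scale_fmt="ue8m0")  # x_s is uint8
y = fp8_gemm(x_q, x_s, w_q, w_s)  # scale_fmt inferred from the uint8 scales
print(y.shape)  # torch.Size([4, 512])
```
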
21 changes: 6 additions & 15 deletions inference/model.py
@@ -131,33 +131,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, scale_fmt: Optional[str] = None) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
-    This function supports specialized implementations based on quantization
-    and tensor formats.

Args:
x (torch.Tensor): The input tensor.
-        weight (torch.Tensor): The weight tensor. It may be quantized and
-            requires dequantization for certain cases.
+        weight (torch.Tensor): The weight tensor.
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
+        scale_fmt (Optional[str], optional): Scale format for FP8 quantization ("ue8m0" or None). Default is None.

Returns:
-        torch.Tensor: The result of the linear transformation, which may involve
-            quantization-aware computations depending on the input parameters.
-
-    Notes:
-        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
-          is used for computation.
-        - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
-        - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
+        torch.Tensor: The result of the linear transformation.
"""
if weight.element_size() > 1:
return F.linear(x, weight, bias)
elif gemm_impl == "bf16":
-        weight = weight_dequant(weight, weight.scale)
+        weight = weight_dequant(weight, weight.scale, scale_fmt=scale_fmt)
return F.linear(x, weight, bias)
else:
x, scale = act_quant(x, block_size, scale_fmt)
-        y = fp8_gemm(x, scale, weight, weight.scale)
+        y = fp8_gemm(x, scale, weight, weight.scale, scale_fmt)
if bias is not None:
y += bias
return y
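
The first branch distinguishes full-precision from FP8 weights purely by storage width: FP8 tensors store one byte per element, so `element_size() > 1` means a bf16/fp32 weight. A quick check of that premise:

```python
import torch

assert torch.empty(1, dtype=torch.float8_e4m3fn).element_size() == 1
assert torch.empty(1, dtype=torch.bfloat16).element_size() == 2
assert torch.empty(1, dtype=torch.float32).element_size() == 4
```
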
@@ -478,7 +469,7 @@ def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask
self.v_cache[:bsz, start_pos:end_pos] = v
scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
else:
-        wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+        wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size, Linear.scale_fmt)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)