From 573b025b7af299202355dfb23bc4e6893b44c853 Mon Sep 17 00:00:00 2001 From: vshevchenko Date: Sun, 23 Mar 2025 09:59:05 +0000 Subject: [PATCH 1/2] CE restricted & CCE loss --- replay/models/nn/optimizer_utils/__init__.py | 1 + .../optimizer_utils/fused_linear_ce_loss.py | 542 ++++++++++++++++++ .../nn/sequential/bert4rec/lightning.py | 36 +- replay/models/nn/sequential/bert4rec/model.py | 37 +- .../models/nn/sequential/sasrec/lightning.py | 231 +++++++- replay/models/nn/sequential/sasrec/model.py | 12 + replay_benchmarks/configs/model/bert4rec.yaml | 4 +- .../configs/model/bert4rec_beauty.yaml | 4 +- .../configs/model/bert4rec_megamarket.yaml | 4 +- .../configs/model/bert4rec_movielens_1m.yaml | 4 +- .../configs/model/bert4rec_movielens_20m.yaml | 4 +- .../configs/model/bert4rec_netflix.yaml | 4 +- .../configs/model/bert4rec_zvuk.yaml | 4 +- replay_benchmarks/configs/model/sasrec.yaml | 2 +- .../configs/model/sasrec_beauty.yaml | 4 +- .../configs/model/sasrec_games.yaml | 4 +- .../configs/model/sasrec_gowalla.yaml | 4 +- .../configs/model/sasrec_megamarket.yaml | 4 +- .../configs/model/sasrec_movielens_1m.yaml | 4 +- .../configs/model/sasrec_movielens_20m.yaml | 6 +- .../configs/model/sasrec_netflix.yaml | 4 +- .../configs/model/sasrec_sports.yaml | 4 +- .../configs/model/sasrec_yelp.yaml | 4 +- .../configs/model/sasrec_zvuk.yaml | 2 +- replay_benchmarks/train_runner.py | 23 + 25 files changed, 904 insertions(+), 48 deletions(-) create mode 100644 replay/models/nn/optimizer_utils/fused_linear_ce_loss.py diff --git a/replay/models/nn/optimizer_utils/__init__.py b/replay/models/nn/optimizer_utils/__init__.py index bc07124aa..20e38a8bc 100644 --- a/replay/models/nn/optimizer_utils/__init__.py +++ b/replay/models/nn/optimizer_utils/__init__.py @@ -2,3 +2,4 @@ if TORCH_AVAILABLE: from .optimizer_factory import FatLRSchedulerFactory, FatOptimizerFactory, LRSchedulerFactory, OptimizerFactory + from .fused_linear_ce_loss import LigerFusedLinearCrossEntropyFunction diff --git a/replay/models/nn/optimizer_utils/fused_linear_ce_loss.py b/replay/models/nn/optimizer_utils/fused_linear_ce_loss.py new file mode 100644 index 000000000..67c9cdc40 --- /dev/null +++ b/replay/models/nn/optimizer_utils/fused_linear_ce_loss.py @@ -0,0 +1,542 @@ +import torch +import triton +import triton.language as tl +try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh +except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh + + + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes 
both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (flaot): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / 
sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, + triton_backend=True +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
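+    # (each chunk then only materializes a chunk_size x V block of logits at a time)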
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + # chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + # num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + # inc_factor = (V + H - 1) // H + # chunk_size = (BT + inc_factor - 1) // inc_factor + # num_chunks = (BT + chunk_size - 1) // chunk_size + + chunk_size = 1024 + if triton_backend: + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_chunks = triton.cdiv(BT, chunk_size) + else: + num_chunks = (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + if triton_backend: + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
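+            # The kernel overwrites logits_chunk with d(loss)/d(logits) for this chunk and
+            # writes the per-row losses into loss_1d_slice, so the full BT x V logits tensor
+            # is never kept alive across chunks.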
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight, + loss_ptr=loss_1d_slice, + z_loss_ptr=z_loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + HAS_WEIGHT=True if ce_weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + grad_logits_chunk = logits_chunk # chunk_size x V + else: + y_chunk = torch.nn.functional.softmax(logits_chunk, dim=1) + loss_1d_slice = -torch.log(y_chunk).gather(1, target_chunk.view(-1, 1)) + loss_1d_slice = loss_1d_slice.squeeze(1) + logits_chunk = y_chunk - torch.nn.functional.one_hot(target_chunk, num_classes=V) + logits_chunk = (logits_chunk * (chunk_size / BT)) + grad_logits_chunk = logits_chunk + + loss_1d[start_idx:end_idx] = loss_1d_slice + if return_z_loss: + z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. + mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + + else: + loss = torch.sum(loss_1d) if triton_backend else torch.mean(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, z_loss, grad_input, grad_weight, grad_bias + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
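+        # NOTE: the element_mul_kernel launches below are currently commented out, so the
+        # precomputed gradients are returned unscaled when grad_output is not 1.0.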
+ BT, H = grad_input.shape + n_rows = BT + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + # element_mul_kernel[(n_rows,)]( + # grad_input, + # grad_input.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_weight, + # grad_weight.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_bias, + # grad_bias.stride(-1), + # grad_output, + # 1, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + return grad_input, grad_weight, grad_bias + + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss: bool = False, + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. + target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
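+        lse_square_scale (float): scale of the auxiliary z-loss term, lse_square_scale * logsumexp(logits) ** 2; 0.0 disables it
+        softcap: if not None, logits are soft-capped to (-softcap, +softcap) via softcap * tanh(x / softcap)
+        return_z_loss (bool): whether to compute the auxiliary z-loss separately (currently only the total loss is returned)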
+ reduction: reduction to apply + """ + + loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward( + _input=_input, + weight=weight, + target=target, + bias=bias, + ce_weight=ce_weight, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + return_z_loss=return_z_loss, + ) + # downcast to dtype and store for backward + ctx.save_for_backward( + grad_input.detach(), + grad_weight.detach() if grad_weight is not None else None, + grad_bias.detach() if bias is not None else None, + ) + ctx.return_z_loss = return_z_loss + # return loss, z_loss + return loss + + @staticmethod + # def backward(ctx, grad_output, grad_output2): + def backward(ctx, grad_output): + if ctx.return_z_loss: + del grad_output2 # z_loss is only for logging + (grad_input, grad_weight, grad_bias) = ctx.saved_tensors + grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward( + grad_output, grad_input, grad_weight, grad_bias + ) + return ( + grad_input, + grad_weight, + None, + grad_bias, + None, + None, + None, + None, + None, + None, + None, + ) \ No newline at end of file diff --git a/replay/models/nn/sequential/bert4rec/lightning.py b/replay/models/nn/sequential/bert4rec/lightning.py index 1823ec4b8..3715e4fb8 100644 --- a/replay/models/nn/sequential/bert4rec/lightning.py +++ b/replay/models/nn/sequential/bert4rec/lightning.py @@ -230,6 +230,8 @@ def _compute_loss(self, batch: Bert4RecTrainingBatch) -> torch.Tensor: loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled elif self._loss_type == "SCE": loss_func = self._compute_loss_scalable_ce + elif self._loss_type == "CE_restricted": + loss_func = self._compute_loss_ce_restricted else: msg = f"Not supported loss type: {self._loss_type}" raise ValueError(msg) @@ -405,6 +407,20 @@ def _compute_loss_scalable_ce( return loss + def _compute_loss_ce_restricted( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + tokens_mask: torch.BoolTensor, + ) -> torch.Tensor: + (logits, labels) = self._get_restricted_logits_for_ce_loss( + feature_tensors, positive_labels, padding_mask, tokens_mask + ) + + loss = self._loss(logits, labels) + return loss + def _get_sampled_logits( self, feature_tensors: TensorMap, @@ -487,11 +503,27 @@ def _get_sampled_logits( vocab_size, ) + def _get_restricted_logits_for_ce_loss( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + tokens_mask: torch.BoolTensor + ): + labels_mask = (~padding_mask) + tokens_mask + masked_tokens = ~labels_mask + positive_labels = cast( + torch.LongTensor, torch.masked_select(positive_labels, masked_tokens) + ) # (masked_batch_seq_size,) + output_emb = self._model.forward_step(feature_tensors, padding_mask, tokens_mask)[masked_tokens] + logits = self._model.get_logits_for_restricted_loss(output_emb) + return (logits, positive_labels) + def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntropyLoss]: if self._loss_type == "BCE": return torch.nn.BCEWithLogitsLoss(reduction="sum") - if self._loss_type == "CE" or self._loss_type == "SCE": + if self._loss_type == "CE" or self._loss_type == "SCE" or self._loss_type == "CE_restricted": return torch.nn.CrossEntropyLoss() msg = "Not supported loss_type" @@ -676,4 +708,4 @@ def _prepare_prediction_batch( padding_mask = torch.nn.functional.pad(padding_mask, (max_len - 
sequence_item_count, 0), value=0) shifted_features, shifted_padding_mask, tokens_mask = _shift_features(schema, features, padding_mask) batch = Bert4RecPredictionBatch(query_id, shifted_padding_mask, shifted_features, tokens_mask) - return batch + return batch \ No newline at end of file diff --git a/replay/models/nn/sequential/bert4rec/model.py b/replay/models/nn/sequential/bert4rec/model.py index 44218c1df..54d3d3ef5 100644 --- a/replay/models/nn/sequential/bert4rec/model.py +++ b/replay/models/nn/sequential/bert4rec/model.py @@ -6,9 +6,12 @@ import torch import torch.nn as nn -from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize -from bitsandbytes.triton.quantize_rowwise import quantize_rowwise -from bitsandbytes.triton.triton_utils import is_triton_available +try: + from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize + from bitsandbytes.triton.quantize_rowwise import quantize_rowwise + from bitsandbytes.triton.triton_utils import is_triton_available +except ModuleNotFoundError: + print("bitsandbytes is not installed. SwitchBack cannot be used.") from replay.data.nn import TensorFeatureInfo, TensorMap, TensorSchema @@ -184,6 +187,18 @@ def get_logits(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.Long """ return self._head(out_embeddings, item_ids) + def get_logits_for_restricted_loss(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Apply head to output embeddings of `forward_step`. + + :param out_embeddings: Embeddings after `forward step`. + :param item_ids: Item ids to calculate scores. + Default: ``None``. + + :returns: Logits for each element in `item_ids`. + """ + return self._head.forward_for_restricted_loss(out_embeddings, item_ids) + def get_query_embeddings(self, inputs: TensorMap, pad_mask: torch.BoolTensor, token_mask: torch.BoolTensor): """ :param inputs: Batch of features. 
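The restricted-loss path added above can be illustrated outside of the RePlay classes. The sketch below is illustrative only (the shapes, tensor names and the random mask are hypothetical, not the library API); it shows why computing logits only for the positions that carry a training target shrinks the memory footprint of the cross-entropy head, which is the idea behind `get_logits_for_restricted_loss` / `forward_for_restricted_loss`.

```python
import torch
import torch.nn.functional as F

B, L, H, V = 32, 128, 256, 50_000              # example batch, seq len, hidden size, catalog size
out_emb = torch.randn(B, L, H)                  # transformer outputs (stand-in for forward_step)
item_emb = torch.randn(V, H)                    # item embedding table used as the output head
bias = torch.zeros(V)
labels = torch.randint(0, V, (B, L))
target_mask = torch.rand(B, L) > 0.5            # True where a position contributes to the loss

# Plain CE materializes a [B, L, V] logits tensor (~200M floats here, ~0.8 GB in fp32).
# The restricted variant keeps only the masked rows: [N, V] with N = target_mask.sum().
masked_emb = out_emb[target_mask]               # [N, H]
masked_labels = labels[target_mask]             # [N]
logits = F.linear(masked_emb, item_emb, bias)   # [N, V], mirrors forward_for_restricted_loss
loss = F.cross_entropy(logits, masked_labels)
```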
@@ -411,6 +426,22 @@ def forward( logits = torch.matmul(out_embeddings, item_embeddings.t()) + bias return logits + def forward_for_restricted_loss( + self, + out_embeddings: torch.Tensor, + item_ids: Optional[torch.LongTensor] = None, + ) -> torch.Tensor: + + item_embeddings = self.get_item_embeddings() + bias = self.get_bias() + if item_ids is not None: + item_embeddings = item_embeddings[item_ids] + bias = bias[item_ids] + + logits = torch.nn.functional.linear(out_embeddings, item_embeddings, bias) + return logits + + @abstractmethod def get_item_embeddings(self) -> torch.Tensor: # pragma: no cover """ diff --git a/replay/models/nn/sequential/sasrec/lightning.py b/replay/models/nn/sequential/sasrec/lightning.py index 6aca6d7cf..b15486715 100644 --- a/replay/models/nn/sequential/sasrec/lightning.py +++ b/replay/models/nn/sequential/sasrec/lightning.py @@ -10,6 +10,26 @@ from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch from .model import SasRecModel +from replay.models.nn.optimizer_utils import LigerFusedLinearCrossEntropyFunction + + +try: + import sys + sys.path.append("/home/jovyan/zhmax/cce_loss/") + from cut_cross_entropy.cce import CCEParams, LinearCrossEntropyFunction, _build_flat_valids + from cut_cross_entropy.cce_backward import cce_backward_kernel + from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel + from cut_cross_entropy.constants import IGNORE_INDEX + from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start + from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel + from cut_cross_entropy.utils import ( + _build_flat_valids, + _handle_eps, + handle_reduction_none, + ) +except ModuleNotFoundError: + print("cut_cross_entropy is not installed. CCE / CCE_minus loss cannot be used.") + class SasRec(lightning.LightningModule): """ @@ -215,6 +235,10 @@ def _compute_loss(self, batch: SasRecTrainingBatch) -> torch.Tensor: loss_func = self._compute_loss_ce if self._loss_sample_count is None else self._compute_loss_ce_sampled elif self._loss_type == "SCE": loss_func = self._compute_loss_scalable_ce + elif self._loss_type == "CE_restricted": + loss_func = self._compute_loss_ce_restricted + elif self._loss_type == "CCE": + loss_func = self._compute_loss_cce else: msg = f"Not supported loss type: {self._loss_type}" raise ValueError(msg) @@ -291,18 +315,13 @@ def _compute_loss_ce( padding_mask: torch.BoolTensor, target_padding_mask: torch.BoolTensor, ) -> torch.Tensor: - # logits: [B x L x V] - logits = self._model.forward( - feature_tensors, - padding_mask, - ) - + # [B x L x V] + logits = self._model.forward(feature_tensors, padding_mask) # labels: [B x L] labels = positive_labels.masked_fill(mask=(~target_padding_mask), value=-100) logits_flat = logits.view(-1, logits.size(-1)) # [(B * L) x V] labels_flat = labels.view(-1) # [(B * L)] - loss = self._loss(logits_flat, labels_flat) return loss @@ -385,6 +404,163 @@ def _compute_loss_scalable_ce( return loss + def _compute_loss_ce_restricted( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor, + ) -> torch.Tensor: + """ + Calculate the Cross-Entropy (CE) loss restricting size of + out_emb and positive labels according to the target padding mask. 
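+        Logits are computed only for the positions where `target_padding_mask` is True,
+        so the materialized logits tensor has shape [n_target_positions, vocab_size]
+        instead of [batch_size, seq_len, vocab_size].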
+ """ + + (logits, labels) = self._get_restricted_logits_for_ce_loss( + feature_tensors, positive_labels, padding_mask, target_padding_mask + ) + logits_flat = logits.view(-1, logits.size(-1)) # [(B * L) x V] + labels_flat = labels.view(-1) # [(B * L)] + loss = self._loss(logits_flat, labels_flat) + + return loss + + def _compute_loss_fused_linear_CE( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor + ) -> torch.Tensor: + + output_emb = self._model.forward_step(feature_tensors, padding_mask)[target_padding_mask] + positive_labels = cast( + torch.LongTensor, torch.masked_select(positive_labels, target_padding_mask) + ) + + # Next token prediction + # output_emb = self._model.forward_step(feature_tensors, target_padding_mask) + # output_emb = output_emb[:, :-1, :][target_padding_mask[:, :-1]] + + # padding_mask[:, 0] = False + # positive_labels = cast(torch.LongTensor, torch.masked_select(positive_labels, padding_mask)) + + loss = self._loss.apply( + output_emb.view(-1, self._model.hidden_size), + self._model._head._item_embedder.get_all_item_weights(), + positive_labels + ) + + return loss + + def _compute_loss_cce( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor + ) -> torch.Tensor: + """ + Cut Cross-Entropy (CCE), a method that computes the cross-entropy loss + without materializing the logits for all tokens into global memory. + The method is implemented in a custom kernel that performs the matrix multiplications + and the log-sum-exp reduction over the vocabulary in flash memory, + making global memory consumption for the cross-entropy computation negligible. + + https://arxiv.org/abs/2411.09009 + https://github.com/apple/ml-cross-entropy + """ + + bias = None + ignore_index = -100 + softcap = None + reduction = "mean" + shift = False + filter_eps = "auto" + use_kahan = False + item_inds = None + + e = self._model.forward_step(feature_tensors, padding_mask) + e = e.to(torch.float16) + targets = cast(torch.LongTensor, positive_labels) + c = self._model._head._item_embedder.get_all_item_weights() + + # Next token prediction + # e = self._model.forward_step(feature_tensors, target_padding_mask) + # e = e[:, :-1, :][target_padding_mask[:, :-1]] + # e = e.to(torch.float16) + # c = self._model._head._item_embedder.get_all_item_weights() + # padding_mask[:, 0] = False + # targets = cast(torch.LongTensor, torch.masked_select(positive_labels, padding_mask)) + + e = e.contiguous() + padding_mask = padding_mask.contiguous() + + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + if not torch.cuda.is_bf16_supported(): + raise RuntimeError( + "Cut Cross Entropy requires an ampere GPU or newer. " + "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." 
+ ) + + batch_shape = targets.size() + + shift = int(shift) + valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if (targets.data_ptr() % 16) != 0: + targets = torch.nn.functional.pad(targets, (0, 1))[:-1] + + assert (targets.data_ptr() % 16) == 0 + + if self._loss_sample_count is not None: + filter_eps = None + n_negative_samples = self._loss_sample_count + vocab_size = self._vocab_size + device = padding_mask.device + + masked_batch_seq_size = targets.size(0) + # probs = torch.ones((masked_batch_seq_size, vocab_size), device=device) + # ids = torch.arange(masked_batch_seq_size, dtype=torch.long, device=device) + # probs[ids, targets] = 0.0 + # negative_labels = torch.multinomial(probs, num_samples=n_negative_samples, replacement=False) + + negative_labels = torch.randint( + low=0, + high=vocab_size, + size=(masked_batch_seq_size, n_negative_samples), + dtype=torch.long, + device=device, + ) + + reject_labels_mask = targets.view(-1, 1) == negative_labels + negative_labels[reject_labels_mask] = vocab_size - 1 + + + item_inds = torch.hstack([targets.view(-1, 1), negative_labels]) + + + params = CCEParams( + targets, + valids, + softcap, + reduction, + _handle_eps(filter_eps, e.dtype), + shift, + batch_shape, + use_kahan, + item_inds + ) + + loss = self._loss.apply(e, c.to(e.dtype), bias, params) + assert isinstance(loss, torch.Tensor) + + return loss + def _get_sampled_logits( self, feature_tensors: TensorMap, @@ -401,6 +577,15 @@ def _get_sampled_logits( device = padding_mask.device output_emb = self._model.forward_step(feature_tensors, padding_mask)[target_padding_mask] + # Next token prediction + # output_emb = self._model.forward_step(feature_tensors, target_padding_mask) + # output_emb = output_emb[:, :-1, :][target_padding_mask[:, :-1]] + # padding_mask[:, 0] = False + # positive_labels = cast(torch.LongTensor, torch.masked_select(positive_labels, padding_mask)) + # masked_batch_seq_size = positive_labels.size(0) + # device = padding_mask.device + + positive_labels = cast(torch.LongTensor, positive_labels.view(-1, 1)) ids = torch.arange(masked_batch_seq_size, dtype=torch.long, device=device) unique_positive_labels, positive_labels_indices = positive_labels.unique(return_inverse=True) @@ -465,13 +650,43 @@ def _get_sampled_logits( return (positive_logits, negative_logits, positive_labels, negative_labels, vocab_size) + def _get_restricted_logits_for_ce_loss( + self, + feature_tensors: TensorMap, + positive_labels: torch.LongTensor, + padding_mask: torch.BoolTensor, + target_padding_mask: torch.BoolTensor + ): + device = padding_mask.device + positive_labels = cast( + torch.LongTensor, torch.masked_select(positive_labels, target_padding_mask) + ) # (masked_batch_seq_size,) + output_emb = self._model.forward_step(feature_tensors, padding_mask) + output_emb = output_emb[target_padding_mask] + + # Next token prediction + # output_emb = self._model.forward_step(feature_tensors, target_padding_mask) + # output_emb = output_emb[:, :-1, :][target_padding_mask[:, :-1]] + + # padding_mask[:, 0] = False + # positive_labels = cast(torch.LongTensor, torch.masked_select(positive_labels, padding_mask)) + + logits = self._model.get_logits_for_restricted_loss(output_emb) + return (logits, positive_labels) + def _create_loss(self) -> Union[torch.nn.BCEWithLogitsLoss, torch.nn.CrossEntropyLoss]: if self._loss_type == "BCE": return torch.nn.BCEWithLogitsLoss(reduction="sum") - if self._loss_type == "CE" or self._loss_type == "SCE": + if 
self._loss_type == "CE" or self._loss_type == "SCE" or self._loss_type == "CE_restricted": return torch.nn.CrossEntropyLoss() + if self._loss_type == "fused_linear_CE": + return LigerFusedLinearCrossEntropyFunction() + + if self._loss_type == "CCE": + return LinearCrossEntropyFunction() + msg = "Not supported loss_type" raise NotImplementedError(msg) diff --git a/replay/models/nn/sequential/sasrec/model.py b/replay/models/nn/sequential/sasrec/model.py index 4ccd6ef37..763869bd2 100644 --- a/replay/models/nn/sequential/sasrec/model.py +++ b/replay/models/nn/sequential/sasrec/model.py @@ -185,6 +185,18 @@ def get_logits(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.Long :returns: Logits for each element in `item_ids`. """ return self._head(out_embeddings, item_ids) + + def get_logits_for_restricted_loss(self, out_embeddings: torch.Tensor, item_ids: Optional[torch.LongTensor] = None) -> torch.Tensor: + """ + Apply head to output embeddings of `forward_step`. + + :param out_embeddings: Embeddings after `forward step`. + :param item_ids: Item ids to calculate scores. + Default: ``None``. + + :returns: Logits for each element in `item_ids`. + """ + return self._head.forward_for_restricted_loss(out_embeddings, item_ids) def _init(self) -> None: for _, param in self.named_parameters(): diff --git a/replay_benchmarks/configs/model/bert4rec.yaml b/replay_benchmarks/configs/model/bert4rec.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec.yaml +++ b/replay_benchmarks/configs/model/bert4rec.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_beauty.yaml b/replay_benchmarks/configs/model/bert4rec_beauty.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec_beauty.yaml +++ b/replay_benchmarks/configs/model/bert4rec_beauty.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_megamarket.yaml b/replay_benchmarks/configs/model/bert4rec_megamarket.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec_megamarket.yaml +++ b/replay_benchmarks/configs/model/bert4rec_megamarket.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_movielens_1m.yaml b/replay_benchmarks/configs/model/bert4rec_movielens_1m.yaml index 340463a2e..1420e7e9f 100755 --- a/replay_benchmarks/configs/model/bert4rec_movielens_1m.yaml +++ b/replay_benchmarks/configs/model/bert4rec_movielens_1m.yaml @@ -9,7 +9,7 @@ model: 
max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_movielens_20m.yaml b/replay_benchmarks/configs/model/bert4rec_movielens_20m.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec_movielens_20m.yaml +++ b/replay_benchmarks/configs/model/bert4rec_movielens_20m.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_netflix.yaml b/replay_benchmarks/configs/model/bert4rec_netflix.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec_netflix.yaml +++ b/replay_benchmarks/configs/model/bert4rec_netflix.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/bert4rec_zvuk.yaml b/replay_benchmarks/configs/model/bert4rec_zvuk.yaml index 269ebf206..4e144dbbd 100755 --- a/replay_benchmarks/configs/model/bert4rec_zvuk.yaml +++ b/replay_benchmarks/configs/model/bert4rec_zvuk.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CE_restricted loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec.yaml b/replay_benchmarks/configs/model/sasrec.yaml index cf5ce9b59..f69596848 100755 --- a/replay_benchmarks/configs/model/sasrec.yaml +++ b/replay_benchmarks/configs/model/sasrec.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 100 hidden_size: 256 dropout_rate: 0.1 - loss_type: BCE # CE, BCE, SCE + loss_type: BCE # CE, BCE, SCE, CCE loss_sample_count: 500 n_buckets: 443 bucket_size_x: 443 diff --git a/replay_benchmarks/configs/model/sasrec_beauty.yaml b/replay_benchmarks/configs/model/sasrec_beauty.yaml index d156a29fc..4118305dc 100755 --- a/replay_benchmarks/configs/model/sasrec_beauty.yaml +++ b/replay_benchmarks/configs/model/sasrec_beauty.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_games.yaml b/replay_benchmarks/configs/model/sasrec_games.yaml index d156a29fc..4118305dc 100755 --- 
a/replay_benchmarks/configs/model/sasrec_games.yaml +++ b/replay_benchmarks/configs/model/sasrec_games.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_gowalla.yaml b/replay_benchmarks/configs/model/sasrec_gowalla.yaml index 1d17b12ef..d63aaf4be 100755 --- a/replay_benchmarks/configs/model/sasrec_gowalla.yaml +++ b/replay_benchmarks/configs/model/sasrec_gowalla.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: 500 n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_megamarket.yaml b/replay_benchmarks/configs/model/sasrec_megamarket.yaml index 1d17b12ef..d63aaf4be 100755 --- a/replay_benchmarks/configs/model/sasrec_megamarket.yaml +++ b/replay_benchmarks/configs/model/sasrec_megamarket.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: 500 n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_movielens_1m.yaml b/replay_benchmarks/configs/model/sasrec_movielens_1m.yaml index 5effdbd6a..4365aaa3c 100755 --- a/replay_benchmarks/configs/model/sasrec_movielens_1m.yaml +++ b/replay_benchmarks/configs/model/sasrec_movielens_1m.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml b/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml index d156a29fc..5f5f5e0cd 100755 --- a/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml +++ b/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml @@ -9,8 +9,8 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE - loss_sample_count: Null + loss_type: CCE # CE, BCE, SCE, CCE + loss_sample_count: 199 n_buckets: 443 bucket_size_x: 443 bucket_size_y: 512 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: 16-mixed #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_netflix.yaml b/replay_benchmarks/configs/model/sasrec_netflix.yaml index d156a29fc..4118305dc 100755 --- a/replay_benchmarks/configs/model/sasrec_netflix.yaml +++ b/replay_benchmarks/configs/model/sasrec_netflix.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 
bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_sports.yaml b/replay_benchmarks/configs/model/sasrec_sports.yaml index d156a29fc..4118305dc 100755 --- a/replay_benchmarks/configs/model/sasrec_sports.yaml +++ b/replay_benchmarks/configs/model/sasrec_sports.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_yelp.yaml b/replay_benchmarks/configs/model/sasrec_yelp.yaml index d156a29fc..4118305dc 100755 --- a/replay_benchmarks/configs/model/sasrec_yelp.yaml +++ b/replay_benchmarks/configs/model/sasrec_yelp.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: Null n_buckets: 443 bucket_size_x: 443 @@ -23,4 +23,4 @@ model: num_workers: 4 patience: 4 max_epochs: 20 - precision: null #default 32, bf16-mixed + precision: null #default 32, 16-mixed, bf16-mixed diff --git a/replay_benchmarks/configs/model/sasrec_zvuk.yaml b/replay_benchmarks/configs/model/sasrec_zvuk.yaml index 1d17b12ef..023425ac6 100755 --- a/replay_benchmarks/configs/model/sasrec_zvuk.yaml +++ b/replay_benchmarks/configs/model/sasrec_zvuk.yaml @@ -9,7 +9,7 @@ model: max_seq_len: 128 hidden_size: 256 dropout_rate: 0.1 - loss_type: CE # CE, BCE, SCE + loss_type: CE # CE, BCE, SCE, CCE loss_sample_count: 500 n_buckets: 443 bucket_size_x: 443 diff --git a/replay_benchmarks/train_runner.py b/replay_benchmarks/train_runner.py index 851d32485..fdff57d94 100755 --- a/replay_benchmarks/train_runner.py +++ b/replay_benchmarks/train_runner.py @@ -5,6 +5,7 @@ from pathlib import Path import optuna +import pandas as pd import torch import lightning as L from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger @@ -357,6 +358,27 @@ def objective(trial): logging.info(f"Best hyperparameters: {study.best_params}") + def _save_allocated_memory(self): + devices = [int(self.config["env"]["CUDA_VISIBLE_DEVICES"])] + torch.cuda.synchronize() + allocated = torch.cuda.memory_allocated(device=devices[0]) / 1024**3 # GB + max_allocated = torch.cuda.max_memory_allocated(device=devices[0]) / 1024**3 # GB + torch.cuda.reset_peak_memory_stats() + + data = { + 'allocated_memory': [allocated], + 'max_allocated_memory': [max_allocated] + } + df = pd.DataFrame(data) + + df.to_csv(os.path.join( + self.csv_logger.log_dir, + "memory_stats.csv" + ), index=False) + + logging.info(f"Allocated memory: {allocated} GB") + logging.info(f"Max allocated memory: {max_allocated} GB") + def run(self): """Execute the training pipeline.""" train_dataloader, val_dataloader, val_pred_dataloader, prediction_dataloader = ( @@ -435,6 +457,7 @@ def run(self): ) else: trainer.fit(model, train_dataloader, val_dataloader) + self._save_allocated_memory() if self.model_name.lower() == "sasrec": best_model = SasRec.load_from_checkpoint( From 8812eaf3823f1743d568821fdd4f667f51f2808f Mon Sep 17 00:00:00 2001 From: ZhMax Date: Mon, 24 Mar 2025 15:08:54 +0000 Subject: [PATCH 2/2] kernels --- kernels/__init__.py | 0 
kernels/cut_cross_entropy/__init__.py | 12 + kernels/cut_cross_entropy/cce.py | 204 ++++++ kernels/cut_cross_entropy/cce_backward.py | 669 ++++++++++++++++++ kernels/cut_cross_entropy/cce_lse_forward.py | 370 ++++++++++ kernels/cut_cross_entropy/constants.py | 2 + kernels/cut_cross_entropy/doc.py | 58 ++ kernels/cut_cross_entropy/indexed_dot.py | 158 +++++ .../cut_cross_entropy/linear_cross_entropy.py | 120 ++++ kernels/cut_cross_entropy/tl_autotune.py | 595 ++++++++++++++++ kernels/cut_cross_entropy/tl_utils.py | 90 +++ kernels/cut_cross_entropy/torch_compile.py | 82 +++ kernels/cut_cross_entropy/utils.py | 55 ++ .../fused_linear_cross_entropy/__init__.py | 1 + .../fused_linear_ce_loss.py | 542 ++++++++++++++ .../models/nn/sequential/sasrec/lightning.py | 23 +- replay_benchmarks/configs/config.yaml | 4 +- .../configs/model/sasrec_movielens_20m.yaml | 1 - 18 files changed, 2972 insertions(+), 14 deletions(-) create mode 100644 kernels/__init__.py create mode 100644 kernels/cut_cross_entropy/__init__.py create mode 100644 kernels/cut_cross_entropy/cce.py create mode 100644 kernels/cut_cross_entropy/cce_backward.py create mode 100644 kernels/cut_cross_entropy/cce_lse_forward.py create mode 100644 kernels/cut_cross_entropy/constants.py create mode 100644 kernels/cut_cross_entropy/doc.py create mode 100644 kernels/cut_cross_entropy/indexed_dot.py create mode 100644 kernels/cut_cross_entropy/linear_cross_entropy.py create mode 100644 kernels/cut_cross_entropy/tl_autotune.py create mode 100644 kernels/cut_cross_entropy/tl_utils.py create mode 100644 kernels/cut_cross_entropy/torch_compile.py create mode 100644 kernels/cut_cross_entropy/utils.py create mode 100644 kernels/fused_linear_cross_entropy/__init__.py create mode 100644 kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py diff --git a/kernels/__init__.py b/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/kernels/cut_cross_entropy/__init__.py b/kernels/cut_cross_entropy/__init__.py new file mode 100644 index 000000000..046057d13 --- /dev/null +++ b/kernels/cut_cross_entropy/__init__.py @@ -0,0 +1,12 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +from cut_cross_entropy.linear_cross_entropy import ( + LinearCrossEntropy, + LinearCrossEntropyImpl, + linear_cross_entropy, +) + +__all__ = [ + "LinearCrossEntropy", + "LinearCrossEntropyImpl", + "linear_cross_entropy", +] \ No newline at end of file diff --git a/kernels/cut_cross_entropy/cce.py b/kernels/cut_cross_entropy/cce.py new file mode 100644 index 000000000..ade740a8b --- /dev/null +++ b/kernels/cut_cross_entropy/cce.py @@ -0,0 +1,204 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
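+# Vendored from https://github.com/apple/ml-cross-entropy (Cut Cross-Entropy, arXiv:2411.09009),
+# extended with an optional `item_inds` field on CCEParams so that the LSE forward and backward
+# kernels can be restricted to a sampled set of items (used by the sampled CCE loss in SasRec).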
+from dataclasses import dataclass +from typing import cast + +import torch + +from cut_cross_entropy.cce_backward import cce_backward_kernel +from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel +from cut_cross_entropy.constants import IGNORE_INDEX +from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start +from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel +from cut_cross_entropy.utils import ( + _build_flat_valids, + _handle_eps, + handle_reduction_none, +) + + +@dataclass +class CCEParams: + targets: torch.Tensor + valids: torch.Tensor | None + softcap: float | None + reduction: str + filter_eps: float | None + shift: int + batch_shape: torch.Size + use_kahan: bool + item_inds: torch.Tensor | None + + +@torch.compile(fullgraph=True, dynamic=True) +def sort_logit_avg(logit_avg: torch.Tensor) -> torch.Tensor: + return torch.argsort(logit_avg).to(torch.int32) + + +class LinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None, + params: CCEParams, + ) -> torch.Tensor: + needs_grad = e.requires_grad or c.requires_grad + return_logit_avg = needs_grad and params.filter_eps is not None + + ret = cce_lse_forward_kernel( + e=e, + c=c, + bias=bias, + valids=params.valids, + softcap=params.softcap, + return_logit_avg=return_logit_avg, + item_inds=params.item_inds + ) + if return_logit_avg: + assert isinstance(ret, tuple) + lse, logit_avg = ret + else: + assert isinstance(ret, torch.Tensor) + lse = ret + logit_avg = None + + neg_dot = indexed_neg_dot_forward_kernel( + e=e, + c=c, + inds=params.targets, + bias=bias, + shift=params.shift, + valids=params.valids, + softcap=params.softcap, + out_dtype=lse.dtype, + ) + + nll = neg_dot.add_(lse) + + reduction = params.reduction + if reduction == "mean": + loss = nll.mean() + elif reduction == "sum": + loss = nll.sum() + elif reduction == "none": + loss = handle_reduction_none(params.batch_shape, params.valids, params.shift, nll) + else: + raise ValueError(f"Unknown reduction {reduction}") + + ctx.save_for_backward(e, c, bias, lse, params.targets, params.valids, logit_avg) + ctx.params = params + + return loss + + @staticmethod + def backward( + ctx, grad_out: torch.Tensor + ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None, None]: + e, c, bias, lse, targets, valids, logit_avg = ctx.saved_tensors + + if logit_avg is not None: + vocab_ordering = sort_logit_avg(logit_avg) + else: + vocab_ordering = None + + params = cast(CCEParams, ctx.params) + reduction = params.reduction + if reduction == "mean": + grad_scale = 1 / lse.numel() + elif reduction == "sum": + grad_scale = 1.0 + elif reduction == "none": + grad_scale = 1.0 + grad_out = grad_out.view(-1) + else: + raise ValueError(f"Unknown reduction {reduction}") + + de, dc, dbias = cce_backward_kernel( + do=grad_out, + e=e, + c=c, + bias=bias, + lse=lse, + valids=valids, + softcap=params.softcap, + filter_eps=params.filter_eps, + targets=targets, + shift=params.shift, + vocab_ordering=vocab_ordering, + grad_scale=grad_scale, + use_kahan=params.use_kahan, + item_inds=params.item_inds + ) + + return de, dc, dbias, None + + +def linear_cross_entropy_apply( + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None, + params: CCEParams, +) -> torch.Tensor: + loss = LinearCrossEntropyFunction.apply(e, c, bias, params) + assert isinstance(loss, torch.Tensor) + + if params.shift != 0 and params.reduction == "none": + 
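+        # with shifted (next-token) targets and reduction="none", the first `shift`
+        # positions carry no loss, so they are trimmed before returning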
loss = loss[..., params.shift :] + + return loss + + +@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) +@add_doc_start(*(doc_str + "\n" for doc_str in CCE_OPTS_DOC)) +def cce_linear_cross_entropy( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + ignore_index: int = IGNORE_INDEX, + softcap: float | None = None, + reduction: str = "mean", + shift: bool | int = 0, + filter_eps: float | str | None = "auto", + use_kahan: bool = False, +) -> torch.Tensor: + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + if not torch.cuda.is_bf16_supported(): + raise RuntimeError( + "Cut Cross Entropy requires an ampere GPU or newer. " + "Consider using torch_compile_linear_cross_entropy for scenarios where one is not available." + ) + + batch_shape = targets.size() + + e = e.contiguous() + targets = targets.contiguous() + + shift = int(shift) + valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if (targets.data_ptr() % 16) != 0: + targets = torch.nn.functional.pad(targets, (0, 1))[:-1] + + assert (targets.data_ptr() % 16) == 0 + + return linear_cross_entropy_apply( + e, + c, + bias, + CCEParams( + targets, + valids, + softcap, + reduction, + _handle_eps(filter_eps, e.dtype), + shift, + batch_shape, + use_kahan, + ), + ) \ No newline at end of file diff --git a/kernels/cut_cross_entropy/cce_backward.py b/kernels/cut_cross_entropy/cce_backward.py new file mode 100644 index 000000000..a06a1a1ee --- /dev/null +++ b/kernels/cut_cross_entropy/cce_backward.py @@ -0,0 +1,669 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import torch +import triton +import triton.language as tl + +from cut_cross_entropy.tl_autotune import cce_backward_autotune, cce_sampled_backward_autotune +from cut_cross_entropy.tl_utils import ( + b_bin_fn, + tl_and_reduce_fn, + tl_lock_add, + tl_lock_kahan_sum, + tl_softcapping, + tl_softcapping_grad, +) + + +@triton.jit +def _mm_backward( + do, + da_ptrs, + dac_ptrs, + partial_mask_a, + da_lock_ptr, + n_locks, + b_ptrs, + partial_mask_b, + stride_ad, + stride_bd, + D, + BLOCK_D: tl.constexpr, + EVEN_D: tl.constexpr, + USE_KAHAN: tl.constexpr, +): + d_inds = tl.arange(0, BLOCK_D)[None, :] + + b_ptrs = b_ptrs + d_inds * stride_bd + da_ptrs = da_ptrs + d_inds * stride_ad + if USE_KAHAN: + dac_ptrs = dac_ptrs + d_inds * stride_ad + + for d in range(0, tl.cdiv(D, BLOCK_D)): + if EVEN_D: + mask = partial_mask_b + else: + mask = partial_mask_b & (d_inds < (D - d * BLOCK_D)) + + b = tl.load(b_ptrs, mask=mask, other=0.0) + + da_i = tl.dot(do, b).to(da_ptrs.dtype.element_ty) + + if EVEN_D: + mask = partial_mask_a + else: + mask = partial_mask_a & (d_inds < (D - d * BLOCK_D)) + + lock_offset = d // tl.cdiv(D, BLOCK_D * n_locks) + this_da_lock_ptr = da_lock_ptr + lock_offset + + if USE_KAHAN: + tl_lock_kahan_sum(da_ptrs, dac_ptrs, da_i, mask, this_da_lock_ptr) + else: + tl_lock_add(da_ptrs, da_i, mask, this_da_lock_ptr) + + b_ptrs += BLOCK_D * stride_bd + da_ptrs += BLOCK_D * stride_ad + if USE_KAHAN: + dac_ptrs += BLOCK_D * stride_ad + + +@triton.jit +def _block_is_filtered(check_val: tl.tensor, filter_eps: tl.tensor) -> tl.tensor: + return tl.reduce(check_val < filter_eps, None, tl_and_reduce_fn) + + +def _cce_backward_kernel( + E, + C, + Bias, + LSE, + dOut, + grad_scale, + Valids, + VocabOrdering, + softcap, + Targets, + dE, + dEC, + dELocks, + dC, + dCC, + dCLocks, + dBias, + B, + D, + V, + BMax, + n_de_locks_0, + n_de_locks_1, + n_dc_locks_0, + n_dc_locks_1, 
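+    # strides of E (B x D), C (V x D), Bias and the Valids index tensor follow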
+ stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_biasv, + stride_vb, + filter_eps, + shift, + B_BIN, + BLOCK_B: tl.constexpr, + BLOCK_V: tl.constexpr, + BLOCK_D: tl.constexpr, + MM_BACK_BLOCK_D: tl.constexpr, + GROUP_B: tl.constexpr, + EVEN_D: tl.constexpr, + MM_BACK_EVEN_D: tl.constexpr, + ITEM_DO: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + HAS_VOCAB_ORDERING: tl.constexpr, + FILTER_GRAD: tl.constexpr, + HAS_TARGETS: tl.constexpr, + HAS_SOFTCAP: tl.constexpr, + HAS_SHIFT: tl.constexpr, + USE_KAHAN: tl.constexpr, + COMPUTE_DC: tl.constexpr, + COMPUTE_DE: tl.constexpr, + COMPUTE_DBIAS: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_b_chunks = tl.cdiv(B, BLOCK_B) + num_v_chunks = tl.cdiv(V, BLOCK_V) + num_v_in_group = GROUP_B * num_v_chunks + group_id = pid // num_v_in_group + first_pid_b = group_id * GROUP_B + group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) + pid_b = first_pid_b + ((pid % num_v_in_group) % group_size_b) + pid_v = (pid % num_v_in_group) // group_size_b + + offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + if HAS_VALIDS: + offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax) + + offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V) + if HAS_VOCAB_ORDERING: + offs_v = tl.load(VocabOrdering + offs_v, mask=offs_v < V, other=V) + + offs_d = tl.arange(0, BLOCK_D) + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) + + accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) + for d in range(0, tl.cdiv(D, BLOCK_D)): + e_mask = offs_b[:, None] < BMax + if not EVEN_D: + e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D)) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + + c_mask = offs_v[None, :] < V + if not EVEN_D: + c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D)) + + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + accum = tl.dot(e, c, accum) + + e_ptrs += BLOCK_D * stride_ed + c_ptrs += BLOCK_D * stride_cd + + tl.debug_barrier() + + if HAS_BIAS: + bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0) + bias = bias.to(dtype=accum.dtype) + accum += bias[None, :] + + if HAS_SOFTCAP: + accum = tl_softcapping(accum, softcap) + + if HAS_VALIDS: + direct_offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + lse = tl.load(LSE + direct_offs_b, mask=direct_offs_b < B, other=float("inf")) + else: + lse = tl.load(LSE + offs_b, mask=offs_b < B, other=float("inf")) + + d_accum = tl.exp(accum - lse[:, None]) + d_accum = tl.where(offs_v[None, :] < V, d_accum, 0.0) + + if HAS_TARGETS: + if HAS_SHIFT: + target_offs_b = offs_b + shift + else: + target_offs_b = offs_b + + targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1) + is_target = targets[:, None] == offs_v[None, :] + d_accum += tl.where(is_target, -1.0, 0.0) + else: + is_target = None + + if FILTER_GRAD: + if _block_is_filtered(tl.abs(d_accum), filter_eps): + return + + if HAS_SOFTCAP: + d_accum = tl_softcapping_grad(d_accum, accum, softcap) + + if ITEM_DO: + d_out = tl.load(dOut) + else: + if HAS_SHIFT: + d_out_offs_b = offs_b + shift + else: + d_out_offs_b = offs_b + + d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0)[:, None] + + d_out = grad_scale * d_out + + d_accum = d_accum * d_out + + if COMPUTE_DBIAS: + tl.atomic_add(dBias + offs_v * stride_biasv, tl.sum(d_accum, 0), mask=offs_v < V) + + d_accum = d_accum.to(e_ptrs.dtype.element_ty) + + if COMPUTE_DE: + lock_offset = (pid_b // tl.cdiv(B, 
BLOCK_B * n_de_locks_0)) * n_de_locks_1 + + _mm_backward( + d_accum, + dE + (offs_b[:, None] * stride_eb), + dEC + (offs_b[:, None] * stride_eb) if USE_KAHAN else None, + offs_b[:, None] < BMax, + dELocks + lock_offset, + n_de_locks_1, + C + offs_v[:, None] * stride_cv, + offs_v[:, None] < V, + stride_ed, + stride_cd, + D, + MM_BACK_BLOCK_D, + MM_BACK_EVEN_D, + USE_KAHAN, + ) + + if COMPUTE_DC: + lock_offset = (pid_v // tl.cdiv(V, BLOCK_V * n_dc_locks_0)) * n_dc_locks_1 + + _mm_backward( + tl.trans(d_accum), + dC + (offs_v[:, None] * stride_cv), + dCC + (offs_v[:, None] * stride_cv) if USE_KAHAN else None, + offs_v[:, None] < V, + dCLocks + lock_offset, + n_dc_locks_1, + E + (offs_b[:, None] * stride_eb), + offs_b[:, None] < BMax, + stride_cd, + stride_ed, + D, + MM_BACK_BLOCK_D, + MM_BACK_EVEN_D, + USE_KAHAN, + ) + + +def _cce_back_block_d(args) -> int: + block_d = args["BLOCK_D"] + return 2 * block_d + + +_cce_backward_kernel = triton.jit(_cce_backward_kernel) +_cce_backward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, + "MM_BACK_BLOCK_D": lambda args: _cce_back_block_d(args), + "MM_BACK_EVEN_D": lambda args: (args["D"] % _cce_back_block_d(args)) == 0, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, + "FILTER_GRAD": lambda args: args["filter_eps"] is not None, + "HAS_TARGETS": lambda args: args["Targets"] is not None, + "HAS_SOFTCAP": lambda args: args["softcap"] is not None, + "HAS_SHIFT": lambda args: args["shift"] != 0, + "ITEM_DO": lambda args: args["dOut"].numel() == 1, + "GROUP_B": lambda args: 8, + "COMPUTE_DC": lambda args: args["dC"] is not None, + "COMPUTE_DE": lambda args: args["dE"] is not None, + "COMPUTE_DBIAS": lambda args: args["dBias"] is not None, + } +)(_cce_backward_kernel) +_cce_backward_kernel = cce_backward_autotune()(_cce_backward_kernel) # type: ignore + + +def _cce_sampled_backward_kernel( + E, + C, + Inds, + Bias, + LSE, + dOut, + grad_scale, + Valids, + VocabOrdering, + softcap, + Targets, + dE, + dEC, + dELocks, + dC, + dCC, + dCLocks, + dBias, + B, + D, + V, + SAMPLE_NUMS, + BMax, + n_de_locks_0, + n_de_locks_1, + n_dc_locks_0, + n_dc_locks_1, + stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_ib, + stride_is, + stride_biasv, + stride_vb, + filter_eps, + shift, + B_BIN, + BLOCK_B: tl.constexpr, + BLOCK_V: tl.constexpr, + BLOCK_D: tl.constexpr, + MM_BACK_BLOCK_D: tl.constexpr, + GROUP_B: tl.constexpr, + EVEN_D: tl.constexpr, + MM_BACK_EVEN_D: tl.constexpr, + ITEM_DO: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + HAS_VOCAB_ORDERING: tl.constexpr, + FILTER_GRAD: tl.constexpr, + HAS_TARGETS: tl.constexpr, + HAS_SOFTCAP: tl.constexpr, + HAS_SHIFT: tl.constexpr, + USE_KAHAN: tl.constexpr, + COMPUTE_DC: tl.constexpr, + COMPUTE_DE: tl.constexpr, + COMPUTE_DBIAS: tl.constexpr, +): + pid = tl.program_id(axis=0) + idx = tl.program_id(axis=1) + offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B) + offs_d = tl.arange(0, BLOCK_D) + + # de_accum = tl.zeros((BLOCK_B, BLOCK_D), dtype=tl.float16) + # dc_accum = tl.zeros((BLOCK_B, BLOCK_D), dtype=tl.float16) + # inds_accum = tl.zeros((BLOCK_B, ), dtype=tl.int32) + # for idx in range(0, SAMPLE_NUMS): + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + e_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + + inds_ptrs = Inds + offs_b * stride_ib + idx + inds_mask = offs_b < 
BMax + inds = tl.load(inds_ptrs, mask=inds_mask, other=V) + c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + c_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + dot_sum = tl.sum(e.to(tl.float32) * c.to(tl.float32), axis=1) + + if idx > 0: + dot_sum += tl.log(V - 1.0) + dot_sum -= tl.log(1.0 * SAMPLE_NUMS) + + lse = tl.load(LSE + offs_b, mask=offs_b < B, other=float("inf")) + + d_accum = tl.exp(dot_sum - lse) + d_accum = tl.where(inds < V, d_accum, 0.0) + + if HAS_TARGETS: + if HAS_SHIFT: + target_offs_b = offs_b + shift + else: + target_offs_b = offs_b + + targets = tl.load(Targets + target_offs_b, mask=target_offs_b < BMax, other=V + 1) + is_target = targets == idx + d_accum += tl.where(is_target, -1.0, 0.0) + else: + is_target = None + + + if ITEM_DO: + d_out = tl.load(dOut) + else: + if HAS_SHIFT: + d_out_offs_b = offs_b + shift + else: + d_out_offs_b = offs_b + + d_out = tl.load(dOut + d_out_offs_b, mask=d_out_offs_b < BMax, other=0.0) + + d_out = grad_scale * d_out + d_accum = d_accum * d_out + d_accum = d_accum.to(e_ptrs.dtype.element_ty) + + if COMPUTE_DE: + de_ptrs = dE + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + de_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + de_out = d_accum[:, None] * c + tl.atomic_add(de_ptrs, de_out, mask=de_mask) + + if COMPUTE_DC: + dc_ptrs = dC + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + dc_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + dc_out = d_accum[:, None] * e + tl.atomic_add(dc_ptrs, dc_out, mask=dc_mask) + + # if COMPUTE_DE: + # de_accum += d_accum[:, None] * c + + # if COMPUTE_DC: + # dc_accum += d_accum[:, None] * e + + + # if COMPUTE_DE: + # de_ptrs = dE + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + # de_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + # tl.store(de_ptrs, de_accum, mask=de_mask) + + # if COMPUTE_DC: + # dc_ptrs = dC + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + # dc_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + # tl.store(dc_ptrs, dc_accum, mask=dc_mask) + + + +_cce_sampled_backward_kernel = triton.jit(_cce_sampled_backward_kernel) +_cce_sampled_backward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: (args["D"] % args["BLOCK_D"]) == 0, + "MM_BACK_BLOCK_D": lambda args: _cce_back_block_d(args), + "MM_BACK_EVEN_D": lambda args: (args["D"] % _cce_back_block_d(args)) == 0, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VOCAB_ORDERING": lambda args: args["VocabOrdering"] is not None, + "FILTER_GRAD": lambda args: args["filter_eps"] is not None, + "HAS_TARGETS": lambda args: args["Targets"] is not None, + "HAS_SOFTCAP": lambda args: args["softcap"] is not None, + "HAS_SHIFT": lambda args: args["shift"] != 0, + "ITEM_DO": lambda args: args["dOut"].numel() == 1, + "GROUP_B": lambda args: 8, + "COMPUTE_DC": lambda args: args["dC"] is not None, + "COMPUTE_DE": lambda args: args["dE"] is not None, + "COMPUTE_DBIAS": lambda args: args["dBias"] is not None, + } +)(_cce_sampled_backward_kernel) +_cce_sampled_backward_kernel = cce_sampled_backward_autotune()(_cce_sampled_backward_kernel) # type: ignore + +def cce_backward_kernel( + do: torch.Tensor, + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None, + lse: torch.Tensor, + valids: torch.Tensor | None, + softcap: float | None, + filter_eps: float | None, + 
targets: torch.Tensor | None = None, + shift: int = 0, + vocab_ordering: torch.Tensor | None = None, + grad_scale: float = 1.0, + use_kahan: bool = False, + item_inds: torch.Tensor | None = None, +) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]: + assert do.numel() in (e.size(0), 1) + assert c.size(1) == e.size(1) + assert lse.size(0) == e.size(0) or (valids is not None and lse.size(0) == valids.size(0)) + assert e.dtype in ( + torch.float16, + torch.bfloat16, + ), "Backwards requires embeddings to be bf16 or fp16" + assert c.dtype in ( + torch.float16, + torch.bfloat16, + ), "Backwards requires classifier to be bf16 or fp16" + + do = do.contiguous() + lse = lse.contiguous() + + de = torch.zeros_like(e) if e.requires_grad else None + dc = torch.zeros_like(c) if c.requires_grad else None + + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) if bias.requires_grad else None + else: + dbias = None + + if de is not None: + assert de.stride() == e.stride() + + if dc is not None: + assert dc.stride() == c.stride() + + if dbias is not None: + assert bias is not None + assert dbias.stride() == bias.stride() + + if use_kahan: + dec = torch.zeros_like(e) if de is not None else None + dcc = torch.zeros_like(c) if dc is not None else None + else: + dec = None + dcc = None + + if dec is not None: + assert dec.stride() == e.stride() + + if dcc is not None: + assert dcc.stride() == e.stride() + + if valids is not None: + assert valids.ndim == 1 + B = valids.size(0) + else: + B = e.size(0) + + if do.numel() > 1: + do = do.contiguous() + lse = lse.contiguous() + assert do.stride(0) == lse.stride(0), f"{do.stride()=}, {lse.stride()=}" + + if item_inds is None: + def grid(META): + return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(c.size(0), META["BLOCK_V"]),) + + if vocab_ordering is not None: + assert vocab_ordering.ndim == 1 + assert vocab_ordering.numel() == c.size(0) + assert vocab_ordering.stride(0) == 1 + + nd_locks = triton.cdiv(c.size(1), 64) + if de is not None: + de_locks = e.new_zeros((triton.cdiv(B, 128), nd_locks), dtype=torch.int32) + de_lock_sizes = de_locks.size() + else: + de_locks = None + de_lock_sizes = (None, None) + + if dc is not None: + dc_locks = c.new_zeros((triton.cdiv(c.size(0), 128), nd_locks), dtype=torch.int32) + dc_lock_sizes = dc_locks.size() + else: + dc_locks = None + dc_lock_sizes = (None, None) + + _cce_backward_kernel[grid]( + e, + c, + bias, + lse, + do, + grad_scale, + valids, + vocab_ordering, + softcap, + targets, + de, + dec, + de_locks, + dc, + dcc, + dc_locks, + dbias, + B, + e.size(1), + c.size(0), + e.size(0), + *de_lock_sizes, + *dc_lock_sizes, + e.stride(0), + e.stride(1), + c.stride(0), + c.stride(1), + 1 if bias is None else bias.stride(0), + 1 if valids is None else valids.stride(0), + filter_eps, + shift=shift, + B_BIN=b_bin_fn(B), + USE_KAHAN=use_kahan, + ) + else: + SAMPLE_NUMS = item_inds.size(1) + def grid(META): + return (triton.cdiv(B, META["BLOCK_B"]), SAMPLE_NUMS) + D = e.size(1) + BLOCK_D = int(2**torch.ceil(torch.log2(torch.tensor(D)))) + + # nd_locks = triton.cdiv(c.size(1), 64) + # if de is not None: + # de_locks = e.new_zeros((triton.cdiv(B, 128), nd_locks), dtype=torch.int32) + # de_lock_sizes = de_locks.size() + # else: + # de_locks = None + # de_lock_sizes = (None, None) + + # if dc is not None: + # dc_locks = c.new_zeros((triton.cdiv(c.size(0), 128), nd_locks), dtype=torch.int32) + # dc_lock_sizes = dc_locks.size() + # else: + # dc_locks = None + # dc_lock_sizes = (None, None) + + 
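+ # Sampled path: the grid below launches one program per (row block, sampled class) pair and the
+ # kernel accumulates dE/dC directly with tl.atomic_add, so the lock buffers and Kahan
+ # compensation buffers used by the dense kernel above are deliberately passed as None.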
targets_cce_sampled_loss = torch.zeros_like(lse) + + _cce_sampled_backward_kernel[grid]( + e, + c, + item_inds, + bias, + lse, + do, + grad_scale, + valids, + vocab_ordering, + softcap, + targets_cce_sampled_loss, + de, + None, #dec, + None,#de_locks, + dc, + None, #dcc, + None, #dc_locks, + dbias, + B, + D, + c.size(0), + SAMPLE_NUMS, + e.size(0), + *(None, None), #*de_lock_sizes, + *(None, None), #*dc_lock_sizes, + e.stride(0), + e.stride(1), + c.stride(0), + c.stride(1), + item_inds.stride(0), + item_inds.stride(1), + 1 if bias is None else bias.stride(0), + 1 if valids is None else valids.stride(0), + filter_eps, + shift=shift, + B_BIN=b_bin_fn(B), + USE_KAHAN=use_kahan, + BLOCK_D=BLOCK_D + ) + + if dbias is not None: + assert bias is not None + dbias = dbias.to(dtype=bias.dtype) + + return de, dc, dbias \ No newline at end of file diff --git a/kernels/cut_cross_entropy/cce_lse_forward.py b/kernels/cut_cross_entropy/cce_lse_forward.py new file mode 100644 index 000000000..48443a446 --- /dev/null +++ b/kernels/cut_cross_entropy/cce_lse_forward.py @@ -0,0 +1,370 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +from typing import Literal, overload + +import torch +import triton +import triton.language as tl + +from cut_cross_entropy.tl_autotune import cce_forward_autotune, cce_sampled_forward_autotune +from cut_cross_entropy.tl_utils import b_bin_fn, tl_logaddexp, tl_softcapping + + +def _cce_lse_forward_kernel( + E, + C, + Bias, + LSE, + LA, + Locks, + Valids, + softcap, + B, + V, + D, + BMax, + stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_biasv, + stride_lse_b, + stride_vb, + num_locks, + # Meta-parameters + B_BIN, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_V: tl.constexpr, + BLOCK_D: tl.constexpr, # + GROUP_B: tl.constexpr, # + EVEN_D: tl.constexpr, + HAS_SOFTCAP: tl.constexpr, + HAS_LA: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_b = tl.cdiv(B, BLOCK_B) + num_pid_v = tl.cdiv(V, BLOCK_V) + num_pid_in_group = GROUP_B * num_pid_v + group_id = pid // num_pid_in_group + first_pid_b = group_id * GROUP_B + group_size_b = min(num_pid_b - first_pid_b, GROUP_B) + pid_b = first_pid_b + ((pid % num_pid_in_group) % group_size_b) + pid_v = (pid % num_pid_in_group) // group_size_b + + offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + if HAS_VALIDS: + offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax) + + offs_v = pid_v * BLOCK_V + tl.arange(0, BLOCK_V) + offs_d = tl.arange(0, BLOCK_D) + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + c_ptrs = C + (offs_v[None, :] * stride_cv + offs_d[:, None] * stride_cd) + + accum = tl.zeros((BLOCK_B, BLOCK_V), dtype=tl.float32) + for d in range(0, tl.cdiv(D, BLOCK_D)): + e_mask = offs_b[:, None] < BMax + if not EVEN_D: + e_mask = e_mask & (offs_d[None, :] < (D - d * BLOCK_D)) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + + c_mask = offs_v[None, :] < V + if not EVEN_D: + c_mask = c_mask & (offs_d[:, None] < (D - d * BLOCK_D)) + + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + accum = tl.dot(e, c, accum, input_precision=DOT_PRECISION) + + e_ptrs += BLOCK_D * stride_ed + c_ptrs += BLOCK_D * stride_cd + + tl.debug_barrier() + + if HAS_BIAS: + bias = tl.load(Bias + offs_v * stride_biasv, mask=offs_v < V, other=0.0) + bias = bias.to(dtype=accum.dtype) + accum += bias[None, :] + + logits = tl.where(offs_v[None, :] < V, accum, -float("inf")) + if HAS_SOFTCAP: + logits = tl_softcapping(logits, softcap) + + 
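+ # At this point `logits` holds this tile of e @ c.T (plus bias, optionally softcapped). The rest
+ # of the kernel folds it into a running row-wise logsumexp in the LSE buffer under a spin lock,
+ # and optionally accumulates the per-class logit average into LA.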
if HAS_LA: + this_avg_logit = tl.sum(logits, 0) / B + tl.atomic_add(LA + offs_v, this_avg_logit, mask=offs_v < V) + + this_mx = tl.max(logits, axis=1) + e = tl.exp(logits - this_mx[:, None]) + this_lse = this_mx + tl.log(tl.sum(e, axis=1)) + + offs_b = pid_b * BLOCK_B + tl.arange(0, BLOCK_B) + o_mask = offs_b < B + + lse_ptrs = LSE + (stride_lse_b * offs_b) + + this_locks = Locks + (pid_b // tl.cdiv(B, BLOCK_B * num_locks)) + while tl.atomic_cas(this_locks, 0, 1) == 1: + pass + + lse = tl.load(lse_ptrs, mask=o_mask, other=0.0, eviction_policy="evict_last") + lse = tl_logaddexp(lse, this_lse) + tl.store(lse_ptrs, lse, mask=o_mask, eviction_policy="evict_last") + + tl.debug_barrier() + tl.atomic_xchg(this_locks, 0) + + +_cce_lse_forward_kernel = triton.jit(_cce_lse_forward_kernel) +_cce_lse_forward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_SOFTCAP": lambda args: args["softcap"] is not None, + "HAS_LA": lambda args: args["LA"] is not None, + "GROUP_B": lambda args: 8, + "DOT_PRECISION": lambda args: "tf32" + if torch.get_float32_matmul_precision() == "high" + else "ieee", + } +)(_cce_lse_forward_kernel) +_cce_lse_forward_kernel = cce_forward_autotune()(_cce_lse_forward_kernel) # type: ignore + + +def _cce_lse_sampled_forward_kernel( + E, + C, + Inds, + Bias, + LSE, + LA, + Locks, + Valids, + softcap, + B, + V, + D, + SAMPLE_NUMS, + BMax, + stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_ib, + stride_is, + stride_biasv, + stride_lse_b, + stride_vb, + num_locks, + # Meta-parameters + B_BIN, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + BLOCK_B: tl.constexpr, + BLOCK_V: tl.constexpr, + BLOCK_D: tl.constexpr, # + GROUP_B: tl.constexpr, # + EVEN_D: tl.constexpr, + HAS_SOFTCAP: tl.constexpr, + HAS_LA: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + pid = tl.program_id(axis=0) + m = tl.full((BLOCK_B, ), float("-inf"), dtype=tl.float32) + d = tl.zeros((BLOCK_B, ), dtype=tl.float32) + + offs_b = pid * BLOCK_B + tl.arange(0, BLOCK_B) + offs_d = tl.arange(0, BLOCK_D) + + for idx in range(0, SAMPLE_NUMS): + e_ptrs = E + (offs_b[:, None] * stride_eb + offs_d[None, :] * stride_ed) + e_mask = (offs_b[:, None] < BMax) & (offs_d[None, :] < D) + + inds_ptrs = Inds + offs_b * stride_ib + idx + inds_mask = offs_b < BMax + inds = tl.load(inds_ptrs, mask=inds_mask, other=V) + c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + c_mask = (inds[:, None] < V) & (offs_d[None, :] < D) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + dot_sum = tl.sum(e.to(tl.float32) * c.to(tl.float32), axis=1) + + if idx > 0: + dot_sum += tl.log(V - 1.0) + dot_sum -= tl.log(1.0 * SAMPLE_NUMS) + + block_max = dot_sum + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.exp(dot_sum - m_new) + m = m_new + + lse = m + tl.log(d) + lse_ptrs = LSE + offs_b + out_mask = (offs_b < BMax) + tl.store(lse_ptrs, lse, mask = out_mask) + +_cce_lse_sampled_forward_kernel = triton.jit(_cce_lse_sampled_forward_kernel) +_cce_lse_sampled_forward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_SOFTCAP": lambda args: args["softcap"] is not None, + "HAS_LA": lambda args: args["LA"] is not None, + 
"GROUP_B": lambda args: 8, + "DOT_PRECISION": lambda args: "tf32" + if torch.get_float32_matmul_precision() == "high" + else "ieee", + } +)(_cce_lse_sampled_forward_kernel) +_cce_lse_sampled_forward_kernel = cce_sampled_forward_autotune()(_cce_lse_sampled_forward_kernel) # type: ignore + + +@overload +def cce_lse_forward_kernel( + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None = None, + valids: torch.Tensor | None = None, + softcap: float | None = None, + return_logit_avg: Literal[False] = False, +) -> torch.Tensor: ... + + +@overload +def cce_lse_forward_kernel( + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None = None, + valids: torch.Tensor | None = None, + softcap: float | None = None, + return_logit_avg: Literal[True] = True, +) -> tuple[torch.Tensor, torch.Tensor]: ... + + +@overload +def cce_lse_forward_kernel( + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None = None, + valids: torch.Tensor | None = None, + softcap: float | None = None, + return_logit_avg: bool = False, +) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: ... + + +def cce_lse_forward_kernel( + e: torch.Tensor, + c: torch.Tensor, + bias: torch.Tensor | None = None, + valids: torch.Tensor | None = None, + softcap: float | None = None, + return_logit_avg: bool = False, + item_inds: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor: + # Check constraints. + assert e.shape[1] == c.shape[1], "Incompatible dimensions" + assert e.is_contiguous(), "Matrix A must be contiguous" + if valids is not None: + assert valids.ndim == 1 + B = valids.numel() + else: + B, _ = e.shape + + if bias is not None: + assert bias.ndim == 1 + assert c.shape[0] == bias.shape[0] + + V, D = c.shape + # Allocates output. + lse = e.new_full((B,), -float("inf"), dtype=torch.float32) + + + if item_inds is None: + locks = e.new_full( + (triton.cdiv(B, 128),), + 0, + dtype=torch.uint32, + ) + + if return_logit_avg: + logit_avg = e.new_full((V,), 0.0, dtype=torch.float32) + else: + logit_avg = None + + # 1D launch kernel where each block gets its own program. + def grid(META) -> tuple[int]: + return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(V, META["BLOCK_V"]),) + + _cce_lse_forward_kernel[grid]( + e, + c, + bias, + lse, # + logit_avg, + locks, + valids, + softcap, + B, + V, + D, # + e.size(0), + e.stride(0), + e.stride(1), # + c.stride(0), + c.stride(1), # + 1 if bias is None else bias.stride(0), + lse.stride(0), + 1 if valids is None else valids.stride(0), + num_locks=locks.size(0), + B_BIN=b_bin_fn(B), + ) + else: + SAMPLE_NUMS = item_inds.size(1) + if return_logit_avg: + logit_avg = e.new_full((SAMPLE_NUMS,), 0.0, dtype=torch.float32) + else: + logit_avg = None + # 1D launch kernel where each block gets its own program. 
+ def grid(META) -> tuple[int]: + return (triton.cdiv(B, META['BLOCK_B']), ) + BLOCK_D = int(2**torch.ceil(torch.log2(torch.tensor(D)))) + _cce_lse_sampled_forward_kernel[grid]( + e, + c, + item_inds, + bias, + lse, # + logit_avg, + None, #locks + valids, + softcap, + B, + V, + D, # + SAMPLE_NUMS, + e.size(0), + e.stride(0), + e.stride(1), # + c.stride(0), + c.stride(1), # + item_inds.stride(0), + item_inds.stride(1), + 1 if bias is None else bias.stride(0), + lse.stride(0), + 1 if valids is None else valids.stride(0), + num_locks=None, # num_locks=locks.size(0), + B_BIN=b_bin_fn(B), + BLOCK_D=BLOCK_D + ) + + if return_logit_avg: + assert logit_avg is not None + return lse, logit_avg + else: + return lse \ No newline at end of file diff --git a/kernels/cut_cross_entropy/constants.py b/kernels/cut_cross_entropy/constants.py new file mode 100644 index 000000000..2ba4670ac --- /dev/null +++ b/kernels/cut_cross_entropy/constants.py @@ -0,0 +1,2 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +IGNORE_INDEX: int = -100 \ No newline at end of file diff --git a/kernels/cut_cross_entropy/doc.py b/kernels/cut_cross_entropy/doc.py new file mode 100644 index 000000000..4b3d2803a --- /dev/null +++ b/kernels/cut_cross_entropy/doc.py @@ -0,0 +1,58 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +LINEAR_CROSS_ENTROPY_DOC = """Computes cross-entropy loss using the logits generated by performing + the matrix multiplication between the embeddings (e) and classifier (c). + + This method saves GPU memory by not materializing the logits into GPU + main memory. + + + Specifically, this computes + + ```python + + loss = F.cross_entropy((e @ c.T).float(), targets) + ``` + + without allocating the intermediary (e @ c.T).float() matrix. + + :param e: Embedding of the inputs used to compute the logits. Shape (..., D) + :param c: Classifier matrix. Shape (NumClasses, D) + :param targets: The target class for each input. Values must be in [0, NumClasses). Shape (...) + :param ignore_index: If an input as a target of this value, it is ignored in the loss computation. + :param softcap: The value for logit softcapping. + :param reduction: The reduction to perform over the loss. Supports "mean", "sum", and "none". + :param shift: When non-zero, the embedding and targets will be shifted along the temporal axis to perform nth-next token prediction. + Specifically, this is used to efficiently compute the following + + ```python + shift_e = e[..., :-shift, :].flatten(0, -2) + shift_targets = targets[..., shift:].flatten() + + loss = F.cross_entropy((shift_e @ c.T), targets) + ``` + + If given a boolean value, False will be treated as zero and True will be treated as one. + + When this value is non-zero or True, e and targets must have shape (..., T, D) and (..., T), respectively. + + Integer values must be in [0, T) +""" + +CCE_OPTS_DOC = [ + """ + :param filter_eps: The threshold value used to determine which locations can be safely ignored + in gradient computation. The default value of "auto" will automatically choose a value + based on the input dtype.""", + """ + :param use_kahan: Uses Kahan summation to increase the precision of CCE's reduction along the vocab axis. 
This only + makes sense to set to True when filter_eps is None (or is a very very small value).""", +] + + +def add_doc_start(*docstr: str): + def add_doc(fn): + fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "") + + return fn + + return add_doc \ No newline at end of file diff --git a/kernels/cut_cross_entropy/indexed_dot.py b/kernels/cut_cross_entropy/indexed_dot.py new file mode 100644 index 000000000..314d00864 --- /dev/null +++ b/kernels/cut_cross_entropy/indexed_dot.py @@ -0,0 +1,158 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import torch +import triton +import triton.language as tl + +from cut_cross_entropy.tl_autotune import indexed_dot_autotune +from cut_cross_entropy.tl_utils import b_bin_fn +from cut_cross_entropy.utils import softcapping + + +def _indexed_neg_dot_forward_kernel( + E, + C, + Inds, + Bias, + Valids, + Out, + B, + D, + V, + BMax, + stride_eb, + stride_ed, + stride_cv, + stride_cd, + stride_ib, + stride_biasv, + stride_vb, + shift, + B_BIN, + BLOCK_B: tl.constexpr, + BLOCK_D: tl.constexpr, + GROUP_B: tl.constexpr, + HAS_BIAS: tl.constexpr, + HAS_VALIDS: tl.constexpr, + EVEN_D: tl.constexpr, + HAS_SHIFT: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_b_chunks = tl.cdiv(B, BLOCK_B) + num_d_chunks = tl.cdiv(D, BLOCK_D) + num_d_in_group = GROUP_B * num_d_chunks + group_id = pid // num_d_in_group + first_pid_b = group_id * GROUP_B + group_size_b = min(num_b_chunks - first_pid_b, GROUP_B) + pid_b = first_pid_b + ((pid % num_d_in_group) % group_size_b) + pid_d = (pid % num_d_in_group) // group_size_b + + offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B + if HAS_VALIDS: + offs_b = tl.load(Valids + stride_vb * offs_b, mask=offs_b < B, other=BMax) + + offs_d = tl.arange(0, BLOCK_D) + pid_d * BLOCK_D + e_ptrs = E + (stride_eb * offs_b[:, None] + stride_ed * offs_d[None, :]) + + e_mask = offs_b[:, None] < BMax + if not EVEN_D: + e_mask = e_mask & (offs_d[None, :] < D) + + e = tl.load(e_ptrs, mask=e_mask, other=0.0) + + if HAS_SHIFT: + offs_b = offs_b + shift + + inds = tl.load(Inds + stride_ib * offs_b, mask=offs_b < BMax, other=V) + + c_ptrs = C + (inds[:, None] * stride_cv + offs_d[None, :] * stride_cd) + + c_mask = inds[:, None] < V + if not EVEN_D: + c_mask = c_mask & (offs_d[None, :] < D) + + c = tl.load(c_ptrs, mask=c_mask, other=0.0) + + offs_b = tl.arange(0, BLOCK_B) + pid_b * BLOCK_B + out_ptrs = Out + offs_b + dot = e.to(tl.float32) * c.to(tl.float32) + neg_dot = -tl.sum(dot, 1) + + if HAS_BIAS: + bias = tl.load(Bias + inds * stride_biasv, mask=inds < V, other=0.0) + bias = bias.to(tl.float32) + neg_dot -= bias + + tl.atomic_add(out_ptrs, neg_dot.to(out_ptrs.dtype.element_ty), mask=offs_b < B) + + +_indexed_neg_dot_forward_kernel = triton.jit(_indexed_neg_dot_forward_kernel) +_indexed_neg_dot_forward_kernel = triton.heuristics( # type: ignore + { + "EVEN_D": lambda args: args["D"] % args["BLOCK_D"] == 0, + "HAS_BIAS": lambda args: args["Bias"] is not None, + "HAS_VALIDS": lambda args: args["Valids"] is not None, + "HAS_SHIFT": lambda args: args["shift"] != 0, + "GROUP_B": lambda args: 8, + } +)(_indexed_neg_dot_forward_kernel) +_indexed_neg_dot_forward_kernel = indexed_dot_autotune()(_indexed_neg_dot_forward_kernel) # type: ignore + + +def indexed_neg_dot_forward_kernel( + e: torch.Tensor, + c: torch.Tensor, + inds: torch.Tensor, + bias: torch.Tensor | None = None, + shift: int = 0, + valids: torch.Tensor | None = None, + softcap: float | None = None, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert 
inds.ndim == 1 + assert e.ndim == 2 + assert c.ndim == 2 + assert inds.size(0) == e.size(0) + assert c.size(1) == e.size(1) + + if valids is not None: + assert valids.ndim == 1 + B = valids.size(0) + else: + B = e.size(0) + + out = e.new_zeros((B,), dtype=torch.float32) + + def grid(META) -> tuple[int]: + return (triton.cdiv(B, META["BLOCK_B"]) * triton.cdiv(e.size(1), META["BLOCK_D"]),) + + _indexed_neg_dot_forward_kernel[grid]( + e, + c, + inds, + bias, + valids, + out, + B, + e.size(1), + c.size(0), + e.size(0), + e.stride(0), + e.stride(1), + c.stride(0), + c.stride(1), + inds.stride(0), + 1 if bias is None else bias.stride(0), + 1 if valids is None else valids.stride(0), + shift=shift, + B_BIN=b_bin_fn(B), + ) + + if softcap is not None: + out = softcapping(out, softcap) + + if out_dtype is None: + out_dtype = e.dtype + + out = out.to(out_dtype) + + return out \ No newline at end of file diff --git a/kernels/cut_cross_entropy/linear_cross_entropy.py b/kernels/cut_cross_entropy/linear_cross_entropy.py new file mode 100644 index 000000000..4a84b25e9 --- /dev/null +++ b/kernels/cut_cross_entropy/linear_cross_entropy.py @@ -0,0 +1,120 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import enum +import platform +from enum import auto +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn + +from cut_cross_entropy.constants import IGNORE_INDEX +from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start +from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy + + +class LinearCrossEntropyImpl(enum.IntEnum): + CCE = auto() + TORCH_COMPILE = auto() + CCE_EXACT = auto() + + +PLATFORM_SYSTEM = platform.system() + +if TYPE_CHECKING or PLATFORM_SYSTEM != "Darwin": + from cut_cross_entropy.cce import cce_linear_cross_entropy + + LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.CCE +else: + cce_linear_cross_entropy = None + LCE_IMPL_DEFAULT = LinearCrossEntropyImpl.TORCH_COMPILE + + +@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) +@add_doc_start(*(doc_str + " Only valid for the cce implementation.\n" for doc_str in CCE_OPTS_DOC)) +def linear_cross_entropy( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + ignore_index: int = IGNORE_INDEX, + softcap: float | None = None, + reduction: str = "mean", + shift: bool | int = 0, + filter_eps: float | str | None = "auto", + use_kahan: bool = False, + impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, +) -> torch.Tensor: + """ + :param impl: The linear cross entropy implementation to use. Currently supports cce, torch_compile, and cce_exact. + """ + + if isinstance(impl, LinearCrossEntropyImpl): + impl = impl.name.lower() + + if isinstance(shift, int) and (shift < 0 or shift >= targets.size(-1)): + raise ValueError(f"Shift must be in the range [0, {targets.size(-1)}). Got {shift}.") + + match impl: + case "cce" | "cce_exact": + if platform.system() == "Darwin": + raise RuntimeError( + "CCE does not support MacOS. Please use torch_compile when running on MacOS instead." 
+ ) + + if impl == "cce_exact": + filter_eps = None + use_kahan = True + + assert cce_linear_cross_entropy is not None + return cce_linear_cross_entropy( + e, c, targets, bias, ignore_index, softcap, reduction, shift, filter_eps, use_kahan + ) + case "torch_compile": + return torch_compile_linear_cross_entropy( + e, c, targets, bias, ignore_index, softcap, reduction, shift + ) + case _: + raise NotImplementedError(f"{impl} is not implemented.") + + +class LinearCrossEntropy(nn.Module): + def __init__( + self, + ignore_index: int = IGNORE_INDEX, + softcap: float | None = None, + reduction: str = "mean", + shift: bool | int = 0, + filter_eps: float | str | None = "auto", + use_kahan: bool = False, + impl: str | LinearCrossEntropyImpl = LCE_IMPL_DEFAULT, + ): + super().__init__() + self.ignore_index = ignore_index + self.softcap = softcap + self.reduction = reduction + self.filter_eps = filter_eps + self.shift = shift + self.use_kahan = use_kahan + + self.impl = impl + + def forward( + self, + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: + return linear_cross_entropy( + e, + c, + targets, + bias=bias, + ignore_index=self.ignore_index, + softcap=self.softcap, + reduction=self.reduction, + shift=self.shift, + filter_eps=self.filter_eps, + use_kahan=self.use_kahan, + impl=self.impl, + ) \ No newline at end of file diff --git a/kernels/cut_cross_entropy/tl_autotune.py b/kernels/cut_cross_entropy/tl_autotune.py new file mode 100644 index 000000000..50b909650 --- /dev/null +++ b/kernels/cut_cross_entropy/tl_autotune.py @@ -0,0 +1,595 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import functools +import heapq +import os +from dataclasses import dataclass, field +from typing import Any, Callable + +import torch +import triton +from triton import Config, cdiv +from triton.runtime import autotuner, driver +from triton.testing import ( + get_dram_gbps, + get_max_simd_tflops, + get_max_tensorcore_tflops, + nvsmi, +) + +_AUTOTUNE: bool = os.getenv("CCE_AUTOTUNE", "0") != "0" + + +@dataclass +class NoneSupportRestorer: + reset_idx: list[int] + restore_idx: list[int] + _restore_copies: list[torch.Tensor | None] = field(default_factory=list, init=False) + + def pre_hook(self, args: list[torch.Tensor | None | Any]) -> None: + for i in self.reset_idx: + v = args[i] + if v is not None: + assert isinstance(v, torch.Tensor) + v.zero_() + + for i in self.reset_idx: + v = args[i] + if v is not None: + assert isinstance(v, torch.Tensor) + self._restore_copies.append(v.clone()) + else: + self._restore_copies.append(None) + + def post_hook(self, args: list[torch.Tensor | None | Any], _exception) -> None: + for j, i in enumerate(self.reset_idx): + v = args[i] + if v is not None: + old_v = self._restore_copies[j] + assert isinstance(v, torch.Tensor) + assert old_v is not None + + v.copy_(old_v) + + self._restore_copies = [] + + +@functools.wraps(triton.autotune) +def _cce_autotune(*args, **kwargs) -> Callable[..., autotuner.Autotuner]: + def decorator(fn): + reset_idx = [] + restore_idx = [] + arg_names = fn.arg_names + reset_to_zero = kwargs.pop("reset_to_zero", None) + if reset_to_zero is not None: + reset_idx = [arg_names.index(k) for k in reset_to_zero] + + restore_value = kwargs.pop("restore_value", None) + if restore_value is not None: + restore_idx = [arg_names.index(k) for k in restore_value] + + restorer = NoneSupportRestorer(reset_idx, restore_idx) + if len(reset_idx) > 0: + kwargs["pre_hook"] = restorer.pre_hook + + if 
len(restore_idx) > 0: + kwargs["post_hook"] = restorer.post_hook + + return triton.autotune(*args, **kwargs)(fn) + + return decorator + + +@functools.lru_cache() +def get_clock_rate_in_khz(): + try: + return nvsmi(["clocks.max.sm"])[0] * 1e3 + except FileNotFoundError: + import pynvml + + pynvml.nvmlInit() + handle = pynvml.nvmlDeviceGetHandleByIndex(0) + return pynvml.nvmlDeviceGetMaxClockInfo(handle, pynvml.NVML_CLOCK_SM) * 1e3 + + +def get_tensorcore_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_tensorcore_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_simd_tflops(device, num_ctas, num_warps, dtype): + """return compute throughput in TOPS""" + total_warps = num_ctas * min(num_warps, 4) + num_subcores = ( + driver.active.utils.get_device_properties(device)["multiprocessor_count"] * 4 + ) # on recent GPUs + tflops = ( + min(num_subcores, total_warps) + / num_subcores + * get_max_simd_tflops(dtype, get_clock_rate_in_khz(), device) + ) + return tflops + + +def get_tflops(device, num_ctas, num_warps, dtype): + capability = torch.cuda.get_device_capability(device) + if capability[0] < 8 and dtype == torch.float32: + return get_simd_tflops(device, num_ctas, num_warps, dtype) + return get_tensorcore_tflops(device, num_ctas, num_warps, dtype) + + +def early_config_prune( + configs, + named_args, + *, + shared_memory_factor: float = 1.0, + max_num_warps: int | None = None, + **kwargs, +): + device = torch.cuda.current_device() + capability = torch.cuda.get_device_capability() + # BLOCK_B, BLOCK_V, BLOCK_D, SPLIT_K, num_warps, num_stages + dtsize = named_args["E"].element_size() + + if max_num_warps is not None: + configs = [config for config in configs if config.num_warps <= max_num_warps] + + # 1. make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_B, BLOCK_V, BLOCK_D, num_stages = ( + kw["BLOCK_B"], + kw["BLOCK_V"], + kw["BLOCK_D"], + config.num_stages, + ) + + max_shared_memory = driver.active.utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = ( + shared_memory_factor * (BLOCK_B + BLOCK_V) * BLOCK_D * num_stages * dtsize + ) + if required_shared_memory > max_shared_memory: + continue + + pruned_configs.append(config) + + configs = pruned_configs + + # group configs by (BLOCK_B,_N,_K, num_warps) + configs_map = {} + for config in configs: + kw = config.kwargs + BLOCK_B, BLOCK_V, BLOCK_D, num_warps, num_stages = ( + kw["BLOCK_B"], + kw["BLOCK_V"], + kw["BLOCK_D"], + config.num_warps, + config.num_stages, + ) + + key = (BLOCK_B, BLOCK_V, BLOCK_D, num_warps) + if key in configs_map: + configs_map[key].append((config, num_stages)) + else: + configs_map[key] = [(config, num_stages)] + + pruned_configs = [] + for k, v in configs_map.items(): + BLOCK_B, BLOCK_V, BLOCK_D, num_warps = k + if capability[0] >= 8: + # compute cycles (only works for ampere GPUs) + mmas = BLOCK_B * BLOCK_V * BLOCK_D / (16 * 8 * 16) + mma_cycles = mmas / min(4, num_warps) * 8 + + ldgsts_latency = 300 # Does this matter? 
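+ # Worked example of this stage heuristic (illustrative numbers only): for BLOCK_B = BLOCK_V = 128,
+ # BLOCK_D = 32 and 4 warps, mmas = 128 * 128 * 32 / (16 * 8 * 16) = 256 and
+ # mma_cycles = 256 / 4 * 8 = 512, so optimal_num_stages ~= 300 / 512 ~= 0.59 and the two configs
+ # in this group whose num_stages land nearest that value (preferring the larger side) are kept.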
+ optimal_num_stages = ldgsts_latency / mma_cycles + + # nearest stages, prefer large #stages + nearest = heapq.nsmallest( + 2, + v, + key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 + else x[1] - optimal_num_stages, + ) + + for n in nearest: + pruned_configs.append(n[0]) + else: # Volta & Turing only supports num_stages <= 2 + random_config = v[0][0] + random_config.num_stages = 2 + pruned_configs.append(random_config) + return pruned_configs + + +def _total_ops_fn(B, V, D) -> float: + return 2 * B * V * D + 10 * B * V + + +def _total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): + return B * dtsize + + +def estimate_matmul_time( + # backend, device, + num_warps, + num_stages, # + E, + B, + V, + D, # + BLOCK_B, + BLOCK_V, + BLOCK_D, + debug=False, + total_ops_fn=_total_ops_fn, + total_store_fn=_total_store_fn, + **kwargs, # +): + """return estimated running time in ms + = max(compute, loading) + store""" + device = torch.cuda.current_device() + dtype = E.dtype + dtsize = E.element_size() + + num_cta_b = cdiv(B, BLOCK_B) + num_cta_v = cdiv(V, BLOCK_V) + num_ctas = num_cta_b * num_cta_v + + # If the input is smaller than the block size + B, V = max(B, BLOCK_B), max(V, BLOCK_V) + + # time to compute + total_ops = total_ops_fn(B, V, D) + total_ops = total_ops / (1024 * 1024 * 1024) # GOPS + tput = get_tflops(device, num_ctas, num_warps, dtype) + compute_ms = total_ops / tput + + # time to load data + num_sm = driver.active.utils.get_device_properties(device)["multiprocessor_count"] + active_cta_ratio = min(1, num_ctas / num_sm) + active_cta_ratio_bw1 = min(1, num_ctas / 32) # 32 active ctas are enough to saturate + active_cta_ratio_bw2 = max(min(1, (num_ctas - 32) / (108 - 32)), 0) # 32-108, remaining 5% + dram_bw = get_dram_gbps(device) * ( + active_cta_ratio_bw1 * 0.95 + active_cta_ratio_bw2 * 0.05 + ) # in GB/s + l2_bw = dram_bw * 4 # rough estimation (should be 4.7 for A100?) 
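+ # The achievable DRAM bandwidth above ramps from ~95% of peak once 32 CTAs are resident to 100%
+ # at ~108 CTAs (one per SM on an A100-class GPU); L2 bandwidth is then taken as roughly 4x DRAM.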
+ # assume 80% of (following) loads are in L2 cache + load_a_dram = B * D * dtsize * (1 + 0.2 * (num_cta_v - 1)) + load_a_l2 = B * D * dtsize * 0.8 * (num_cta_v - 1) + load_b_dram = V * D * dtsize * (1 + 0.2 * (num_cta_b - 1)) + load_b_l2 = V * D * dtsize * 0.8 * (num_cta_b - 1) + # total + total_dram = (load_a_dram + load_b_dram) / (1024 * 1024) # MB + total_l2 = (load_a_l2 + load_b_l2) / (1024 * 1024) + # loading time in ms + load_ms = total_dram / dram_bw + total_l2 / l2_bw + + # estimate storing time + store_bw = dram_bw * 0.4 # :o + store_dram = total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v) / (1024 * 1024) + store_ms = store_dram / store_bw + + total_time_ms = max(compute_ms, load_ms) + store_ms + if debug: + print( + f"{BLOCK_B=}, {BLOCK_V=}, {BLOCK_D=}, {num_warps=}, {num_stages=}, " + f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, " + f"loading time: {load_ms}ms, store time: {store_ms}ms, " + f"Activate CTAs: {active_cta_ratio*100}%" + ) + return total_time_ms + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_B": block_m, + "BLOCK_V": block_n, + "BLOCK_D": block_k, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + return configs + + +def get_autotune_config(): + return [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=2, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 32, "BLOCK_D": 32}, + num_stages=4, + num_warps=4, + ), + Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 32}, num_stages=5, num_warps=2), + # good for int8 + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=3, + num_warps=16, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=3, + num_warps=16, + ), + Config( + {"BLOCK_B": 256, "BLOCK_V": 64, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 256, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 128, "BLOCK_D": 128}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 128, "BLOCK_V": 64, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 64, "BLOCK_V": 128, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_B": 
128, "BLOCK_V": 32, "BLOCK_D": 64}, + num_stages=4, + num_warps=4, + ), + Config({"BLOCK_B": 64, "BLOCK_V": 32, "BLOCK_D": 64}, num_stages=5, num_warps=2), + ] + get_configs_io_bound() + + +def _heuristics_from_config(config: Config) -> Callable[..., autotuner.Heuristics]: + return triton.heuristics({k: (lambda args, _v=v: _v) for k, v in config.all_kwargs().items()}) + + +def _cce_forward_best_config() -> Config: + return Config(dict(BLOCK_B=256, BLOCK_V=128, BLOCK_D=32), num_warps=8, num_stages=3) + + +def _cce_sampled_forward_best_config() -> Config: + # return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=2, num_stages=3) + # return Config(dict(BLOCK_B=128, BLOCK_V=128), num_warps=16, num_stages=4) + return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=16, num_stages=4) + + + +def cce_forward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, + restore_value=["LSE"], + reset_to_zero=["LA"], + ) + else: + return _heuristics_from_config(_cce_forward_best_config()) + + +def cce_sampled_forward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": early_config_prune, + "perf_model": estimate_matmul_time, + "top_k": 10, + }, + restore_value=["LSE"], + reset_to_zero=["LA"], + ) + else: + return _heuristics_from_config(_cce_sampled_forward_best_config()) + + +def _bw_total_ops_fn(B, V, D) -> float: + return 2 * B * V * D + 6 * B * V + 0.2 * (2 * B * V * D + 2 * B * V * D) + + +def _bw_total_store_fn(B, V, D, dtsize, num_cta_b, num_cta_v): + return 0.2 * (num_cta_v * B * D * dtsize + num_cta_b * D * V * dtsize) + + +def _cce_backward_best_config() -> Config: + return Config(dict(BLOCK_B=128, BLOCK_V=128, BLOCK_D=32), num_warps=4, num_stages=4) + + +def _cce_sampled_backward_best_config() -> Config: + # return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=2, num_stages=5) + # return Config(dict(BLOCK_B=128, BLOCK_V=128), num_warps=16, num_stages=4) + return Config(dict(BLOCK_B=32, BLOCK_V=128), num_warps=16, num_stages=4) + +def cce_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, shared_memory_factor=2.0 + ), + "perf_model": functools.partial( + estimate_matmul_time, + total_ops_fn=_bw_total_ops_fn, + total_store_fn=_bw_total_store_fn, + ), + "top_k": 5, + }, + reset_to_zero=["dE", "dC", "dEC", "dCC", "dBias"], + ) + else: + return _heuristics_from_config(_cce_backward_best_config()) + + +def cce_sampled_backward_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=get_autotune_config(), + key=["V", "D", "B_BIN"], + prune_configs_by={ + "early_config_prune": functools.partial( + early_config_prune, shared_memory_factor=2.0 + ), + "perf_model": functools.partial( + estimate_matmul_time, + total_ops_fn=_bw_total_ops_fn, + total_store_fn=_bw_total_store_fn, + ), + "top_k": 5, + }, + reset_to_zero=["dE", "dC", "dEC", "dCC", "dBias"], + ) + else: + return 
_heuristics_from_config(_cce_sampled_backward_best_config()) + + +def _indexed_dot_best_config() -> Config: + return Config(dict(BLOCK_B=128, BLOCK_D=256), num_warps=16, num_stages=4) + + +def _indexed_dot_all_configs() -> list[Config]: + return [ + Config( + dict( + BLOCK_B=128, + BLOCK_D=128, + ), + num_warps=4, + num_stages=4, + ), + Config( + dict( + BLOCK_B=128, + BLOCK_D=128, + ), + num_warps=8, + num_stages=4, + ), + Config( + dict( + BLOCK_B=256, + BLOCK_D=256, + ), + num_warps=16, + num_stages=4, + ), + Config( + dict( + BLOCK_B=256, + BLOCK_D=128, + ), + num_warps=16, + num_stages=4, + ), + Config( + dict( + BLOCK_B=128, + BLOCK_D=256, + ), + num_warps=16, + num_stages=4, + ), + ] + + +def indexed_dot_autotune() -> Callable[..., autotuner.Autotuner | autotuner.Heuristics]: + if _AUTOTUNE: + return _cce_autotune( + configs=_indexed_dot_all_configs(), + key=["D", "B_BIN"], + reset_to_zero=["Out"], + ) + else: + return _heuristics_from_config(_indexed_dot_best_config()) \ No newline at end of file diff --git a/kernels/cut_cross_entropy/tl_utils.py b/kernels/cut_cross_entropy/tl_utils.py new file mode 100644 index 000000000..e35411c6f --- /dev/null +++ b/kernels/cut_cross_entropy/tl_utils.py @@ -0,0 +1,90 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. +import triton +import triton.language as tl +from triton.language.extra import libdevice as tl_libdevice + + +@triton.jit +def tl_and_reduce_fn(a, b): + return a & b + + +@triton.jit +def tl_tanh(a: tl.tensor) -> tl.tensor: + return tl_libdevice.tanh(a) + + +@triton.jit +def tl_log1p(a: tl.tensor) -> tl.tensor: + return tl_libdevice.log1p(a) + + +@triton.jit +def tl_softcapping(v: tl.tensor, softcap: float) -> tl.tensor: + return tl_tanh(v / softcap) * softcap + + +@triton.jit +def tl_softcapping_grad(dv: tl.tensor, v: tl.tensor, softcap: float) -> tl.tensor: + v = v / softcap + return dv * (1 - v * v) + + +@triton.jit +def tl_logaddexp(a, b) -> tl.tensor: + minx = tl.minimum(a, b) + mx = tl.maximum(a, b) + return tl_log1p(tl.exp(minx - mx)) + mx + + +@triton.jit +def tl_2sum(a: tl.tensor, b: tl.tensor) -> tuple[tl.tensor, tl.tensor]: + s = a + b + + a_prime = s - b + b_prime = s - a_prime + + delta_a = a - a_prime + delta_b = b - b_prime + + t = delta_a + delta_b + return s, t + + +@triton.jit +def tl_lock_kahan_sum(ptrs, c_ptrs, v, mask, lock_ptr): + while tl.atomic_cas(lock_ptr, 0, 1) == 1: + pass + + s = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + c = tl.load(c_ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + + s, c = tl_2sum(s, c + v) + + tl.store(ptrs, s, mask=mask, eviction_policy="evict_last") + tl.store(c_ptrs, c, mask=mask, eviction_policy="evict_last") + + tl.debug_barrier() + tl.atomic_xchg(lock_ptr, 0) + + +@triton.jit +def tl_lock_add(ptrs, v, mask, lock_ptr): + while tl.atomic_cas(lock_ptr, 0, 1) == 1: + pass + + cur_v = tl.load(ptrs, mask=mask, other=0.0, eviction_policy="evict_last") + new_v = v + cur_v + tl.store(ptrs, new_v, mask=mask, eviction_policy="evict_last") + + tl.debug_barrier() + tl.atomic_xchg(lock_ptr, 0) + + +def b_bin_fn(b: int) -> int: + if b >= 1024: + return 1024 + elif b <= 128: + return 128 + else: + return 512 \ No newline at end of file diff --git a/kernels/cut_cross_entropy/torch_compile.py b/kernels/cut_cross_entropy/torch_compile.py new file mode 100644 index 000000000..a2e7d124b --- /dev/null +++ b/kernels/cut_cross_entropy/torch_compile.py @@ -0,0 +1,82 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
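tl_2sum above is the classic error-free 2Sum transformation, and tl_lock_kahan_sum uses it to keep a (sum, compensation) pair per gradient element. A plain-Python sketch of the same compensated update, purely illustrative (scalar rather than blocked/atomic):

import numpy as np

def two_sum(a, b):
    # Error-free transformation mirroring tl_2sum: s is the rounded sum, t the rounding error.
    s = a + b
    a_prime = s - b
    b_prime = s - a_prime
    t = (a - a_prime) + (b - b_prime)
    return s, t

values = np.full(1_000_000, 1e-4, dtype=np.float32)

naive = np.float32(0.0)
s = np.float32(0.0)   # running sum, as stored at `ptrs`
c = np.float32(0.0)   # compensation term, as stored at `c_ptrs`
for v in values:
    naive = naive + v
    s, c = two_sum(s, np.float32(c + v))  # the update tl_lock_kahan_sum performs under its lock

print(float(values.astype(np.float64).sum()))  # float64 reference
print(float(naive))                            # plain float32 accumulation drifts
print(float(s + c))                            # compensated result stays close to the reference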
+import torch +import torch.nn.functional as F + +from cut_cross_entropy.constants import IGNORE_INDEX +from cut_cross_entropy.doc import LINEAR_CROSS_ENTROPY_DOC, add_doc_start +from cut_cross_entropy.utils import ( + _build_flat_valids, + handle_reduction_none, + softcapping, +) + + +@torch.compile(fullgraph=True, dynamic=True) +def torch_compile_linear_cross_entropy_apply( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + softcap: float | None = None, + *, + ignore_index: int = IGNORE_INDEX, + reduction: str = "mean", +) -> torch.Tensor: + logits = e @ c.T + + if bias is not None: + logits = logits + bias + + if softcap is not None: + logits = softcapping(logits, softcap) + + loss = F.cross_entropy(logits.float(), targets, ignore_index=ignore_index, reduction=reduction) + + return loss + + +@add_doc_start(LINEAR_CROSS_ENTROPY_DOC) +def torch_compile_linear_cross_entropy( + e: torch.Tensor, + c: torch.Tensor, + targets: torch.Tensor, + bias: torch.Tensor | None = None, + ignore_index: int = IGNORE_INDEX, + softcap: float | None = None, + reduction: str = "mean", + shift: bool | int = 0, +) -> torch.Tensor: + assert e.size()[0:-1] == targets.size() + assert e.size(-1) == c.size(1) + + orig_b_size = targets.size() + e = e.contiguous() + targets = targets.contiguous() + + shift = int(shift) + valids = _build_flat_valids(targets, ignore_index, shift) + + e = e.flatten(0, -2) + targets = targets.flatten() + + if valids is not None: + e = e[valids] + targets = targets[(valids + shift) if shift != 0 else valids] + + loss = torch_compile_linear_cross_entropy_apply( + e, + c, + targets, + bias, + softcap, + ignore_index=ignore_index, + reduction=reduction, + ) + + if reduction == "none": + loss = handle_reduction_none(orig_b_size, valids, shift, loss) + + if shift != 0: + loss = loss[..., shift:] + + return loss \ No newline at end of file diff --git a/kernels/cut_cross_entropy/utils.py b/kernels/cut_cross_entropy/utils.py new file mode 100644 index 000000000..4e95e1a0e --- /dev/null +++ b/kernels/cut_cross_entropy/utils.py @@ -0,0 +1,55 @@ +# Copyright (C) 2024 Apple Inc. All Rights Reserved. 
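A minimal usage sketch for torch_compile_linear_cross_entropy defined above, the fallback suggested when Cut Cross Entropy's bf16/Ampere requirements are not met. The shapes are hypothetical, and it assumes the kernels/ directory is on sys.path so the package imports as cut_cross_entropy:

import torch

from cut_cross_entropy.torch_compile import torch_compile_linear_cross_entropy

B, T, D, V = 4, 32, 64, 1000  # hypothetical batch, sequence, hidden and vocab sizes
e = torch.randn(B, T, D, requires_grad=True)
c = torch.randn(V, D, requires_grad=True)
targets = torch.randint(V, (B, T))

# shift=1 performs next-token prediction: positions e[..., :-1, :] are scored against targets[..., 1:].
loss = torch_compile_linear_cross_entropy(e, c, targets, shift=1, reduction="mean")
loss.backward()
print(loss.item(), e.grad.shape, c.grad.shape)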
+import torch + + +@torch.compile(fullgraph=True, dynamic=True) +def softcapping(logits: torch.Tensor, softcap: float) -> torch.Tensor: + return torch.tanh(logits / softcap) * softcap + + +def _handle_eps(filter_eps: float | str | None, dtype: torch.dtype) -> float | None: + match filter_eps: + case None: + return None + case float(): + return filter_eps + case "auto": + return torch.finfo(dtype).eps / 32 + case _: + raise RuntimeError(f"Unknown eps {filter_eps=}") + + +def _build_flat_valids( + targets: torch.Tensor, + ignore_index: int, + shift: int, +) -> torch.Tensor | None: + if shift != 0: + targets = targets[..., shift:] + else: + targets = targets.flatten() + + valids = (targets != ignore_index).nonzero().to(torch.int32) + + if shift == 0: + assert valids.size(1) == 1 + return valids.squeeze(1) if valids.numel() != targets.numel() else None + + for i in range(targets.ndim - 1): + valids[:, i] *= targets.stride(i) + + assert targets.stride(-1) == 1 + + return valids.sum(1) + + +def handle_reduction_none( + batch_shape: torch.Size, valids: torch.Tensor | None, shift: int, loss: torch.Tensor +) -> torch.Tensor: + if valids is None: + return loss.view(batch_shape) + + full_loss = loss.new_zeros((batch_shape.numel(),)) + full_loss[(valids + shift) if shift != 0 else valids] = loss + + return full_loss.view(batch_shape) \ No newline at end of file diff --git a/kernels/fused_linear_cross_entropy/__init__.py b/kernels/fused_linear_cross_entropy/__init__.py new file mode 100644 index 000000000..917dfc305 --- /dev/null +++ b/kernels/fused_linear_cross_entropy/__init__.py @@ -0,0 +1 @@ +from fused_linear_cross_entropy.fused_linear_ce_loss import LigerFusedLinearCrossEntropyFunction \ No newline at end of file diff --git a/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py b/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py new file mode 100644 index 000000000..67c9cdc40 --- /dev/null +++ b/kernels/fused_linear_cross_entropy/fused_linear_ce_loss.py @@ -0,0 +1,542 @@ +import torch +import triton +import triton.language as tl +try: + # typical import path with dispatch available + from triton.language.extra.libdevice import tanh +except ModuleNotFoundError: + # for working with NGC containers + from triton.language.extra.cuda.libdevice import tanh + + + + +# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19 +# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling +# The optimal maximum block size depends on your hardware, your kernel, and your dtype +MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning + + +@triton.jit +def liger_cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + weight_ptr, + loss_ptr, + z_loss_ptr, + loss_stride, + n_cols, + n_non_ignore, + sum_non_ignore_weight, + weight_sum, + ignore_index, + lse_square_scale: tl.constexpr, + label_smoothing: tl.constexpr, + reduction: tl.constexpr, # set it as constexpr since reduction is always known at compile time + softcap, + RETURN_Z_LOSS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HAS_WEIGHT: tl.constexpr, + HAS_SOFTCAPPING: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + We only consider hard label + mean reduction for now. Please refer to https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html for the math. 
+ + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + weight_ptr: Pointer to weight tensor. + loss_ptr: Pointer to tensor to store the loss. + z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0. + loss_stride (int): The stride of the loss tensor. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (flaot): The number of non-ignored elements in the batch. + sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch. + weight_sum (float): The sum of weight tensor. + ignore_index (int): The index to ignore in the target. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + lse_square_scale (float): The scaler of (logsumexp(_input)) ^ 2 adding to the loss for the stability of training. + reduction (str): The string for the reduction to apply + softcap (float): The upper threshold for scaling logits to the range (-softcap, +softcap). + RETURN_Z_LOSS (int): The boolean value to decide whether storing z loss to z_loss_ptr or not. It must be 0 or 1. + BLOCK_SIZE (int): The block size for Triton operations. + HAS_WEIGHT (bool): The boolean value to determine whether assigning weight to each of the classes. + HAS_SOFTCAPPING (bool): The boolean value to determine whether applying soft-capping or not. + """ + + # https://github.com/triton-lang/triton/issues/1058 + # If B*T*V is too large, program_id * stride will overflow out of int32, so we convert to int64 + program_id = tl.program_id(0).to(tl.int64) + + # 1. Load Y_ptr first because if the target is ignore_index, we can return right away + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + # 2. locate the start index + X_ptr += program_id * X_stride + + if y == ignore_index: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS: + z_loss_ptr += program_id * loss_stride + + if HAS_WEIGHT: + weight_y = tl.load(weight_ptr + y).cast(tl.float32) + + # Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax) + # Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. 
use the notation from the paper + ori_X_y = tl.load(X_ptr + y).cast(tl.float32) # we need to store the original value of X_y for the loss calculation + if HAS_SOFTCAPPING: + ori_X_y = softcap * tanh(ori_X_y / softcap) + + # Label smoothing is a general case of normal cross entropy + # See the full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issue-2503665310 + scaled_x_sum = 0.0 + eps = label_smoothing / n_cols + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + X_block = softcap * tanh(X_block / softcap) + block_max = tl.max(X_block) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + if HAS_WEIGHT: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block * weight_block, 0.0)) + else: + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + # log (sum(e^(X_i))) = log (sum(e ^ (max(X) * e ^ (X_i - max(X))))) + # = log (e^(max(X)) * sum(e ^ (X_i - max(X)))) + # = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d + lse = m + tl.log(d) + + # 4. [Online Softmax] Second pass: compute gradients + # For 'mean' reduction, gradients are normalized by number of non-ignored elements (N) + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # For label smoothing: + # dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + # With Z loss: + # dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y + # dx_y = dx_i - (1 - label_smoothing) / N + # For 'sum' reduction, no normalization is applied: + # dx_y = softmax(x_y) - 1 + # dx_i = softmax(x_i), for i ≠ y + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load( + X_ptr + X_offsets, + mask=X_offsets < n_cols, + other=float("-inf"), + # Ensure float32 precision for softmax calculation + ).cast(tl.float32) + if HAS_SOFTCAPPING: + intermediate = tanh(X_block / softcap) + X_block = softcap * intermediate + + if not HAS_WEIGHT: + # softmax(x_i) + X_block = tl.exp(X_block - m) / d + # derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i) + X_block += 2 * lse_square_scale * lse * X_block + # smoothing term + X_block += -eps + # special handle dx_y + X_block = tl.where(X_offsets != y, X_block, X_block - (1 - label_smoothing)) + # reduction scale + if reduction == "mean": + X_block = X_block / n_non_ignore + else: + weight_block = tl.load(weight_ptr + X_offsets, mask=X_offsets < n_cols) + softmax_X = tl.exp(X_block - m) / d + # derivative of original_loss + dloss_ori = (1 - label_smoothing) * softmax_X + # specially handle dx_y + dloss_ori = tl.where(X_offsets != y, dloss_ori, dloss_ori - (1 - label_smoothing)) + dloss_ori = dloss_ori * weight_y + # derivative of smooth_loss + dloss_smooth = eps * (-weight_block + softmax_X * weight_sum) + # derivative of z-loss + dz_loss = 2 * lse_square_scale * lse * softmax_X + # reduction scale + if reduction == "mean": + dloss_ori = dloss_ori / sum_non_ignore_weight + dloss_smooth = dloss_smooth / 
sum_non_ignore_weight + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + dz_loss = dz_loss / n_non_ignore + # derivative of total_loss + X_block = dloss_ori + dloss_smooth + dz_loss + + # chain rule softcapping + # d(softcap * tanh(x / softcap)) = (1 - tanh^2(x / softcap)) + if HAS_SOFTCAPPING: + X_block = X_block * (1 - intermediate * intermediate) + + tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in + # https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/ops/cross_entropy.py#L34 + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + # = X_y - m - log d = X_y - lse + # sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1 + # So we can safely calculate log (softmax(X_y)) without overflow + loss = lse - ori_X_y + if HAS_WEIGHT: + loss = weight_y * loss + + # Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (sum(-eps * x_i) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + # pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516 + # See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087 + if label_smoothing > 0: + if HAS_WEIGHT: + smooth_loss = scaled_x_sum + eps * lse * weight_sum + else: + smooth_loss = scaled_x_sum + label_smoothing * lse + loss = loss * (1 - label_smoothing) + smooth_loss + + # An auxiliary loss, z_loss + # Refer to Page14 Loss function section in the paper PaLM: https://www.jmlr.org/papers/v24/22-1144.html + z_loss = lse_square_scale * lse * lse + # Normalize the loss by the number of non-ignored elements if reduction is "mean" + if reduction == "mean": + if HAS_WEIGHT: + loss = loss / sum_non_ignore_weight + else: + loss = loss / n_non_ignore + # TODO: Implement weighted z_loss. Currently, z_loss is not scaled by weight. + z_loss = z_loss / n_non_ignore + loss += z_loss + + tl.store(loss_ptr, loss) + if RETURN_Z_LOSS: + tl.store(z_loss_ptr, z_loss) + + +def fused_linear_cross_entropy_forward( + _input, + weight, + target, + ce_weight=None, + bias=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss=False, + triton_backend=True +): + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" + device = _input.device + + # inputs have shape: BT x H + # materialized activations will have shape: BT x V + # the increase in memory = BT x V + # reduction can be achieved by partitioning the number of tokens BT into smaller chunks. 
+ # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be: + # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor + # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048 + BT, H = _input.shape + V = weight.shape[0] + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + + # inc_factor = triton.cdiv(V, H) # (V + H - 1) // H + # chunk_size = triton.next_power_of_2(triton.cdiv(BT, inc_factor)) # (BT + inc_factor - 1) // inc_factor + # num_chunks = triton.cdiv(BT, chunk_size) # (BT + chunk_size - 1) // chunk_size + + # inc_factor = (V + H - 1) // H + # chunk_size = (BT + inc_factor - 1) // inc_factor + # num_chunks = (BT + chunk_size - 1) // chunk_size + + chunk_size = 1024 + if triton_backend: + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V)) + num_chunks = triton.cdiv(BT, chunk_size) + else: + num_chunks = (BT + chunk_size - 1) // chunk_size + + grad_weight = torch.zeros_like(weight, device=device) if weight.requires_grad else None + grad_input = torch.zeros_like(_input, device=device) + grad_bias = torch.zeros_like(bias, device=device) if bias is not None else None + # we use fp32 for loss accumulator + loss_1d = torch.zeros(BT, dtype=torch.float32, device=device) + z_loss_1d = torch.zeros(BT, dtype=_input.dtype, device=_input.device) if return_z_loss else None + + # TODO: evaluate how CUDA synchronization caused by .item() affects the speed + target_mask = target != ignore_index + total_n_non_ignore = target_mask.sum().item() + total_sum_non_ignore_ce_weight = total_n_non_ignore + ce_weight_sum = 0.0 + if ce_weight is not None: + assert ce_weight.shape[0] == V, f"If given, weight has to be a Tensor of size V. Got: {ce_weight.shape}" + assert torch.is_floating_point(ce_weight), ( + f"If given, weight has to be a Tensor of floating point dtype. Got: {ce_weight.dtype}" + ) + total_sum_non_ignore_ce_weight = ( + torch.gather(ce_weight, dim=0, index=target.masked_select(target_mask)).sum().item() + ) + ce_weight_sum = ce_weight.sum().item() + if ce_weight.stride(-1) != 1: + ce_weight = ce_weight.contiguous() + + for chunk_id in range(num_chunks): + start_idx = chunk_id * chunk_size + end_idx = min((chunk_id + 1) * chunk_size, BT) + _input_chunk = _input[start_idx:end_idx] # chunk_size x H + + # when doing matmul, use the original precision + logits_chunk = _input_chunk @ weight.t() # chunk_size x V + if bias is not None: + logits_chunk = logits_chunk + bias + + target_chunk = target[start_idx:end_idx] # chunk_size, + + n_rows = logits_chunk.shape[0] + + + + # ensure _input and target are contiguous + logits_chunk = logits_chunk.contiguous() + target_chunk = target_chunk.contiguous() + + if triton_backend: + # unreduced loss + loss_1d_slice = loss_1d[start_idx:end_idx] # chunk_size, + z_loss_1d_slice = z_loss_1d[start_idx:end_idx] if return_z_loss else None + + # Here we calculate the gradient of logits_chunk in place so we can save memory. 
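+            # The kernel writes d(loss)/d(logits) back into logits_chunk, so the full BT x V logits
+            # tensor is never kept around; grad_input and grad_weight are then accumulated
+            # chunk by chunk from this in-place gradient right after the kernel call.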
+ liger_cross_entropy_kernel[(n_rows,)]( + X_ptr=logits_chunk, + X_stride=logits_chunk.stride(-2), + Y_ptr=target_chunk, + Y_stride=target_chunk.stride(-1), # always 1 + weight_ptr=ce_weight, + loss_ptr=loss_1d_slice, + z_loss_ptr=z_loss_1d_slice, + loss_stride=loss_1d_slice.stride(-1), # always 1 + n_cols=V, + n_non_ignore=total_n_non_ignore, + sum_non_ignore_weight=total_sum_non_ignore_ce_weight, + weight_sum=ce_weight_sum, + ignore_index=ignore_index, + lse_square_scale=lse_square_scale, + label_smoothing=label_smoothing, + reduction=reduction, + softcap=softcap, + RETURN_Z_LOSS=return_z_loss, + HAS_WEIGHT=True if ce_weight is not None else False, + HAS_SOFTCAPPING=True if softcap is not None else False, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=32, + ) + grad_logits_chunk = logits_chunk # chunk_size x V + else: + y_chunk = torch.nn.functional.softmax(logits_chunk, dim=1) + loss_1d_slice = -torch.log(y_chunk).gather(1, target_chunk.view(-1, 1)) + loss_1d_slice = loss_1d_slice.squeeze(1) + logits_chunk = y_chunk - torch.nn.functional.one_hot(target_chunk, num_classes=V) + logits_chunk = (logits_chunk * (chunk_size / BT)) + grad_logits_chunk = logits_chunk + + loss_1d[start_idx:end_idx] = loss_1d_slice + if return_z_loss: + z_loss_1d[start_idx:end_idx] = z_loss_1d_slice + + grad_input[start_idx:end_idx] = grad_logits_chunk @ weight + + if grad_weight is not None: + torch.addmm( + input=grad_weight, + mat1=logits_chunk.t().to( + _input_chunk.dtype + ), # In an autocast scenario without bias, differing logits_chunk data types will cause an addmm operation error. + mat2=_input_chunk, + out=grad_weight, + alpha=1.0, + beta=1.0, + ) + + if bias is not None: + torch.add( + input=grad_bias, + other=logits_chunk.sum(dim=0), + out=grad_bias, + alpha=1.0, + ) + + if reduction == "none": + loss = loss_1d + z_loss = z_loss_1d if return_z_loss else None + + else: + loss = torch.sum(loss_1d) if triton_backend else torch.mean(loss_1d) + z_loss = torch.sum(z_loss_1d) if return_z_loss else None + + return loss, z_loss, grad_input, grad_weight, grad_bias + +def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias): + # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time + if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place + # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton. 
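+        # NOTE: the element_mul_kernel launches below are commented out in this port, so scaling the
+        # cached gradients by a non-unit grad_output is currently a no-op; this only matters when
+        # this loss is not the last node of the computational graph.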
+ BT, H = grad_input.shape + n_rows = BT + # BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(H)) + + # element_mul_kernel[(n_rows,)]( + # grad_input, + # grad_input.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + # handle grad_weight + if grad_weight is not None: + V, H = grad_weight.shape + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_weight, + # grad_weight.stride(-2), + # grad_output, + # H, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + + if grad_bias is not None: + V = grad_bias.shape[0] + n_rows = V + + # element_mul_kernel[(n_rows,)]( + # grad_bias, + # grad_bias.stride(-1), + # grad_output, + # 1, + # BLOCK_SIZE=BLOCK_SIZE, + # num_warps=32 if not is_hip() else 16, + # ) + return grad_input, grad_weight, grad_bias + + +class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + _input, + weight, + target, + bias=None, + ce_weight=None, + ignore_index=-100, + lse_square_scale=0.0, + label_smoothing=0.0, + reduction="mean", + softcap=None, + return_z_loss: bool = False, + ): + """ + Fusing the last linear layer with cross-entropy loss + Reference: https://github.com/mgmalek/efficient_cross_entropy + + Handle the forward and backward pass of the final linear layer via cross-entropy loss by avoiding + the materialization of the large logits tensor. Since Cross Entropy Loss is the last layer, we can + compute the gradient at the forward pass. By doing so, we don't have to store the _input and target + for the backward pass. + + _input: (B*T, H) where B is batch size, T is sequence length, H is hidden dimension. + target: (B*T) where each value is in [0, V-1] + weight: (V, H) where V is the number of classes + bias: (V) where V is the number of classes + ce_weight: a manual rescaling weight given to each class. If given, has to be a Tensor of size V and floating point dtype + ignore_index: the index to ignore in the target + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. 
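+        lse_square_scale (float): scale of the auxiliary z-loss term, lse_square_scale * logsumexp(logits) ** 2, added for training stability
+        softcap (float): if not None, logits are soft-capped into (-softcap, +softcap) via softcap * tanh(logits / softcap)
+        return_z_loss (bool): whether to additionally compute the auxiliary z-loss (currently only the main loss is returned)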
+        reduction: the reduction to apply to the loss
+        """
+
+        loss, z_loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
+            _input=_input,
+            weight=weight,
+            target=target,
+            bias=bias,
+            ce_weight=ce_weight,
+            ignore_index=ignore_index,
+            lse_square_scale=lse_square_scale,
+            label_smoothing=label_smoothing,
+            reduction=reduction,
+            softcap=softcap,
+            return_z_loss=return_z_loss,
+        )
+        # detach the precomputed gradients and store them for backward
+        ctx.save_for_backward(
+            grad_input.detach(),
+            grad_weight.detach() if grad_weight is not None else None,
+            grad_bias.detach() if bias is not None else None,
+        )
+        ctx.return_z_loss = return_z_loss
+        # return loss, z_loss
+        return loss
+
+    @staticmethod
+    # def backward(ctx, grad_output, grad_output2):  # use this signature if forward also returns z_loss
+    def backward(ctx, grad_output):
+        # z_loss is only for logging: if forward is switched to return (loss, z_loss), the incoming
+        # gradient for z_loss should simply be discarded here.
+        (grad_input, grad_weight, grad_bias) = ctx.saved_tensors
+        grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
+            grad_output, grad_input, grad_weight, grad_bias
+        )
+        return (
+            grad_input,
+            grad_weight,
+            None,
+            grad_bias,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+            None,
+        )
\ No newline at end of file
diff --git a/replay/models/nn/sequential/sasrec/lightning.py b/replay/models/nn/sequential/sasrec/lightning.py
index b15486715..a28ad56b9 100644
--- a/replay/models/nn/sequential/sasrec/lightning.py
+++ b/replay/models/nn/sequential/sasrec/lightning.py
@@ -10,19 +10,20 @@
 from .dataset import SasRecPredictionBatch, SasRecTrainingBatch, SasRecValidationBatch
 from .model import SasRecModel
 
-from replay.models.nn.optimizer_utils import LigerFusedLinearCrossEntropyFunction
+
+
+import sys
+sys.path.append("./kernels")
+
+try:
+    from kernels.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
+
+except ModuleNotFoundError:
+    print("fused linear cross entropy is not installed. fused_linear_CE loss cannot be used.")
 
 try:
-    import sys
-    sys.path.append("/home/jovyan/zhmax/cce_loss/")
-    from cut_cross_entropy.cce import CCEParams, LinearCrossEntropyFunction, _build_flat_valids
-    from cut_cross_entropy.cce_backward import cce_backward_kernel
-    from cut_cross_entropy.cce_lse_forward import cce_lse_forward_kernel
-    from cut_cross_entropy.constants import IGNORE_INDEX
-    from cut_cross_entropy.doc import CCE_OPTS_DOC, LINEAR_CROSS_ENTROPY_DOC, add_doc_start
-    from cut_cross_entropy.indexed_dot import indexed_neg_dot_forward_kernel
-    from cut_cross_entropy.utils import (
+    from kernels.cut_cross_entropy.cce import CCEParams, LinearCrossEntropyFunction, _build_flat_valids
+    from kernels.cut_cross_entropy.utils import (
         _build_flat_valids,
         _handle_eps,
         handle_reduction_none,
@@ -538,7 +539,7 @@ def _compute_loss_cce(
         )
 
         reject_labels_mask = targets.view(-1, 1) == negative_labels
-        negative_labels[reject_labels_mask] = vocab_size - 1
+        negative_labels[reject_labels_mask] = vocab_size
 
         item_inds = torch.hstack([targets.view(-1, 1), negative_labels])
 
diff --git a/replay_benchmarks/configs/config.yaml b/replay_benchmarks/configs/config.yaml
index 4d2736877..2151b18cb 100755
--- a/replay_benchmarks/configs/config.yaml
+++ b/replay_benchmarks/configs/config.yaml
@@ -1,6 +1,6 @@
 defaults:
-  - dataset: megamarket
-  - model: sasrec_megamarket
+  - dataset: movielens_20m
+  - model: sasrec_movielens_20m
   - mode: train
   - acceleration: null
 
diff --git a/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml b/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml
index 5f5f5e0cd..b827532db 100755
--- a/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml
+++ b/replay_benchmarks/configs/model/sasrec_movielens_20m.yaml
@@ -18,7 +18,6 @@ model:
   training_params:
     embedding_dim: 256
     learning_rate: 0.001
-    weight_decay: 0.00001
     batch_size: 128
     num_workers: 4
     patience: 4
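Not part of the patch: a minimal usage sketch of the exported autograd function, assuming a CUDA device with Triton available; the shapes and variable names (hidden, weight, targets) are illustrative only.

import torch
from kernels.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction

# illustrative shapes: B*T = 4096 tokens, hidden size H = 256, item-catalog size V = 32000
hidden = torch.randn(4096, 256, device="cuda", requires_grad=True)
weight = torch.randn(32000, 256, device="cuda", requires_grad=True)
targets = torch.randint(0, 32000, (4096,), device="cuda")

# the output projection and the cross-entropy are fused: the 4096 x 32000 logits are only
# materialized chunk by chunk inside the forward pass, and gradients are cached there as well
loss = LigerFusedLinearCrossEntropyFunction.apply(hidden, weight, targets)
loss.backward()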