Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/array/cuda/gather_mm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ __global__ void GatherMMScatterKernel(
for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
if (l + outloop < out_len) {
// iterate over elements of a row of A
for (unsigned int i = 0; i < a_tile; i++) {
const DType a_val = sh_A[local_row * sh_a_tile + i];
Expand Down Expand Up @@ -170,7 +170,7 @@ __global__ void GatherMMScatterKernel2(
for (unsigned int outloop = 0; outloop < out_len; outloop += 32) {
DType out_reg = static_cast<DType>(0.0f); // thread private
const unsigned int l = laneId;
if (l < out_len) {
if (l + outloop < out_len) {
const DType b_val = B[row_b * out_len + (outloop + l)];
/* iterate over elements of a row of A */
for (unsigned int i = 0; i < a_tile; i++) {
Expand Down