Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backend/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
dump_ir = os.environ.get("DLC_DUMP_IR", "0") == "1"
replace_ttshared_ir = os.environ.get("DLC_REPLACE_TTSHARED_IR_FILE", None)
replace_linked_ir = os.environ.get("DLC_REPLACE_LINKED_IR_FILE", None)
replace_commonir_ir = os.environ.get("DLC_REPLACE_COMMON_IR_FILE", None)
replace_commonir_linked_ir = os.environ.get("DLC_REPLACE_COMMONIR_LINKED_IR_FILE", None)
if (
dump_ir
Expand Down Expand Up @@ -432,6 +433,10 @@ def ttir_to_ttsharedir_ascend(mod, metadata, opt, *, named_ops=False):


def commonir_to_linkedir(commonir, metadata, opt, *, named_ops=False):
if replace_commonir_ir is not None:
print(f"[DEBUG] Replace common ir with {replace_commonir_ir}")
commonir = Path(replace_commonir_ir).read_text()

assert isinstance(commonir, str)
if opt.debug or dump_ir:
dicp_utils._dump_stage_ir(commonir, metadata["hash"], "kernel.commonir.mlir")
Expand Down
281 changes: 233 additions & 48 deletions commonir/src/target/codegen_commonir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -619,12 +619,8 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const BufferLoadNode *op,
int dim = buffer_shape.size();

String buffer_name_val = "";
if (auto memrefInfo = dynamic_cast<Memref *>(type_info[buffer_name])) {
if (memrefInfo->is_arg) {
buffer_name_val = buffer_name + "_Recast";
} else {
buffer_name_val = buffer_name;
}
if (dynamic_cast<Memref *>(type_info[buffer_name])) {
buffer_name_val = buffer_name;
} else {
LOG(FATAL) << buffer_name << " should be a memref";
}
Expand Down Expand Up @@ -702,31 +698,96 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data,
region_indeces.push_back(r.get()->min);
}
String buffer_name_val = "";
if (auto memrefInfo = dynamic_cast<Memref *>(type_info[buffer_name])) {
if (memrefInfo->is_arg) {
buffer_name_val = buffer_name + "_Recast";
} else {
buffer_name_val = buffer_name;
}
if (dynamic_cast<Memref *>(type_info[buffer_name])) {
buffer_name_val = buffer_name;
} else {
LOG(FATAL) << buffer_name << " should be a memref";
}
auto *src_memref = dynamic_cast<Memref *>(type_info[buffer_name_val]);
ICHECK(src_memref) << buffer_name_val << " should be a memref";

String new_buffer_name = buffer_name_val;
String src_data_info = GetMemrefInfo(buffer_name_val);
if (!(IsEqual(buffer_shape, region_shape) && AllZero(region_indeces))) {
bool is_full_region =
IsEqual(buffer_shape, region_shape) && AllZero(region_indeces);
if (src_memref->is_arg) {
Array<PrimExpr> curr_shape = is_full_region ? buffer_shape : region_shape;
Array<PrimExpr> curr_indices;
if (is_full_region) {
for (int i = 0; i < dim; ++i) {
curr_indices.push_back(make_const(DataType::Int(64), 0));
}
} else {
curr_indices = region_indeces;
}

Array<String> cast_shape_array = GenConvertIndex(curr_shape);
Array<PrimExpr> strides_expr;
for (int i = 0; i < src_memref->dim; ++i) {
strides_expr.push_back(src_memref->stride[i]);
}
Array<String> cast_stride_array = GenConvertIndex(strides_expr);

PrimExpr dynamic_offset = make_const(DataType::Int(64), 0);
for (int i = 0; i < src_memref->dim; ++i) {
PrimExpr idx = curr_indices[i];
PrimExpr stride = strides_expr[i];
if (idx.dtype() != stride.dtype()) {
idx = tvm::tir::Cast(stride.dtype(), idx);
}
dynamic_offset = dynamic_offset + idx * stride;
}

Array<PrimExpr> offset_exprs;
offset_exprs.push_back(arith::Analyzer().Simplify(dynamic_offset));
Array<String> cast_offset_array = GenConvertIndex(offset_exprs);
unsigned long offset = ComputeOffset(src_memref, curr_indices);

new_buffer_name = buffer_name_val + "_reinterpret";
auto tempMemref = new Memref(new_buffer_name, curr_shape, buffer_type,
src_memref->address_space, offset == -1,
src_memref->stride, offset);
String dst_data_info = GetMemrefInfo(tempMemref);

temp << "memref.reinterpret_cast %" << buffer_name_val;
temp << " to offset: [";
if (offset == -1) {
temp << cast_offset_array[0];
} else {
temp << offset;
}
temp << "], sizes: [";
for (int i = 0; i < dim; ++i) {
if (i > 0) {
temp << ", ";
}
temp << cast_shape_array[i];
}
temp << "], strides: [";
for (int i = 0; i < dim; ++i) {
if (i > 0) {
temp << ", ";
}
temp << cast_stride_array[i];
}
temp << "] : ";
temp << src_data_info;
temp << " to ";
temp << dst_data_info;

delete tempMemref;
new_buffer_name = SSAGetID(temp.str(), buffer_type);
this->type_info[new_buffer_name] = new Memref(
new_buffer_name, curr_shape, buffer_type, src_memref->address_space,
offset == -1, src_memref->stride, offset);
} else if (!is_full_region) {
Array<String> cast_offset_array = GenConvertIndex(region_indeces);
Array<String> cast_shape_array = GenConvertIndex(region_shape);
if (!dynamic_cast<Memref *>(type_info[buffer_name_val])) {
LOG(FATAL) << buffer_name_val << " should be a memref";
}
unsigned long offset = ComputeOffset(
dynamic_cast<Memref *>(type_info[buffer_name_val]), region_indeces);
unsigned long offset = ComputeOffset(src_memref, region_indeces);
new_buffer_name = buffer_name_val + "_subview";
auto tempMemref = new Memref(
new_buffer_name, region_shape, buffer_type,
dynamic_cast<Memref *>(type_info[buffer_name_val])->address_space,
offset == -1,
dynamic_cast<Memref *>(type_info[buffer_name_val])->stride, offset);
auto tempMemref = new Memref(new_buffer_name, region_shape, buffer_type,
src_memref->address_space, offset == -1,
src_memref->stride, offset);
String dst_data_info = GetMemrefInfo(tempMemref);
temp << "memref.subview \%" + buffer_name_val;
temp << "[";
Expand Down Expand Up @@ -760,10 +821,8 @@ String CodeGenTileLangCOMMONIR::GenSubviewFromRegion(Buffer buffer_data,
delete tempMemref;
new_buffer_name = SSAGetID(temp.str(), buffer_type);
this->type_info[new_buffer_name] = new Memref(
new_buffer_name, region_shape, buffer_type,
dynamic_cast<Memref *>(type_info[buffer_name_val])->address_space,
offset == -1,
dynamic_cast<Memref *>(type_info[buffer_name_val])->stride, offset);
new_buffer_name, region_shape, buffer_type, src_memref->address_space,
offset == -1, src_memref->stride, offset);
}
return new_buffer_name;
}
Expand Down Expand Up @@ -844,12 +903,39 @@ void CodeGenTileLangCOMMONIR::VisitExpr_(const CallNode *op, std::ostream &os) {
InfinityCodegen(op, os);
} else if (op->op.same_as(Op::Get("tl.tileop.reduce"))) {
ReduceCodegen(op, os);
} else if (op->op.same_as(Op::Get("tir.fabs"))) {
ICHECK_EQ(op->args.size(), 1) << "abs expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
if (op->dtype.is_float()) {
os << "math.absf %" << operand << " : ";
} else {
os << "math.absi %" << operand << " : ";
}
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.sqrt"))) {
ICHECK_EQ(op->args.size(), 1) << "sqrt expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "math.sqrt %" << operand << " : ";
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.exp"))) {
ICHECK_EQ(op->args.size(), 1) << "exp expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "math.exp %" << operand << " : ";
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.tanh"))) {
ICHECK_EQ(op->args.size(), 1) << "tanh expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "math.tanh %" << operand << " : ";
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.log"))) {
ICHECK_EQ(op->args.size(), 1) << "log expects 1 argument";
std::string operand = SSAGetID(PrintExpr(op->args[0]), op->args[0]->dtype);
os << "math.log %" << operand << " : ";
PrintType(op->dtype, os);
} else if (op->op.same_as(Op::Get("tir.rsqrt"))) {
StubCodegen(op, os, "tir.rsqrt");
} else if (op->op.same_as(Op::Get("tir.sigmoid"))) {
StubCodegen(op, os, "tir.sigmoid");
} else if (op->op.same_as(Op::Get("tir.exp"))) {
StubCodegen(op, os, "tir.exp");
} else if (op->op.same_as(builtin::if_then_else())) {
IfThenElseCodegen(op, os);
} else {
Expand Down Expand Up @@ -1020,11 +1106,123 @@ void CodeGenTileLangCOMMONIR::InfinityCodegen(const CallNode *op,
void CodeGenTileLangCOMMONIR::ReduceCodegen(const CallNode *op,
                                            std::ostream &os) {
  // Lower a tl.tileop.reduce call to `linalg.reduce` over tensors, then
  // materialize the result back into the destination memref.
  tvm::tl::ReduceOp reduceop(op->args);

  ICHECK(reduceop->type->isSum() || reduceop->type->isMax())
      << "Currently we only support: sum or max";

  // Lower src/dst regions to memref subviews covering exactly the
  // reduced-over data.
  String src_data_name =
      GenSubviewFromRegion(reduceop->src, reduceop->srcRegion_->region);
  String dst_data_name =
      GenSubviewFromRegion(reduceop->dst, reduceop->dstRegion_->region);

  auto *src_memref = dynamic_cast<Memref *>(type_info[src_data_name]);
  auto *dst_memref = dynamic_cast<Memref *>(type_info[dst_data_name]);
  ICHECK(src_memref) << src_data_name << " should be a memref";
  ICHECK(dst_memref) << dst_data_name << " should be a memref";

  DataType src_dtype = src_memref->dtype;
  DataType dst_dtype = dst_memref->dtype;

  String src_tensor_name = CreateMemrefToTensor(src_data_name);
  // Always use a temporary tensor as reduce outs, then materialize to memref.
  // This keeps reduction in tensor form instead of memref-backed tensor views.
  String init_tensor_name = CreateNewTensor(dst_data_name, "max_vals");

  if (reduceop->clear) {
    // `clear` seeds the accumulator with the reduction identity:
    // 0 for sum, the dtype's minimum value for max.
    PrimExpr init_value;
    if (reduceop->type->isSum()) {
      init_value = make_zero(dst_dtype);
    } else if (dst_dtype.is_int()) {
      // Minimum signed value for `bits`, built via an unsigned shift:
      // ~0ULL << (bits-1) is 0b100...0; converting back to signed avoids
      // the signed-overflow UB of -(1LL << 63) when bits == 64.
      init_value = make_const(
          dst_dtype, static_cast<long long>(~0ULL << (dst_dtype.bits() - 1)));
    } else if (dst_dtype.is_uint()) {
      init_value = make_const(dst_dtype, 0);
    } else if (dst_dtype.is_float() || dst_dtype.is_bfloat16()) {
      init_value = make_const(dst_dtype, -INFINITY);
    } else {
      LOG(FATAL) << "Unsupported dtype for max reduce init: " << dst_dtype;
    }
    std::string init_value_name = SSAGetID(PrintExpr(init_value), dst_dtype);
    std::ostringstream fill_temp;
    fill_temp << "linalg.fill ins(%" << init_value_name << " : ";
    PrintType(dst_dtype, fill_temp);
    fill_temp << ") outs(%" << init_tensor_name << " : "
              << GetTensorInfo(init_tensor_name) << ") -> "
              << GetTensorInfo(init_tensor_name);
    // linalg.fill yields a fresh SSA tensor; register its type so later
    // GetTensorInfo lookups resolve against the filled value.
    String filled_from_tensor_name = init_tensor_name;
    init_tensor_name = SSAGetID(fill_temp.str(), dst_dtype);
    auto *filled_tensor_template =
        dynamic_cast<Tensor *>(type_info_tensor[filled_from_tensor_name]);
    ICHECK(filled_tensor_template)
        << filled_from_tensor_name << " should be a tensor";
    auto *filled_tensor =
        new Tensor(init_tensor_name, filled_tensor_template->shape, dst_dtype,
                   filled_tensor_template->address_space);
    filled_tensor->var_id = init_tensor_name;
    this->type_info_tensor[init_tensor_name] = filled_tensor;
  }

  // Emit linalg.reduce with an inline combiner region.
  std::ostringstream reduce_temp;
  reduce_temp << "linalg.reduce ins(%" << src_tensor_name << " : "
              << GetTensorInfo(src_tensor_name) << ") outs(%"
              << init_tensor_name << " : " << GetTensorInfo(init_tensor_name)
              << ") dimensions = [" << reduceop->dim << "]\n";
  reduce_temp << "  (%in: ";
  PrintType(src_dtype, reduce_temp);
  reduce_temp << ", %acc: ";
  PrintType(dst_dtype, reduce_temp);
  reduce_temp << ") {\n";

  // When src/dst dtypes differ, cast the incoming element before combining.
  std::string rhs_name = "in";
  if (src_dtype != dst_dtype) {
    reduce_temp << "    %in_cast = " << GetCastOp(src_dtype, dst_dtype)
                << " %in : ";
    PrintType(src_dtype, reduce_temp);
    reduce_temp << " to ";
    PrintType(dst_dtype, reduce_temp);
    reduce_temp << "\n";
    rhs_name = "in_cast";
  }

  reduce_temp << "    %reduced = ";
  if (reduceop->type->isSum()) {
    if (dst_dtype.is_int() || dst_dtype.is_uint()) {
      reduce_temp << "arith.addi %acc, %" << rhs_name << " : ";
    } else if (dst_dtype.is_float() || dst_dtype.is_bfloat16()) {
      reduce_temp << "arith.addf %acc, %" << rhs_name << " : ";
    } else {
      LOG(FATAL) << "Unsupported dtype for sum reduce: " << dst_dtype;
    }
  } else if (dst_dtype.is_int()) {
    reduce_temp << "arith.maxsi %acc, %" << rhs_name << " : ";
  } else if (dst_dtype.is_uint()) {
    reduce_temp << "arith.maxui %acc, %" << rhs_name << " : ";
  } else if (dst_dtype.is_float() || dst_dtype.is_bfloat16()) {
    reduce_temp << "arith.maxnumf %acc, %" << rhs_name << " : ";
  } else {
    LOG(FATAL) << "Unsupported dtype for max reduce: " << dst_dtype;
  }
  PrintType(dst_dtype, reduce_temp);
  reduce_temp << "\n";
  reduce_temp << "    linalg.yield %reduced : ";
  PrintType(dst_dtype, reduce_temp);
  reduce_temp << "\n";
  reduce_temp << "  }";

  String reduced_tensor_name = SSAGetID(reduce_temp.str(), dst_dtype);
  auto *reduced_tensor_template =
      dynamic_cast<Tensor *>(type_info_tensor[init_tensor_name]);
  ICHECK(reduced_tensor_template) << init_tensor_name << " should be a tensor";
  auto *reduced_tensor =
      new Tensor(reduced_tensor_name, reduced_tensor_template->shape, dst_dtype,
                 reduced_tensor_template->address_space);
  reduced_tensor->var_id = reduced_tensor_name;
  this->type_info_tensor[reduced_tensor_name] = reduced_tensor;

  // Copy the reduced tensor back into the destination memref.
  // NOTE: the reduce itself was already emitted via SSAGetID above; no
  // extra `linalg.reduce` line must be printed here (stale pre-rewrite
  // stub removed).
  this->PrintIndent();
  this->stream << "bufferization.materialize_in_destination %"
               << reduced_tensor_name << " in writable %" << dst_data_name
               << " : (" << GetTensorInfo(reduced_tensor_name) << ", "
               << GetMemrefInfo(dst_data_name) << ") -> ()\n";
}

void CodeGenTileLangCOMMONIR::VisitStmt_(const LetStmtNode *op) {
Expand All @@ -1044,12 +1242,8 @@ void CodeGenTileLangCOMMONIR::VisitStmt_(const BufferStoreNode *op) {
int dim = buffer_shape.size();

String buffer_name_val = "";
if (auto memrefInfo = dynamic_cast<Memref *>(type_info[buffer_name])) {
if (memrefInfo->is_arg) {
buffer_name_val = buffer_name + "_Recast";
} else {
buffer_name_val = buffer_name;
}
if (dynamic_cast<Memref *>(type_info[buffer_name])) {
buffer_name_val = buffer_name;
} else {
LOG(FATAL) << buffer_name << " should be a memref";
}
Expand Down Expand Up @@ -1300,7 +1494,7 @@ void CodeGenTileLangCOMMONIR::GenRecastFromArg(Buffer curr_buffer,

void CodeGenTileLangCOMMONIR::AddFunction(const GlobalVar &gvar,
const PrimFunc &f) {
this->stream << "module {\n";
this->stream << "module attributes {dicp.backend = \"ascend\"} {\n";

// If the function has already been forward-declared, this is a
// no-op.
Expand All @@ -1319,8 +1513,6 @@ void CodeGenTileLangCOMMONIR::AddFunction(const GlobalVar &gvar,

this->stream << "func.func @" << func_name << "(";

std::vector<String> recast_need_insert;

this->type_info.clear();
size_t n = f->params.size();
for (size_t i = 0; i < f->params.size(); ++i) {
Expand All @@ -1337,9 +1529,6 @@ void CodeGenTileLangCOMMONIR::AddFunction(const GlobalVar &gvar,
Memref *buffer = new Memref(arg_name, f->buffer_map[v], true);
this->type_info[arg_name] = buffer;
stream << "%" << arg_name << ": " << GetMemrefInfo(arg_name);
String recast_inst = "";
GenRecastFromArg(f->buffer_map[v], arg_name, recast_inst);
recast_need_insert.push_back(recast_inst);

if (auto *ptr = v->type_annotation.as<PointerTypeNode>()) {
if (auto *prim = ptr->element_type.as<PrimTypeNode>()) {
Expand All @@ -1361,10 +1550,6 @@ void CodeGenTileLangCOMMONIR::AddFunction(const GlobalVar &gvar,
this->PrintIndent();
stream << "{\n";
int func_body_scope = this->BeginScope();
for (String recast_inst : recast_need_insert) {
this->PrintIndent();
stream << recast_inst;
}
this->PrintStmt(f->body);
this->EndScope(func_body_scope);
this->PrintIndent();
Expand Down Expand Up @@ -1392,7 +1577,7 @@ String CodeGenTileLangCOMMONIR::GetMemrefInfo(Memref *memrefObj) {
std::ostringstream memref_type;
memref_type << "memref<";
if (memrefObj->is_arg) {
memref_type << "?x";
memref_type << "*x";
} else {
for (PrimExpr s : memrefObj->shape) {
if (auto s_int = as_const_int(s)) {
Expand Down
Loading
Loading