diff --git a/tmva/sofie/CMakeLists.txt b/tmva/sofie/CMakeLists.txt index dc44ac0a59af2..6fdc7a46183ee 100644 --- a/tmva/sofie/CMakeLists.txt +++ b/tmva/sofie/CMakeLists.txt @@ -65,6 +65,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie TMVA/ROperator_Einsum.hxx TMVA/ROperator_Random.hxx TMVA/ROperator_ScatterElements.hxx + TMVA/ROperator_ScatterND.hxx TMVA/ROperator_Gather.hxx TMVA/ROperator_GatherND.hxx TMVA/ROperator_NonZero.hxx diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx index e9b2078bc73a1..16be75c5aa960 100644 --- a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx @@ -140,6 +140,7 @@ public: auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB); fBroadcastFlag = ret.first; fShapeY = ret.second; + auto lengthY = ConvertShapeToLength(fShapeY); if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) { bool broadcast = fBroadcastFlag > 0; if (broadcast) { @@ -193,7 +194,7 @@ public: const std::string &nameB = fNBroadcastedB.empty() ? 
fNB : fNBroadcastedB; auto dataA = static_cast(model.GetInitializedTensorData(nameA).get()); auto dataB = static_cast(model.GetInitializedTensorData(nameB).get()); - std::vector dataY(ConvertShapeToLength(fShapeY)); + std::vector dataY(lengthY); for (size_t i = 0; i < dataY.size(); i++) { dataY[i] = BinaryOperatorTrait::Func(dataA[i], dataB[i]); } @@ -207,6 +208,59 @@ public: << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl; } + } else if (((model.IsShapeTensor(fNA) && model.IsShapeTensor(fNB)) || + (model.IsShapeTensor(fNA) && model.IsConstantTensor(fNB)) || + (model.IsShapeTensor(fNB) && model.IsConstantTensor(fNA))) + && (fShapeA.size() <=1 && fShapeB.size() <=1 && model.GetTensorType(fNA) == ETensorType::INT64)) { + // case of shape tensors ( tensors are of rank 0 or 1 ) + std::vector dimValA; + std::vector dimValB; + if (model.IsShapeTensor(fNA)) + dimValA = model.GetShapeTensorValues(fNA); + if (model.IsShapeTensor(fNB)) + dimValB = model.GetShapeTensorValues(fNB); + // adjust for broadcasting - repeat values until they reach the shape of Y + if (!fShapeY.empty() && fShapeY[0] > 1) { + if (dimValA.size() == 1) dimValA = std::vector( fShapeY[0], dimValA[0]); + if (dimValB.size() == 1) dimValB = std::vector( fShapeY[0], dimValB[0]); + } + + auto convertDataToDim = [&](const std::string & name, const std::vector & shape, std::vector & dimValues) { + auto data = static_cast(model.GetInitializedTensorData(name).get()); + dimValues.resize(lengthY); + for (size_t i = 0; i < lengthY; i++) { + if (!shape.empty() && lengthY == shape[0]) + dimValues[i] = Dim{ static_cast(data[i])}; + else // case dataA is a scalar + dimValues[i] = Dim{ static_cast(data[0])}; + } + }; + if (model.IsConstantTensor(fNA)) { + convertDataToDim(fNA,fShapeA,dimValA); + } else if (model.IsConstantTensor(fNB)) { + convertDataToDim(fNB,fShapeB,dimValB); + } + + //perform binary
operations on shape tensors + std::vector dimValY(lengthY); + for (size_t i = 0; i < lengthY; i++) { + if (!dimValA[i].isParam && !dimValB[i].isParam) { + size_t d = BinaryOperatorTrait::Func(dimValA[i].dim, dimValB[i].dim); + dimValY[i] = Dim{d}; + } else { + auto res = BinaryOperatorTrait::Op(dimValA[i].GetVal(), dimValB[i].GetVal()); + dimValY[i] = Dim{res, static_cast(-1)}; + } + } + model.AddShapeTensor(fNY,dimValY, fShapeY.empty()); // cannot be a scalar + if (model.Verbose()) { + std::cout << BinaryOperatorTrait::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA) + << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " " + << ConvertShapeToString(fShapeY) << " : " << ConvertDimShapeToString(dimValY) << " (shape)" << std::endl; + } + // no code needs to be generated (flag this as a constant output tensor) + fIsOutputConstant = true; + } else { // case of defined and non-constant tensors model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY); @@ -279,9 +333,6 @@ public: opName = "op_" + opName; - if (fDimShapeY.empty()) { - throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first"); - } std::stringstream out; out << SP << "\n//------ " << opName << " " << BinaryOperatorTrait::Name() << " --> " << ConvertDimShapeToString(fDimShapeY) << "\n"; diff --git a/tmva/sofie/inc/TMVA/ROperator_Cast.hxx b/tmva/sofie/inc/TMVA/ROperator_Cast.hxx index 8267bb8a7e4f4..cace65040c772 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Cast.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Cast.hxx @@ -20,13 +20,14 @@ private: std::string fNX; std::string fNY; std::vector fShape; - std::string fAttrType = "float"; + ETensorType fType; public: ROperator_Cast(){} - ROperator_Cast(std::string attr_type,std::string nameX, std::string nameY): - fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), - fAttrType(attr_type) { + ROperator_Cast(ETensorType type,std::string nameX, std::string nameY): + 
fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), + fType(type) + { fInputTensorNames = { fNX }; fOutputTensorNames = { fNY }; } @@ -51,21 +52,21 @@ public: if (model.IsInitializedTensor(fNX)) { fIsOutputConstant = true; auto inputData = model.GetInitializedTensorData(fNX); - if (ConvertStringToType(fAttrType) == ETensorType::INT64) { + if (fType == ETensorType::INT64) { model.AddConstantTensor(fNY, ConvertShapeToInt(fShape), static_cast(inputData.get())); model.SetNotWritableInitializedTensor(fNX); } else fIsOutputConstant = false; - } else if (model.IsShapeTensor(fNX) && ConvertStringToType(fAttrType) == ETensorType::INT64) { + } else if (model.IsShapeTensor(fNX) && fType == ETensorType::INT64) { auto shapeData = model.GetShapeTensorValues(fNX); model.AddShapeTensor(fNY, shapeData, fShape.size() == 0); fIsOutputConstant = true; } if (!fIsOutputConstant) - model.AddIntermediateTensor(fNY, ConvertStringToType(fAttrType), fShape); + model.AddIntermediateTensor(fNY, fType, fShape); if (model.Verbose()) { - std::cout << "Cast : " << ConvertTypeToString(inputType) << " " << fNX << " -> " << fAttrType << " for " << fNY + std::cout << "Cast : " << ConvertTypeToString(inputType) << " " << fNX << " -> " << ConvertTypeToString(fType) << " for " << fNY << " shape " << ConvertDimShapeToString(fShape); if (fIsOutputConstant) std::cout << " (constant) "; std::cout << std::endl; @@ -86,7 +87,7 @@ public: out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; - out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< fAttrType << ">(tensor_" << fNX << "[id]);\n"; + out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< ConvertTypeToString(fType) << ">(tensor_" << fNX << "[id]);\n"; out << SP << "}\n"; return out.str(); diff --git a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx index ecdd0b435fe37..ae56750e5ebd3 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx 
@@ -165,7 +165,7 @@ namespace SOFIE{ } if (fNC != ""){ if (model.CheckIfTensorAlreadyExist(fNC) == false){ //input must be a graph input, or already initialized intermediate tensor - throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor" + fNC + " is not found in model"); + throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNC + " is not found in model"); } } if (model.IsDynamicTensor(fNA) || model.IsDimInputTensor(fNA) ) { @@ -222,7 +222,7 @@ namespace SOFIE{ if (fIsDynamic && shapeY.empty()) broadcast_needed = true; else - // consider broadcasting also if same length + // consider broadcasting also if they have different length broadcast_needed = (fShapeC != shapeY); @@ -285,7 +285,12 @@ namespace SOFIE{ int64_t dimA = fShapeA.size(); int64_t dimB = fShapeB.size(); int64_t dimY = fShapeY.size(); - if (dimA != dimB || dimA != dimY) { + int64_t dimC = fShapeC.size(); + if (dimA != dimB || dimA != dimY || (fBroadcastBias && dimC != dimY)) { + std::cout << " shape A " << ConvertDimShapeToString(fShapeA) + << " shape B " << ConvertDimShapeToString(fShapeB) + << " shape C " << ConvertShapeToString(fShapeC) + << " shape Y " << ConvertDimShapeToString(fShapeY) << std::endl; throw std::runtime_error("TMVA SOFIE Gemm(MatMul) has invalid shape for inputs or output"); } auto m = (fAttrTransA ?
fShapeA[dimA-1].GetVal() : fShapeA[dimA-2].GetVal()); @@ -357,6 +362,9 @@ namespace SOFIE{ } // do the bias broadcasting if (fBroadcastBias) { + // also shapeC has prepended 1 to be same rank of Y + std::vector sC = {fShapeC[dimC-2], fShapeC[dimC-1]}; + fAttrBeta = 1.; out << SP << "for (size_t j = 0; j < " << sY[0] << "; j++) { \n"; out << SP << SP << "size_t y_index = "; @@ -369,11 +377,11 @@ namespace SOFIE{ out << SP << SP << "for (size_t k = 0; k < " << sY[1] << "; k++) { \n"; std::string bias_index; - if (fShapeC[0] == 1 && fShapeC[1] == sY[1].dim) + if (sC[0] == 1 && sC[1] == sY[1].dim) bias_index = "k"; - else if (fShapeC[1] == 1 && fShapeC[0] == sY[0].dim) + else if (sC[1] == 1 && sC[0] == sY[0].dim) bias_index = "j"; - else if (fShapeC[0] == 1 && fShapeC[1] == 1) // scalar case + else if (sC[0] == 1 && sC[1] == 1) // scalar case bias_index = "0"; else { throw std::runtime_error("TMVA SOFIE Gemm Op - invalid shape for bias tensor " + ConvertShapeToString(fShapeC)); diff --git a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx index 64778f38753cd..2012a9a40a0ed 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx @@ -18,7 +18,7 @@ namespace SOFIE{ enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumSquare, ReduceProd, InvalidReduceOp }; -template +template class ROperator_Reduce final : public ROperator { private: @@ -76,7 +76,7 @@ public: std::sort(ax.begin(), ax.end()); for (size_t j = 0; j < ax.size(); j++) { // erase reduced dimensions, but keep last one - if (outputShape.size() > 1) { + if (outputShape.size() > 0) { outputShape.erase(outputShape.begin() + ax[j]); for (size_t k = j+1; k < ax.size(); k++) ax[k] -= 1; // decrease by one since we have removed a value @@ -120,9 +120,6 @@ public: std::string Generate(std::string opName) override { opName = "op_" + opName; - if (fShapeX.empty() || fShapeY.empty()) { - throw std::runtime_error("TMVA SOFIE Reduce Op called to 
Generate without being initialized first"); - } auto inputLength = TMVA::Experimental::SOFIE::ConvertDimShapeToLength(fShapeX); auto outputLength = TMVA::Experimental::SOFIE::ConvertDimShapeToLength(fShapeY); diff --git a/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx new file mode 100644 index 0000000000000..32272ea03c2fc --- /dev/null +++ b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx @@ -0,0 +1,193 @@ +#ifndef TMVA_SOFIE_ROPERATOR_ScatterND +#define TMVA_SOFIE_ROPERATOR_ScatterND + +#include "TMVA/SOFIE_common.hxx" +#include "TMVA/ROperator.hxx" +#include "TMVA/RModel.hxx" + +#include +#include +#include + +namespace TMVA{ +namespace Experimental{ +namespace SOFIE{ + +class ROperator_ScatterND final : public ROperator +{ +private: + + + std::string fNX; + std::string fNI; + std::string fNU; + std::string fNY; + std::string fReduction; + + std::vector fShapeX; + std::vector fShapeI; + std::vector fShapeY; + + + std::vector fIndices; // indices vector in case they are known at initialization + + std::string fType; + + +public: + ROperator_ScatterND(){} + ROperator_ScatterND(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY, + std::string reduction): + fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)), + fNY(UTILITY::Clean_name(nameY)), fReduction(reduction) + { + fInputTensorNames = { fNX, fNI, fNU }; + fOutputTensorNames = { fNY }; + } + + void Initialize(RModel& model) override { + + // input must be a graph input, or already initialized intermediate tensor + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNX + "is not found in model"); + } + if (!model.CheckIfTensorAlreadyExist(fNI)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNI + "is not found in model"); + } + if 
(!model.CheckIfTensorAlreadyExist(fNU)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNU + " is not found in model"); + } + //tbd check for constant tensors + + fShapeX = model.GetDimTensorShape(fNX); + fShapeI = model.GetDimTensorShape(fNI); + auto shapeU = model.GetDimTensorShape(fNU); + + // Validate inputs if fShapeI last is not dynamic + + //if (!model.IsDynamicTensor(fNI)) { + const size_t r = fShapeX.size(); // rank of data + const size_t q = fShapeI.size(); // rank of indices + if (!(fShapeI.back().isParam) ) { + const size_t k = fShapeI.back().dim; // index depth + + if (k > r) + throw std::invalid_argument( + "ScatterND: last dim of indices (" + std::to_string(k) + + ") must be <= rank of data (" + std::to_string(r) + ")"); + + // Expected updates rank = q - 1 + r - k + int64_t expected_updates_rank = q - 1 + r - k; + if ((int64_t) shapeU.size() != expected_updates_rank) + throw std::invalid_argument("ScatterND: updates rank mismatch"); + } else { + // Assumption is that last dimension of index shape is known (is not dynamic) + throw std::runtime_error("TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported"); + } + + // output shape is equal to input shape + fShapeY = fShapeX; + + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY); + if (model.Verbose()) { + std::cout << "ScatterND: input: " << ConvertDimShapeToString(fShapeX) + << " indices " << ConvertDimShapeToString(fShapeI) + << " update " << ConvertDimShapeToString(shapeU); + std::cout << "\t----> " << ConvertDimShapeToString(fShapeY) << std::endl; + } + } + + std::string Generate(std::string opName) override { + if (fIsOutputConstant) { + // no code to generate here for constant output.
Tensor output is defined in Session constructor + return "//---------------------------------------\n"; + } + opName = "op_" + opName; + std::stringstream out; + out << "//--------- ScatterND " << opName << " --> " << ConvertDimShapeToString(fShapeY) << "\n"; + + size_t r = fShapeX.size(); + + // Strides + auto stridesX = UTILITY::ComputeStrideFromShape(fShapeX); + auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY); + auto stridesI = UTILITY::ComputeStrideFromShape(fShapeI); + + // case input_index_shape == rank of input + size_t k = fShapeI.back().dim; + + // Total number of index tuples = product of indices dims except last + std::vector shapeIndFirst(fShapeI.begin(), fShapeI.begin()+ fShapeI.size()-1); + auto num_index_tuples = ConvertDimShapeToLength(shapeIndFirst); + + //slice size (is product of input from k to r) + std::vector shapeSlice(fShapeX.begin()+k, fShapeX.end()); + auto slice_size = ConvertDimShapeToLength(shapeSlice); + + auto data_length = ConvertDimShapeToLength(fShapeX); + + //step1: input->output + out << SP << "// Step 1: copy input data to output\n"; + out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << data_length << ", tensor_" << fNY << ");\n"; + + // Step 2: Emit strides as a static constexpr array + out << SP << "// Step 2: data strides (row-major)\n"; + //to do: use static constexpr for defined strides + out << SP << "size_t " << opName << "_data_strides[" << r << "] = {"; + for (size_t i = 0; i < r; ++i) + out << stridesX[i] << (i + 1 < r ? 
", " : ""); + out << "};\n\n"; + + // Step 3: Scatter loop + out << SP << "// Step 3: scatter updates into output\n"; + out << SP << "for (int64_t idx = 0; idx < " << num_index_tuples << "; idx++) {\n"; + + // Resolve flat data offset from k-dimensional index tuple + out << SP << SP << "int64_t data_offset = 0;\n"; + for (size_t dim = 0; dim < k; ++dim) { + out << SP << SP << "{\n"; + out << SP << SP << SP << "int64_t coord = tensor_" << fNI + << "[idx * " << k << " + " << dim << "];\n"; + // Support negative indices + out << SP << SP << SP << "if (coord < 0) coord += " << fShapeX[dim] << ";\n"; + out << SP << SP << SP << "data_offset += coord * " + << opName << "_data_strides[" << dim << "];\n"; + out << SP << SP << "}\n"; + } + + // Apply updates with reduction + out << SP << SP << "for (int64_t s = 0; s < " << slice_size << "; s++) {\n"; + out << SP << SP << SP << "auto upd = tensor_" << fNU + << "[idx * " << slice_size << " + s];\n"; + + if (fReduction.empty() || fReduction == "none") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = upd;\n"; + } else if (fReduction == "add") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] += upd;\n"; + } else if (fReduction == "mul") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] *= upd;\n"; + } else if (fReduction == "min") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] = " + << "std::min(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else if (fReduction == "max") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = " + << "std::max(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else { + throw std::runtime_error( + "TMVA SOFIE ScatterND: unsupported reduction '" + fReduction + "'"); + } + + out << SP << SP << "}\n"; // end slice loop + out << SP << "}\n"; // end index tuple loop + + return out.str(); + } + +}; + +}//SOFIE +}//Experimental +}//TMVA + + +#endif //TMVA_SOFIE_ROPERATOR_RELU diff --git 
a/tmva/sofie/inc/TMVA/ROperator_Slice.hxx b/tmva/sofie/inc/TMVA/ROperator_Slice.hxx index 674f8a4776520..7b7c9563db352 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Slice.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Slice.hxx @@ -337,7 +337,7 @@ public: size_t ndim = fShapeInput.size(); fIdentitySlice = fShapeOutput.size() == ndim; // check also if input data is not input to the model. In that case we copy the data since we cannot just copy from the input pointer - fIdentitySlice &= !model.IsReadyInputTensor(fNData); + fIdentitySlice &= (!model.IsReadyInputTensor(fNData) && !model.IsDimInputTensor(fNData)); for (size_t idim = 0; idim < ndim; idim++) { if (!fIdentitySlice) break; fIdentitySlice &= (fStart[idim].GetVal() == "0"); diff --git a/tmva/sofie/inc/TMVA/ROperator_Where.hxx b/tmva/sofie/inc/TMVA/ROperator_Where.hxx index 3064080507e28..494834cacc69e 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Where.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Where.hxx @@ -7,293 +7,441 @@ #include -namespace TMVA{ -namespace Experimental{ -namespace SOFIE{ +namespace TMVA { +namespace Experimental { +namespace SOFIE { - - -template -class ROperator_Where final : public ROperator{ +template +class ROperator_Where final : public ROperator { private: bool fIsInputBoolTensor = false; - - std::string fNA; - std::string fNB; - std::string fNC; - std::string fNBroadcastedA; - std::string fNBroadcastedB; + // Tensor names: C = condition, X = true branch, Y = false branch, Z = output + std::string fNC; // condition (bool) + std::string fNX; // true-branch values + std::string fNY; // false-branch values + std::string fNZ; // output std::string fNBroadcastedC; - std::string fNY; + std::string fNBroadcastedX; + std::string fNBroadcastedY; - - std::vector fShapeA; - std::vector fShapeB; + // Static shapes (used when all inputs are non-dynamic) std::vector fShapeC; + std::vector fShapeX; std::vector fShapeY; + std::vector fShapeZ; + + // Dynamic shapes (Dim-aware, used when any input is dynamic) + 
std::vector fDimShapeC; + std::vector fDimShapeX; + std::vector fDimShapeY; + std::vector fDimShapeZ; + // Broadcast flag: mirrors convention of BasicBinary + // bit 0: broadcast Y->X (Y needs expanding) + // bit 1: broadcast X->Y (X needs expanding) + // bit 2: broadcast C->Z (C needs expanding) + // bit 4: shapes may differ at runtime (dynamic) + int fBroadcastFlag = 0; public: - ROperator_Where(){} - ROperator_Where(const std::string & nameA, const std::string & nameB, const std::string & nameC, const std::string & nameY): - fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){ - fInputTensorNames = { fNA, fNB, fNC }; - fOutputTensorNames = { fNY }; - } + ROperator_Where() {} + ROperator_Where(const std::string &nameC, + const std::string &nameX, + const std::string &nameY, + const std::string &nameZ) + : fNC(UTILITY::Clean_name(nameC)), + fNX(UTILITY::Clean_name(nameX)), + fNY(UTILITY::Clean_name(nameY)), + fNZ(UTILITY::Clean_name(nameZ)) + { + fInputTensorNames = { fNC, fNX, fNY }; + fOutputTensorNames = { fNZ }; + } // type of output given input - std::vector TypeInference(std::vector input) override { - return input; + std::vector TypeInference(std::vector input) override + { + // output type follows X (and Y), not C (which is bool) + return { input[1] }; } // shape of output tensors given input tensors - std::vector> ShapeInference(std::vector> input) override { - // assume now inputs have same shape (no broadcasting) - auto ret = std::vector>(1, input[0]); // return vector size 1 with first input - return ret; + std::vector> ShapeInference(std::vector> input) override + { + // conservative: assume same shape (broadcasting resolved in Initialize) + return { input[1] }; } - void Initialize(RModel& model) override { - // input must be a graph input, or already initialized intermediate tensor - if (!model.CheckIfTensorAlreadyExist(fNA)){ - throw std::runtime_error(std::string("TMVA 
SOFIE Where Op Input Tensor ") + fNA + "is not found in model"); - } - if (!model.CheckIfTensorAlreadyExist(fNB)) { - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNB + "is not found in model"); - } - if (!model.CheckIfTensorAlreadyExist(fNC)) { - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNC + "is not found in model"); - } - // check if fNC input tensor is boolean + void Initialize(RModel &model) override + { + // ---------------------------------------------------------------- // + // Check all inputs exist + // ---------------------------------------------------------------- // + if (!model.CheckIfTensorAlreadyExist(fNC)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: condition tensor ") + fNC + " not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNX)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: X tensor ") + fNX + " not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNY)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: Y tensor ") + fNY + " not found in model"); + + // condition tensor is bool (uint8) - mark if it is a live input tensor if (model.IsReadyInputTensor(fNC)) fIsInputBoolTensor = true; - // check broadcast for A, B and C - fShapeA = model.GetTensorShape(fNA); - fShapeB = model.GetTensorShape(fNB); - fShapeC = model.GetTensorShape(fNC); - bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB) || !UTILITY::AreSameShape(fShapeA, fShapeC); - if (broadcast) { - // find shape to broadcast between A,B,C looking for max length - size_t lengthA = ConvertShapeToLength(fShapeA); - size_t lengthB = ConvertShapeToLength(fShapeB); - size_t lengthC = ConvertShapeToLength(fShapeC); - bool broadcastA = false, broadcastB = false, broadcastC = false; - if (lengthA >= lengthB && lengthA >= lengthC) { - fShapeY = fShapeA; - //broadcast B and C if different than A - broadcastB = (lengthB != lengthA); - broadcastC = (lengthC != lengthA); - } 
- else if (lengthB >= lengthA && lengthB >= lengthC) { - fShapeY = fShapeB; - //broadcast A and C if different than B - broadcastA = (lengthA != lengthB); - broadcastC = (lengthC != lengthB); - } - else if (lengthC >= lengthA && lengthC >= lengthB) { - fShapeY = fShapeC; - //broadcast A and B if different than C - broadcastA = (lengthA != lengthC); - broadcastB = (lengthB != lengthC); - } - // Broadcast A to Y - if (broadcastA) { - fNBroadcastedA = "BC_" + fNA + "_to_" + fNY; - if (model.IsInitializedTensor(fNA)) { - auto data = model.GetInitializedTensorData(fNA); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeA, fShapeY), - std::default_delete()); - // Update the data and the shape of A - model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData); - fShapeA = fShapeY; - } else { - // Add an intermediate tensor for broadcasting A - model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY); - } - } - // Broadcast B to Y - if (broadcastB) { - fNBroadcastedB = "BC_" + fNB + "_to_" + fNY; - if (model.IsInitializedTensor(fNB)) { - auto data = model.GetInitializedTensorData(fNB); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeB, fShapeY), - std::default_delete()); - // do not update tensor B but add broadcasted one (since it can be input to some other operators) - model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData); - fShapeB = fShapeY; - } else { - // Add an intermediate tensor for broadcasting B - model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY); - } - } - // Broadcast C to Y - if (broadcastC) { - fNBroadcastedC = "BC_" + fNC + "_to_" + fNY; - if (model.IsInitializedTensor(fNC)) { - auto data = model.GetInitializedTensorData(fNC); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeC, 
fShapeY), - std::default_delete()); - // do not update tensor C but add broadcasted one (since it can be input to some other operators) - model.AddConstantTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY, broadcastedData); - fShapeC = fShapeY; - } else { - // Add an intermediate tensor for broadcasting B - model.AddIntermediateTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY); - } - } + // ---------------------------------------------------------------- // + // Collect shapes – dynamic or static + // ---------------------------------------------------------------- // + int dynamicInputs = 0; // bitmask: bit0=C, bit1=X, bit2=Y + + if (model.IsDynamicTensor(fNC)) { + fDimShapeC = model.GetDynamicTensorShape(fNC); + dynamicInputs |= 1; } else { - fShapeY = fShapeA; + fShapeC = model.GetTensorShape(fNC); + fDimShapeC = ConvertShapeToDim(fShapeC); } - // check case of constant output (if all inputs are defined) - if (model.IsInitializedTensor(fNC)) { - - std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC; - auto dataC = static_cast(model.GetInitializedTensorData(nameC).get()); - model.SetNotWritableInitializedTensor(nameC); - T * dataA = nullptr; - T * dataB = nullptr; - std::vector shapeDataA; - std::vector shapeDataB; - if (model.IsInitializedTensor(fNA)) { - std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA; - dataA = static_cast(model.GetInitializedTensorData(nameA).get()); - // flag tensors to not be written in a file - model.SetNotWritableInitializedTensor(nameA); - } else if (model.IsShapeTensor(fNA)) - shapeDataA = model.GetShapeTensorValues(fNA); - if (model.IsInitializedTensor(fNB)) { - std::string nameB = fNBroadcastedB.empty()? 
fNB : fNBroadcastedB; - dataB = static_cast(model.GetInitializedTensorData(nameB).get()); - model.SetNotWritableInitializedTensor(nameB); - } else if (model.IsShapeTensor(fNB)) - shapeDataB = model.GetShapeTensorValues(fNB); - - std::vector dataY; - std::vector shapeDataY; - - bool isOutputConstantTensor = true; - if (dataA && dataB) { - dataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < dataY.size(); i++) - dataY[i] = (dataC[i]) ? dataA[i] : dataB[i]; - } - else if (dataA && shapeDataB.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? Dim{size_t(dataA[i])} : shapeDataB[i]; - isOutputConstantTensor &= !shapeDataY[i].isParam; - } - } - else if (dataB && shapeDataA.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? shapeDataB[i] : Dim{size_t(dataB[i])}; - isOutputConstantTensor &= !shapeDataY[i].isParam; - } + if (model.IsDynamicTensor(fNX)) { + fDimShapeX = model.GetDynamicTensorShape(fNX); + dynamicInputs |= 2; + } else { + fShapeX = model.GetTensorShape(fNX); + fDimShapeX = ConvertShapeToDim(fShapeX); + } + if (model.IsDynamicTensor(fNY)) { + fDimShapeY = model.GetDynamicTensorShape(fNY); + dynamicInputs |= 4; + } else { + fShapeY = model.GetTensorShape(fNY); + fDimShapeY = ConvertShapeToDim(fShapeY); + } + + if (model.Verbose()) { + if (dynamicInputs & 1) + std::cout << "Where : condition " << fNC << " is dynamic " << ConvertDimShapeToString(fDimShapeC) << "\n"; + if (dynamicInputs & 2) + std::cout << "Where : X " << fNX << " is dynamic " << ConvertDimShapeToString(fDimShapeX) << "\n"; + if (dynamicInputs & 4) + std::cout << "Where : Y " << fNY << " is dynamic " << ConvertDimShapeToString(fDimShapeY) << "\n"; + } + + // ---------------------------------------------------------------- // + // Static path: all shapes known at code-gen time + // 
---------------------------------------------------------------- // + if (dynamicInputs == 0) { + + // Multidirectional broadcast over all three tensors + auto retXY = UTILITY::MultidirectionalBroadcastShape(fShapeX, fShapeY); + fBroadcastFlag = retXY.first; + fShapeZ = retXY.second; + // also factor in C + auto retCZ = UTILITY::MultidirectionalBroadcastShape(fShapeC, fShapeZ); + fBroadcastFlag |= retCZ.first; + fShapeZ = retCZ.second; + + bool allConstant = model.IsConstantTensor(fNC) && + model.IsConstantTensor(fNX) && + model.IsConstantTensor(fNY); + + if (allConstant) { + // ---------------------------------------------------------- + // Constant folding: evaluate Where at model initialisation + // ---------------------------------------------------------- + auto broadcastIfNeeded = [&](const std::string &name, + const std::vector &shape, + std::string &bcName, + const std::string &prefix) { + if (shape != fShapeZ) { + bcName = prefix + name + "to" + fNZ; + auto data = model.GetInitializedTensorData(name); + std::shared_ptr bcData( + UTILITY::UnidirectionalBroadcast(static_cast(data.get()), shape, fShapeZ), + std::default_delete()); + model.AddConstantTensor(bcName, model.GetTensorType(name), fShapeZ, bcData); + } + }; + + broadcastIfNeeded(fNX, fShapeX, fNBroadcastedX, "BC_"); + broadcastIfNeeded(fNY, fShapeY, fNBroadcastedY, "BC_"); + broadcastIfNeeded(fNC, fShapeC, fNBroadcastedC, "BC_"); + + const std::string &nameC = fNBroadcastedC.empty() ? fNC : fNBroadcastedC; + const std::string &nameX = fNBroadcastedX.empty() ? fNX : fNBroadcastedX; + const std::string &nameY = fNBroadcastedY.empty() ? 
fNY : fNBroadcastedY; + + auto dataC = static_cast(model.GetInitializedTensorData(nameC).get()); + auto dataX = static_cast (model.GetInitializedTensorData(nameX).get()); + auto dataY = static_cast (model.GetInitializedTensorData(nameY).get()); + + size_t len = ConvertShapeToLength(fShapeZ); + std::vector dataZ(len); + for (size_t i = 0; i < len; ++i) + dataZ[i] = dataC[i] ? dataX[i] : dataY[i]; + + model.AddConstantTensor(fNZ, fShapeZ, dataZ.data()); + model.SetNotWritableInitializedTensor(nameC); + model.SetNotWritableInitializedTensor(nameX); + model.SetNotWritableInitializedTensor(nameY); + fIsOutputConstant = true; + fOutputTensorNames.pop_back(); + + if (model.Verbose()) + std::cout << "Where --> " << fNZ << " " << ConvertShapeToString(fShapeZ) + << " : " << ConvertValuesToString(dataZ) << " (constant)\n"; + } else { + // ---------------------------------------------------------- + // Non-constant static: register broadcasted intermediates + // ---------------------------------------------------------- + auto registerBC = [&](const std::string &name, + const std::vector &shape, + std::string &bcName, + const std::string &prefix) { + if (shape != fShapeZ) { + bcName = prefix + name + "to" + fNZ; + if (model.IsInitializedTensor(name)) { + auto data = model.GetInitializedTensorData(name); + std::shared_ptr bcData( + UTILITY::UnidirectionalBroadcast(static_cast(data.get()), shape, fShapeZ), + std::default_delete()); + model.AddConstantTensor(bcName, model.GetTensorType(name), fShapeZ, bcData); + } else { + model.AddIntermediateTensor(bcName, model.GetTensorType(name), fShapeZ); + } + } + }; + + registerBC(fNX, fShapeX, fNBroadcastedX, "BC_"); + registerBC(fNY, fShapeY, fNBroadcastedY, "BC_"); + registerBC(fNC, fShapeC, fNBroadcastedC, "BC_"); + + fDimShapeZ = ConvertShapeToDim(fShapeZ); + model.AddIntermediateTensor(fNZ, model.GetTensorType(fNX), fShapeZ); + + if (model.Verbose()) + std::cout << "Where : C=" << fNC << " " << ConvertShapeToString(fShapeC) + << " 
X=" << fNX << " " << ConvertShapeToString(fShapeX) + << " Y=" << fNY << " " << ConvertShapeToString(fShapeY) + << " --> Z=" << fNZ << " " << ConvertShapeToString(fShapeZ) << "\n"; } - else if (shapeDataB.size() > 0 && shapeDataA.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? shapeDataA[i] : shapeDataB[i]; - isOutputConstantTensor &= !shapeDataY[i].isParam; + + } else { + // ---------------------------------------------------------------- // + // Dynamic path: at least one input has a parametric shape + // ---------------------------------------------------------------- // + auto retXY = UTILITY::MultidirectionalBroadcastShape(fDimShapeX, fDimShapeY); + fBroadcastFlag = retXY.first; + fDimShapeZ = retXY.second; + auto retCZ = UTILITY::MultidirectionalBroadcastShape(fDimShapeC, fDimShapeZ); + fBroadcastFlag |= retCZ.first; + fDimShapeZ = retCZ.second; + + // Resolve std::max params to actual input dim params (same logic as BasicBinary) + if (fBroadcastFlag & 4) { + auto IsInputDimParam = [&](const std::string &p) { + for (auto &input : model.GetInputTensorNames()) + for (auto &s : model.GetDimTensorShape(input)) + if (s.isParam && s.param == p) return true; + return false; + }; + for (size_t i = 0; i < fDimShapeZ.size(); i++) { + auto &s = fDimShapeZ[i]; + if (s.isParam && s.param.find("std::max") != std::string::npos) { + // prefer X dim over Y dim + if (i < fDimShapeX.size() && IsInputDimParam(fDimShapeX[i].param)) { + s = (fDimShapeX[i].dim != 1) ? fDimShapeX[i] : fDimShapeY[i]; + } else if (i < fDimShapeY.size() && IsInputDimParam(fDimShapeY[i].param)) { + s = (fDimShapeY[i].dim != 1) ? 
fDimShapeY[i] : fDimShapeX[i]; + } + } } } - fIsOutputConstant = true; // this contains both case constant tensor output ans shape tensor output - if (isOutputConstantTensor && dataY.empty()) { - dataY.resize(shapeDataY.size()); - for (size_t i = 0; i < shapeDataY.size(); i++) - dataY[i] = static_cast(shapeDataY[i].dim); - } - if (dataY.size() > 0) - model.AddConstantTensor(fNY, fShapeY, dataY.data()); - else if (shapeDataY.size() > 0 ) - model.AddShapeTensor(fNY, shapeDataY, fShapeY.size() == 0); - else { - fIsOutputConstant = false; - } - if (fIsOutputConstant && model.Verbose()) - std::cout << "Where op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : " - << ((dataY.size() > 0) ? ConvertValuesToString(dataY) : ConvertDimShapeToString(shapeDataY) ) - << ((dataY.size() > 0) ? " (constant)" : " (shape)") << std::endl; - // output is a constant tensor - if (fIsOutputConstant) fOutputTensorNames.pop_back(); - } - if (!fIsOutputConstant) { - model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY); - if (model.Verbose()) - std::cout << "Where op " << " condition : " << fNC << " " << ConvertShapeToString(fShapeC) << - " X " << fNA << " " << ConvertShapeToString(fShapeA) << " Y " << fNB << " " << ConvertShapeToString(fShapeB) - << " ---> " << fNY << " " << ConvertShapeToString(fShapeY) << std::endl; + model.AddIntermediateTensor(fNZ, model.GetTensorType(fNX), fDimShapeZ); + + if (model.Verbose()) + std::cout << "Where (dynamic) : C=" << ConvertDimShapeToString(fDimShapeC) + << " X=" << ConvertDimShapeToString(fDimShapeX) + << " Y=" << ConvertDimShapeToString(fDimShapeY) + << " --> Z=" << ConvertDimShapeToString(fDimShapeZ) << "\n"; } } - std::string GenerateInitCode() override { + std::string GenerateInitCode() override + { std::stringstream out; return out.str(); } - std::string Generate(std::string opName) override { - + std::string Generate(std::string opName) override + { if (fIsOutputConstant) return ""; opName = "op_" + opName; - if 
(fShapeY.empty()) { + if (fDimShapeZ.empty()) { throw std::runtime_error("TMVA SOFIE Where Op called to Generate without being initialized first"); } + std::stringstream out; - out << SP << "\n//-------- Where " << opName << " --> " << ConvertShapeToString(fShapeY) << "\n"; - size_t length = ConvertShapeToLength(fShapeY); - std::string typeName = TensorType::Name(); - // Broadcast A if it's uninitialized - if (fShapeA != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedA << ");\n"; + out << SP << "\n//------ WHERE " << opName << " --> " << ConvertDimShapeToString(fDimShapeZ) << "\n"; + + // ---------------------------------------------------------------- // + // Runtime broadcast validation (dynamic shapes, flag bit 4) + // ---------------------------------------------------------------- // + if (fBroadcastFlag & 4) { + auto lengthX = ConvertDimShapeToLength(fDimShapeX); + auto lengthY = ConvertDimShapeToLength(fDimShapeY); + auto lengthC = ConvertDimShapeToLength(fDimShapeC); + out << SP << "if (" << lengthX << " != " << lengthY << " || " + << lengthX << " != " << lengthC << ") {\n"; + for (size_t i = 0; i < fDimShapeZ.size(); i++) { + // validate X vs Z + if (i < fDimShapeX.size() && fDimShapeX[i].isParam) { + out << SP << SP << "if (" << fDimShapeX[i] << " != 1 && " + << fDimShapeX[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw std::runtime_error(\"SOFIE Where: cannot broadcast X dim " << i << " in " << opName << "\");\n"; + } + // validate Y vs Z + if (i < fDimShapeY.size() && fDimShapeY[i].isParam) { + out << SP << SP << "if (" << fDimShapeY[i] << " != 1 && " + << fDimShapeY[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw 
std::runtime_error(\"SOFIE Where: cannot broadcast Y dim " << i << " in " << opName << "\");\n"; + } + // validate C vs Z + if (i < fDimShapeC.size() && fDimShapeC[i].isParam) { + out << SP << SP << "if (" << fDimShapeC[i] << " != 1 && " + << fDimShapeC[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i << " in " << opName << "\");\n"; + } + } + out << SP << "}\n"; + } + + // ---------------------------------------------------------------- // + // Runtime broadcasting for non-constant, non-initialised tensors + // ---------------------------------------------------------------- // + // Broadcast X if needed + if (!fNBroadcastedX.empty()) { + out << SP << "// Broadcast X tensor " << fNX << "\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" + << TensorType::Name() << ">(tensor_" << fNX << ", " + << ConvertDimShapeToString(fDimShapeX) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedX << ");\n"; } - // Broadcast B if it's uninitialized - if (fShapeB != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedB << ");\n"; + // Broadcast Y if needed + if (!fNBroadcastedY.empty()) { + out << SP << "// Broadcast Y tensor " << fNY << "\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" + << TensorType::Name() << ">(tensor_" << fNY << ", " + << ConvertDimShapeToString(fDimShapeY) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedY << ");\n"; } - // Broadcast C if it's uninitialized - if (fShapeC != fShapeY) { - // special case if C is an input tensor + // Broadcast C (condition) if needed + if 
(!fNBroadcastedC.empty()) { if (fIsInputBoolTensor) { + // live bool input: need a temporary std::vector for the broadcast utility size_t inputLength = ConvertShapeToLength(fShapeC); - out << SP << "std::vector tmp_tensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n"; + out << SP << "std::vector tmp_tensor_" << fNC + << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast" + << "(tmp_tensor_" << fNC << ".data(), " + << ConvertDimShapeToString(fDimShapeC) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedC << ");\n"; + } else { + out << SP << "// Broadcast condition tensor " << fNC << "\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast" + << "(tensor_" << fNC << ", " + << ConvertDimShapeToString(fDimShapeC) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedC << ");\n"; } - out << SP << "// Broadcasting uninitialized tensor " << fNC << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tmp_tensor_" << fNC << ".data(), " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedC << ");\n"; } - std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA; - std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB; - std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC; - out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n"; - // get output tensor applying condition - out << SP << SP << "tensor_" << fNY << "[id] = " << "tensor_" << nameC << "[id] ? tensor_" - << nameA << "[id] : tensor_" + nameB + "[id];\n"; - out << SP << "}\n"; + + // Final (possibly broadcasted) tensor names + const std::string nameX = fNBroadcastedX.empty() ? fNX : fNBroadcastedX; + const std::string nameY = fNBroadcastedY.empty() ? 
fNY : fNBroadcastedY; + const std::string nameC = fNBroadcastedC.empty() ? fNC : fNBroadcastedC; + + // ---------------------------------------------------------------- // + // Generate loop(s) with per-dimension stride-based index arithmetic + // (same pattern as BasicBinary) + // ---------------------------------------------------------------- // + auto stridesX = UTILITY::ComputeStrideFromShape(fDimShapeX); + auto stridesY = UTILITY::ComputeStrideFromShape(fDimShapeY); + auto stridesC = UTILITY::ComputeStrideFromShape(fDimShapeC); + auto stridesZ = UTILITY::ComputeStrideFromShape(fDimShapeZ); + + auto buildIdxExpr = [&](const std::vector &dimShape, + const std::vector &strides, + size_t rankZ) -> std::string { + if (dimShape.empty() || + std::all_of(dimShape.begin(), dimShape.end(), + [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) + return "0"; + std::string expr; + size_t offset = rankZ - dimShape.size(); + for (size_t i = 0; i < dimShape.size(); ++i) { + if (dimShape[i].dim == 1 || dimShape[i].GetVal() == "1") continue; + expr += "idx_" + std::to_string(i + offset); + if (strides[i].GetVal() != "1") + expr += " * " + strides[i].GetVal(); + expr += " + "; + } + if (expr.size() >= 3) + for (int j = 0; j < 3; j++) expr.pop_back(); // remove trailing " + " + return expr.empty() ? 
"0" : expr; + }; + + std::string idxX = buildIdxExpr(fDimShapeX, stridesX, fDimShapeZ.size()); + std::string idxY = buildIdxExpr(fDimShapeY, stridesY, fDimShapeZ.size()); + std::string idxC = buildIdxExpr(fDimShapeC, stridesC, fDimShapeZ.size()); + + // Emit nested loops over output shape + int nloop = 0; + std::string idxZ; + if (fDimShapeZ.empty() || + std::all_of(fDimShapeZ.begin(), fDimShapeZ.end(), + [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) { + idxZ = "0"; + } else { + for (size_t i = 0; i < fDimShapeZ.size(); ++i) { + if (fDimShapeZ[i].dim != 1 && fDimShapeZ[i].GetVal() != "1") { + nloop++; + for (int j = 0; j < nloop; j++) out << SP; + out << "for (size_t idx_" << i << " = 0; idx_" << i + << " < " << fDimShapeZ[i] << "; ++idx_" << i << ") {\n"; + idxZ += "idx_" + std::to_string(i); + if (stridesZ[i].GetVal() != "1") + idxZ += " * " + stridesZ[i].GetVal(); + idxZ += " + "; + } + } + if (idxZ.size() >= 3) + for (int j = 0; j < 3; j++) idxZ.pop_back(); + } + + // Inner assignment + for (int j = 0; j < nloop + 1; j++) out << SP; + out << "tensor_" << fNZ << "[" << idxZ << "] = " + << "tensor_" << nameC << "[" << idxC << "] ? 
" + << "tensor_" << nameX << "[" << idxX << "] : " + << "tensor_" << nameY << "[" << idxY << "];\n"; + + // Close loops + for (int i = nloop; i > 0; i--) { + for (int j = 0; j < i; j++) out << SP; + out << "}\n"; + } + return out.str(); } - }; -}//SOFIE -}//Experimental -}//TMVA - +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA -#endif //TMVA_SOFIE_ROperator_Where +#endif // TMVA_SOFIE_ROperator_Where diff --git a/tmva/sofie/src/SOFIE_common.cxx b/tmva/sofie/src/SOFIE_common.cxx index cad95a159f58a..aca8afabc4281 100644 --- a/tmva/sofie/src/SOFIE_common.cxx +++ b/tmva/sofie/src/SOFIE_common.cxx @@ -99,6 +99,8 @@ std::string ConvertTypeToString(ETensorType type){ } } +// invert function might now work correctly for booleans +// prefer avoid using it if possible ETensorType ConvertStringToType(std::string type){ if(type == "float32" || type == "float" || type == "Float"){ return ETensorType::FLOAT; @@ -106,10 +108,13 @@ ETensorType ConvertStringToType(std::string type){ else if(type == "int64" || type == "int64_t"){ return ETensorType::INT64; } + else if(type == "int32" || type == "int32_t"){ + return ETensorType::INT32; + } else if (type == "double" || type == "float64"){ return ETensorType::DOUBLE; } - else if (type == "bool" ){ + else if (type == "bool" || type == "uint8_t" ){ return ETensorType::BOOL; } else{ diff --git a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index 94993f601a3c4..25bf2350c5a61 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -3006,3 +3006,57 @@ TEST(ONNX, NotIsNaN) } } +TEST(ONNX, ScatterND_1) +{ + // test 1-D scatter (k=1, scalar slice) + std::vector input = {1.,2.,3.,4.,5.}; // shape {5} + std::vector indices = { 0, 2, 4}; // shape {3,1} + std::vector updates = { 10.,30.,50.}; // shape {3} + std::vector correct_output = {10., 2., 30., 4., 50.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_1", input, 
indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_2) +{ + // test 2-d Scatter - scatter rows - reduction = 'add + std::vector input = {1.,1.,2.,2.,3.,3.}; // shape {3,2} + std::vector indices = { 0, 1}; // shape {2,1} + std::vector updates = { 10.,10.,20.,20.}; // shape { 2,2} + std::vector correct_output = {11., 11., 22., 22., 3., 3.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_2", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_3) +{ + // test element wise scatter (k==rank input) reduction = 'mul' + std::vector input = {1.,2.,3.,4.}; // shape {2,2} + std::vector indices = { 0,0, 1,1}; // shape {2,2} + std::vector updates = { 11.,22.}; // shape { 2} + std::vector correct_output = {11., 2., 3., 88.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_3", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + diff --git a/tmva/sofie/test/input_models/ScatterND_1.onnx b/tmva/sofie/test/input_models/ScatterND_1.onnx new file mode 100644 index 0000000000000..6e6bd2b58c0f7 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_1.onnx @@ -0,0 +1,21 @@ +  onnx-example:” ++ +data +indices +updatesoutput" ScatterND TestGraphZ +data + + +Z +indices +  + +Z +updates + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_2.onnx b/tmva/sofie/test/input_models/ScatterND_2.onnx new 
file mode 100644 index 0000000000000..9211d555dffda --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_2.onnx @@ -0,0 +1,22 @@ +  onnx-example:µ +@ +data +indices +updatesoutput" ScatterND* + reduction"add  TestGraphZ +data +  + +Z +indices +  + +Z +updates +  + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_3.onnx b/tmva/sofie/test/input_models/ScatterND_3.onnx new file mode 100644 index 0000000000000..20d83a7dd1715 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_3.onnx @@ -0,0 +1,22 @@ +  onnx-example:± +@ +data +indices +updatesoutput" ScatterND* + reduction"mul  TestGraphZ +data +  + +Z +indices +  + +Z +updates + + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 80069c44c6929..5dc49688dfa6f 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -76,6 +76,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParseEinsum.cxx src/ParseRandom.cxx src/ParseScatterElements.cxx + src/ParseScatterND.cxx src/ParseNonZero.cxx src/ParseNot.cxx ${PROTO_SRCS} diff --git a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx index 1efb901d8e791..ed42f7381693b 100644 --- a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx +++ b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx @@ -29,6 +29,8 @@ class RModelParser_ONNX { public: struct OperatorsMapImpl; + enum EFusedOp { kMatMulAdd, kConvAdd, kConvTransAdd, kGemmRelu, kBatchnormRelu}; + private: bool fVerbose = false; @@ -36,8 +38,9 @@ private: std::unique_ptr fOperatorsMapImpl; // Type of the tensors std::unordered_map fTensorTypeMap; - // flag list of fused operators - std::vector fFusedOperators; + + // List of fused operators storing as key the second operator and a value a pair of fusion type and parent operator + std::map> fFusedOperators; public: diff --git 
a/tmva/sofie_parsers/src/ParseBasicNary.cxx b/tmva/sofie_parsers/src/ParseBasicNary.cxx index 0da3d493e095d..13e5d1ce5e5b4 100644 --- a/tmva/sofie_parsers/src/ParseBasicNary.cxx +++ b/tmva/sofie_parsers/src/ParseBasicNary.cxx @@ -43,19 +43,22 @@ std::unique_ptr ParseBasicNary(RModelParser_ONNX& parser, const onnx: return op; } - +// Max ParserFuncSignature ParseMax = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +// Min ParserFuncSignature ParseMin= [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +//Mean ParserFuncSignature ParseMean = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +// Sum ParserFuncSignature ParseSum = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; diff --git a/tmva/sofie_parsers/src/ParseCast.cxx b/tmva/sofie_parsers/src/ParseCast.cxx index 0e07fd7d164fc..6970fc2fcbdce 100644 --- a/tmva/sofie_parsers/src/ParseCast.cxx +++ b/tmva/sofie_parsers/src/ParseCast.cxx @@ -13,21 +13,23 @@ ParserFuncSignature ParseCast = [](RModelParser_ONNX &parser, const onnx::NodePr " but its type is not yet registered"); } - std::unique_ptr op; - std::string attr_type; + ETensorType output_type = ETensorType::UNDEFINED; for (int_t i = 0; i < nodeproto.attribute_size(); i++) { std::string attribute_name = nodeproto.attribute(i).name(); - if (attribute_name == "to") - attr_type = ConvertTypeToString(static_cast(nodeproto.attribute(i).i())); + if (attribute_name == "to") { + output_type = static_cast(nodeproto.attribute(i).i()); + } } + if (output_type == ETensorType::UNDEFINED) + throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has invalid output type"); std::string output_name = nodeproto.output(0); - op.reset(new ROperator_Cast(attr_type, nodeproto.input(0), output_name)); + auto op = 
std::make_unique(output_type, nodeproto.input(0), output_name); if (!parser.IsRegisteredTensorType(output_name)) { - ETensorType output_type = ConvertStringToType(attr_type); parser.RegisterTensorType(output_name, output_type); + std::cout << "Cast -> " << output_name << " type " << (int) output_type << " " << ConvertTypeToString(output_type) << std::endl; } return op; diff --git a/tmva/sofie_parsers/src/ParseComparision.cxx b/tmva/sofie_parsers/src/ParseComparision.cxx index 06128e6fd00d6..bd1598bb5142b 100644 --- a/tmva/sofie_parsers/src/ParseComparision.cxx +++ b/tmva/sofie_parsers/src/ParseComparision.cxx @@ -65,17 +65,17 @@ ParserFuncSignature ParseLess = [](RModelParser_ONNX &parser, const onnx::NodePr return ParseComparision(parser, nodeproto); }; -// Parse Mul +// Parse LessEq ParserFuncSignature ParseLessEq = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; -// Parse Div +// Parse Greater ParserFuncSignature ParseGreater = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; -// Parse Pow +// Parse GreaterEq ParserFuncSignature ParseGreaterEq = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; diff --git a/tmva/sofie_parsers/src/ParseConstant.cxx b/tmva/sofie_parsers/src/ParseConstant.cxx index a4ba71d038389..c48b0507769ca 100644 --- a/tmva/sofie_parsers/src/ParseConstant.cxx +++ b/tmva/sofie_parsers/src/ParseConstant.cxx @@ -37,7 +37,7 @@ ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::No std::string output_name = nodeproto.output(0); ETensorType output_type = ETensorType::FLOAT; std::vector shape; // output shape (use in case of constant operator) - // it should be only one attribute (Constant or 1 or 0 COnstant of Shape) + // it should be only one attribute (Constant or 1 or 0 Constant of Shape) if (nodeproto.attribute_size() > 1) throw 
std::runtime_error("TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1"); if (nodeproto.attribute_size() > 0) { @@ -176,6 +176,7 @@ ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::No if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, output_type); + std::cout << "Constant -> " << output_name << " " << (int) output_type << std::endl; } if (parser.Verbose()) diff --git a/tmva/sofie_parsers/src/ParseExpand.cxx b/tmva/sofie_parsers/src/ParseExpand.cxx index e55e791aaee0c..8eae2de4e16e2 100644 --- a/tmva/sofie_parsers/src/ParseExpand.cxx +++ b/tmva/sofie_parsers/src/ParseExpand.cxx @@ -13,6 +13,7 @@ ParserFuncSignature ParseExpand = [](RModelParser_ONNX &parser, const onnx::Node const std::string input_name = nodeproto.input(0); if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << "input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } else { throw std::runtime_error( "TMVA::SOFIE ONNX Parser Expand op has input tensor " + input_name + @@ -39,6 +40,9 @@ ParserFuncSignature ParseExpand = [](RModelParser_ONNX &parser, const onnx::Node case ETensorType::INT64: op.reset(new ROperator_Expand(input_name, shape_name, output_name)); break; + case ETensorType::BOOL: + op.reset(new ROperator_Expand(input_name, shape_name, output_name)); + break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Expand Operator does " "not support input type " + diff --git a/tmva/sofie_parsers/src/ParseNot.cxx b/tmva/sofie_parsers/src/ParseNot.cxx index c1bc7b027a5ad..cd03b00d0c6be 100644 --- a/tmva/sofie_parsers/src/ParseNot.cxx +++ b/tmva/sofie_parsers/src/ParseNot.cxx @@ -17,8 +17,9 @@ ParserFuncSignature ParseNot = [](RModelParser_ONNX &parser, const onnx::NodePro if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << 
"NOT op input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; if (input_type !=ETensorType::BOOL && input_type !=ETensorType::UINT8 ) - std::runtime_error("TMVA::SOFIE ONNX Parser Not op has invalid input type " + ConvertTypeToString(input_type)); + throw std::runtime_error("TMVA::SOFIE ONNX Parser Not op has invalid input type " + ConvertTypeToString(input_type)); } else { throw std::runtime_error("TMVA::SOFIE ONNX Parser Not op has input tensor " + input_name + diff --git a/tmva/sofie_parsers/src/ParseReduce.cxx b/tmva/sofie_parsers/src/ParseReduce.cxx index 6c18a4371c342..1753979cfa7bb 100644 --- a/tmva/sofie_parsers/src/ParseReduce.cxx +++ b/tmva/sofie_parsers/src/ParseReduce.cxx @@ -57,14 +57,15 @@ std::unique_ptr ParseReduce(RModelParser_ONNX &parser, const onnx::No std::vector({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()}); } } - switch (input_type) { - case ETensorType::FLOAT: - op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); - break; - default: - throw std::runtime_error("TMVA::SOFIE - Unsupported - Reduce Operator does not yet support input type " + - std::to_string(static_cast(input_type))); - } + op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); + // switch (input_type) { + // case ETensorType::FLOAT: + // op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); + // break; + // default: + // throw std::runtime_error("TMVA::SOFIE - Unsupported - Reduce Operator does not yet support input type " + + // std::to_string(static_cast(input_type))); + // } if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, input_type); diff --git a/tmva/sofie_parsers/src/ParseReshape.cxx b/tmva/sofie_parsers/src/ParseReshape.cxx index ddb316ca837a4..5db12d9aac847 100644 --- a/tmva/sofie_parsers/src/ParseReshape.cxx +++ 
b/tmva/sofie_parsers/src/ParseReshape.cxx @@ -26,6 +26,7 @@ ParserFuncSignature ParseReshape = [](RModelParser_ONNX &parser, const onnx::Nod ? nodeproto.input(1) : ""; if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << "Reshape/Un/Squueze op input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } else { throw std::runtime_error("TMVA::SOFIE ONNX Parser Reshape op has input tensor" + input_name + " but its type is not yet registered"); @@ -57,6 +58,7 @@ ParserFuncSignature ParseReshape = [](RModelParser_ONNX &parser, const onnx::Nod if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, input_type); + std::cout << "Reshape/Un/Squueze register output " << output_name << " with type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } return op; diff --git a/tmva/sofie_parsers/src/ParseScatterND.cxx b/tmva/sofie_parsers/src/ParseScatterND.cxx new file mode 100644 index 0000000000000..feda091182c63 --- /dev/null +++ b/tmva/sofie_parsers/src/ParseScatterND.cxx @@ -0,0 +1,58 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_ScatterND.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseScatterND = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + + if (nodeproto.input_size() != 3) { + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has invalid input size"); + } + // data is input 0 + if (!parser.IsRegisteredTensorType(nodeproto.input(0))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(0) + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(1))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(1) + + " but its type is 
not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(2))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(2) + + " but its type is not yet registered"); + } + ETensorType input_type = parser.GetTensorType(nodeproto.input(0)); + if (parser.GetTensorType(nodeproto.input(2)) != input_type) { + throw std::runtime_error("TMVA::SOFIE ONNX parser ScatterND op has input tensors of different types: " + + nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) + + " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type)); + } + + std::string reduction; + for (int i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "reduction") + reduction = nodeproto.attribute(i).s(); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + op.reset(new ROperator_ScatterND(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), + output_name, reduction)); + + // Infer the output type + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/ParseTranspose.cxx b/tmva/sofie_parsers/src/ParseTranspose.cxx index 198f57ba90d46..8178f37dac0b3 100644 --- a/tmva/sofie_parsers/src/ParseTranspose.cxx +++ b/tmva/sofie_parsers/src/ParseTranspose.cxx @@ -40,6 +40,20 @@ ParserFuncSignature ParseTranspose = [](RModelParser_ONNX &parser, const onnx::N op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); } break; + case ETensorType::BOOL: + if (!attr_perm.empty()) { + op.reset(new ROperator_Transpose(attr_perm, nodeproto.input(0), nodeproto.output(0))); + } else { + op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); + } + break; + case 
ETensorType::UINT8: + if (!attr_perm.empty()) { + op.reset(new ROperator_Transpose(attr_perm, nodeproto.input(0), nodeproto.output(0))); + } else { + op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); + } + break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Transpose does not yet support input type " + std::to_string(static_cast(input_type))); diff --git a/tmva/sofie_parsers/src/ParseWhere.cxx b/tmva/sofie_parsers/src/ParseWhere.cxx index c072ad14cb956..6ebcf161e5012 100644 --- a/tmva/sofie_parsers/src/ParseWhere.cxx +++ b/tmva/sofie_parsers/src/ParseWhere.cxx @@ -32,10 +32,10 @@ ParserFuncSignature ParseWhere = [](RModelParser_ONNX &parser, const onnx::NodeP switch (input_type) { case ETensorType::FLOAT: - op.reset(new ROperator_Where(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name)); + op.reset(new ROperator_Where(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name)); break; case ETensorType::INT64: - op.reset(new ROperator_Where(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name)); + op.reset(new ROperator_Where(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name)); break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Where Operator does not yet support input type " + diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index b77451da25c5b..038afc8df1b74 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -94,6 +94,7 @@ extern ParserFuncSignature ParseWhere; extern ParserFuncSignature ParseEinsum; extern ParserFuncSignature ParseRandom; extern ParserFuncSignature ParseScatterElements; +extern ParserFuncSignature ParseScatterND; extern ParserFuncSignature ParseNonZero; // Declaration of fused operators extern ParserFuseFuncSignature ParseFuseConvAdd; @@ -250,6 +251,7 @@ 
RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("RandomUniform", ParseRandom); RegisterOperator("RandomUniformLike", ParseRandom); RegisterOperator("ScatterElements", ParseScatterElements); + RegisterOperator("ScatterND", ParseScatterND); RegisterOperator("NonZero", ParseNonZero); } @@ -305,48 +307,60 @@ RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphpr if (fVerbose) std::cout << "Parsing operator " << op_type << std::endl; - // skip already fused operators - if (fFusedOperators[idx]) return nullptr; + // perform the fusion of operators + if (fFusedOperators.count(idx) == 1) { + int idx1 = fFusedOperators[idx].second; + if (fVerbose) { + std::cout << "\tFusing operators " << graphproto.node(idx1).name() + << " with " << graphproto.node(idx).name() << std::endl; + } + if (fFusedOperators[idx].first == EFusedOp::kMatMulAdd) { + return ParseFuseMatMulAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kConvAdd) { + return ParseFuseConvAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kConvTransAdd) { + return ParseFuseConvTransposeAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kGemmRelu) { + return ParseFuseGemmRelu(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kBatchnormRelu) { + return ParseFuseBatchnormRelu(*this, graphproto.node(idx1), graphproto.node(idx)); + } + } - // try to fuse with following operator in case it is not last one + // try to fuse with following operator in case it is not last one and having only a single child if (children.size() == 1) { int idx2 = children.front(); if (op_type == "MatMul") { // Fuse MatMul and Add if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") { - fFusedOperators[idx2] = true; - 
return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2)); - } - else { - return ParseMatMul(*this, graphproto.node(idx)); + fFusedOperators[idx2] = {EFusedOp::kMatMulAdd, idx}; + return nullptr; } } else if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") { // Fuse Conv or ConvTranspose without bias and Add if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") { if (nodeproto.op_type() == "Conv") { - fFusedOperators[idx2] = true; - return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = { EFusedOp::kConvAdd, idx}; + return nullptr; } else { - fFusedOperators[idx2] = true; - return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = { EFusedOp::kConvTransAdd, idx}; + return nullptr; } } } else if (nodeproto.op_type() == "Gemm") { // Fuse Gemm with activation operators if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") { - fFusedOperators[idx2] = true; - return ParseFuseGemmRelu(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = {EFusedOp::kGemmRelu, idx}; + return nullptr; } } else if (nodeproto.op_type() == "BatchNormalization") { if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") { - fFusedOperators[idx2] = true; - return ParseFuseBatchnormRelu(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = {EFusedOp::kBatchnormRelu, idx}; + return nullptr; } } } - - auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type); if (it == fOperatorsMapImpl->fOperatorsMap.end()) { std::cout << "operator " << op_type << " is not supported" << std::endl; @@ -769,7 +783,6 @@ void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & // we have to record order of node execution separately to // account for fused operators size_t node_order_exec = 0; - fFusedOperators = 
std::vector(graph.node_size(), false); for (int i = 0; i < graph.node_size(); i++) { std::string op_type = graph.node(nodesOrder[i]).op_type();