From 6bf9819330e18bfd9faa9d5fc5de1e6102c0041d Mon Sep 17 00:00:00 2001 From: moneta Date: Mon, 16 Mar 2026 17:05:48 +0100 Subject: [PATCH 1/4] [tmva][sofie] Add new ScatterND operator Add an implementation of ScatterND operator which is needed to parse the MLPF model from CMS Include also 3 tests to probe the different types of scattering which can be performed --- tmva/sofie/CMakeLists.txt | 1 + tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx | 192 ++++++++++++++++++ tmva/sofie/test/TestCustomModelsFromONNX.cxx | 54 +++++ tmva/sofie/test/input_models/ScatterND_1.onnx | 21 ++ tmva/sofie/test/input_models/ScatterND_2.onnx | 22 ++ tmva/sofie/test/input_models/ScatterND_3.onnx | 22 ++ tmva/sofie_parsers/CMakeLists.txt | 1 + tmva/sofie_parsers/src/ParseScatterND.cxx | 58 ++++++ tmva/sofie_parsers/src/RModelParser_ONNX.cxx | 2 + 9 files changed, 373 insertions(+) create mode 100644 tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx create mode 100644 tmva/sofie/test/input_models/ScatterND_1.onnx create mode 100644 tmva/sofie/test/input_models/ScatterND_2.onnx create mode 100644 tmva/sofie/test/input_models/ScatterND_3.onnx create mode 100644 tmva/sofie_parsers/src/ParseScatterND.cxx diff --git a/tmva/sofie/CMakeLists.txt b/tmva/sofie/CMakeLists.txt index dc44ac0a59af2..6fdc7a46183ee 100644 --- a/tmva/sofie/CMakeLists.txt +++ b/tmva/sofie/CMakeLists.txt @@ -65,6 +65,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofie TMVA/ROperator_Einsum.hxx TMVA/ROperator_Random.hxx TMVA/ROperator_ScatterElements.hxx + TMVA/ROperator_ScatterND.hxx TMVA/ROperator_Gather.hxx TMVA/ROperator_GatherND.hxx TMVA/ROperator_NonZero.hxx diff --git a/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx new file mode 100644 index 0000000000000..570b3f7a294aa --- /dev/null +++ b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx @@ -0,0 +1,192 @@ +#ifndef TMVA_SOFIE_ROPERATOR_ScatterND +#define TMVA_SOFIE_ROPERATOR_ScatterND + +#include "TMVA/SOFIE_common.hxx" +#include 
"TMVA/ROperator.hxx" +#include "TMVA/RModel.hxx" + +#include +#include +#include + +namespace TMVA{ +namespace Experimental{ +namespace SOFIE{ + +class ROperator_ScatterND final : public ROperator +{ +private: + + + std::string fNX; + std::string fNI; + std::string fNU; + std::string fNY; + std::string fReduction; + + std::vector fShapeX; + std::vector fShapeI; + std::vector fShapeY; + + + std::vector fIndices; // indices vector in case they are known at initialization + + std::string fType; + + +public: + ROperator_ScatterND(){} + ROperator_ScatterND(const std::string & nameX, const std::string & nameI, const std::string & nameU, const std::string & nameY, + std::string reduction): + fNX(UTILITY::Clean_name(nameX)), fNI(UTILITY::Clean_name(nameI)), fNU(UTILITY::Clean_name(nameU)), + fNY(UTILITY::Clean_name(nameY)), fReduction(reduction) + { + fInputTensorNames = { fNX, fNI, fNU }; + fOutputTensorNames = { fNY }; + } + + void Initialize(RModel& model) override { + + // input must be a graph input, or already initialized intermediate tensor + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNX + "is not found in model"); + } + if (!model.CheckIfTensorAlreadyExist(fNI)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNI + "is not found in model"); + } + if (!model.CheckIfTensorAlreadyExist(fNU)) { + throw std::runtime_error(std::string("TMVA SOFIE ScatterND Op Input Tensor ") + fNU + "is not found in model"); + } + //tbd check for constant tensors + + fShapeX = model.GetDimTensorShape(fNX); + fShapeI = model.GetDimTensorShape(fNI); + auto shapeU = model.GetDimTensorShape(fNU); + + // Validate inputs if fShapeI last is not dynamic + + //if (!model.IsDynamicTensor(fNI)) { + const size_t r = fShapeX.size(); // rank of data + const size_t q = fShapeI.size(); // rank of indices + if (!(fShapeI.back().isParam) ) { + const size_t k = fShapeI.back().dim; // 
index depth + + if (k > r) + throw std::invalid_argument( + "ScatterND: last dim of indices (" + std::to_string(k) + + ") must be <= rank of data (" + std::to_string(r) + ")"); + + // Expected updates rank = q - 1 + r - k + int64_t expected_updates_rank = q - 1 + r - k; + if ((int64_t) shapeU.size() != expected_updates_rank) + throw std::invalid_argument("ScatterND: updates rank mismatch"); + } else { + // Assumption is that last dimension of index shape is known (is not dynamic) + throw std::runtime_error("TMVA SOFIE ScatterND : Index_shape(-1) is not known. This case is not supported"); + } + + // output shape is equal to input shape + fShapeY = fShapeX; + + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY); + if (model.Verbose()) { + std::cout << "ScatterElements: input: " << ConvertDimShapeToString(fShapeX) + << " indices " << ConvertDimShapeToString(fShapeI) + << " update " << ConvertDimShapeToString(shapeU); + std::cout << "\t----> " << ConvertDimShapeToString(fShapeY) << std::endl; + } + } + + std::string Generate(std::string opName) override { + if (fIsOutputConstant) { + // no code to generate here for constant output. 
Tensor output is defined in Session constructor + return "//---------------------------------------\n"; + } + opName = "op_" + opName; + std::stringstream out; + out << "//--------- ScatterND " << opName << " --> " << ConvertDimShapeToString(fShapeY) << "\n"; + + size_t r = fShapeX.size(); + + // Strides + auto stridesX = UTILITY::ComputeStrideFromShape(fShapeX); + auto stridesY = UTILITY::ComputeStrideFromShape(fShapeY); + auto stridesI = UTILITY::ComputeStrideFromShape(fShapeI); + + // case input_index_shape == rank of input + size_t k = fShapeI.back().dim; + + // Total number of index tuples = product of indices dims except last + std::vector shapeIndFirst(fShapeI.begin(), fShapeI.begin()+ fShapeI.size()-1); + auto num_index_tuples = ConvertDimShapeToLength(shapeIndFirst); + + //slice size (is product of input from k to r) + std::vector shapeSlice(fShapeX.begin()+k, fShapeX.end()); + auto slice_size = ConvertDimShapeToLength(shapeSlice); + + auto data_length = ConvertDimShapeToLength(fShapeX); + + //step1: input->output + out << SP << "// Step 1: copy input data to output\n"; + out << SP << "std::copy(tensor_" << fNX << ", tensor_" << fNX << " + " << data_length << ", tensor_" << fNY << ");\n"; + + // Step 2: Emit strides as a static constexpr array + out << SP << "// Step 2: data strides (row-major)\n"; + out << SP << "static constexpr int64_t " << opName << "_data_strides[" << r << "] = {"; + for (size_t i = 0; i < r; ++i) + out << stridesX[i] << (i + 1 < r ? 
", " : ""); + out << "};\n\n"; + + // Step 3: Scatter loop + out << SP << "// Step 3: scatter updates into output\n"; + out << SP << "for (int64_t idx = 0; idx < " << num_index_tuples << "; idx++) {\n"; + + // Resolve flat data offset from k-dimensional index tuple + out << SP << SP << "int64_t data_offset = 0;\n"; + for (size_t dim = 0; dim < k; ++dim) { + out << SP << SP << "{\n"; + out << SP << SP << SP << "int64_t coord = tensor_" << fNI + << "[idx * " << k << " + " << dim << "];\n"; + // Support negative indices + out << SP << SP << SP << "if (coord < 0) coord += " << fShapeX[dim] << ";\n"; + out << SP << SP << SP << "data_offset += coord * " + << opName << "_data_strides[" << dim << "];\n"; + out << SP << SP << "}\n"; + } + + // Apply updates with reduction + out << SP << SP << "for (int64_t s = 0; s < " << slice_size << "; s++) {\n"; + out << SP << SP << SP << "auto upd = tensor_" << fNU + << "[idx * " << slice_size << " + s];\n"; + + if (fReduction.empty() || fReduction == "none") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = upd;\n"; + } else if (fReduction == "add") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] += upd;\n"; + } else if (fReduction == "mul") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] *= upd;\n"; + } else if (fReduction == "min") { + out << SP << SP << SP << "tensor_" << fNY<< "[data_offset + s] = " + << "std::min(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else if (fReduction == "max") { + out << SP << SP << SP << "tensor_" << fNY << "[data_offset + s] = " + << "std::max(tensor_" << fNY << "[data_offset + s], upd);\n"; + } else { + throw std::runtime_error( + "TMVA SOFIE ScatterND: unsupported reduction '" + fReduction + "'"); + } + + out << SP << SP << "}\n"; // end slice loop + out << SP << "}\n"; // end index tuple loop + + return out.str(); + } + +}; + +}//SOFIE +}//Experimental +}//TMVA + + +#endif //TMVA_SOFIE_ROPERATOR_ScatterND diff --git 
a/tmva/sofie/test/TestCustomModelsFromONNX.cxx b/tmva/sofie/test/TestCustomModelsFromONNX.cxx index 94993f601a3c4..25bf2350c5a61 100644 --- a/tmva/sofie/test/TestCustomModelsFromONNX.cxx +++ b/tmva/sofie/test/TestCustomModelsFromONNX.cxx @@ -3006,3 +3006,57 @@ TEST(ONNX, NotIsNaN) } } +TEST(ONNX, ScatterND_1) +{ + // test 1-D scatter (k=1, scalar slice) + std::vector input = {1.,2.,3.,4.,5.}; // shape {5} + std::vector indices = { 0, 2, 4}; // shape {3,1} + std::vector updates = { 10.,30.,50.}; // shape {3} + std::vector correct_output = {10., 2., 30., 4., 50.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_1", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_2) +{ + // test 2-d Scatter - scatter rows - reduction = 'add + std::vector input = {1.,1.,2.,2.,3.,3.}; // shape {3,2} + std::vector indices = { 0, 1}; // shape {2,1} + std::vector updates = { 10.,10.,20.,20.}; // shape { 2,2} + std::vector correct_output = {11., 11., 22., 22., 3., 3.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_2", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + +TEST(ONNX, ScatterND_3) +{ + // test element wise scatter (k==rank input) reduction = 'mul' + std::vector input = {1.,2.,3.,4.}; // shape {2,2} + std::vector indices = { 0,0, 1,1}; // shape {2,2} + std::vector updates = { 11.,22.}; // shape { 2} + std::vector correct_output = {11., 2., 3., 88.}; + + ASSERT_INCLUDE_AND_RUN(std::vector, "ScatterND_3", input, indices, updates); + + // Checking output size + EXPECT_EQ(output.size(), correct_output.size()); + // Checking output + for (size_t i = 0; 
i < output.size(); ++i) { + EXPECT_LE(std::abs(output[i] - correct_output[i]), DEFAULT_TOLERANCE); + } +} + diff --git a/tmva/sofie/test/input_models/ScatterND_1.onnx b/tmva/sofie/test/input_models/ScatterND_1.onnx new file mode 100644 index 0000000000000..6e6bd2b58c0f7 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_1.onnx @@ -0,0 +1,21 @@ +  onnx-example:” ++ +data +indices +updatesoutput" ScatterND TestGraphZ +data + + +Z +indices +  + +Z +updates + + +b +output + + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_2.onnx b/tmva/sofie/test/input_models/ScatterND_2.onnx new file mode 100644 index 0000000000000..9211d555dffda --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_2.onnx @@ -0,0 +1,22 @@ +  onnx-example:µ +@ +data +indices +updatesoutput" ScatterND* + reduction"add  TestGraphZ +data +  + +Z +indices +  + +Z +updates +  + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie/test/input_models/ScatterND_3.onnx b/tmva/sofie/test/input_models/ScatterND_3.onnx new file mode 100644 index 0000000000000..20d83a7dd1715 --- /dev/null +++ b/tmva/sofie/test/input_models/ScatterND_3.onnx @@ -0,0 +1,22 @@ +  onnx-example:± +@ +data +indices +updatesoutput" ScatterND* + reduction"mul  TestGraphZ +data +  + +Z +indices +  + +Z +updates + + +b +output +  + +B \ No newline at end of file diff --git a/tmva/sofie_parsers/CMakeLists.txt b/tmva/sofie_parsers/CMakeLists.txt index 80069c44c6929..5dc49688dfa6f 100644 --- a/tmva/sofie_parsers/CMakeLists.txt +++ b/tmva/sofie_parsers/CMakeLists.txt @@ -76,6 +76,7 @@ ROOT_STANDARD_LIBRARY_PACKAGE(ROOTTMVASofieParser src/ParseEinsum.cxx src/ParseRandom.cxx src/ParseScatterElements.cxx + src/ParseScatterND.cxx src/ParseNonZero.cxx src/ParseNot.cxx ${PROTO_SRCS} diff --git a/tmva/sofie_parsers/src/ParseScatterND.cxx b/tmva/sofie_parsers/src/ParseScatterND.cxx new file mode 100644 index 0000000000000..feda091182c63 --- /dev/null +++ 
b/tmva/sofie_parsers/src/ParseScatterND.cxx @@ -0,0 +1,58 @@ +#include "TMVA/RModelParser_ONNX.hxx" +#include "TMVA/ROperator_ScatterND.hxx" +#include "onnx_proto3.pb.h" + +namespace TMVA { +namespace Experimental { +namespace SOFIE { + +ParserFuncSignature ParseScatterND = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { + + if (nodeproto.input_size() != 3) { + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has invalid input size"); + } + // data is input 0 + if (!parser.IsRegisteredTensorType(nodeproto.input(0))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(0) + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(1))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(1) + + " but its type is not yet registered"); + } + if (!parser.IsRegisteredTensorType(nodeproto.input(2))){ + throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(2) + + " but its type is not yet registered"); + } + ETensorType input_type = parser.GetTensorType(nodeproto.input(0)); + if (parser.GetTensorType(nodeproto.input(2)) != input_type) { + throw std::runtime_error("TMVA::SOFIE ONNX parser ScatterND op has input tensors of different types: " + + nodeproto.input(2) + " : " + ConvertTypeToString(parser.GetTensorType(nodeproto.input(2))) + + " and " + nodeproto.input(0) + " : " + ConvertTypeToString(input_type)); + } + + std::string reduction; + for (int i = 0; i < nodeproto.attribute_size(); i++) { + std::string attribute_name = nodeproto.attribute(i).name(); + if (attribute_name == "reduction") + reduction = nodeproto.attribute(i).s(); + } + + std::unique_ptr op; + std::string output_name = nodeproto.output(0); + + op.reset(new ROperator_ScatterND(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), + output_name, reduction)); + + // Infer the output 
type + if (!parser.IsRegisteredTensorType(output_name)) { + parser.RegisterTensorType(output_name, input_type); + } + + return op; +}; + + +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index b77451da25c5b..6090b8a0799c6 100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -94,6 +94,7 @@ extern ParserFuncSignature ParseWhere; extern ParserFuncSignature ParseEinsum; extern ParserFuncSignature ParseRandom; extern ParserFuncSignature ParseScatterElements; +extern ParserFuncSignature ParseScatterND; extern ParserFuncSignature ParseNonZero; // Declaration of fused operators extern ParserFuseFuncSignature ParseFuseConvAdd; @@ -250,6 +251,7 @@ RModelParser_ONNX::RModelParser_ONNX() noexcept : fOperatorsMapImpl(std::make_un RegisterOperator("RandomUniform", ParseRandom); RegisterOperator("RandomUniformLike", ParseRandom); RegisterOperator("ScatterElements", ParseScatterElements); + RegisterOperator("ScatterND", ParseScatterND); RegisterOperator("NonZero", ParseNonZero); } From aca08005518b5d897cbcaefb590d3397592bb125 Mon Sep 17 00:00:00 2001 From: moneta Date: Tue, 17 Mar 2026 22:39:31 +0100 Subject: [PATCH 2/4] [tmva][sofie] Apply several fixes to parse MLPF model - Fix in operator Reduce to return a scalar and not a tensor of shape [1] - Fix handling of output boolean type in Cast. Do not convert the type into a string, because a boolean is converted to a uint8_t which can be a native uint8_t or a bool. Avoid then calling function ConvertStringToType if possible - Fix fusion of operators. Perform fusion not at first op encountered but at the last one, in order to parse before all operators which can provide an input to last fused one. This was the case in MLPF where there was a MatMul + Constant + Add, where Constant is an input to Add. 
- remove check in Generate on empty shapes because scalars tensors have empty shapes --- tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx | 3 - tmva/sofie/inc/TMVA/ROperator_Cast.hxx | 19 +-- tmva/sofie/inc/TMVA/ROperator_Gemm.hxx | 2 +- tmva/sofie/inc/TMVA/ROperator_Reduce.hxx | 7 +- tmva/sofie/inc/TMVA/ROperator_Where.hxx | 120 +++++++++--------- tmva/sofie/src/SOFIE_common.cxx | 7 +- .../inc/TMVA/RModelParser_ONNX.hxx | 7 +- tmva/sofie_parsers/src/ParseBasicNary.cxx | 5 +- tmva/sofie_parsers/src/ParseCast.cxx | 14 +- tmva/sofie_parsers/src/ParseComparision.cxx | 6 +- tmva/sofie_parsers/src/ParseConstant.cxx | 1 + tmva/sofie_parsers/src/ParseExpand.cxx | 4 + tmva/sofie_parsers/src/ParseNot.cxx | 3 +- tmva/sofie_parsers/src/ParseReduce.cxx | 17 +-- tmva/sofie_parsers/src/ParseReshape.cxx | 2 + tmva/sofie_parsers/src/ParseTranspose.cxx | 14 ++ tmva/sofie_parsers/src/RModelParser_ONNX.cxx | 49 ++++--- 17 files changed, 161 insertions(+), 119 deletions(-) diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx index e9b2078bc73a1..50bae3e04b6ec 100644 --- a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx @@ -279,9 +279,6 @@ public: opName = "op_" + opName; - if (fDimShapeY.empty()) { - throw std::runtime_error("TMVA SOFIE Binary Op called to Generate without being initialized first"); - } std::stringstream out; out << SP << "\n//------ " << opName << " " << BinaryOperatorTrait::Name() << " --> " << ConvertDimShapeToString(fDimShapeY) << "\n"; diff --git a/tmva/sofie/inc/TMVA/ROperator_Cast.hxx b/tmva/sofie/inc/TMVA/ROperator_Cast.hxx index 8267bb8a7e4f4..cace65040c772 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Cast.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Cast.hxx @@ -20,13 +20,14 @@ private: std::string fNX; std::string fNY; std::vector fShape; - std::string fAttrType = "float"; + ETensorType fType; public: ROperator_Cast(){} - ROperator_Cast(std::string 
attr_type,std::string nameX, std::string nameY): - fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), - fAttrType(attr_type) { + ROperator_Cast(ETensorType type,std::string nameX, std::string nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), + fType(type) + { fInputTensorNames = { fNX }; fOutputTensorNames = { fNY }; } @@ -51,21 +52,21 @@ public: if (model.IsInitializedTensor(fNX)) { fIsOutputConstant = true; auto inputData = model.GetInitializedTensorData(fNX); - if (ConvertStringToType(fAttrType) == ETensorType::INT64) { + if (fType == ETensorType::INT64) { model.AddConstantTensor(fNY, ConvertShapeToInt(fShape), static_cast(inputData.get())); model.SetNotWritableInitializedTensor(fNX); } else fIsOutputConstant = false; - } else if (model.IsShapeTensor(fNX) && ConvertStringToType(fAttrType) == ETensorType::INT64) { + } else if (model.IsShapeTensor(fNX) && fType == ETensorType::INT64) { auto shapeData = model.GetShapeTensorValues(fNX); model.AddShapeTensor(fNY, shapeData, fShape.size() == 0); fIsOutputConstant = true; } if (!fIsOutputConstant) - model.AddIntermediateTensor(fNY, ConvertStringToType(fAttrType), fShape); + model.AddIntermediateTensor(fNY, fType, fShape); if (model.Verbose()) { - std::cout << "Cast : " << ConvertTypeToString(inputType) << " " << fNX << " -> " << fAttrType << " for " << fNY + std::cout << "Cast : " << ConvertTypeToString(inputType) << " " << fNX << " -> " << ConvertTypeToString(fType) << " for " << fNY << " shape " << ConvertDimShapeToString(fShape); if (fIsOutputConstant) std::cout << " (constant) "; std::cout << std::endl; @@ -86,7 +87,7 @@ public: out << SP << "for (int id = 0; id < " << length << " ; id++){\n"; - out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< fAttrType << ">(tensor_" << fNX << "[id]);\n"; + out << SP << SP << "tensor_" << fNY << "[id] = static_cast<"<< ConvertTypeToString(fType) << ">(tensor_" << fNX << "[id]);\n"; out << SP << "}\n"; return out.str(); diff 
--git a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx index ecdd0b435fe37..c0cbe18f11475 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx @@ -165,7 +165,7 @@ namespace SOFIE{ } if (fNC != ""){ if (model.CheckIfTensorAlreadyExist(fNC) == false){ //input must be a graph input, or already initialized intermediate tensor - throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor" + fNC + " is not found in model"); + throw std::runtime_error("TMVA SOFIE Gemm Op Input Tensor " + fNC + " is not found in model"); } } if (model.IsDynamicTensor(fNA) || model.IsDimInputTensor(fNA) ) { diff --git a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx index 64778f38753cd..2012a9a40a0ed 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Reduce.hxx @@ -18,7 +18,7 @@ namespace SOFIE{ enum EReduceOpMode { ReduceMean, ReduceSum, ReduceSumSquare, ReduceProd, InvalidReduceOp }; -template +template class ROperator_Reduce final : public ROperator { private: @@ -76,7 +76,7 @@ public: std::sort(ax.begin(), ax.end()); for (size_t j = 0; j < ax.size(); j++) { // erase reduced dimensions, but keep last one - if (outputShape.size() > 1) { + if (outputShape.size() > 0) { outputShape.erase(outputShape.begin() + ax[j]); for (size_t k = j+1; k < ax.size(); k++) ax[k] -= 1; // decrease by one since we have removed a value @@ -120,9 +120,6 @@ public: std::string Generate(std::string opName) override { opName = "op_" + opName; - if (fShapeX.empty() || fShapeY.empty()) { - throw std::runtime_error("TMVA SOFIE Reduce Op called to Generate without being initialized first"); - } auto inputLength = TMVA::Experimental::SOFIE::ConvertDimShapeToLength(fShapeX); auto outputLength = TMVA::Experimental::SOFIE::ConvertDimShapeToLength(fShapeY); diff --git a/tmva/sofie/inc/TMVA/ROperator_Where.hxx b/tmva/sofie/inc/TMVA/ROperator_Where.hxx index 
3064080507e28..59cb311f1a203 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Where.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Where.hxx @@ -20,26 +20,26 @@ private: bool fIsInputBoolTensor = false; - std::string fNA; - std::string fNB; + std::string fNX; + std::string fNY; std::string fNC; - std::string fNBroadcastedA; - std::string fNBroadcastedB; + std::string fNBroadcastedX; + std::string fNBroadcastedY; std::string fNBroadcastedC; std::string fNY; - std::vector fShapeA; - std::vector fShapeB; + std::vector fShapeX; + std::vector fShapeY; std::vector fShapeC; std::vector fShapeY; public: ROperator_Where(){} - ROperator_Where(const std::string & nameA, const std::string & nameB, const std::string & nameC, const std::string & nameY): - fNA(UTILITY::Clean_name(nameA)), fNB(UTILITY::Clean_name(nameB)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){ - fInputTensorNames = { fNA, fNB, fNC }; + ROperator_Where(const std::string & nameX, const std::string & nameY, const std::string & nameC, const std::string & nameY): + fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){ + fInputTensorNames = { fNX, fNY, fNC }; fOutputTensorNames = { fNY }; } @@ -57,11 +57,11 @@ public: void Initialize(RModel& model) override { // input must be a graph input, or already initialized intermediate tensor - if (!model.CheckIfTensorAlreadyExist(fNA)){ - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNA + "is not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNX)){ + throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNX + "is not found in model"); } - if (!model.CheckIfTensorAlreadyExist(fNB)) { - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNB + "is not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNY)) { + throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNY + "is 
not found in model"); } if (!model.CheckIfTensorAlreadyExist(fNC)) { throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNC + "is not found in model"); @@ -70,24 +70,24 @@ public: if (model.IsReadyInputTensor(fNC)) fIsInputBoolTensor = true; // check broadcast for A, B and C - fShapeA = model.GetTensorShape(fNA); - fShapeB = model.GetTensorShape(fNB); + fShapeX = model.GetTensorShape(fNX); + fShapeY = model.GetTensorShape(fNY); fShapeC = model.GetTensorShape(fNC); - bool broadcast = !UTILITY::AreSameShape(fShapeA, fShapeB) || !UTILITY::AreSameShape(fShapeA, fShapeC); + bool broadcast = !UTILITY::AreSameShape(fShapeX, fShapeY) || !UTILITY::AreSameShape(fShapeX, fShapeC); if (broadcast) { // find shape to broadcast between A,B,C looking for max length - size_t lengthA = ConvertShapeToLength(fShapeA); - size_t lengthB = ConvertShapeToLength(fShapeB); + size_t lengthA = ConvertShapeToLength(fShapeX); + size_t lengthB = ConvertShapeToLength(fShapeY); size_t lengthC = ConvertShapeToLength(fShapeC); bool broadcastA = false, broadcastB = false, broadcastC = false; if (lengthA >= lengthB && lengthA >= lengthC) { - fShapeY = fShapeA; + fShapeY = fShapeX; //broadcast B and C if different than A broadcastB = (lengthB != lengthA); broadcastC = (lengthC != lengthA); } else if (lengthB >= lengthA && lengthB >= lengthC) { - fShapeY = fShapeB; + fShapeY = fShapeY; //broadcast A and C if different than B broadcastA = (lengthA != lengthB); broadcastC = (lengthC != lengthB); @@ -101,34 +101,34 @@ public: // Broadcast A to Y if (broadcastA) { - fNBroadcastedA = "BC_" + fNA + "_to_" + fNY; - if (model.IsInitializedTensor(fNA)) { - auto data = model.GetInitializedTensorData(fNA); + fNBroadcastedX = "BC_" + fNX + "_to_" + fNY; + if (model.IsInitializedTensor(fNX)) { + auto data = model.GetInitializedTensorData(fNX); std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeA, fShapeY), + 
UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeX, fShapeY), std::default_delete()); // Update the data and the shape of A - model.AddConstantTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY, broadcastedData); - fShapeA = fShapeY; + model.AddConstantTensor(fNBroadcastedX, model.GetTensorType(fNX), fShapeY, broadcastedData); + fShapeX = fShapeY; } else { // Add an intermediate tensor for broadcasting A - model.AddIntermediateTensor(fNBroadcastedA, model.GetTensorType(fNA), fShapeY); + model.AddIntermediateTensor(fNBroadcastedX, model.GetTensorType(fNX), fShapeY); } } // Broadcast B to Y if (broadcastB) { - fNBroadcastedB = "BC_" + fNB + "_to_" + fNY; - if (model.IsInitializedTensor(fNB)) { - auto data = model.GetInitializedTensorData(fNB); + fNBroadcastedY = "BC_" + fNY + "_to_" + fNY; + if (model.IsInitializedTensor(fNY)) { + auto data = model.GetInitializedTensorData(fNY); std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeB, fShapeY), + UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeY, fShapeY), std::default_delete()); // do not update tensor B but add broadcasted one (since it can be input to some other operators) - model.AddConstantTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY, broadcastedData); - fShapeB = fShapeY; + model.AddConstantTensor(fNBroadcastedY, model.GetTensorType(fNY), fShapeY, broadcastedData); + fShapeY = fShapeY; } else { // Add an intermediate tensor for broadcasting B - model.AddIntermediateTensor(fNBroadcastedB, model.GetTensorType(fNB), fShapeY); + model.AddIntermediateTensor(fNBroadcastedY, model.GetTensorType(fNY), fShapeY); } } // Broadcast C to Y @@ -148,7 +148,7 @@ public: } } } else { - fShapeY = fShapeA; + fShapeY = fShapeX; } // check case of constant output (if all inputs are defined) if (model.IsInitializedTensor(fNC)) { @@ -160,19 +160,19 @@ public: T * dataB = nullptr; std::vector shapeDataA; std::vector shapeDataB; - if 
(model.IsInitializedTensor(fNA)) { - std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA; - dataA = static_cast(model.GetInitializedTensorData(nameA).get()); + if (model.IsInitializedTensor(fNX)) { + std::string nameX = fNBroadcastedX.empty()? fNX : fNBroadcastedX; + dataA = static_cast(model.GetInitializedTensorData(nameX).get()); // flag tensors to not be written in a file - model.SetNotWritableInitializedTensor(nameA); - } else if (model.IsShapeTensor(fNA)) - shapeDataA = model.GetShapeTensorValues(fNA); - if (model.IsInitializedTensor(fNB)) { - std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB; - dataB = static_cast(model.GetInitializedTensorData(nameB).get()); - model.SetNotWritableInitializedTensor(nameB); - } else if (model.IsShapeTensor(fNB)) - shapeDataB = model.GetShapeTensorValues(fNB); + model.SetNotWritableInitializedTensor(nameX); + } else if (model.IsShapeTensor(fNX)) + shapeDataA = model.GetShapeTensorValues(fNX); + if (model.IsInitializedTensor(fNY)) { + std::string nameY = fNBroadcastedY.empty()? 
fNY : fNBroadcastedY; + dataB = static_cast(model.GetInitializedTensorData(nameY).get()); + model.SetNotWritableInitializedTensor(nameY); + } else if (model.IsShapeTensor(fNY)) + shapeDataB = model.GetShapeTensorValues(fNY); std::vector dataY; std::vector shapeDataY; @@ -226,10 +226,10 @@ public: if (fIsOutputConstant) fOutputTensorNames.pop_back(); } if (!fIsOutputConstant) { - model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY); + model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY); if (model.Verbose()) std::cout << "Where op " << " condition : " << fNC << " " << ConvertShapeToString(fShapeC) << - " X " << fNA << " " << ConvertShapeToString(fShapeA) << " Y " << fNB << " " << ConvertShapeToString(fShapeB) + " X " << fNX << " " << ConvertShapeToString(fShapeX) << " Y " << fNY << " " << ConvertShapeToString(fShapeY) << " ---> " << fNY << " " << ConvertShapeToString(fShapeY) << std::endl; } } @@ -253,18 +253,18 @@ public: size_t length = ConvertShapeToLength(fShapeY); std::string typeName = TensorType::Name(); // Broadcast A if it's uninitialized - if (fShapeA != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNA << "\n"; + if (fShapeX != fShapeY) { + out << SP << "// Broadcasting uninitialized tensor " << fNX << "\n"; //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNA << ", " << ConvertShapeToString(fShapeA) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedA << ");\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNX << ", " << ConvertShapeToString(fShapeX) << ", " << ConvertShapeToString(fShapeY) + << ", tensor_" << fNBroadcastedX << ");\n"; } // Broadcast B if it's uninitialized - if (fShapeB != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNB << "\n"; + if (fShapeY != fShapeY) { + out << SP << "// Broadcasting uninitialized tensor " << fNY << "\n"; //out << 
SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNB << ", " << ConvertShapeToString(fShapeB) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedB << ");\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNY << ", " << ConvertShapeToString(fShapeY) << ", " << ConvertShapeToString(fShapeY) + << ", tensor_" << fNBroadcastedY << ");\n"; } // Broadcast C if it's uninitialized if (fShapeC != fShapeY) { @@ -278,13 +278,13 @@ public: out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tmp_tensor_" << fNC << ".data(), " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY) << ", tensor_" << fNBroadcastedC << ");\n"; } - std::string nameA = fNBroadcastedA.empty()? fNA : fNBroadcastedA; - std::string nameB = fNBroadcastedB.empty()? fNB : fNBroadcastedB; + std::string nameX = fNBroadcastedX.empty()? fNX : fNBroadcastedX; + std::string nameY = fNBroadcastedY.empty()? fNY : fNBroadcastedY; std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC; out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n"; // get output tensor applying condition out << SP << SP << "tensor_" << fNY << "[id] = " << "tensor_" << nameC << "[id] ? 
tensor_" - << nameA << "[id] : tensor_" + nameB + "[id];\n"; + << nameX << "[id] : tensor_" + nameY + "[id];\n"; out << SP << "}\n"; return out.str(); } diff --git a/tmva/sofie/src/SOFIE_common.cxx b/tmva/sofie/src/SOFIE_common.cxx index cad95a159f58a..aca8afabc4281 100644 --- a/tmva/sofie/src/SOFIE_common.cxx +++ b/tmva/sofie/src/SOFIE_common.cxx @@ -99,6 +99,8 @@ std::string ConvertTypeToString(ETensorType type){ } } +// invert function might now work correctly for booleans +// prefer avoid using it if possible ETensorType ConvertStringToType(std::string type){ if(type == "float32" || type == "float" || type == "Float"){ return ETensorType::FLOAT; @@ -106,10 +108,13 @@ ETensorType ConvertStringToType(std::string type){ else if(type == "int64" || type == "int64_t"){ return ETensorType::INT64; } + else if(type == "int32" || type == "int32_t"){ + return ETensorType::INT32; + } else if (type == "double" || type == "float64"){ return ETensorType::DOUBLE; } - else if (type == "bool" ){ + else if (type == "bool" || type == "uint8_t" ){ return ETensorType::BOOL; } else{ diff --git a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx index 1efb901d8e791..ed42f7381693b 100644 --- a/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx +++ b/tmva/sofie_parsers/inc/TMVA/RModelParser_ONNX.hxx @@ -29,6 +29,8 @@ class RModelParser_ONNX { public: struct OperatorsMapImpl; + enum EFusedOp { kMatMulAdd, kConvAdd, kConvTransAdd, kGemmRelu, kBatchnormRelu}; + private: bool fVerbose = false; @@ -36,8 +38,9 @@ private: std::unique_ptr fOperatorsMapImpl; // Type of the tensors std::unordered_map fTensorTypeMap; - // flag list of fused operators - std::vector fFusedOperators; + + // List of fused operators storing as key the second operator and a value a pair of fusion type and parent operator + std::map> fFusedOperators; public: diff --git a/tmva/sofie_parsers/src/ParseBasicNary.cxx b/tmva/sofie_parsers/src/ParseBasicNary.cxx index 
0da3d493e095d..13e5d1ce5e5b4 100644 --- a/tmva/sofie_parsers/src/ParseBasicNary.cxx +++ b/tmva/sofie_parsers/src/ParseBasicNary.cxx @@ -43,19 +43,22 @@ std::unique_ptr ParseBasicNary(RModelParser_ONNX& parser, const onnx: return op; } - +// Max ParserFuncSignature ParseMax = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +// Min ParserFuncSignature ParseMin= [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +//Mean ParserFuncSignature ParseMean = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; +// Sum ParserFuncSignature ParseSum = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseBasicNary(parser, nodeproto); }; diff --git a/tmva/sofie_parsers/src/ParseCast.cxx b/tmva/sofie_parsers/src/ParseCast.cxx index 0e07fd7d164fc..6970fc2fcbdce 100644 --- a/tmva/sofie_parsers/src/ParseCast.cxx +++ b/tmva/sofie_parsers/src/ParseCast.cxx @@ -13,21 +13,23 @@ ParserFuncSignature ParseCast = [](RModelParser_ONNX &parser, const onnx::NodePr " but its type is not yet registered"); } - std::unique_ptr op; - std::string attr_type; + ETensorType output_type = ETensorType::UNDEFINED; for (int_t i = 0; i < nodeproto.attribute_size(); i++) { std::string attribute_name = nodeproto.attribute(i).name(); - if (attribute_name == "to") - attr_type = ConvertTypeToString(static_cast(nodeproto.attribute(i).i())); + if (attribute_name == "to") { + output_type = static_cast(nodeproto.attribute(i).i()); + } } + if (output_type == ETensorType::UNDEFINED) + throw std::runtime_error("TMVA::SOFIE ONNX Parser Cast op has invalid output type"); std::string output_name = nodeproto.output(0); - op.reset(new ROperator_Cast(attr_type, nodeproto.input(0), output_name)); + auto op = std::make_unique(output_type, nodeproto.input(0), output_name); if (!parser.IsRegisteredTensorType(output_name)) { - 
ETensorType output_type = ConvertStringToType(attr_type); parser.RegisterTensorType(output_name, output_type); + std::cout << "Cast -> " << output_name << " type " << (int) output_type << " " << ConvertTypeToString(output_type) << std::endl; } return op; diff --git a/tmva/sofie_parsers/src/ParseComparision.cxx b/tmva/sofie_parsers/src/ParseComparision.cxx index 06128e6fd00d6..bd1598bb5142b 100644 --- a/tmva/sofie_parsers/src/ParseComparision.cxx +++ b/tmva/sofie_parsers/src/ParseComparision.cxx @@ -65,17 +65,17 @@ ParserFuncSignature ParseLess = [](RModelParser_ONNX &parser, const onnx::NodePr return ParseComparision(parser, nodeproto); }; -// Parse Mul +// Parse LessEq ParserFuncSignature ParseLessEq = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; -// Parse Div +// Parse Greater ParserFuncSignature ParseGreater = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; -// Parse Pow +// Parse GreaterEq ParserFuncSignature ParseGreaterEq = [](RModelParser_ONNX &parser, const onnx::NodeProto &nodeproto) { return ParseComparision(parser, nodeproto); }; diff --git a/tmva/sofie_parsers/src/ParseConstant.cxx b/tmva/sofie_parsers/src/ParseConstant.cxx index a4ba71d038389..2b5a2ab37031a 100644 --- a/tmva/sofie_parsers/src/ParseConstant.cxx +++ b/tmva/sofie_parsers/src/ParseConstant.cxx @@ -176,6 +176,7 @@ ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::No if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, output_type); + std::cout << "Constant -> " << output_name << " " << (int) output_type << std::endl; } if (parser.Verbose()) diff --git a/tmva/sofie_parsers/src/ParseExpand.cxx b/tmva/sofie_parsers/src/ParseExpand.cxx index e55e791aaee0c..8eae2de4e16e2 100644 --- a/tmva/sofie_parsers/src/ParseExpand.cxx +++ b/tmva/sofie_parsers/src/ParseExpand.cxx @@ -13,6 +13,7 @@ 
ParserFuncSignature ParseExpand = [](RModelParser_ONNX &parser, const onnx::Node const std::string input_name = nodeproto.input(0); if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << "input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } else { throw std::runtime_error( "TMVA::SOFIE ONNX Parser Expand op has input tensor " + input_name + @@ -39,6 +40,9 @@ ParserFuncSignature ParseExpand = [](RModelParser_ONNX &parser, const onnx::Node case ETensorType::INT64: op.reset(new ROperator_Expand(input_name, shape_name, output_name)); break; + case ETensorType::BOOL: + op.reset(new ROperator_Expand(input_name, shape_name, output_name)); + break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Expand Operator does " "not support input type " + diff --git a/tmva/sofie_parsers/src/ParseNot.cxx b/tmva/sofie_parsers/src/ParseNot.cxx index c1bc7b027a5ad..cd03b00d0c6be 100644 --- a/tmva/sofie_parsers/src/ParseNot.cxx +++ b/tmva/sofie_parsers/src/ParseNot.cxx @@ -17,8 +17,9 @@ ParserFuncSignature ParseNot = [](RModelParser_ONNX &parser, const onnx::NodePro if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << "NOT op input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; if (input_type !=ETensorType::BOOL && input_type !=ETensorType::UINT8 ) - std::runtime_error("TMVA::SOFIE ONNX Parser Not op has invalid input type " + ConvertTypeToString(input_type)); + throw std::runtime_error("TMVA::SOFIE ONNX Parser Not op has invalid input type " + ConvertTypeToString(input_type)); } else { throw std::runtime_error("TMVA::SOFIE ONNX Parser Not op has input tensor " + input_name + diff --git a/tmva/sofie_parsers/src/ParseReduce.cxx b/tmva/sofie_parsers/src/ParseReduce.cxx index 6c18a4371c342..1753979cfa7bb 100644 --- a/tmva/sofie_parsers/src/ParseReduce.cxx +++ 
b/tmva/sofie_parsers/src/ParseReduce.cxx @@ -57,14 +57,15 @@ std::unique_ptr ParseReduce(RModelParser_ONNX &parser, const onnx::No std::vector({nodeproto.attribute(i).ints().begin(), nodeproto.attribute(i).ints().end()}); } } - switch (input_type) { - case ETensorType::FLOAT: - op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); - break; - default: - throw std::runtime_error("TMVA::SOFIE - Unsupported - Reduce Operator does not yet support input type " + - std::to_string(static_cast(input_type))); - } + op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); + // switch (input_type) { + // case ETensorType::FLOAT: + // op.reset(new ROperator_Reduce(attr_keepdims, attr_axes, input_name, axes_name, output_name)); + // break; + // default: + // throw std::runtime_error("TMVA::SOFIE - Unsupported - Reduce Operator does not yet support input type " + + // std::to_string(static_cast(input_type))); + // } if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, input_type); diff --git a/tmva/sofie_parsers/src/ParseReshape.cxx b/tmva/sofie_parsers/src/ParseReshape.cxx index ddb316ca837a4..5db12d9aac847 100644 --- a/tmva/sofie_parsers/src/ParseReshape.cxx +++ b/tmva/sofie_parsers/src/ParseReshape.cxx @@ -26,6 +26,7 @@ ParserFuncSignature ParseReshape = [](RModelParser_ONNX &parser, const onnx::Nod ? 
nodeproto.input(1) : ""; if (parser.IsRegisteredTensorType(input_name)) { input_type = parser.GetTensorType(input_name); + std::cout << "Reshape/Un/Squueze op input type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } else { throw std::runtime_error("TMVA::SOFIE ONNX Parser Reshape op has input tensor" + input_name + " but its type is not yet registered"); @@ -57,6 +58,7 @@ ParserFuncSignature ParseReshape = [](RModelParser_ONNX &parser, const onnx::Nod if (!parser.IsRegisteredTensorType(output_name)) { parser.RegisterTensorType(output_name, input_type); + std::cout << "Reshape/Un/Squueze register output " << output_name << " with type is " << static_cast(input_type) << " " << ConvertTypeToString(input_type) << std::endl; } return op; diff --git a/tmva/sofie_parsers/src/ParseTranspose.cxx b/tmva/sofie_parsers/src/ParseTranspose.cxx index 198f57ba90d46..8178f37dac0b3 100644 --- a/tmva/sofie_parsers/src/ParseTranspose.cxx +++ b/tmva/sofie_parsers/src/ParseTranspose.cxx @@ -40,6 +40,20 @@ ParserFuncSignature ParseTranspose = [](RModelParser_ONNX &parser, const onnx::N op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); } break; + case ETensorType::BOOL: + if (!attr_perm.empty()) { + op.reset(new ROperator_Transpose(attr_perm, nodeproto.input(0), nodeproto.output(0))); + } else { + op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); + } + break; + case ETensorType::UINT8: + if (!attr_perm.empty()) { + op.reset(new ROperator_Transpose(attr_perm, nodeproto.input(0), nodeproto.output(0))); + } else { + op.reset(new ROperator_Transpose(nodeproto.input(0), nodeproto.output(0))); + } + break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Operator Transpose does not yet support input type " + std::to_string(static_cast(input_type))); diff --git a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx index 6090b8a0799c6..038afc8df1b74 
100644 --- a/tmva/sofie_parsers/src/RModelParser_ONNX.cxx +++ b/tmva/sofie_parsers/src/RModelParser_ONNX.cxx @@ -307,48 +307,60 @@ RModelParser_ONNX::ParseOperator(const size_t i, const onnx::GraphProto &graphpr if (fVerbose) std::cout << "Parsing operator " << op_type << std::endl; - // skip already fused operators - if (fFusedOperators[idx]) return nullptr; + // perform the fusion of operators + if (fFusedOperators.count(idx) == 1) { + int idx1 = fFusedOperators[idx].second; + if (fVerbose) { + std::cout << "\tFusing operators " << graphproto.node(idx1).name() + << " with " << graphproto.node(idx1).name() << std::endl; + } + if (fFusedOperators[idx].first == EFusedOp::kMatMulAdd) { + return ParseFuseMatMulAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kConvAdd) { + return ParseFuseConvAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kConvTransAdd) { + return ParseFuseConvTransposeAdd(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kGemmRelu) { + return ParseFuseGemmRelu(*this, graphproto.node(idx1), graphproto.node(idx)); + } else if (fFusedOperators[idx].first == EFusedOp::kBatchnormRelu) { + return ParseFuseBatchnormRelu(*this, graphproto.node(idx1), graphproto.node(idx)); + } + } - // try to fuse with following operator in case it is not last one + // try to fuse with following operator in case it is not last one and having only a single child if (children.size() == 1) { int idx2 = children.front(); if (op_type == "MatMul") { // Fuse MatMul and Add if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") { - fFusedOperators[idx2] = true; - return ParseFuseMatMulAdd(*this, graphproto.node(idx), graphproto.node(idx2)); - } - else { - return ParseMatMul(*this, graphproto.node(idx)); + fFusedOperators[idx2] = {EFusedOp::kMatMulAdd, idx}; + return nullptr; } } else 
if (nodeproto.op_type() == "Conv" || nodeproto.op_type() == "ConvTranspose") { // Fuse Conv or ConvTranspose without bias and Add if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Add") { if (nodeproto.op_type() == "Conv") { - fFusedOperators[idx2] = true; - return ParseFuseConvAdd(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = { EFusedOp::kConvAdd, idx}; + return nullptr; } else { - fFusedOperators[idx2] = true; - return ParseFuseConvTransposeAdd(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = { EFusedOp::kConvTransAdd, idx}; + return nullptr; } } } else if (nodeproto.op_type() == "Gemm") { // Fuse Gemm with activation operators if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") { - fFusedOperators[idx2] = true; - return ParseFuseGemmRelu(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = {EFusedOp::kGemmRelu, idx}; + return nullptr; } } else if (nodeproto.op_type() == "BatchNormalization") { if (idx2 < graphproto.node_size() && graphproto.node(idx2).op_type() == "Relu") { - fFusedOperators[idx2] = true; - return ParseFuseBatchnormRelu(*this, graphproto.node(idx), graphproto.node(idx2)); + fFusedOperators[idx2] = {EFusedOp::kBatchnormRelu, idx}; + return nullptr; } } } - - auto it = fOperatorsMapImpl->fOperatorsMap.find(op_type); if (it == fOperatorsMapImpl->fOperatorsMap.end()) { std::cout << "operator " << op_type << " is not supported" << std::endl; @@ -771,7 +783,6 @@ void RModelParser_ONNX::ParseONNXGraph(RModel & rmodel, const onnx::GraphProto & // we have to record order of node execution separately to // account for fused operators size_t node_order_exec = 0; - fFusedOperators = std::vector(graph.node_size(), false); for (int i = 0; i < graph.node_size(); i++) { std::string op_type = graph.node(nodesOrder[i]).op_type(); From 6982ef20ec6482c9d414127e201b0097ea1bebb7 Mon Sep 17 00:00:00 2001 From: moneta Date: Wed, 18 
Mar 2026 15:23:15 +0100 Subject: [PATCH 3/4] [tmva][sofie] Add support for dynamic tensors in Where operator --- tmva/sofie/inc/TMVA/ROperator_Where.hxx | 612 +++++++++++++++--------- 1 file changed, 380 insertions(+), 232 deletions(-) diff --git a/tmva/sofie/inc/TMVA/ROperator_Where.hxx b/tmva/sofie/inc/TMVA/ROperator_Where.hxx index 59cb311f1a203..494834cacc69e 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Where.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Where.hxx @@ -7,293 +7,441 @@ #include -namespace TMVA{ -namespace Experimental{ -namespace SOFIE{ +namespace TMVA { +namespace Experimental { +namespace SOFIE { - - -template -class ROperator_Where final : public ROperator{ +template +class ROperator_Where final : public ROperator { private: bool fIsInputBoolTensor = false; - - std::string fNX; - std::string fNY; - std::string fNC; + // Tensor names: C = condition, X = true branch, Y = false branch, Z = output + std::string fNC; // condition (bool) + std::string fNX; // true-branch values + std::string fNY; // false-branch values + std::string fNZ; // output + std::string fNBroadcastedC; std::string fNBroadcastedX; std::string fNBroadcastedY; - std::string fNBroadcastedC; - std::string fNY; - - std::vector fShapeX; - std::vector fShapeY; + // Static shapes (used when all inputs are non-dynamic) std::vector fShapeC; + std::vector fShapeX; std::vector fShapeY; + std::vector fShapeZ; + + // Dynamic shapes (Dim-aware, used when any input is dynamic) + std::vector fDimShapeC; + std::vector fDimShapeX; + std::vector fDimShapeY; + std::vector fDimShapeZ; + // Broadcast flag: mirrors convention of BasicBinary + // bit 0: broadcast Y->X (Y needs expanding) + // bit 1: broadcast X->Y (X needs expanding) + // bit 2: broadcast C->Z (C needs expanding) + // bit 4: shapes may differ at runtime (dynamic) + int fBroadcastFlag = 0; public: - ROperator_Where(){} - ROperator_Where(const std::string & nameX, const std::string & nameY, const std::string & nameC, const std::string & 
nameY): - fNX(UTILITY::Clean_name(nameX)), fNY(UTILITY::Clean_name(nameY)), fNC(UTILITY::Clean_name(nameC)), fNY(UTILITY::Clean_name(nameY)){ - fInputTensorNames = { fNX, fNY, fNC }; - fOutputTensorNames = { fNY }; - } + ROperator_Where() {} + ROperator_Where(const std::string &nameC, + const std::string &nameX, + const std::string &nameY, + const std::string &nameZ) + : fNC(UTILITY::Clean_name(nameC)), + fNX(UTILITY::Clean_name(nameX)), + fNY(UTILITY::Clean_name(nameY)), + fNZ(UTILITY::Clean_name(nameZ)) + { + fInputTensorNames = { fNC, fNX, fNY }; + fOutputTensorNames = { fNZ }; + } // type of output given input - std::vector TypeInference(std::vector input) override { - return input; + std::vector TypeInference(std::vector input) override + { + // output type follows X (and Y), not C (which is bool) + return { input[1] }; } // shape of output tensors given input tensors - std::vector> ShapeInference(std::vector> input) override { - // assume now inputs have same shape (no broadcasting) - auto ret = std::vector>(1, input[0]); // return vector size 1 with first input - return ret; + std::vector> ShapeInference(std::vector> input) override + { + // conservative: assume same shape (broadcasting resolved in Initialize) + return { input[1] }; } - void Initialize(RModel& model) override { - // input must be a graph input, or already initialized intermediate tensor - if (!model.CheckIfTensorAlreadyExist(fNX)){ - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNX + "is not found in model"); - } - if (!model.CheckIfTensorAlreadyExist(fNY)) { - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNY + "is not found in model"); - } - if (!model.CheckIfTensorAlreadyExist(fNC)) { - throw std::runtime_error(std::string("TMVA SOFIE Where Op Input Tensor ") + fNC + "is not found in model"); - } - // check if fNC input tensor is boolean + void Initialize(RModel &model) override + { + // 
---------------------------------------------------------------- // + // Check all inputs exist + // ---------------------------------------------------------------- // + if (!model.CheckIfTensorAlreadyExist(fNC)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: condition tensor ") + fNC + " not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNX)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: X tensor ") + fNX + " not found in model"); + if (!model.CheckIfTensorAlreadyExist(fNY)) + throw std::runtime_error(std::string("TMVA SOFIE Where Op: Y tensor ") + fNY + " not found in model"); + + // condition tensor is bool (uint8) - mark if it is a live input tensor if (model.IsReadyInputTensor(fNC)) fIsInputBoolTensor = true; - // check broadcast for A, B and C - fShapeX = model.GetTensorShape(fNX); - fShapeY = model.GetTensorShape(fNY); - fShapeC = model.GetTensorShape(fNC); - bool broadcast = !UTILITY::AreSameShape(fShapeX, fShapeY) || !UTILITY::AreSameShape(fShapeX, fShapeC); - if (broadcast) { - // find shape to broadcast between A,B,C looking for max length - size_t lengthA = ConvertShapeToLength(fShapeX); - size_t lengthB = ConvertShapeToLength(fShapeY); - size_t lengthC = ConvertShapeToLength(fShapeC); - bool broadcastA = false, broadcastB = false, broadcastC = false; - if (lengthA >= lengthB && lengthA >= lengthC) { - fShapeY = fShapeX; - //broadcast B and C if different than A - broadcastB = (lengthB != lengthA); - broadcastC = (lengthC != lengthA); - } - else if (lengthB >= lengthA && lengthB >= lengthC) { - fShapeY = fShapeY; - //broadcast A and C if different than B - broadcastA = (lengthA != lengthB); - broadcastC = (lengthC != lengthB); - } - else if (lengthC >= lengthA && lengthC >= lengthB) { - fShapeY = fShapeC; - //broadcast A and B if different than C - broadcastA = (lengthA != lengthC); - broadcastB = (lengthB != lengthC); - } - // Broadcast A to Y - if (broadcastA) { - fNBroadcastedX = "BC_" + fNX + "_to_" + fNY; - 
if (model.IsInitializedTensor(fNX)) { - auto data = model.GetInitializedTensorData(fNX); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeX, fShapeY), - std::default_delete()); - // Update the data and the shape of A - model.AddConstantTensor(fNBroadcastedX, model.GetTensorType(fNX), fShapeY, broadcastedData); - fShapeX = fShapeY; - } else { - // Add an intermediate tensor for broadcasting A - model.AddIntermediateTensor(fNBroadcastedX, model.GetTensorType(fNX), fShapeY); - } - } - // Broadcast B to Y - if (broadcastB) { - fNBroadcastedY = "BC_" + fNY + "_to_" + fNY; - if (model.IsInitializedTensor(fNY)) { - auto data = model.GetInitializedTensorData(fNY); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeY, fShapeY), - std::default_delete()); - // do not update tensor B but add broadcasted one (since it can be input to some other operators) - model.AddConstantTensor(fNBroadcastedY, model.GetTensorType(fNY), fShapeY, broadcastedData); - fShapeY = fShapeY; - } else { - // Add an intermediate tensor for broadcasting B - model.AddIntermediateTensor(fNBroadcastedY, model.GetTensorType(fNY), fShapeY); - } - } - // Broadcast C to Y - if (broadcastC) { - fNBroadcastedC = "BC_" + fNC + "_to_" + fNY; - if (model.IsInitializedTensor(fNC)) { - auto data = model.GetInitializedTensorData(fNC); - std::shared_ptr broadcastedData( - UTILITY::UnidirectionalBroadcast(static_cast(data.get()), fShapeC, fShapeY), - std::default_delete()); - // do not update tensor C but add broadcasted one (since it can be input to some other operators) - model.AddConstantTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY, broadcastedData); - fShapeC = fShapeY; - } else { - // Add an intermediate tensor for broadcasting B - model.AddIntermediateTensor(fNBroadcastedC, model.GetTensorType(fNC), fShapeY); - } - } + // ---------------------------------------------------------------- // + // Collect 
shapes – dynamic or static + // ---------------------------------------------------------------- // + int dynamicInputs = 0; // bitmask: bit0=C, bit1=X, bit2=Y + + if (model.IsDynamicTensor(fNC)) { + fDimShapeC = model.GetDynamicTensorShape(fNC); + dynamicInputs |= 1; + } else { + fShapeC = model.GetTensorShape(fNC); + fDimShapeC = ConvertShapeToDim(fShapeC); + } + if (model.IsDynamicTensor(fNX)) { + fDimShapeX = model.GetDynamicTensorShape(fNX); + dynamicInputs |= 2; + } else { + fShapeX = model.GetTensorShape(fNX); + fDimShapeX = ConvertShapeToDim(fShapeX); + } + if (model.IsDynamicTensor(fNY)) { + fDimShapeY = model.GetDynamicTensorShape(fNY); + dynamicInputs |= 4; } else { - fShapeY = fShapeX; + fShapeY = model.GetTensorShape(fNY); + fDimShapeY = ConvertShapeToDim(fShapeY); } - // check case of constant output (if all inputs are defined) - if (model.IsInitializedTensor(fNC)) { - - std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC; - auto dataC = static_cast(model.GetInitializedTensorData(nameC).get()); - model.SetNotWritableInitializedTensor(nameC); - T * dataA = nullptr; - T * dataB = nullptr; - std::vector shapeDataA; - std::vector shapeDataB; - if (model.IsInitializedTensor(fNX)) { - std::string nameX = fNBroadcastedX.empty()? 
fNX : fNBroadcastedX; - dataA = static_cast(model.GetInitializedTensorData(nameX).get()); - // flag tensors to not be written in a file + + if (model.Verbose()) { + if (dynamicInputs & 1) + std::cout << "Where : condition " << fNC << " is dynamic " << ConvertDimShapeToString(fDimShapeC) << "\n"; + if (dynamicInputs & 2) + std::cout << "Where : X " << fNX << " is dynamic " << ConvertDimShapeToString(fDimShapeX) << "\n"; + if (dynamicInputs & 4) + std::cout << "Where : Y " << fNY << " is dynamic " << ConvertDimShapeToString(fDimShapeY) << "\n"; + } + + // ---------------------------------------------------------------- // + // Static path: all shapes known at code-gen time + // ---------------------------------------------------------------- // + if (dynamicInputs == 0) { + + // Multidirectional broadcast over all three tensors + auto retXY = UTILITY::MultidirectionalBroadcastShape(fShapeX, fShapeY); + fBroadcastFlag = retXY.first; + fShapeZ = retXY.second; + // also factor in C + auto retCZ = UTILITY::MultidirectionalBroadcastShape(fShapeC, fShapeZ); + fBroadcastFlag |= retCZ.first; + fShapeZ = retCZ.second; + + bool allConstant = model.IsConstantTensor(fNC) && + model.IsConstantTensor(fNX) && + model.IsConstantTensor(fNY); + + if (allConstant) { + // ---------------------------------------------------------- + // Constant folding: evaluate Where at model initialisation + // ---------------------------------------------------------- + auto broadcastIfNeeded = [&](const std::string &name, + const std::vector &shape, + std::string &bcName, + const std::string &prefix) { + if (shape != fShapeZ) { + bcName = prefix + name + "to" + fNZ; + auto data = model.GetInitializedTensorData(name); + std::shared_ptr bcData( + UTILITY::UnidirectionalBroadcast(static_cast(data.get()), shape, fShapeZ), + std::default_delete()); + model.AddConstantTensor(bcName, model.GetTensorType(name), fShapeZ, bcData); + } + }; + + broadcastIfNeeded(fNX, fShapeX, fNBroadcastedX, "BC_"); + 
broadcastIfNeeded(fNY, fShapeY, fNBroadcastedY, "BC_"); + broadcastIfNeeded(fNC, fShapeC, fNBroadcastedC, "BC_"); + + const std::string &nameC = fNBroadcastedC.empty() ? fNC : fNBroadcastedC; + const std::string &nameX = fNBroadcastedX.empty() ? fNX : fNBroadcastedX; + const std::string &nameY = fNBroadcastedY.empty() ? fNY : fNBroadcastedY; + + auto dataC = static_cast(model.GetInitializedTensorData(nameC).get()); + auto dataX = static_cast (model.GetInitializedTensorData(nameX).get()); + auto dataY = static_cast (model.GetInitializedTensorData(nameY).get()); + + size_t len = ConvertShapeToLength(fShapeZ); + std::vector dataZ(len); + for (size_t i = 0; i < len; ++i) + dataZ[i] = dataC[i] ? dataX[i] : dataY[i]; + + model.AddConstantTensor(fNZ, fShapeZ, dataZ.data()); + model.SetNotWritableInitializedTensor(nameC); model.SetNotWritableInitializedTensor(nameX); - } else if (model.IsShapeTensor(fNX)) - shapeDataA = model.GetShapeTensorValues(fNX); - if (model.IsInitializedTensor(fNY)) { - std::string nameY = fNBroadcastedY.empty()? 
fNY : fNBroadcastedY; - dataB = static_cast(model.GetInitializedTensorData(nameY).get()); model.SetNotWritableInitializedTensor(nameY); - } else if (model.IsShapeTensor(fNY)) - shapeDataB = model.GetShapeTensorValues(fNY); + fIsOutputConstant = true; + fOutputTensorNames.pop_back(); - std::vector dataY; - std::vector shapeDataY; + if (model.Verbose()) + std::cout << "Where --> " << fNZ << " " << ConvertShapeToString(fShapeZ) + << " : " << ConvertValuesToString(dataZ) << " (constant)\n"; + } else { + // ---------------------------------------------------------- + // Non-constant static: register broadcasted intermediates + // ---------------------------------------------------------- + auto registerBC = [&](const std::string &name, + const std::vector &shape, + std::string &bcName, + const std::string &prefix) { + if (shape != fShapeZ) { + bcName = prefix + name + "to" + fNZ; + if (model.IsInitializedTensor(name)) { + auto data = model.GetInitializedTensorData(name); + std::shared_ptr bcData( + UTILITY::UnidirectionalBroadcast(static_cast(data.get()), shape, fShapeZ), + std::default_delete()); + model.AddConstantTensor(bcName, model.GetTensorType(name), fShapeZ, bcData); + } else { + model.AddIntermediateTensor(bcName, model.GetTensorType(name), fShapeZ); + } + } + }; - bool isOutputConstantTensor = true; - if (dataA && dataB) { - dataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < dataY.size(); i++) - dataY[i] = (dataC[i]) ? dataA[i] : dataB[i]; - } - else if (dataA && shapeDataB.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? Dim{size_t(dataA[i])} : shapeDataB[i]; - isOutputConstantTensor &= !shapeDataY[i].isParam; - } - } - else if (dataB && shapeDataA.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? 
shapeDataB[i] : Dim{size_t(dataB[i])}; - isOutputConstantTensor &= !shapeDataY[i].isParam; - } + registerBC(fNX, fShapeX, fNBroadcastedX, "BC_"); + registerBC(fNY, fShapeY, fNBroadcastedY, "BC_"); + registerBC(fNC, fShapeC, fNBroadcastedC, "BC_"); + + fDimShapeZ = ConvertShapeToDim(fShapeZ); + model.AddIntermediateTensor(fNZ, model.GetTensorType(fNX), fShapeZ); + + if (model.Verbose()) + std::cout << "Where : C=" << fNC << " " << ConvertShapeToString(fShapeC) + << " X=" << fNX << " " << ConvertShapeToString(fShapeX) + << " Y=" << fNY << " " << ConvertShapeToString(fShapeY) + << " --> Z=" << fNZ << " " << ConvertShapeToString(fShapeZ) << "\n"; } - else if (shapeDataB.size() > 0 && shapeDataA.size()>0 ) { - shapeDataY.resize(ConvertShapeToLength(fShapeY)); - for (size_t i = 0; i < shapeDataY.size(); i++) { - shapeDataY[i] = (dataC[i]) ? shapeDataA[i] : shapeDataB[i]; - isOutputConstantTensor &= !shapeDataY[i].isParam; + + } else { + // ---------------------------------------------------------------- // + // Dynamic path: at least one input has a parametric shape + // ---------------------------------------------------------------- // + auto retXY = UTILITY::MultidirectionalBroadcastShape(fDimShapeX, fDimShapeY); + fBroadcastFlag = retXY.first; + fDimShapeZ = retXY.second; + auto retCZ = UTILITY::MultidirectionalBroadcastShape(fDimShapeC, fDimShapeZ); + fBroadcastFlag |= retCZ.first; + fDimShapeZ = retCZ.second; + + // Resolve std::max params to actual input dim params (same logic as BasicBinary) + if (fBroadcastFlag & 4) { + auto IsInputDimParam = [&](const std::string &p) { + for (auto &input : model.GetInputTensorNames()) + for (auto &s : model.GetDimTensorShape(input)) + if (s.isParam && s.param == p) return true; + return false; + }; + for (size_t i = 0; i < fDimShapeZ.size(); i++) { + auto &s = fDimShapeZ[i]; + if (s.isParam && s.param.find("std::max") != std::string::npos) { + // prefer X dim over Y dim + if (i < fDimShapeX.size() && 
IsInputDimParam(fDimShapeX[i].param)) { + s = (fDimShapeX[i].dim != 1) ? fDimShapeX[i] : fDimShapeY[i]; + } else if (i < fDimShapeY.size() && IsInputDimParam(fDimShapeY[i].param)) { + s = (fDimShapeY[i].dim != 1) ? fDimShapeY[i] : fDimShapeX[i]; + } + } } } - fIsOutputConstant = true; // this contains both case constant tensor output ans shape tensor output - if (isOutputConstantTensor && dataY.empty()) { - dataY.resize(shapeDataY.size()); - for (size_t i = 0; i < shapeDataY.size(); i++) - dataY[i] = static_cast(shapeDataY[i].dim); - } - if (dataY.size() > 0) - model.AddConstantTensor(fNY, fShapeY, dataY.data()); - else if (shapeDataY.size() > 0 ) - model.AddShapeTensor(fNY, shapeDataY, fShapeY.size() == 0); - else { - fIsOutputConstant = false; - } - if (fIsOutputConstant && model.Verbose()) - std::cout << "Where op ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : " - << ((dataY.size() > 0) ? ConvertValuesToString(dataY) : ConvertDimShapeToString(shapeDataY) ) - << ((dataY.size() > 0) ? 
" (constant)" : " (shape)") << std::endl; - // output is a constant tensor - if (fIsOutputConstant) fOutputTensorNames.pop_back(); - } - if (!fIsOutputConstant) { - model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY); - if (model.Verbose()) - std::cout << "Where op " << " condition : " << fNC << " " << ConvertShapeToString(fShapeC) << - " X " << fNX << " " << ConvertShapeToString(fShapeX) << " Y " << fNY << " " << ConvertShapeToString(fShapeY) - << " ---> " << fNY << " " << ConvertShapeToString(fShapeY) << std::endl; + model.AddIntermediateTensor(fNZ, model.GetTensorType(fNX), fDimShapeZ); + + if (model.Verbose()) + std::cout << "Where (dynamic) : C=" << ConvertDimShapeToString(fDimShapeC) + << " X=" << ConvertDimShapeToString(fDimShapeX) + << " Y=" << ConvertDimShapeToString(fDimShapeY) + << " --> Z=" << ConvertDimShapeToString(fDimShapeZ) << "\n"; } } - std::string GenerateInitCode() override { + std::string GenerateInitCode() override + { std::stringstream out; return out.str(); } - std::string Generate(std::string opName) override { - + std::string Generate(std::string opName) override + { if (fIsOutputConstant) return ""; opName = "op_" + opName; - if (fShapeY.empty()) { + if (fDimShapeZ.empty()) { throw std::runtime_error("TMVA SOFIE Where Op called to Generate without being initialized first"); } + std::stringstream out; - out << SP << "\n//-------- Where " << opName << " --> " << ConvertShapeToString(fShapeY) << "\n"; - size_t length = ConvertShapeToLength(fShapeY); - std::string typeName = TensorType::Name(); - // Broadcast A if it's uninitialized - if (fShapeX != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNX << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNX << ", " << ConvertShapeToString(fShapeX) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedX << ");\n"; + out << SP << "\n//------ WHERE " << opName << " --> " 
<< ConvertDimShapeToString(fDimShapeZ) << "\n"; + + // ---------------------------------------------------------------- // + // Runtime broadcast validation (dynamic shapes, flag bit 4) + // ---------------------------------------------------------------- // + if (fBroadcastFlag & 4) { + auto lengthX = ConvertDimShapeToLength(fDimShapeX); + auto lengthY = ConvertDimShapeToLength(fDimShapeY); + auto lengthC = ConvertDimShapeToLength(fDimShapeC); + out << SP << "if (" << lengthX << " != " << lengthY << " || " + << lengthX << " != " << lengthC << ") {\n"; + for (size_t i = 0; i < fDimShapeZ.size(); i++) { + // validate X vs Z + if (i < fDimShapeX.size() && fDimShapeX[i].isParam) { + out << SP << SP << "if (" << fDimShapeX[i] << " != 1 && " + << fDimShapeX[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw std::runtime_error(\"SOFIE Where: cannot broadcast X dim " << i << " in " << opName << "\");\n"; + } + // validate Y vs Z + if (i < fDimShapeY.size() && fDimShapeY[i].isParam) { + out << SP << SP << "if (" << fDimShapeY[i] << " != 1 && " + << fDimShapeY[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw std::runtime_error(\"SOFIE Where: cannot broadcast Y dim " << i << " in " << opName << "\");\n"; + } + // validate C vs Z + if (i < fDimShapeC.size() && fDimShapeC[i].isParam) { + out << SP << SP << "if (" << fDimShapeC[i] << " != 1 && " + << fDimShapeC[i] << " != " << fDimShapeZ[i] << ")\n"; + out << SP << SP << SP + << "throw std::runtime_error(\"SOFIE Where: cannot broadcast C dim " << i << " in " << opName << "\");\n"; + } + } + out << SP << "}\n"; + } + + // ---------------------------------------------------------------- // + // Runtime broadcasting for non-constant, non-initialised tensors + // ---------------------------------------------------------------- // + // Broadcast X if needed + if (!fNBroadcastedX.empty()) { + out << SP << "// Broadcast X tensor " << fNX << "\n"; + out << SP << 
"TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" + << TensorType::Name() << ">(tensor_" << fNX << ", " + << ConvertDimShapeToString(fDimShapeX) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedX << ");\n"; } - // Broadcast B if it's uninitialized - if (fShapeY != fShapeY) { - out << SP << "// Broadcasting uninitialized tensor " << fNY << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tensor_" << fNY << ", " << ConvertShapeToString(fShapeY) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedY << ");\n"; + // Broadcast Y if needed + if (!fNBroadcastedY.empty()) { + out << SP << "// Broadcast Y tensor " << fNY << "\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast<" + << TensorType::Name() << ">(tensor_" << fNY << ", " + << ConvertDimShapeToString(fDimShapeY) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedY << ");\n"; } - // Broadcast C if it's uninitialized - if (fShapeC != fShapeY) { - // special case if C is an input tensor + // Broadcast C (condition) if needed + if (!fNBroadcastedC.empty()) { if (fIsInputBoolTensor) { + // live bool input: need a temporary std::vector for the broadcast utility size_t inputLength = ConvertShapeToLength(fShapeC); - out << SP << "std::vector tmp_tensor_" << fNC << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n"; + out << SP << "std::vector tmp_tensor_" << fNC + << "(tensor_" << fNC << ", tensor_" << fNC << " + " << inputLength << ");\n"; + out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast" + << "(tmp_tensor_" << fNC << ".data(), " + << ConvertDimShapeToString(fDimShapeC) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedC << ");\n"; + } else { + out << SP << "// Broadcast condition tensor " << fNC << "\n"; + out << SP << 
"TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast" + << "(tensor_" << fNC << ", " + << ConvertDimShapeToString(fDimShapeC) << ", " + << ConvertDimShapeToString(fDimShapeZ) << ", tensor_" << fNBroadcastedC << ");\n"; } - out << SP << "// Broadcasting uninitialized tensor " << fNC << "\n"; - //out << SP << "{\n"; - out << SP << "TMVA::Experimental::SOFIE::UTILITY::UnidirectionalBroadcast(tmp_tensor_" << fNC << ".data(), " << ConvertShapeToString(fShapeC) << ", " << ConvertShapeToString(fShapeY) - << ", tensor_" << fNBroadcastedC << ");\n"; } - std::string nameX = fNBroadcastedX.empty()? fNX : fNBroadcastedX; - std::string nameY = fNBroadcastedY.empty()? fNY : fNBroadcastedY; - std::string nameC = fNBroadcastedC.empty()? fNC : fNBroadcastedC; - out << SP << "for (size_t id = 0; id < " << length << " ; id++){\n"; - // get output tensor applying condition - out << SP << SP << "tensor_" << fNY << "[id] = " << "tensor_" << nameC << "[id] ? tensor_" - << nameX << "[id] : tensor_" + nameY + "[id];\n"; - out << SP << "}\n"; + + // Final (possibly broadcasted) tensor names + const std::string nameX = fNBroadcastedX.empty() ? fNX : fNBroadcastedX; + const std::string nameY = fNBroadcastedY.empty() ? fNY : fNBroadcastedY; + const std::string nameC = fNBroadcastedC.empty() ? 
fNC : fNBroadcastedC; + + // ---------------------------------------------------------------- // + // Generate loop(s) with per-dimension stride-based index arithmetic + // (same pattern as BasicBinary) + // ---------------------------------------------------------------- // + auto stridesX = UTILITY::ComputeStrideFromShape(fDimShapeX); + auto stridesY = UTILITY::ComputeStrideFromShape(fDimShapeY); + auto stridesC = UTILITY::ComputeStrideFromShape(fDimShapeC); + auto stridesZ = UTILITY::ComputeStrideFromShape(fDimShapeZ); + + auto buildIdxExpr = [&](const std::vector &dimShape, + const std::vector &strides, + size_t rankZ) -> std::string { + if (dimShape.empty() || + std::all_of(dimShape.begin(), dimShape.end(), + [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) + return "0"; + std::string expr; + size_t offset = rankZ - dimShape.size(); + for (size_t i = 0; i < dimShape.size(); ++i) { + if (dimShape[i].dim == 1 || dimShape[i].GetVal() == "1") continue; + expr += "idx_" + std::to_string(i + offset); + if (strides[i].GetVal() != "1") + expr += " * " + strides[i].GetVal(); + expr += " + "; + } + if (expr.size() >= 3) + for (int j = 0; j < 3; j++) expr.pop_back(); // remove trailing " + " + return expr.empty() ? 
"0" : expr; + }; + + std::string idxX = buildIdxExpr(fDimShapeX, stridesX, fDimShapeZ.size()); + std::string idxY = buildIdxExpr(fDimShapeY, stridesY, fDimShapeZ.size()); + std::string idxC = buildIdxExpr(fDimShapeC, stridesC, fDimShapeZ.size()); + + // Emit nested loops over output shape + int nloop = 0; + std::string idxZ; + if (fDimShapeZ.empty() || + std::all_of(fDimShapeZ.begin(), fDimShapeZ.end(), + [](Dim d) { return d.dim == 1 || d.GetVal() == "1"; })) { + idxZ = "0"; + } else { + for (size_t i = 0; i < fDimShapeZ.size(); ++i) { + if (fDimShapeZ[i].dim != 1 && fDimShapeZ[i].GetVal() != "1") { + nloop++; + for (int j = 0; j < nloop; j++) out << SP; + out << "for (size_t idx_" << i << " = 0; idx_" << i + << " < " << fDimShapeZ[i] << "; ++idx_" << i << ") {\n"; + idxZ += "idx_" + std::to_string(i); + if (stridesZ[i].GetVal() != "1") + idxZ += " * " + stridesZ[i].GetVal(); + idxZ += " + "; + } + } + if (idxZ.size() >= 3) + for (int j = 0; j < 3; j++) idxZ.pop_back(); + } + + // Inner assignment + for (int j = 0; j < nloop + 1; j++) out << SP; + out << "tensor_" << fNZ << "[" << idxZ << "] = " + << "tensor_" << nameC << "[" << idxC << "] ? " + << "tensor_" << nameX << "[" << idxX << "] : " + << "tensor_" << nameY << "[" << idxY << "];\n"; + + // Close loops + for (int i = nloop; i > 0; i--) { + for (int j = 0; j < i; j++) out << SP; + out << "}\n"; + } + return out.str(); } - }; -}//SOFIE -}//Experimental -}//TMVA - +} // namespace SOFIE +} // namespace Experimental +} // namespace TMVA -#endif //TMVA_SOFIE_ROperator_Where +#endif // TMVA_SOFIE_ROperator_Where From 9260cd5d85b398b89e1d93324dfd49f8144cba59 Mon Sep 17 00:00:00 2001 From: moneta Date: Thu, 19 Mar 2026 22:56:35 +0100 Subject: [PATCH 4/4] [tmva][sofie] Apply some fixes needed for MLPF model Fix also a bug when doing Gemm and applying the bias in case of stacked matrix multiplications. 
The bias was not correctly broadcasted in this case --- tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx | 56 ++++++++++++++++++- tmva/sofie/inc/TMVA/ROperator_Gemm.hxx | 18 ++++-- tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx | 3 +- tmva/sofie/inc/TMVA/ROperator_Slice.hxx | 2 +- tmva/sofie_parsers/src/ParseConstant.cxx | 2 +- tmva/sofie_parsers/src/ParseWhere.cxx | 4 +- 6 files changed, 74 insertions(+), 11 deletions(-) diff --git a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx index 50bae3e04b6ec..16be75c5aa960 100644 --- a/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_BasicBinary.hxx @@ -140,6 +140,7 @@ public: auto ret = UTILITY::MultidirectionalBroadcastShape(fShapeA, fShapeB); fBroadcastFlag = ret.first; fShapeY = ret.second; + auto lengthY = ConvertShapeToLength(fShapeY); if (model.IsConstantTensor(fNA) && model.IsConstantTensor(fNB)) { bool broadcast = fBroadcastFlag > 0; if (broadcast) { @@ -193,7 +194,7 @@ public: const std::string &nameB = fNBroadcastedB.empty() ? 
fNB : fNBroadcastedB; auto dataA = static_cast(model.GetInitializedTensorData(nameA).get()); auto dataB = static_cast(model.GetInitializedTensorData(nameB).get()); - std::vector dataY(ConvertShapeToLength(fShapeY)); + std::vector dataY(lengthY); for (size_t i = 0; i < dataY.size(); i++) { dataY[i] = BinaryOperatorTrait::Func(dataA[i], dataB[i]); } @@ -207,6 +208,59 @@ << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " " << ConvertShapeToString(fShapeY) << " : " << ConvertValuesToString(dataY) << std::endl; } + } else if (((model.IsShapeTensor(fNA) && model.IsShapeTensor(fNB)) || + (model.IsShapeTensor(fNA) && model.IsConstantTensor(fNB)) || + (model.IsShapeTensor(fNB) && model.IsConstantTensor(fNA))) + && (fShapeA.size() <=1 && fShapeB.size() <=1 && model.GetTensorType(fNA) == ETensorType::INT64)) { + // case of shape tensors ( tensors are of rank 0 or 1 ) + std::vector dimValA; + std::vector dimValB; + if (model.IsShapeTensor(fNA)) + dimValA = model.GetShapeTensorValues(fNA); + if (model.IsShapeTensor(fNB)) + dimValB = model.GetShapeTensorValues(fNB); + // adjust for broadcasting - repeat values until it reaches shapes of Y + if (!fShapeY.empty() && fShapeY[0] > 1) { + if (dimValA.size() == 1) dimValA = std::vector( fShapeY[0], dimValA[0]); + if (dimValB.size() == 1) dimValB = std::vector( fShapeY[0], dimValB[0]); + } + + auto convertDataToDim = [&](const std::string & name, const std::vector & shape, std::vector & dimValues) { + auto data = static_cast(model.GetInitializedTensorData(name).get()); + dimValues.resize(lengthY); + for (size_t i = 0; i < lengthY; i++) { + if (!shape.empty() && lengthY == shape[0]) + dimValues[i] = Dim{ static_cast(data[i])}; + else // case dataA is a scalar + dimValues[i] = Dim{ static_cast(data[0])}; + } + }; + if (model.IsConstantTensor(fNA)) { + convertDataToDim(fNA,fShapeA,dimValA); + } else if (model.IsConstantTensor(fNB)) { + convertDataToDim(fNB,fShapeB,dimValB); + } + + //perform binary
operations on shape tensors + std::vector dimValY(lengthY); + for (size_t i = 0; i < lengthY; i++) { + if (!dimValA[i].isParam && !dimValB[i].isParam) { + size_t d = BinaryOperatorTrait::Func(dimValA[i].dim, dimValB[i].dim); + dimValY[i] = Dim{d}; + } else { + auto res = BinaryOperatorTrait::Op(dimValA[i].GetVal(), dimValB[i].GetVal()); + dimValY[i] = Dim{res, static_cast(-1)}; + } + } + model.AddShapeTensor(fNY,dimValY, fShapeY.empty()); // cannot be a scalar + if (model.Verbose()) { + std::cout << BinaryOperatorTrait::Name() << " : " << fNA << " " << ConvertShapeToString(fShapeA) + << " , " << fNB << " " << ConvertShapeToString(fShapeB) << " ---> " << fNY << " " + << ConvertShapeToString(fShapeY) << " : " << ConvertDimShapeToString(dimValY) << " (shape)" << std::endl; + } + // no code needs to be generated (flag this as a constant output tensor) + fIsOutputConstant = true; + } else { // case of defined and non-constant tensors model.AddIntermediateTensor(fNY, model.GetTensorType(fNA), fShapeY); diff --git a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx index c0cbe18f11475..ae56750e5ebd3 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Gemm.hxx @@ -222,7 +222,7 @@ namespace SOFIE{ if (fIsDynamic && shapeY.empty()) broadcast_needed = true; else - // consider broadcasting also if same length + // consider broadcasting also if they have different length broadcast_needed = (fShapeC != shapeY); @@ -285,7 +285,12 @@ int64_t dimA = fShapeA.size(); int64_t dimB = fShapeB.size(); int64_t dimY = fShapeY.size(); - if (dimA != dimB || dimA != dimY) { + int64_t dimC = fShapeC.size(); + if (dimA != dimB || dimA != dimY || (fBroadcastBias && dimC != dimY)) { + std::cout << " shape A " << ConvertDimShapeToString(fShapeA) + << " shape B " << ConvertDimShapeToString(fShapeB) + << " shape C " << ConvertShapeToString(fShapeC) + << " shape Y " << ConvertDimShapeToString(fShapeY) << std::endl;
throw std::runtime_error("TMVA SOFIE Gemm(MatMul) has invalid shape for inputs or output"); } auto m = (fAttrTransA ? fShapeA[dimA-1].GetVal() : fShapeA[dimA-2].GetVal()); @@ -357,6 +362,9 @@ } // do the bias broadcasting if (fBroadcastBias) { + // shapeC also has 1's prepended so that it has the same rank as Y + std::vector sC = {fShapeC[dimC-2], fShapeC[dimC-1]}; + fAttrBeta = 1.; out << SP << "for (size_t j = 0; j < " << sY[0] << "; j++) { \n"; out << SP << SP << "size_t y_index = "; @@ -369,11 +377,11 @@ out << SP << SP << "for (size_t k = 0; k < " << sY[1] << "; k++) { \n"; std::string bias_index; - if (fShapeC[0] == 1 && fShapeC[1] == sY[1].dim) + if (sC[0] == 1 && sC[1] == sY[1].dim) bias_index = "k"; - else if (fShapeC[1] == 1 && fShapeC[0] == sY[0].dim) + else if (sC[1] == 1 && sC[0] == sY[0].dim) bias_index = "j"; - else if (fShapeC[0] == 1 && fShapeC[1] == 1) // scalar case + else if (sC[0] == 1 && sC[1] == 1) // scalar case bias_index = "0"; else { throw std::runtime_error("TMVA SOFIE Gemm Op - invalid shape for bias tensor " + ConvertShapeToString(fShapeC)); diff --git a/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx index 570b3f7a294aa..32272ea03c2fc 100644 --- a/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_ScatterND.hxx @@ -132,7 +132,8 @@ // Step 2: Emit strides as a static constexpr array out << SP << "// Step 2: data strides (row-major)\n"; - out << SP << "static constexpr int64_t " << opName << "_data_strides[" << r << "] = {"; + // TODO: use static constexpr for defined strides + out << SP << "size_t " << opName << "_data_strides[" << r << "] = {"; for (size_t i = 0; i < r; ++i) out << stridesX[i] << (i + 1 < r ?
", " : ""); out << "};\n\n"; diff --git a/tmva/sofie/inc/TMVA/ROperator_Slice.hxx b/tmva/sofie/inc/TMVA/ROperator_Slice.hxx index 674f8a4776520..7b7c9563db352 100644 --- a/tmva/sofie/inc/TMVA/ROperator_Slice.hxx +++ b/tmva/sofie/inc/TMVA/ROperator_Slice.hxx @@ -337,7 +337,7 @@ public: size_t ndim = fShapeInput.size(); fIdentitySlice = fShapeOutput.size() == ndim; // check also if input data is not input to the model. In that case we copy the data since we cannot just copy from the input pointer - fIdentitySlice &= !model.IsReadyInputTensor(fNData); + fIdentitySlice &= (!model.IsReadyInputTensor(fNData) && !model.IsDimInputTensor(fNData)); for (size_t idim = 0; idim < ndim; idim++) { if (!fIdentitySlice) break; fIdentitySlice &= (fStart[idim].GetVal() == "0"); diff --git a/tmva/sofie_parsers/src/ParseConstant.cxx b/tmva/sofie_parsers/src/ParseConstant.cxx index 2b5a2ab37031a..c48b0507769ca 100644 --- a/tmva/sofie_parsers/src/ParseConstant.cxx +++ b/tmva/sofie_parsers/src/ParseConstant.cxx @@ -37,7 +37,7 @@ ParserFuncSignature ParseConstant = [](RModelParser_ONNX &parser, const onnx::No std::string output_name = nodeproto.output(0); ETensorType output_type = ETensorType::FLOAT; std::vector shape; // output shape (use in case of constant operator) - // it should be only one attribute (Constant or 1 or 0 COnstant of Shape) + // it should be only one attribute (Constant or 1 or 0 Constant of Shape) if (nodeproto.attribute_size() > 1) throw std::runtime_error("TMVA::SOFIE ONNX Parser Constant or ConstantOfShape and attribute size is larger than 1"); if (nodeproto.attribute_size() > 0) { diff --git a/tmva/sofie_parsers/src/ParseWhere.cxx b/tmva/sofie_parsers/src/ParseWhere.cxx index c072ad14cb956..6ebcf161e5012 100644 --- a/tmva/sofie_parsers/src/ParseWhere.cxx +++ b/tmva/sofie_parsers/src/ParseWhere.cxx @@ -32,10 +32,10 @@ ParserFuncSignature ParseWhere = [](RModelParser_ONNX &parser, const onnx::NodeP switch (input_type) { case ETensorType::FLOAT: - op.reset(new 
ROperator_Where(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name)); + op.reset(new ROperator_Where(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name)); break; case ETensorType::INT64: - op.reset(new ROperator_Where(nodeproto.input(1), nodeproto.input(2), nodeproto.input(0), output_name)); + op.reset(new ROperator_Where(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), output_name)); break; default: throw std::runtime_error("TMVA::SOFIE - Unsupported - Where Operator does not yet support input type " +