[tmva][sofie] Add new ScatterND operator#21621
[tmva][sofie] Add new ScatterND operator#21621lmoneta wants to merge 4 commits intoroot-project:masterfrom
Conversation
Add an implementation of ScatterND operator which is needed to parse the MLPF model from CMS Include also 3 tests to probe the different type of scattering wich can be performed
guitargeek
left a comment
There was a problem hiding this comment.
Great! I have also tested it with my AD refactoring PR, and it works well.
I have a few small suggestions though, and also one more general question: why not format new source file with clang-format? That would make the code more consistent in the long run, and also not confuse new contributors with the CI checks that are red when the code is not formatted.
| if (!parser.IsRegisteredTensorType(nodeproto.input(0))){ | ||
| throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(0) | ||
| + " but its type is not yet registered"); | ||
| } | ||
| if (!parser.IsRegisteredTensorType(nodeproto.input(1))){ | ||
| throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(1) | ||
| + " but its type is not yet registered"); | ||
| } | ||
| if (!parser.IsRegisteredTensorType(nodeproto.input(2))){ | ||
| throw std::runtime_error("TMVA::SOFIE ONNX Parser ScatterND op has input tensor " + nodeproto.input(2) | ||
| + " but its type is not yet registered"); | ||
| } |
There was a problem hiding this comment.
I guess that can also be done in a loop from zero to nodeproto.input_size().
| std::unique_ptr<ROperator> op; | ||
| std::string output_name = nodeproto.output(0); | ||
|
|
||
| op.reset(new ROperator_ScatterND(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), | ||
| output_name, reduction)); |
There was a problem hiding this comment.
| std::unique_ptr<ROperator> op; | |
| std::string output_name = nodeproto.output(0); | |
| op.reset(new ROperator_ScatterND(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), | |
| output_name, reduction)); | |
| auto op = std::make_unique<ROperator_ScatterND>(nodeproto.input(0), nodeproto.input(1), nodeproto.input(2), | |
| nodeproto.output(0), reduction)); |
We can use C++17 now, and it's better to have no naked new or delete, because these can be a red flag.
| namespace TMVA { | ||
| namespace Experimental { | ||
| namespace SOFIE { |
There was a problem hiding this comment.
| namespace TMVA { | |
| namespace Experimental { | |
| namespace SOFIE { | |
| namespace TMVA::Experimental::SOFIE { |
Personal preference maybe, but I think this is more readable.
| namespace TMVA{ | ||
| namespace Experimental{ | ||
| namespace SOFIE{ | ||
|
|
There was a problem hiding this comment.
| namespace TMVA{ | |
| namespace Experimental{ | |
| namespace SOFIE{ | |
| namespace TMVA::Experimental::SOFIE { | |
Same here maybe.
Test Results 22 files 22 suites 3d 5h 8m 41s ⏱️ For more details on these failures, see this check. Results for commit 9260cd5. ♻️ This comment has been updated with latest results. |
- Fix in operator Reduce to return a scalar and not a tensor of shape [1] - Fix handling of output boolean type in Cast. Do not convert type in a string, because a boolean is converted to a uint8_t which can be a native uint8_t or a bool. Avoid then calling function ConvertStrigToType if possible - Fix fusion of operators. Perform fusion not at first op encountered but at the last onem in order to parse before all operators which can provide an input to last fused one. This was the case in MLPF where there was a MatMul + Constant + Add, where COnstant is an input to Add. - remove check in Generate on empty shapes because scalars tensors have empty shapes
Fix also a bug when doing Gemm and applying the bias in case of stacked matrix multiplications. The bias was not correctly broadcasted in this case
58a9ff4 to
9260cd5
Compare
Add implementation of ScatterND operator which is needed to parse MLPF model from CMS
Include also 3 tests to probe the different type of scattering wich can be performed by the operator