-
Notifications
You must be signed in to change notification settings - Fork 128
Support dynamic input values for range op #4591
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b85386b
78895ff
73a5b2e
8b7f747
c6b1daa
7f933bc
2600d73
ed25b9b
0af88e2
aa24135
35f84f7
9d6c8b0
43cb32e
56212be
b38fe2b
bf61ce5
90a3c84
04574e2
ec8a66c
cd8ad21
fba8f69
49c2b02
a56cfa2
57c5e6a
b5d2665
5eef3d7
d8c9e2d
83ecc93
dc445f1
2c55e8e
fb69c63
4d24399
676f408
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -248,6 +248,7 @@ register_migraphx_ops( | |
| nearbyint | ||
| neg | ||
| nonmaxsuppression | ||
| dynamic_range | ||
| nonzero | ||
| onehot | ||
| outline | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
| #ifndef MIGRAPHX_GUARD_OPERATORS_DYNAMIC_RANGE_HPP | ||
| #define MIGRAPHX_GUARD_OPERATORS_DYNAMIC_RANGE_HPP | ||
|
|
||
| #include <migraphx/check_shapes.hpp> | ||
| #include <migraphx/argument.hpp> | ||
| #include <migraphx/config.hpp> | ||
| #include <migraphx/value.hpp> | ||
| #include <migraphx/op/name.hpp> | ||
| #include <migraphx/dyn_output.hpp> | ||
| #include <cmath> | ||
| #include <limits> | ||
|
|
||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
| namespace op { | ||
|
|
||
| struct dynamic_range : op_name<dynamic_range> | ||
| { | ||
| // Intentionally capped at int::max to match the maximum tensor size in MIGraphX | ||
| std::size_t max_output = std::numeric_limits<int>::max(); | ||
|
|
||
| template <class Self, class F> | ||
| static auto reflect(Self& self, F f) | ||
| { | ||
| return pack(f(self.max_output, "max_output")); | ||
| } | ||
|
|
||
| shape compute_shape(std::vector<shape> inputs) const | ||
| { | ||
| check_shapes{inputs, *this}.has(3).same_type(); | ||
| const auto& type = inputs.at(0).type(); | ||
| // The output shape is 1D with unknown size if we don't evaluate. | ||
| return shape{type, {shape::dynamic_dimension{0, max_output}}}; | ||
| } | ||
|
CharlieL7 marked this conversation as resolved.
|
||
| argument compute(const dyn_output&, std::vector<argument> args) const | ||
| { | ||
| size_t num_elements = 0; | ||
| visit_all(args[0], args[1], args[2])([&](auto start, auto limit, auto delta) { | ||
| auto start_val = start.front(); | ||
| auto limit_val = limit.front(); | ||
| auto delta_val = delta.front(); | ||
|
|
||
|
pfultz2 marked this conversation as resolved.
|
||
| if(not(delta_val > 0 or delta_val < 0)) | ||
| MIGRAPHX_THROW("dynamic_range: delta must be non-zero"); | ||
|
|
||
| // number_of_elements = max( ceil( (limit - start) / delta ), 0 ) | ||
| double start_d = start_val; | ||
| double limit_d = limit_val; | ||
| double delta_d = delta_val; | ||
| double num_elements_d = std::ceil((limit_d - start_d) / delta_d); | ||
| if(not std::isfinite(num_elements_d)) | ||
| MIGRAPHX_THROW("dynamic_range: computed element count is not finite"); | ||
|
|
||
| num_elements = static_cast<size_t>(std::max(0.0, num_elements_d)); | ||
| }); | ||
|
|
||
|
kazhang2 marked this conversation as resolved.
|
||
| num_elements = std::min(num_elements, max_output); | ||
| argument result{shape{args[0].get_shape().type(), {num_elements}}}; | ||
|
|
||
| visit_all(args[0], args[2], result)([&](auto start, auto delta, auto output) { | ||
| auto start_val = start.front(); | ||
| auto delta_val = delta.front(); | ||
|
|
||
| for(size_t i = 0; i < num_elements; ++i) | ||
| { | ||
| output[i] = start_val + (static_cast<decltype(start_val)>(i) * delta_val); | ||
| } | ||
| }); | ||
|
|
||
| return result; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace op | ||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
| #endif | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,7 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
|
|
@@ -41,11 +41,13 @@ struct parse_range : op_parser<parse_range> | |
| std::vector<instruction_ref> args) const | ||
| { | ||
| auto start_arg = args[0]->eval(); | ||
| check_arg_empty(start_arg, "PARSE_RANGE: start arg dynamic shape is not supported"); | ||
| auto limit_arg = args[1]->eval(); | ||
| check_arg_empty(limit_arg, "PARSE_RANGE: limit arg dynamic shape is not supported"); | ||
| auto delta_arg = args[2]->eval(); | ||
| check_arg_empty(delta_arg, "PARSE_RANGE: delta arg dynamic shape is not supported"); | ||
|
|
||
| if(start_arg.empty() or limit_arg.empty() or delta_arg.empty()) | ||
| { | ||
| return info.add_instruction(make_op("dynamic_range"), args); | ||
| } | ||
|
Comment on lines
43
to
+50
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is having no elements a valid Range operator?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it will return empty array Range(5, 5, 1) → max(ceil(0/1), 0) = 0 elements.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean, is it even valid ONNX? When would an empty tensor output be handled or expected?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It looks like zero elements is what the ONNX spec is expecting. Weird to me since I don't know how you would use the empty tensor output but it's what the spec says. |
||
|
|
||
| assert(args[0]->get_shape().elements() == 1 and args[1]->get_shape().elements() == 1 and | ||
| args[2]->get_shape().elements() == 1); | ||
|
|
@@ -57,10 +59,14 @@ struct parse_range : op_parser<parse_range> | |
| auto limit_val = limit.front(); | ||
| auto delta_val = delta.front(); | ||
|
|
||
| size_t num_elements = | ||
| ceil(static_cast<double>(limit_val - start_val) / static_cast<double>(delta_val)); | ||
| if(not(delta_val > 0 or delta_val < 0)) | ||
| MIGRAPHX_THROW("PARSE_RANGE: delta must be non-zero"); | ||
|
|
||
| assert(num_elements > 0); | ||
| double start_d = start_val; | ||
| double limit_d = limit_val; | ||
| double delta_d = delta_val; | ||
| double num_elements_d = ceil((limit_d - start_d) / delta_d); | ||
| size_t num_elements = std::max(0.0, num_elements_d); | ||
|
|
||
| using type = decltype(start_val); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,40 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
|
|
||
| #include <onnx_test.hpp> | ||
|
|
||
| TEST_CASE(range_inputs_test) | ||
| { | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| auto start = mm->add_parameter("start", migraphx::shape{migraphx::shape::float_type, {1}, {0}}); | ||
| auto limit = mm->add_parameter("limit", migraphx::shape{migraphx::shape::float_type, {1}, {0}}); | ||
| auto delta = mm->add_parameter("delta", migraphx::shape{migraphx::shape::float_type, {1}, {0}}); | ||
|
|
||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
|
|
||
| auto prog = optimize_onnx("range_inputs_test.onnx"); | ||
|
|
||
| EXPECT(p == prog); | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,177 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2026 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
|
|
||
| #include <migraphx/instruction.hpp> | ||
| #include <migraphx/literal.hpp> | ||
| #include <migraphx/make_op.hpp> | ||
| #include <migraphx/register_target.hpp> | ||
| #include <migraphx/verify.hpp> | ||
| #include <migraphx/module.hpp> | ||
| #include <migraphx/program.hpp> | ||
|
|
||
| #include <test.hpp> | ||
|
|
||
| TEST_CASE(dynamic_range_float_inc) | ||
| { | ||
| // Start=0, Limit=5, Delta=1 -> [0, 1, 2, 3, 4] | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| migraphx::shape s{migraphx::shape::float_type, {1}, {0}}; | ||
| auto start = mm->add_parameter("start", s); | ||
| auto limit = mm->add_parameter("limit", s); | ||
| auto delta = mm->add_parameter("delta", s); | ||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
| p.compile(migraphx::make_target("ref")); | ||
|
|
||
| std::vector<float> start_val = {0.0f}; | ||
| std::vector<float> limit_val = {5.0f}; | ||
| std::vector<float> delta_val = {1.0f}; | ||
|
|
||
| migraphx::parameter_map pp; | ||
| pp["start"] = migraphx::argument(s, start_val.data()); | ||
| pp["limit"] = migraphx::argument(s, limit_val.data()); | ||
| pp["delta"] = migraphx::argument(s, delta_val.data()); | ||
|
|
||
| auto result = p.eval(pp).back(); | ||
| std::vector<float> result_vector; | ||
| result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); | ||
|
|
||
| std::vector<float> gold = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; | ||
| EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); | ||
| } | ||
|
kazhang2 marked this conversation as resolved.
|
||
|
|
||
| TEST_CASE(dynamic_range_float_step) | ||
| { | ||
| // Start=0, Limit=5, Delta=2 -> [0, 2, 4] | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| migraphx::shape s{migraphx::shape::float_type, {1}, {0}}; | ||
| auto start = mm->add_parameter("start", s); | ||
| auto limit = mm->add_parameter("limit", s); | ||
| auto delta = mm->add_parameter("delta", s); | ||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
| p.compile(migraphx::make_target("ref")); | ||
|
|
||
| std::vector<float> start_val = {0.0f}; | ||
| std::vector<float> limit_val = {5.0f}; | ||
| std::vector<float> delta_val = {2.0f}; | ||
|
|
||
| migraphx::parameter_map pp; | ||
| pp["start"] = migraphx::argument(s, start_val.data()); | ||
| pp["limit"] = migraphx::argument(s, limit_val.data()); | ||
| pp["delta"] = migraphx::argument(s, delta_val.data()); | ||
|
|
||
| auto result = p.eval(pp).back(); | ||
| std::vector<float> result_vector; | ||
| result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); | ||
|
|
||
| std::vector<float> gold = {0.0f, 2.0f, 4.0f}; | ||
| EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); | ||
| } | ||
|
|
||
| TEST_CASE(dynamic_range_float_dec) | ||
| { | ||
| // Start=5, Limit=0, Delta=-1 -> [5, 4, 3, 2, 1] | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| migraphx::shape s{migraphx::shape::float_type, {1}, {0}}; | ||
| auto start = mm->add_parameter("start", s); | ||
| auto limit = mm->add_parameter("limit", s); | ||
| auto delta = mm->add_parameter("delta", s); | ||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
| p.compile(migraphx::make_target("ref")); | ||
|
|
||
| std::vector<float> start_val = {5.0f}; | ||
| std::vector<float> limit_val = {0.0f}; | ||
| std::vector<float> delta_val = {-1.0f}; | ||
|
|
||
| migraphx::parameter_map pp; | ||
| pp["start"] = migraphx::argument(s, start_val.data()); | ||
| pp["limit"] = migraphx::argument(s, limit_val.data()); | ||
| pp["delta"] = migraphx::argument(s, delta_val.data()); | ||
|
|
||
| auto result = p.eval(pp).back(); | ||
| std::vector<float> result_vector; | ||
| result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); | ||
|
|
||
| std::vector<float> gold = {5.0f, 4.0f, 3.0f, 2.0f, 1.0f}; | ||
| EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); | ||
| } | ||
|
|
||
| TEST_CASE(dynamic_range_int) | ||
| { | ||
| // Start=1, Limit=10, Delta=3 -> [1, 4, 7] | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| migraphx::shape s{migraphx::shape::int32_type, {1}, {0}}; | ||
| auto start = mm->add_parameter("start", s); | ||
| auto limit = mm->add_parameter("limit", s); | ||
| auto delta = mm->add_parameter("delta", s); | ||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
| p.compile(migraphx::make_target("ref")); | ||
|
|
||
| std::vector<int32_t> start_val = {1}; | ||
| std::vector<int32_t> limit_val = {10}; | ||
| std::vector<int32_t> delta_val = {3}; | ||
|
|
||
| migraphx::parameter_map pp; | ||
| pp["start"] = migraphx::argument(s, start_val.data()); | ||
| pp["limit"] = migraphx::argument(s, limit_val.data()); | ||
| pp["delta"] = migraphx::argument(s, delta_val.data()); | ||
|
|
||
| auto result = p.eval(pp).back(); | ||
| std::vector<int32_t> result_vector; | ||
| result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); | ||
|
|
||
| std::vector<int32_t> gold = {1, 4, 7}; | ||
| EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); | ||
| } | ||
|
|
||
| TEST_CASE(dynamic_range_float_start_equals_limit) | ||
| { | ||
| // Start=5, Limit=5, Delta=1 -> [] (empty output) | ||
| migraphx::program p; | ||
| auto* mm = p.get_main_module(); | ||
| migraphx::shape s{migraphx::shape::float_type, {1}, {0}}; | ||
| auto start = mm->add_parameter("start", s); | ||
| auto limit = mm->add_parameter("limit", s); | ||
| auto delta = mm->add_parameter("delta", s); | ||
| mm->add_instruction(migraphx::make_op("dynamic_range"), start, limit, delta); | ||
| p.compile(migraphx::make_target("ref")); | ||
|
|
||
| std::vector<float> start_val = {5.0f}; | ||
| std::vector<float> limit_val = {5.0f}; | ||
| std::vector<float> delta_val = {1.0f}; | ||
|
|
||
| migraphx::parameter_map pp; | ||
| pp["start"] = migraphx::argument(s, start_val.data()); | ||
| pp["limit"] = migraphx::argument(s, limit_val.data()); | ||
| pp["delta"] = migraphx::argument(s, delta_val.data()); | ||
|
|
||
| auto result = p.eval(pp).back(); | ||
| std::vector<float> result_vector; | ||
| result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); | ||
|
|
||
| EXPECT(result_vector.empty()); | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.