Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -2271,18 +2271,7 @@
]
},
"torch.Tensor.select_scatter": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.select_scatter",
"min_input_args": 3,
"args_list": [
"src",
"dim",
"index"
],
"kwargs_change": {
"src": "values",
"dim": "axis"
}
"Matcher": "ChangePrefixMatcher"
},
"torch.Tensor.set_": {
"Matcher": "TensorSetMatcher",
Expand Down Expand Up @@ -11127,20 +11116,7 @@
]
},
"torch.select_scatter": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.select_scatter",
"min_input_args": 4,
"args_list": [
"input",
"src",
"dim",
"index"
],
"kwargs_change": {
"input": "x",
"src": "values",
"dim": "axis"
}
"Matcher": "ChangePrefixMatcher"
},
"torch.selu": {
"Matcher": "GenericMatcher",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_Tensor_select_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_case_3():
import torch
input = torch.zeros(2, 2)
src = torch.ones(2)
result = input.select_scatter(src, dim=1, index=1)
result = input.select_scatter(src=src, dim=1, index=1)
"""
)
obj.run(pytorch_code, ["result"])
Expand All @@ -61,7 +61,7 @@ def test_case_4():
import torch
input = torch.zeros(2, 2)
src = torch.ones(2)
result = input.select_scatter(src=src, dim=1, index=1)
result = input.select_scatter(src, 1, 1)
"""
)
obj.run(pytorch_code, ["result"])
Expand Down
6 changes: 3 additions & 3 deletions tests/test_select_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_case_3():
import torch
x = torch.zeros((2,3,4)).type(torch.float32)
values = torch.ones((2,4)).type(torch.float32)
result = torch.select_scatter(input=x, dim=1, src=values, index=1)
result = torch.select_scatter(input=x, src=values, dim=1, index=1)
"""
)
obj.run(pytorch_code, ["result"])
Expand All @@ -61,7 +61,7 @@ def test_case_4():
import torch
x = torch.zeros((2,3,4)).type(torch.float32)
values = torch.ones((2,4)).type(torch.float32)
result = torch.select_scatter(input=x, src=values, dim=1, index=1)
result = torch.select_scatter(x, values, 1, 1)
"""
)
obj.run(pytorch_code, ["result"])
Expand All @@ -73,7 +73,7 @@ def test_case_5():
import torch
x = torch.zeros((2,3,4)).type(torch.float32)
values = torch.ones((2,4)).type(torch.float32)
result = torch.select_scatter(x, values, 1, 1)
result = torch.select_scatter(index=1, dim=1, src=values, input=x)
"""
)
obj.run(pytorch_code, ["result"])
153 changes: 149 additions & 4 deletions tools/validate_unittest/validate_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#

import argparse
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改这个的目的是?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当时主要是为了解决 ChangePrefixMatcher 后 validate_unittest 对 select_scatter 的参数校验误判问题

import importlib
import inspect
import json
import os
import re
Expand Down Expand Up @@ -145,6 +147,149 @@
}


# ChangePrefixMatcher removes explicit args_list metadata, so builtin APIs that
# are not introspectable still need a validator-side fallback.
change_prefix_validator_metadata_fallback = {
"torch.select_scatter": {
"args_list": ["input", "src", "dim", "index"],
"min_input_args": 4,
},
"torch.Tensor.select_scatter": {
"args_list": ["src", "dim", "index"],
"min_input_args": 3,
},
}


def isclassname(name, module_parts):
if name and name[0].isupper():
return True
elif (
name == "profile" and len(module_parts) >= 1 and module_parts[-1] == "profiler"
):
return True
return False


def resolve_torch_callable(function_string):
try:
parts = function_string.split(".")
function_name = parts.pop()
classname = None
if len(parts) >= 1 and isclassname(parts[-1], parts[:-1]):
classname = parts.pop()

if len(parts) > 0:
module_name = ".".join(parts)
module = importlib.import_module(module_name)
else:
module = globals()

if classname is not None:
module = getattr(module, classname)

if not hasattr(module, function_name):
return None

return getattr(module, function_name)
except Exception:
return None


def build_args_list_from_signature(api_target, signature):
args_list = []
need_positional_separator = False
need_keyword_separator = True

for param in signature.parameters.values():
if param.name == "self" and api_target.startswith("torch.Tensor."):
continue

if param.kind == inspect.Parameter.POSITIONAL_ONLY:
args_list.append(param.name)
need_positional_separator = True
elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD:
if need_positional_separator:
args_list.append("/")
need_positional_separator = False
args_list.append(param.name)
elif param.kind == inspect.Parameter.VAR_POSITIONAL:
if need_positional_separator:
args_list.append("/")
need_positional_separator = False
args_list.append("*" + param.name)
need_keyword_separator = False
elif param.kind == inspect.Parameter.KEYWORD_ONLY:
if need_positional_separator:
args_list.append("/")
need_positional_separator = False
if need_keyword_separator:
args_list.append("*")
need_keyword_separator = False
args_list.append(param.name)
elif param.kind == inspect.Parameter.VAR_KEYWORD:
if need_positional_separator:
args_list.append("/")
need_positional_separator = False
args_list.append("**" + param.name)

if need_positional_separator:
args_list.append("/")

return args_list


def infer_min_input_args(api_target, signature):
min_input_args = 0
for param in signature.parameters.values():
if param.name == "self" and api_target.startswith("torch.Tensor."):
continue
if (
param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
and param.default == inspect.Parameter.empty
):
min_input_args += 1
return min_input_args


def resolve_validator_metadata(api_target, mapping_data):
min_input_args = mapping_data.get("min_input_args")
args_list_full = mapping_data.get("args_list")

if (
mapping_data.get("Matcher") != "ChangePrefixMatcher"
or "args_list" in mapping_data
):
if min_input_args is None:
min_input_args = -1
if args_list_full is None:
args_list_full = []
return min_input_args, args_list_full

pytorch_callable = resolve_torch_callable(api_target)
if pytorch_callable is not None:
try:
signature = inspect.signature(pytorch_callable)
if args_list_full is None:
args_list_full = build_args_list_from_signature(api_target, signature)
if min_input_args is None:
min_input_args = infer_min_input_args(api_target, signature)
except (TypeError, ValueError):
pass

fallback_metadata = change_prefix_validator_metadata_fallback.get(api_target, {})
if args_list_full is None:
args_list_full = fallback_metadata.get("args_list", [])
if min_input_args is None:
min_input_args = fallback_metadata.get("min_input_args", -1)

return min_input_args, args_list_full


def get_test_cases(discovery_paths=["tests"]):
# Collect the test cases
monkeypatch = pytest.MonkeyPatch()
Expand Down Expand Up @@ -364,11 +509,11 @@ def check_call_variety(test_data, api_mapping, *, api_alias={}, verbose=True):

position_args_checkable = True

min_input_args = mapping_data.get("min_input_args", -1)
min_input_args, args_list_full = resolve_validator_metadata(
api_target, mapping_data
)
aux_detailed_data_api["min_input_args"] = min_input_args

args_list_full = mapping_data.get("args_list", [])

var_arg_name = None
var_kwarg_name = None

Expand Down Expand Up @@ -444,7 +589,7 @@ def check_call_variety(test_data, api_mapping, *, api_alias={}, verbose=True):
all_args = False
all_kwargs = False
not_subsequence = False
all_default = False if "min_input_args" in mapping_data else None
all_default = False if min_input_args >= 0 else None

for case_name, code in unittest_data["code"].items():
try:
Expand Down