Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 1 addition & 20 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -10427,26 +10427,7 @@
"Matcher": "ChangePrefixMatcher"
},
"torch.randint": {
"Matcher": "RandintMatcher",
"paddle_api": "paddle.randint",
"min_input_args": 2,
"args_list": [
"low",
"high",
"size",
"*",
"generator",
"out",
"dtype",
"layout",
"device",
"pin_memory",
"requires_grad"
],
"kwargs_change": {
"size": "shape",
"dtype": "dtype"
}
"Matcher": "ChangePrefixMatcher"
},
"torch.randint_like": {
"Matcher": "RandintLikeMatcher",
Expand Down
4 changes: 2 additions & 2 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_paddle_nodes(self, args, kwargs):
kwargs = self.parse_kwargs(kwargs, allow_none=True)

# temporarily delete these unsupported args, which paddle does not support now
for k in ["layout", "generator", "memory_format", "sparse_grad", "foreach"]:
for k in ["layout", "generator", "memory_format", "sparse_grad", "requires_grad", "pin_memory", "device", "foreach"]:
if k in kwargs:
kwargs.pop(k)
code = f"{self.get_paddle_api()}({self.args_and_kwargs_to_str(args, kwargs)})"
Expand All @@ -401,7 +401,7 @@ def get_paddle_class_nodes(self, func, args, kwargs):
kwargs = self.parse_kwargs(kwargs, allow_none=True)

# temporarily delete these unsupported args, which paddle does not support now
for k in ["layout", "generator", "memory_format", "sparse_grad", "foreach"]:
for k in ["layout", "generator", "memory_format", "sparse_grad", "requires_grad", "pin_memory", "device", "foreach"]:
if k in kwargs:
kwargs.pop(k)
code = f"{self.paddle_api}({self.args_and_kwargs_to_str(args, kwargs)})"
Expand Down
174 changes: 172 additions & 2 deletions tests/test_randint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_case_1():
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_2():
def _test_case_2(): # 2-arg form: paddle.randint(low, high, shape) signature differs from torch.randint(high, size)
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_case_9():
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_10():
def _test_case_10(): # 2-arg form: paddle.randint(low, high, shape) signature differs from torch.randint(high, size)
pytorch_code = textwrap.dedent(
"""
import torch
Expand Down Expand Up @@ -155,3 +155,173 @@ def test_case_12():
"""
)
obj.run(pytorch_code, ["result"], check_value=False)


def test_case_13():
    """Positional low/high with the shape passed via the `size` keyword."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(0, 10, size=(3, 3))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_14():
    """Omit `low` entirely: only `high` and `size` given as keywords."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(high=5, size=(2, 3))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_15():
    """Mix calling styles: `low` positional, `high` and `size` as keywords."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(1, high=10, size=(2, 2))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_16():
    """Keyword `size` combined with an explicit dtype=torch.int32."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(0, 100, size=(4, 4), dtype=torch.int32)
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_17():
    """1-D output with every argument spelled as a keyword."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(low=0, high=10, size=(5,))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_18():
    """Three-dimensional output shape given positionally."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(0, 5, (2, 3, 4))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_19():
    """Fill a pre-allocated tensor through the `out` parameter."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            out = torch.empty(3, 3, dtype=torch.int64)
            result = torch.randint(0, 10, size=(3, 3), out=out)
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_20():
    """The `high` bound is a runtime expression, not a literal."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            base = 5
            result = torch.randint(0, base * 2, (2, 2))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_21():
    """All keyword arguments, deliberately out of signature order."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(size=(2, 2), high=10, low=0, dtype=torch.int64)
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_22():
    """Negative lower bound."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(-10, 10, (3, 3))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_23():
    """Smallest possible output: a single-element 1-D tensor."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(0, 10, (1,))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_24():
    """Four-dimensional output with all keyword arguments."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            result = torch.randint(low=0, high=5, size=(2, 2, 2, 2))
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_25():
    """Shape supplied through a variable rather than an inline tuple."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            shape = (3, 4)
            result = torch.randint(0, 10, shape)
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_26():
    """Positional arguments delivered via *args unpacking."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            args = (0, 10, (2, 2))
            result = torch.randint(*args)
            """
        ),
        ["result"],
        check_value=False,
    )


def test_case_27():
    """Keyword arguments delivered via **kwargs dict unpacking."""
    obj.run(
        textwrap.dedent(
            """
            import torch
            kwargs = {'low': 0, 'high': 10, 'size': (3, 3)}
            result = torch.randint(**kwargs)
            """
        ),
        ["result"],
        check_value=False,
    )