Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 1 addition & 9 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3335,15 +3335,7 @@
"Matcher": "ChangePrefixMatcher"
},
"torch.aminmax": {
"Matcher": "AMinMaxMatcher",
"min_input_args": 1,
"args_list": [
"input",
"*",
"dim",
"keepdim",
"out"
]
"Matcher": "ChangePrefixMatcher"
},
"torch.amp.autocast": {
"Matcher": "ChangePrefixMatcher"
Expand Down
109 changes: 109 additions & 0 deletions tests/test_aminmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,112 @@ def test_case_6():
"""
)
obj.run(pytorch_code, ["out"])


def test_case_7():
    """Reduce a 1-D tensor over all elements (no dim given)."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([5.0, 1.0, 3.0, 9.0, 2.0])
        result = torch.aminmax(t)
        """
    )
    obj.run(code, ["result"])


def test_case_8():
    """Reduce a 3-D tensor along its first axis (dim=0)."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
        result = torch.aminmax(t, dim=0)
        """
    )
    obj.run(code, ["result"])


def test_case_9():
    """Reduce a 3-D tensor along dim=1 while keeping the reduced axis."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
        result = torch.aminmax(t, dim=1, keepdim=True)
        """
    )
    obj.run(code, ["result"])


def test_case_10():
    """Reduce a double-precision (float64) tensor along dim=1."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[1.5, 2.3, 3.7], [4.1, 5.9, 6.2]], dtype=torch.float64)
        result = torch.aminmax(t, dim=1)
        """
    )
    obj.run(code, ["result"])


def test_case_11():
    """Pass keepdim=False explicitly rather than relying on the default."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)
        result = torch.aminmax(t, dim=0, keepdim=False)
        """
    )
    obj.run(code, ["result"])


def test_case_12():
    """Supply the tensor via the `input=` keyword instead of positionally."""
    code = textwrap.dedent(
        """
        import torch
        result = torch.aminmax(input=torch.tensor([3.0, 1.0, 4.0, 1.0, 5.0]))
        """
    )
    obj.run(code, ["result"])


def test_case_13():
    """Forward dim/keepdim through a **kwargs dict unpacking."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        kwargs = {"dim": 0, "keepdim": True}
        result = torch.aminmax(t, **kwargs)
        """
    )
    obj.run(code, ["result"])


def test_case_14():
    """Use an arithmetic expression (2 - 1) as the dim argument."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        result = torch.aminmax(t, dim=2 - 1)
        """
    )
    obj.run(code, ["result"])


def test_case_15():
    """Combine dim, keepdim and a pre-allocated `out` tuple of tensors."""
    code = textwrap.dedent(
        """
        import torch
        t = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        out = tuple([torch.tensor([0.0, 0.0]), torch.tensor([0.0, 0.0])])
        result = torch.aminmax(t, dim=1, keepdim=False, out=out)
        """
    )
    obj.run(code, ["out", "result"])