Skip to content

Commit 7deefdd

Browse files
Aidenwu0209PaddleCI
authored and committed
[API Compatibility No.155] Keep input/target alias for bce_with_logits-part
1 parent eb0a495 commit 7deefdd

3 files changed

Lines changed: 227 additions & 31 deletions

File tree

paconvert/api_mapping.json

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8848,7 +8848,7 @@
88488848
"min_input_args": 2
88498849
},
88508850
"torch.nn.functional.binary_cross_entropy_with_logits": {
8851-
"Matcher": "SizeAverageMatcher",
8851+
"Matcher": "BCEWithLogitsAliasMatcher",
88528852
"paddle_api": "paddle.nn.functional.binary_cross_entropy_with_logits",
88538853
"args_list": [
88548854
"input",
@@ -8859,10 +8859,6 @@
88598859
"reduction",
88608860
"pos_weight"
88618861
],
8862-
"kwargs_change": {
8863-
"input": "logit",
8864-
"target": "label"
8865-
},
88668862
"min_input_args": 2
88678863
},
88688864
"torch.nn.functional.celu": {

paconvert/api_matcher.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4218,6 +4218,107 @@ def generate_code(self, kwargs):
42184218
return GenericMatcher.generate_code(self, kwargs)
42194219

42204220

4221+
class BCEWithLogitsAliasMatcher(BaseMatcher):
    """Matcher for ``torch.nn.functional.binary_cross_entropy_with_logits``.

    Unlike the generic size-average matcher, this one keeps the torch-side
    ``input``/``target`` keyword aliases in the generated Paddle call instead
    of renaming them (the Paddle API accepts the same names here).
    """

    def _generate_code_with_args(self, args, kwargs):
        """Render the Paddle call string from pre-split positional args/kwargs.

        Handles the generic torch-only trailing kwargs (``dtype``, ``out``,
        ``pin_memory``, ``requires_grad``) the same way GenericMatcher does.
        """
        # Paddle-side names produced by "kwargs_change"; a kwarg whose name
        # collides with one of them must not be popped as a torch-only extra.
        mapped_names = self.api_mapping_dict.get("kwargs_change", {}).values()

        # Strip torch kwargs that have no Paddle counterpart at all.
        kwargs = self.change_kwargs(
            kwargs,
            [
                "layout",
                "device",
                "memory_format",
                "inplace",
                "generator",
                "non_blocking",
                "async",
            ],
        )
        kwargs = self.set_paddle_default_kwargs(kwargs)

        def take(name, default):
            # Pop `name` unless it is actually a renamed Paddle kwarg.
            if name in kwargs and name not in mapped_names:
                return kwargs.pop(name)
            return default

        dtype_expr = take("dtype", "None")
        pin_memory_expr = take("pin_memory", "(False)")
        out_expr = take("out", "None")

        stop_gradient_expr = "None"
        if "requires_grad" in kwargs and "requires_grad" not in mapped_names:
            # torch requires_grad=True maps to paddle stop_gradient=not True.
            stop_gradient_expr = "not " + kwargs.pop("requires_grad").strip("()")

        code = "{}({})".format(
            self.get_paddle_api(), self.args_and_kwargs_to_str(args, kwargs)
        )

        if dtype_expr != "None":
            code += ".astype({})".format(dtype_expr)
        if pin_memory_expr != "(False)":
            code += ".pin_memory()"

        has_grad = stop_gradient_expr != "None"
        has_out = out_expr != "None"
        if has_grad and has_out:
            template = textwrap.dedent(
                """
                paddle.assign({}, output={})
                {}.stop_gradient = {}
                {}
                """
            )
            code = template.format(
                code, out_expr, out_expr, stop_gradient_expr, out_expr
            )
        elif has_grad:
            template = textwrap.dedent(
                """
                {} = {}
                {}.stop_gradient = {}
                {}
                """
            )
            tmp = get_unique_name("out")
            code = template.format(tmp, code, tmp, stop_gradient_expr, tmp)
        elif has_out:
            template = textwrap.dedent(
                """
                paddle.assign({}, output={})
                """
            )
            code = template.format(code, out_expr)

        return code

    def get_paddle_nodes(self, args, kwargs):
        """Translate the torch call's AST nodes into Paddle AST nodes.

        Returns "misidentify"/"unchange" sentinels unchanged, parsed AST
        statements on success, or None when the call cannot be handled.
        """
        parsed_kwargs = self.parse_args_and_kwargs(args, kwargs)
        if parsed_kwargs == "misidentify":
            return "misidentify"
        if parsed_kwargs is None:
            return None

        # Fold torch's legacy size_average/reduce flags into `reduction`.
        process_reduce_and_size_average(parsed_kwargs)

        positional = self.parse_args(args)
        # input/target keep their positional slots if they were given that way.
        call_args = positional[:2]
        # The third torch positional slot is `weight`; re-attach it positionally.
        if len(positional) >= 3 and "weight" in parsed_kwargs:
            call_args.append(parsed_kwargs.pop("weight"))

        # Anything already passed positionally must not repeat as a keyword.
        if len(call_args) >= 1:
            parsed_kwargs.pop("input", None)
        if len(call_args) >= 2:
            parsed_kwargs.pop("target", None)

        code = self._generate_code_with_args(call_args, parsed_kwargs)
        if code == "misidentify":
            return "misidentify"
        if code == "unchange":
            return "unchange"
        if code:
            return ast.parse(code).body
        return None
4321+
42214322
class CudaNvtxRangePushMatcher(BaseMatcher):
42224323
def generate_code(self, kwargs):
42234324
code = "{}({})".format(self.get_paddle_api(), kwargs["msg"])

tests/test_nn_functional_binary_cross_entropy_with_logits.py

Lines changed: 125 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ def test_case_1():
2525
pytorch_code = textwrap.dedent(
2626
"""
2727
import torch
28-
a = torch.zeros(3, requires_grad=True)
29-
target = torch.tensor([0.,1.,0.])
28+
a = torch.zeros(3)
29+
a.requires_grad = True
30+
target = torch.zeros(3)
31+
target[1] = 1.0
3032
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target)
3133
"""
3234
)
@@ -37,8 +39,10 @@ def test_case_2():
3739
pytorch_code = textwrap.dedent(
3840
"""
3941
import torch
40-
a = torch.zeros(3, requires_grad=True)
41-
target = torch.tensor([0.,1.,0.])
42+
a = torch.zeros(3)
43+
a.requires_grad = True
44+
target = torch.zeros(3)
45+
target[1] = 1.0
4246
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, weight=None)
4347
"""
4448
)
@@ -49,8 +53,10 @@ def test_case_3():
4953
pytorch_code = textwrap.dedent(
5054
"""
5155
import torch
52-
a = torch.zeros(3, requires_grad=True)
53-
target = torch.tensor([0.,1.,0.])
56+
a = torch.zeros(3)
57+
a.requires_grad = True
58+
target = torch.zeros(3)
59+
target[1] = 1.0
5460
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduce=True)
5561
"""
5662
)
@@ -61,8 +67,10 @@ def test_case_4():
6167
pytorch_code = textwrap.dedent(
6268
"""
6369
import torch
64-
a = torch.zeros(3, requires_grad=True)
65-
target = torch.tensor([0.,1.,0.])
70+
a = torch.zeros(3)
71+
a.requires_grad = True
72+
target = torch.zeros(3)
73+
target[1] = 1.0
6674
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduce=False)
6775
"""
6876
)
@@ -73,8 +81,10 @@ def test_case_5():
7381
pytorch_code = textwrap.dedent(
7482
"""
7583
import torch
76-
a = torch.zeros(3, requires_grad=True)
77-
target = torch.tensor([0.,1.,0.])
84+
a = torch.zeros(3)
85+
a.requires_grad = True
86+
target = torch.zeros(3)
87+
target[1] = 1.0
7888
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='none')
7989
"""
8090
)
@@ -85,8 +95,10 @@ def test_case_6():
8595
pytorch_code = textwrap.dedent(
8696
"""
8797
import torch
88-
a = torch.zeros(3, requires_grad=True)
89-
target = torch.tensor([0.,1.,0.])
98+
a = torch.zeros(3)
99+
a.requires_grad = True
100+
target = torch.zeros(3)
101+
target[1] = 1.0
90102
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='mean')
91103
"""
92104
)
@@ -97,8 +109,10 @@ def test_case_7():
97109
pytorch_code = textwrap.dedent(
98110
"""
99111
import torch
100-
a = torch.zeros(3, requires_grad=True)
101-
target = torch.tensor([0.,1.,0.])
112+
a = torch.zeros(3)
113+
a.requires_grad = True
114+
target = torch.zeros(3)
115+
target[1] = 1.0
102116
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='sum')
103117
"""
104118
)
@@ -109,8 +123,10 @@ def test_case_8():
109123
pytorch_code = textwrap.dedent(
110124
"""
111125
import torch
112-
a = torch.zeros(3, requires_grad=True)
113-
target = torch.tensor([0.,1.,0.])
126+
a = torch.zeros(3)
127+
a.requires_grad = True
128+
target = torch.zeros(3)
129+
target[1] = 1.0
114130
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
115131
"""
116132
)
@@ -122,8 +138,10 @@ def test_case_9():
122138
pytorch_code = textwrap.dedent(
123139
"""
124140
import torch
125-
a = torch.zeros(3, requires_grad=True)
126-
target = torch.tensor([0.,1.,0.])
141+
a = torch.zeros(3)
142+
a.requires_grad = True
143+
target = torch.zeros(3)
144+
target[1] = 1.0
127145
result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, None, None, True, 'sum', None)
128146
"""
129147
)
@@ -135,22 +153,103 @@ def test_case_10():
135153
pytorch_code = textwrap.dedent(
136154
"""
137155
import torch
138-
a = torch.zeros(3, requires_grad=True)
139-
target = torch.tensor([0.,1.,0.])
140-
result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
156+
a = torch.zeros(3)
157+
a.requires_grad = True
158+
target = torch.zeros(3)
159+
target[1] = 1.0
160+
result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target)
141161
"""
142162
)
143-
obj.run(pytorch_code, ["result"])
163+
expect_paddle_code = textwrap.dedent(
164+
"""
165+
import paddle
166+
167+
a = paddle.zeros(3)
168+
a.stop_gradient = not True
169+
target = paddle.zeros(3)
170+
target[1] = 1.0
171+
result = paddle.nn.functional.binary_cross_entropy_with_logits(input=a, target=target)
172+
"""
173+
)
174+
obj.run(pytorch_code, expect_paddle_code=expect_paddle_code)
144175

145176

146-
# generated by validate_unittest autofix, based on test_case_8
147177
def test_case_11():
    """Keyword call with shuffled argument order; checks the exact Paddle code."""
    torch_code = textwrap.dedent(
        """
        import torch
        a = torch.zeros(3)
        a.requires_grad = True
        target = torch.zeros(3)
        target[1] = 1.0
        result = torch.nn.functional.binary_cross_entropy_with_logits(target=target, input=a, reduction='sum', weight=None, pos_weight=None)
        """
    )
    paddle_code = textwrap.dedent(
        """
        import paddle

        a = paddle.zeros(3)
        a.stop_gradient = not True
        target = paddle.zeros(3)
        target[1] = 1.0
        result = paddle.nn.functional.binary_cross_entropy_with_logits(
            target=target, input=a, reduction="sum", weight=None, pos_weight=None
        )
        """
    )
    obj.run(torch_code, expect_paddle_code=paddle_code)
202+
203+
204+
def test_case_12():
    """All-keyword call including legacy size_average/reduce flags."""
    torch_code = textwrap.dedent(
        """
        import torch
        a = torch.zeros(3)
        a.requires_grad = True
        target = torch.zeros(3)
        target[1] = 1.0
        result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
        """
    )
    paddle_code = textwrap.dedent(
        """
        import paddle

        a = paddle.zeros(3)
        a.stop_gradient = not True
        target = paddle.zeros(3)
        target[1] = 1.0
        result = paddle.nn.functional.binary_cross_entropy_with_logits(
            input=a, target=target, weight=None, reduction="mean", pos_weight=None
        )
        """
    )
    obj.run(torch_code, expect_paddle_code=paddle_code)
229+
230+
231+
def test_case_13():
    """Fully reversed keyword order, legacy flags included."""
    torch_code = textwrap.dedent(
        """
        import torch
        a = torch.zeros(3)
        a.requires_grad = True
        target = torch.zeros(3)
        target[1] = 1.0
        result = torch.nn.functional.binary_cross_entropy_with_logits(pos_weight=None, reduction='sum', reduce=True, size_average=None, weight=None, target=target, input=a)
        """
    )
    paddle_code = textwrap.dedent(
        """
        import paddle

        a = paddle.zeros(3)
        a.stop_gradient = not True
        target = paddle.zeros(3)
        target[1] = 1.0
        result = paddle.nn.functional.binary_cross_entropy_with_logits(
            pos_weight=None, reduction="mean", weight=None, target=target, input=a
        )
        """
    )
    obj.run(torch_code, expect_paddle_code=paddle_code)

0 commit comments

Comments (0)