@@ -25,8 +25,10 @@ def test_case_1():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target)
         """
     )
@@ -37,8 +39,10 @@ def test_case_2():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, weight=None)
         """
     )
@@ -49,8 +53,10 @@ def test_case_3():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduce=True)
         """
     )
@@ -61,8 +67,10 @@ def test_case_4():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduce=False)
         """
     )
@@ -73,8 +81,10 @@ def test_case_5():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='none')
         """
     )
@@ -85,8 +95,10 @@ def test_case_6():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='mean')
         """
     )
@@ -97,8 +109,10 @@ def test_case_7():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, reduction='sum')
         """
     )
@@ -109,8 +123,10 @@ def test_case_8():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
         """
     )
@@ -122,8 +138,10 @@ def test_case_9():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(a, target, None, None, True, 'sum', None)
         """
     )
@@ -135,22 +153,103 @@ def test_case_10():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
-        result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
+        result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target)
         """
     )
-    obj.run(pytorch_code, ["result"])
+    expect_paddle_code = textwrap.dedent(
+        """
+        import paddle
+
+        a = paddle.zeros(3)
+        a.stop_gradient = not True
+        target = paddle.zeros(3)
+        target[1] = 1.0
+        result = paddle.nn.functional.binary_cross_entropy_with_logits(input=a, target=target)
+        """
+    )
+    obj.run(pytorch_code, expect_paddle_code=expect_paddle_code)
 
 
-# generated by validate_unittest autofix, based on test_case_8
 def test_case_11():
     pytorch_code = textwrap.dedent(
         """
         import torch
-        a = torch.zeros(3, requires_grad=True)
-        target = torch.tensor([0.,1.,0.])
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
+        result = torch.nn.functional.binary_cross_entropy_with_logits(target=target, input=a, reduction='sum', weight=None, pos_weight=None)
+        """
+    )
+    expect_paddle_code = textwrap.dedent(
+        """
+        import paddle
+
+        a = paddle.zeros(3)
+        a.stop_gradient = not True
+        target = paddle.zeros(3)
+        target[1] = 1.0
+        result = paddle.nn.functional.binary_cross_entropy_with_logits(
+            target=target, input=a, reduction="sum", weight=None, pos_weight=None
+        )
+        """
+    )
+    obj.run(pytorch_code, expect_paddle_code=expect_paddle_code)
+
+
+def test_case_12():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
+        result = torch.nn.functional.binary_cross_entropy_with_logits(input=a, target=target, weight=None, size_average=None, reduce=True, reduction='sum', pos_weight=None)
+        """
+    )
+    expect_paddle_code = textwrap.dedent(
+        """
+        import paddle
+
+        a = paddle.zeros(3)
+        a.stop_gradient = not True
+        target = paddle.zeros(3)
+        target[1] = 1.0
+        result = paddle.nn.functional.binary_cross_entropy_with_logits(
+            input=a, target=target, weight=None, reduction="mean", pos_weight=None
+        )
+        """
+    )
+    obj.run(pytorch_code, expect_paddle_code=expect_paddle_code)
+
+
+def test_case_13():
+    pytorch_code = textwrap.dedent(
+        """
+        import torch
+        a = torch.zeros(3)
+        a.requires_grad = True
+        target = torch.zeros(3)
+        target[1] = 1.0
         result = torch.nn.functional.binary_cross_entropy_with_logits(pos_weight=None, reduction='sum', reduce=True, size_average=None, weight=None, target=target, input=a)
         """
     )
-    obj.run(pytorch_code, ["result"])
+    expect_paddle_code = textwrap.dedent(
+        """
+        import paddle
+
+        a = paddle.zeros(3)
+        a.stop_gradient = not True
+        target = paddle.zeros(3)
+        target[1] = 1.0
+        result = paddle.nn.functional.binary_cross_entropy_with_logits(
+            pos_weight=None, reduction="mean", weight=None, target=target, input=a
+        )
+        """
+    )
+    obj.run(pytorch_code, expect_paddle_code=expect_paddle_code)