Skip to content

Commit c1955b7

Browse files
committed
Add synchronize and check for contiguous
1 parent 95994f8 commit c1955b7

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

tests/optimizer/jit/test_einsum_rewriter.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.

-import functools
 import logging
-import timeit
+import time

 from absl.testing import absltest, parameterized
 import torch
 from torch import Tensor

 from fastseq.logging import get_logger
-from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum
+from fastseq.optimizer.jit.einsum_rewriter import rewrite_einsum, einsum_rewrite_pattern_0
 from fastseq.utils.test_utils import TestCaseBase

 logger = get_logger(__name__, logging.INFO)
@@ -39,11 +38,15 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):

         t0 = torch.randn(shape0, dtype=torch.float32).cuda()
         t1 = torch.randn(shape1, dtype=torch.float32).cuda()
-        repeat_times = 1000
+        repeat_times = 1024

         r0 = run_einsum(eqn, t0, t1)
-        time0 = timeit.Timer(functools.partial(run_einsum, eqn, t0, t1))
-        s0 = time0.timeit(repeat_times)
+        torch.cuda.synchronize()
+        start0 = time.time()
+        for _ in range(repeat_times):
+            run_einsum(eqn, t0, t1)
+        torch.cuda.synchronize()
+        end0 = time.time()

         script_run_einsum = torch.jit.script(run_einsum)
         logger.debug(f"Original graph: \n{script_run_einsum.graph.str()}")
@@ -52,13 +55,28 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
         self.assertTrue('bmm' in script_run_einsum.graph.str())

         r1 = script_run_einsum(eqn, t0, t1)
-        time1 = timeit.Timer(
-            functools.partial(script_run_einsum, eqn, t0, t1))
-        s1 = time1.timeit(repeat_times)
+        torch.cuda.synchronize()
+        start1 = time.time()
+        for _ in range(repeat_times):
+            script_run_einsum(eqn, t0, t1)
+        torch.cuda.synchronize()
+        end1 = time.time()
+
+        r2 = einsum_rewrite_pattern_0(eqn, [t0, t1])
+        torch.cuda.synchronize()
+        start2 = time.time()
+        for _ in range(repeat_times):
+            einsum_rewrite_pattern_0(eqn, [t0, t1])
+        torch.cuda.synchronize()
+        end2 = time.time()

         self.assertTrue(torch.equal(r0, r1))
-        logger.info(f"einsum took: {s0}; optimized einsum torchscript took: "
-                    f"{s1};")
+        self.assertTrue(torch.equal(r0, r2))
+        self.assertEqual(
+            r0.is_contiguous(), r1.is_contiguous(), r2.is_contiguous())
+        logger.info(f"einsum took: {end0 - start0};"
+                    f"optimized einsum torchscript took: {end1 - start1};"
+                    f"optimized einsum python took: {end2 - start2};")


 if __name__ == "__main__":

0 commit comments

Comments (0)