11# Copyright (c) Microsoft Corporation.
22# Licensed under the MIT License.
33
4- import functools
54import logging
6- import timeit
5+ import time
76
87from absl .testing import absltest , parameterized
98import torch
109from torch import Tensor
1110
1211from fastseq .logging import get_logger
13- from fastseq .optimizer .jit .einsum_rewriter import rewrite_einsum
12+ from fastseq .optimizer .jit .einsum_rewriter import rewrite_einsum , einsum_rewrite_pattern_0
1413from fastseq .utils .test_utils import TestCaseBase
1514
1615logger = get_logger (__name__ , logging .INFO )
@@ -39,11 +38,15 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
3938
4039 t0 = torch .randn (shape0 , dtype = torch .float32 ).cuda ()
4140 t1 = torch .randn (shape1 , dtype = torch .float32 ).cuda ()
42- repeat_times = 1000
41+ repeat_times = 1024
4342
4443 r0 = run_einsum (eqn , t0 , t1 )
45- time0 = timeit .Timer (functools .partial (run_einsum , eqn , t0 , t1 ))
46- s0 = time0 .timeit (repeat_times )
44+ torch .cuda .synchronize ()
45+ start0 = time .time ()
46+ for _ in range (repeat_times ):
47+ run_einsum (eqn , t0 , t1 )
48+ torch .cuda .synchronize ()
49+ end0 = time .time ()
4750
4851 script_run_einsum = torch .jit .script (run_einsum )
4952 logger .debug (f"Original graph: \n { script_run_einsum .graph .str ()} " )
@@ -52,13 +55,28 @@ def run_einsum(eqn: str, t0: Tensor, t1: Tensor):
5255 self .assertTrue ('bmm' in script_run_einsum .graph .str ())
5356
5457 r1 = script_run_einsum (eqn , t0 , t1 )
55- time1 = timeit .Timer (
56- functools .partial (script_run_einsum , eqn , t0 , t1 ))
57- s1 = time1 .timeit (repeat_times )
58+ torch .cuda .synchronize ()
59+ start1 = time .time ()
60+ for _ in range (repeat_times ):
61+ script_run_einsum (eqn , t0 , t1 )
62+ torch .cuda .synchronize ()
63+ end1 = time .time ()
64+
65+ r2 = einsum_rewrite_pattern_0 (eqn , [t0 , t1 ])
66+ torch .cuda .synchronize ()
67+ start2 = time .time ()
68+ for _ in range (repeat_times ):
69+ einsum_rewrite_pattern_0 (eqn , [t0 , t1 ])
70+ torch .cuda .synchronize ()
71+ end2 = time .time ()
5872
5973 self .assertTrue (torch .equal (r0 , r1 ))
60- logger .info (f"einsum took: { s0 } ; optimized einsum torchscript took: "
61- f"{ s1 } ;" )
74+ self .assertTrue (torch .equal (r0 , r2 ))
75+ self .assertEqual (
76+ r0 .is_contiguous (), r1 .is_contiguous (), r2 .is_contiguous ())
77+ logger .info (f"einsum took: { end0 - start0 } ;"
78+ f"optimized einsum torchscript took: { end1 - start1 } ;"
79+ f"optimized einsum python took: { end2 - start2 } ;" )
6280
6381
6482if __name__ == "__main__" :
0 commit comments