Skip to content

Commit a536990

Browse files
authored
Dispatch optimizer (#1149)
1 parent e733a49 commit a536990

4 files changed

Lines changed: 747 additions & 3 deletions

File tree

Lines changed: 330 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,330 @@
1+
package de.peeeq.wurstscript.intermediatelang.optimizer;
2+
3+
import de.peeeq.wurstscript.WurstOperator;
4+
import de.peeeq.wurstscript.jassIm.*;
5+
import de.peeeq.wurstscript.translation.imoptimizer.UselessFunctionCallsRemover;
6+
import de.peeeq.wurstscript.translation.imoptimizer.OptimizerPass;
7+
import de.peeeq.wurstscript.translation.imtranslation.ImTranslator;
8+
9+
import java.util.ArrayList;
10+
import java.util.HashSet;
11+
import java.util.IdentityHashMap;
12+
import java.util.Objects;
13+
import java.util.Set;
14+
15+
/**
16+
* Collapses consecutive identical dispatch safety checks into one.
17+
*
18+
* This targets the repetitive pattern emitted by checked dispatch after
19+
* inlining, where several method calls on the same receiver produce adjacent
20+
* copies of the same guard.
21+
*/
22+
public class DispatchCheckDeduplicator implements OptimizerPass {
23+
24+
private int rewrites;
25+
private final IdentityHashMap<ImFunction, IdentityHashMap<ImVar, Boolean>> mayWriteTypeIdMemo = new IdentityHashMap<>();
26+
private SideEffectAnalyzer sideEffectAnalyzer;
27+
28+
@Override
29+
public int optimize(ImTranslator trans) {
30+
rewrites = 0;
31+
mayWriteTypeIdMemo.clear();
32+
ImProg prog = trans.getImProg();
33+
sideEffectAnalyzer = new SideEffectAnalyzer(prog);
34+
for (ImFunction f : prog.getFunctions()) {
35+
optimizeStmts(f.getBody());
36+
}
37+
prog.flatten(trans);
38+
return rewrites;
39+
}
40+
41+
@Override
42+
public String getName() {
43+
return "Dispatch Check Dedup";
44+
}
45+
46+
private void optimizeStmts(ImStmts stmts) {
47+
for (ImStmt s : new ArrayList<>(stmts)) {
48+
if (s instanceof ImIf) {
49+
ImIf imIf = (ImIf) s;
50+
optimizeStmts(imIf.getThenBlock());
51+
optimizeStmts(imIf.getElseBlock());
52+
} else if (s instanceof ImLoop) {
53+
optimizeStmts(((ImLoop) s).getBody());
54+
} else if (s instanceof ImVarargLoop) {
55+
optimizeStmts(((ImVarargLoop) s).getBody());
56+
}
57+
}
58+
59+
int i = 0;
60+
while (i < stmts.size()) {
61+
GuardPattern first = extractDispatchGuard(stmts.get(i));
62+
if (first == null) {
63+
i++;
64+
continue;
65+
}
66+
67+
int j = i + 1;
68+
while (j < stmts.size()) {
69+
ImStmt s = stmts.get(j);
70+
GuardPattern next = extractDispatchGuard(s);
71+
if (next != null) {
72+
if (first.sameGuardAs(next)) {
73+
stmts.remove(j);
74+
rewrites++;
75+
continue;
76+
}
77+
// Different guard: keep scanning only if statement cannot invalidate this guard.
78+
if (invalidatesGuard(s, first)) {
79+
break;
80+
}
81+
j++;
82+
continue;
83+
}
84+
if (invalidatesGuard(s, first)) {
85+
break;
86+
}
87+
j++;
88+
}
89+
90+
i++;
91+
}
92+
}
93+
94+
private boolean invalidatesGuard(ImStmt s, GuardPattern guard) {
95+
if (mayWriteTypeIdFromElement(s, guard.failedCond.typeIdVar)) {
96+
return true;
97+
}
98+
if (s instanceof ImSet) {
99+
ImSet set = (ImSet) s;
100+
ImLExpr left = set.getLeft();
101+
if (left instanceof ImVarAccess) {
102+
ImVar v = ((ImVarAccess) left).getVar();
103+
return v == guard.failedCond.receiverVar || v == guard.failedCond.typeIdVar;
104+
}
105+
if (left instanceof ImVarArrayAccess) {
106+
ImVar v = ((ImVarArrayAccess) left).getVar();
107+
return v == guard.failedCond.typeIdVar;
108+
}
109+
if (left instanceof ImMemberAccess) {
110+
ImVar v = ((ImMemberAccess) left).getVar();
111+
return v == guard.failedCond.typeIdVar;
112+
}
113+
return false;
114+
}
115+
if (s instanceof ImFunctionCall) {
116+
ImFunction f = ((ImFunctionCall) s).getFunc();
117+
return mayWriteTypeId(f, guard.failedCond.typeIdVar);
118+
}
119+
if (s instanceof ImMethodCall) {
120+
ImMethod m = ((ImMethodCall) s).getMethod();
121+
return mayWriteTypeId(m.getImplementation(), guard.failedCond.typeIdVar);
122+
}
123+
if (s instanceof ImDealloc || s instanceof ImAlloc) {
124+
return true;
125+
}
126+
if (s instanceof ImIf || s instanceof ImLoop || s instanceof ImVarargLoop
127+
|| s instanceof ImReturn || s instanceof ImExitwhen) {
128+
return true;
129+
}
130+
return false;
131+
}
132+
133+
private boolean mayWriteTypeIdFromElement(Element elem, ImVar typeIdVar) {
134+
if (elem == null) {
135+
return false;
136+
}
137+
if (sideEffectAnalyzer.directlySetVariables(elem).contains(typeIdVar)) {
138+
return true;
139+
}
140+
for (ImFunction called : sideEffectAnalyzer.calledFunctions(elem)) {
141+
if (mayWriteTypeId(called, typeIdVar)) {
142+
return true;
143+
}
144+
}
145+
return false;
146+
}
147+
148+
private boolean mayWriteTypeId(ImFunction f, ImVar typeIdVar) {
149+
if (f == null) {
150+
return true;
151+
}
152+
if (f.isNative()) {
153+
return !UselessFunctionCallsRemover.isFunctionWithoutSideEffect(f.getName());
154+
}
155+
if (f.isExtern()) {
156+
return true;
157+
}
158+
IdentityHashMap<ImVar, Boolean> byTypeId = mayWriteTypeIdMemo.computeIfAbsent(f, k -> new IdentityHashMap<>());
159+
Boolean memo = byTypeId.get(typeIdVar);
160+
if (memo != null) {
161+
return memo;
162+
}
163+
boolean result = mayWriteTypeIdUsingAnalysis(f, typeIdVar);
164+
byTypeId.put(typeIdVar, result);
165+
return result;
166+
}
167+
168+
private boolean mayWriteTypeIdUsingAnalysis(ImFunction f, ImVar typeIdVar) {
169+
Set<ImFunction> reachable = new HashSet<>();
170+
reachable.add(f);
171+
reachable.addAll(sideEffectAnalyzer.calledFunctions(f.getBody()));
172+
for (ImFunction g : reachable) {
173+
if (g == null) {
174+
return true;
175+
}
176+
if (g.isExtern()) {
177+
return true;
178+
}
179+
if (g.isNative()) {
180+
if (!UselessFunctionCallsRemover.isFunctionWithoutSideEffect(g.getName())) {
181+
return true;
182+
}
183+
continue;
184+
}
185+
if (sideEffectAnalyzer.directlySetVariables(g.getBody()).contains(typeIdVar)) {
186+
return true;
187+
}
188+
}
189+
return false;
190+
}
191+
192+
private GuardPattern extractDispatchGuard(ImStmt stmt) {
193+
if (!(stmt instanceof ImIf)) {
194+
return null;
195+
}
196+
ImIf outer = (ImIf) stmt;
197+
if (!outer.getElseBlock().isEmpty() || outer.getThenBlock().size() != 1) {
198+
return null;
199+
}
200+
GuardCond failed = parseTypeIdZeroCond(outer.getCondition());
201+
if (failed == null) {
202+
return null;
203+
}
204+
205+
ImStmt innerStmt = outer.getThenBlock().get(0);
206+
if (!(innerStmt instanceof ImIf)) {
207+
return null;
208+
}
209+
ImIf inner = (ImIf) innerStmt;
210+
if (inner.getThenBlock().size() != 1 || inner.getElseBlock().size() != 1) {
211+
return null;
212+
}
213+
if (!isReceiverZeroCond(inner.getCondition(), failed.receiverVar)) {
214+
return null;
215+
}
216+
217+
ErrorCall nullErr = parseSingleErrorCall(inner.getThenBlock().get(0));
218+
ErrorCall invalidErr = parseSingleErrorCall(inner.getElseBlock().get(0));
219+
if (nullErr == null || invalidErr == null) {
220+
return null;
221+
}
222+
223+
return new GuardPattern(failed, nullErr, invalidErr);
224+
}
225+
226+
private static GuardCond parseTypeIdZeroCond(ImExpr expr) {
227+
if (!(expr instanceof ImOperatorCall)) {
228+
return null;
229+
}
230+
ImOperatorCall op = (ImOperatorCall) expr;
231+
if (op.getOp() != WurstOperator.EQ || op.getArguments().size() != 2) {
232+
return null;
233+
}
234+
ImExpr a = op.getArguments().get(0);
235+
ImExpr b = op.getArguments().get(1);
236+
GuardCond c = parseTypeIdEqZero(a, b);
237+
if (c != null) {
238+
return c;
239+
}
240+
return parseTypeIdEqZero(b, a);
241+
}
242+
243+
private static GuardCond parseTypeIdEqZero(ImExpr left, ImExpr right) {
244+
if (!(right instanceof ImIntVal) || ((ImIntVal) right).getValI() != 0) {
245+
return null;
246+
}
247+
if (!(left instanceof ImVarArrayAccess)) {
248+
return null;
249+
}
250+
ImVarArrayAccess aa = (ImVarArrayAccess) left;
251+
if (aa.getIndexes().size() != 1 || !(aa.getIndexes().get(0) instanceof ImVarAccess)) {
252+
return null;
253+
}
254+
ImVar receiver = ((ImVarAccess) aa.getIndexes().get(0)).getVar();
255+
return new GuardCond(aa.getVar(), receiver);
256+
}
257+
258+
private static boolean isReceiverZeroCond(ImExpr expr, ImVar receiver) {
259+
if (!(expr instanceof ImOperatorCall)) {
260+
return false;
261+
}
262+
ImOperatorCall op = (ImOperatorCall) expr;
263+
if (op.getOp() != WurstOperator.EQ || op.getArguments().size() != 2) {
264+
return false;
265+
}
266+
return isReceiverEqZero(op.getArguments().get(0), op.getArguments().get(1), receiver)
267+
|| isReceiverEqZero(op.getArguments().get(1), op.getArguments().get(0), receiver);
268+
}
269+
270+
private static boolean isReceiverEqZero(ImExpr left, ImExpr right, ImVar receiver) {
271+
return left instanceof ImVarAccess
272+
&& ((ImVarAccess) left).getVar() == receiver
273+
&& right instanceof ImIntVal
274+
&& ((ImIntVal) right).getValI() == 0;
275+
}
276+
277+
private static ErrorCall parseSingleErrorCall(ImStmt stmt) {
278+
if (!(stmt instanceof ImFunctionCall)) {
279+
return null;
280+
}
281+
ImFunctionCall fc = (ImFunctionCall) stmt;
282+
if (fc.getArguments().size() != 1 || !(fc.getArguments().get(0) instanceof ImStringVal)) {
283+
return null;
284+
}
285+
return new ErrorCall(fc.getFunc(), ((ImStringVal) fc.getArguments().get(0)).getValS());
286+
}
287+
288+
private static final class GuardPattern {
289+
private final GuardCond failedCond;
290+
private final ErrorCall nullError;
291+
private final ErrorCall invalidError;
292+
293+
private GuardPattern(GuardCond failedCond, ErrorCall nullError, ErrorCall invalidError) {
294+
this.failedCond = failedCond;
295+
this.nullError = nullError;
296+
this.invalidError = invalidError;
297+
}
298+
299+
private boolean sameGuardAs(GuardPattern other) {
300+
return failedCond.typeIdVar == other.failedCond.typeIdVar
301+
&& failedCond.receiverVar == other.failedCond.receiverVar
302+
&& nullError.sameAs(other.nullError)
303+
&& invalidError.sameAs(other.invalidError);
304+
}
305+
}
306+
307+
private static final class GuardCond {
308+
private final ImVar typeIdVar;
309+
private final ImVar receiverVar;
310+
311+
private GuardCond(ImVar typeIdVar, ImVar receiverVar) {
312+
this.typeIdVar = typeIdVar;
313+
this.receiverVar = receiverVar;
314+
}
315+
}
316+
317+
private static final class ErrorCall {
318+
private final ImFunction func;
319+
private final String message;
320+
321+
private ErrorCall(ImFunction func, String message) {
322+
this.func = func;
323+
this.message = message;
324+
}
325+
326+
private boolean sameAs(ErrorCall other) {
327+
return func == other.func && Objects.equals(message, other.message);
328+
}
329+
}
330+
}

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImOptimizer.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import de.peeeq.wurstscript.WLogger;
77
import de.peeeq.wurstscript.intermediatelang.optimizer.BranchMerger;
88
import de.peeeq.wurstscript.intermediatelang.optimizer.ConstantAndCopyPropagation;
9+
import de.peeeq.wurstscript.intermediatelang.optimizer.DispatchCheckDeduplicator;
910
import de.peeeq.wurstscript.intermediatelang.optimizer.LocalMerger;
1011
import de.peeeq.wurstscript.intermediatelang.optimizer.SideEffectAnalyzer;
1112
import de.peeeq.wurstscript.intermediatelang.optimizer.SimpleRewrites;
@@ -32,6 +33,7 @@ public class ImOptimizer {
3233
localPasses.add(new ConstantAndCopyPropagation());
3334
localPasses.add(new UselessFunctionCallsRemover());
3435
localPasses.add(new GlobalsInliner());
36+
localPasses.add(new DispatchCheckDeduplicator());
3537
localPasses.add(new SimpleRewrites());
3638
}
3739

0 commit comments

Comments
 (0)