Skip to content

Commit 01f98d5

Browse files
committed
[GR-71718] Fix regression in generator benchmarks on DSL interpreter
PullRequest: graalpython/4131
2 parents d6fbfa2 + e26568c commit 01f98d5

File tree

8 files changed

+48
-66
lines changed

8 files changed

+48
-66
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/asyncio/PAsyncGen.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,11 @@ public static PAsyncGen create(PythonLanguage lang, PFunction function, PBytecod
6161
}
6262

6363
private PAsyncGen(PythonLanguage lang, PFunction function, MaterializedFrame generatorFrame, PBytecodeRootNode rootNode, RootCallTarget[] callTargets) {
64-
super(lang, function, generatorFrame, PythonBuiltinClassType.PAsyncGenerator, false, new BytecodeState(rootNode, callTargets));
64+
super(lang, function, generatorFrame, PythonBuiltinClassType.PAsyncGenerator, new BytecodeState(rootNode, callTargets));
6565
}
6666

6767
public PAsyncGen(PythonLanguage language, PFunction function, PBytecodeDSLRootNode rootNode, ContinuationRootNode continuationRootNode, MaterializedFrame continuationFrame) {
68-
super(language, function, continuationFrame, PythonBuiltinClassType.PAsyncGenerator, false,
68+
super(language, function, continuationFrame, PythonBuiltinClassType.PAsyncGenerator,
6969
new BytecodeDSLState(rootNode, continuationFrame.getArguments(), continuationRootNode));
7070
}
7171

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/CodeNodes.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
141141
parameterNames,
142142
kwOnlyNames);
143143
} else {
144-
ct = deserializeForBytecodeInterpreter(context, codedata, cellvars, freevars);
144+
ct = deserializeForBytecodeInterpreter(context, codedata, cellvars, freevars, flags);
145145
signature = ((PRootNode) ct.getRootNode()).getSignature();
146146
}
147147
if (filename != null) {
@@ -150,17 +150,20 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
150150
return PFactory.createCode(language, ct, signature, nlocals, stacksize, flags, constants, names, varnames, freevars, cellvars, filename, name, qualname, firstlineno, linetable);
151151
}
152152

153-
private static RootCallTarget deserializeForBytecodeInterpreter(PythonContext context, byte[] data, TruffleString[] cellvars, TruffleString[] freevars) {
153+
private static RootCallTarget deserializeForBytecodeInterpreter(PythonContext context, byte[] data, TruffleString[] cellvars, TruffleString[] freevars, int flags) {
154154
CodeUnit codeUnit = MarshalModuleBuiltins.deserializeCodeUnit(null, context, data);
155155
RootNode rootNode;
156156

157157
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
158158
BytecodeDSLCodeUnit code = (BytecodeDSLCodeUnit) codeUnit;
159+
if (code.flags != flags) {
160+
code = code.withFlags(flags);
161+
}
159162
rootNode = code.createRootNode(context, PythonUtils.createFakeSource());
160163
} else {
161164
BytecodeCodeUnit code = (BytecodeCodeUnit) codeUnit;
162-
if (cellvars != null && !Arrays.equals(code.cellvars, cellvars) || freevars != null && !Arrays.equals(code.freevars, freevars)) {
163-
code = new BytecodeCodeUnit(code.name, code.qualname, code.argCount, code.kwOnlyArgCount, code.positionalOnlyArgCount, code.flags, code.names,
165+
if (cellvars != null && !Arrays.equals(code.cellvars, cellvars) || freevars != null && !Arrays.equals(code.freevars, freevars) || flags != code.flags) {
166+
code = new BytecodeCodeUnit(code.name, code.qualname, code.argCount, code.kwOnlyArgCount, code.positionalOnlyArgCount, flags, code.names,
164167
code.varnames, cellvars != null ? cellvars : code.cellvars, freevars != null ? freevars : code.freevars, code.cell2arg,
165168
code.constants, code.startLine,
166169
code.startColumn, code.endLine, code.endColumn, code.code, code.srcOffsetTable,

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/generator/PGenerator.java

Lines changed: 24 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import com.oracle.graal.python.builtins.objects.function.PFunction;
3434
import com.oracle.graal.python.builtins.objects.object.PythonBuiltinObject;
3535
import com.oracle.graal.python.builtins.objects.object.PythonObject;
36+
import com.oracle.graal.python.compiler.CodeUnit;
3637
import com.oracle.graal.python.nodes.bytecode.BytecodeFrameInfo;
3738
import com.oracle.graal.python.nodes.bytecode.FrameInfo;
3839
import com.oracle.graal.python.nodes.bytecode.GeneratorYieldResult;
@@ -66,8 +67,6 @@ public class PGenerator extends PythonBuiltinObject {
6667
private boolean finished;
6768
// running means it is currently on the stack, not just started
6869
private boolean running;
69-
private final boolean isCoroutine;
70-
private final boolean isAsyncGen;
7170

7271
private PCode code;
7372

@@ -117,21 +116,20 @@ public Object handleResult(PythonLanguage language, GeneratorYieldResult result)
117116
public static class BytecodeDSLState {
118117
private final PBytecodeDSLRootNode rootNode;
119118
private final Object[] arguments;
120-
private BytecodeNode bytecodeNode;
119+
private BytecodeLocation lastLocation;
121120
private ContinuationRootNode continuationRootNode;
122121
private boolean isStarted;
123122

124123
public BytecodeDSLState(PBytecodeDSLRootNode rootNode, Object[] arguments, ContinuationRootNode continuationRootNode) {
125124
this.rootNode = rootNode;
126125
this.arguments = arguments;
127126
this.continuationRootNode = continuationRootNode;
128-
this.bytecodeNode = rootNode.getBytecodeNode();
129127
}
130128

131129
public Object handleResult(PGenerator generator, ContinuationResult result) {
132130
assert result.getContinuationRootNode() == null || result.getContinuationRootNode().getFrameDescriptor() == generator.frame.getFrameDescriptor();
133131
isStarted = true;
134-
bytecodeNode = continuationRootNode.getLocation().getBytecodeNode();
132+
lastLocation = continuationRootNode.getLocation();
135133
continuationRootNode = result.getContinuationRootNode();
136134
return result.getResult();
137135
}
@@ -148,40 +146,26 @@ private BytecodeDSLState getBytecodeDSLState() {
148146
return (BytecodeDSLState) state;
149147
}
150148

151-
// An explicit isIterableCoroutine argument is needed for iterable coroutines (generally created
152-
// via types.coroutine)
153149
public static PGenerator create(PythonLanguage lang, PFunction function, PBytecodeRootNode rootNode, RootCallTarget[] callTargets, Object[] arguments,
154-
PythonBuiltinClassType cls, boolean isIterableCoroutine) {
150+
PythonBuiltinClassType cls) {
155151
// note: also done in PAsyncGen.create
156152
MaterializedFrame generatorFrame = rootNode.createGeneratorFrame(arguments);
157-
return new PGenerator(lang, function, generatorFrame, cls, isIterableCoroutine, new BytecodeState(rootNode, callTargets));
158-
}
159-
160-
public static PGenerator create(PythonLanguage lang, PFunction function, PBytecodeDSLRootNode rootNode, Object[] arguments,
161-
PythonBuiltinClassType cls, boolean isIterableCoroutine, ContinuationRootNode continuationRootNode, MaterializedFrame continuationFrame) {
162-
return new PGenerator(lang, function, continuationFrame, cls, isIterableCoroutine, new BytecodeDSLState(rootNode, arguments, continuationRootNode));
163-
}
164-
165-
public static PGenerator create(PythonLanguage lang, PFunction function, PBytecodeRootNode rootNode, RootCallTarget[] callTargets, Object[] arguments,
166-
PythonBuiltinClassType cls) {
167-
return create(lang, function, rootNode, callTargets, arguments, cls, false);
153+
return new PGenerator(lang, function, generatorFrame, cls, new BytecodeState(rootNode, callTargets));
168154
}
169155

170156
public static PGenerator create(PythonLanguage lang, PFunction function, PBytecodeDSLRootNode rootNode, Object[] arguments,
171157
PythonBuiltinClassType cls, ContinuationRootNode continuationRootNode, MaterializedFrame continuationFrame) {
172-
return create(lang, function, rootNode, arguments, cls, false, continuationRootNode, continuationFrame);
158+
return new PGenerator(lang, function, continuationFrame, cls, new BytecodeDSLState(rootNode, arguments, continuationRootNode));
173159
}
174160

175-
protected PGenerator(PythonLanguage lang, PFunction function, MaterializedFrame frame, PythonBuiltinClassType cls, boolean isIterableCoroutine, Object state) {
161+
protected PGenerator(PythonLanguage lang, PFunction function, MaterializedFrame frame, PythonBuiltinClassType cls, Object state) {
176162
super(cls, cls.getInstanceShape(lang));
177163
this.name = function.getName();
178164
this.qualname = function.getQualname();
179165
this.globals = function.getGlobals();
180166
this.generatorFunction = function;
181167
this.frame = frame;
182168
this.finished = false;
183-
this.isCoroutine = isIterableCoroutine || cls == PythonBuiltinClassType.PCoroutine;
184-
this.isAsyncGen = cls == PythonBuiltinClassType.PAsyncGenerator;
185169
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
186170
BytecodeDSLState bytecodeDSLState = (BytecodeDSLState) state;
187171
this.state = state;
@@ -311,7 +295,12 @@ public RootCallTarget getCurrentCallTarget() {
311295
*/
312296
public BytecodeNode getBytecodeNode() {
313297
assert PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER;
314-
return getBytecodeDSLState().bytecodeNode;
298+
BytecodeDSLState state = getBytecodeDSLState();
299+
if (state.lastLocation != null) {
300+
return state.lastLocation.getBytecodeNode();
301+
} else {
302+
return state.rootNode.getBytecodeNode();
303+
}
315304
}
316305

317306
public BytecodeLocation getCurrentLocation() {
@@ -398,12 +387,21 @@ public final void setQualname(TruffleString qualname) {
398387
this.qualname = qualname;
399388
}
400389

390+
private CodeUnit getCodeUnit() {
391+
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
392+
return getBytecodeDSLState().rootNode.getCodeUnit();
393+
} else {
394+
return getBytecodeState().rootNode.getCodeUnit();
395+
}
396+
}
397+
401398
public final boolean isCoroutine() {
402-
return isCoroutine;
399+
CodeUnit codeUnit = getCodeUnit();
400+
return codeUnit.isCoroutine() || codeUnit.isIterableCoroutine();
403401
}
404402

405403
public final boolean isAsyncGen() {
406-
return isAsyncGen;
404+
return getCodeUnit().isAsyncGenerator();
407405
}
408406

409407
public int getBci() {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/compiler/CodeUnit.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ public boolean isAsyncGenerator() {
148148
return (flags & PCode.CO_ASYNC_GENERATOR) != 0;
149149
}
150150

151+
public boolean isIterableCoroutine() {
152+
return (flags & PCode.CO_ITERABLE_COROUTINE) != 0;
153+
}
154+
151155
public boolean isGeneratorOrCoroutine() {
152156
return (flags & (PCode.CO_GENERATOR | PCode.CO_COROUTINE | PCode.CO_ASYNC_GENERATOR | PCode.CO_ITERABLE_COROUTINE)) != 0;
153157
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode/PBytecodeGeneratorFunctionRootNode.java

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import com.oracle.truffle.api.RootCallTarget;
5454
import com.oracle.truffle.api.frame.FrameDescriptor;
5555
import com.oracle.truffle.api.frame.VirtualFrame;
56-
import com.oracle.truffle.api.profiles.ConditionProfile;
5756
import com.oracle.truffle.api.source.SourceSection;
5857
import com.oracle.truffle.api.strings.TruffleString;
5958

@@ -63,8 +62,6 @@ public class PBytecodeGeneratorFunctionRootNode extends PRootNode {
6362

6463
@CompilationFinal(dimensions = 1) private final RootCallTarget[] callTargets;
6564

66-
private final ConditionProfile isIterableCoroutine = ConditionProfile.create();
67-
6865
@TruffleBoundary
6966
public PBytecodeGeneratorFunctionRootNode(PythonLanguage language, FrameDescriptor frameDescriptor, PBytecodeRootNode rootNode, TruffleString originalName) {
7067
super(language, frameDescriptor);
@@ -83,14 +80,7 @@ public Object execute(VirtualFrame frame) {
8380
PFunction generatorFunction = PArguments.getFunctionObject(arguments);
8481
assert generatorFunction != null;
8582
if (rootNode.getCodeUnit().isGenerator()) {
86-
// if CO_ITERABLE_COROUTINE was explicitly set (likely by types.coroutine), we have to
87-
// pass the information to the generator
88-
// .gi_code.co_flags will still be wrong, but at least await will work correctly
89-
if (isIterableCoroutine.profile((generatorFunction.getCode().getFlags() & 0x100) != 0)) {
90-
return PFactory.createIterableCoroutine(language, generatorFunction, rootNode, callTargets, arguments);
91-
} else {
92-
return PFactory.createGenerator(language, generatorFunction, rootNode, callTargets, arguments);
93-
}
83+
return PFactory.createGenerator(language, generatorFunction, rootNode, callTargets, arguments);
9484
} else if (rootNode.getCodeUnit().isCoroutine()) {
9585
return PFactory.createCoroutine(language, generatorFunction, rootNode, callTargets, arguments);
9686
} else if (rootNode.getCodeUnit().isAsyncGenerator()) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode_dsl/BytecodeDSLCodeUnit.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ public BytecodeDSLCodeUnit(TruffleString name, TruffleString qualname, int argCo
8686
this.selfIndex = selfIndex;
8787
}
8888

89+
public BytecodeDSLCodeUnit withFlags(int flags) {
90+
return new BytecodeDSLCodeUnit(name, qualname, argCount, kwOnlyArgCount, positionalOnlyArgCount, flags,
91+
names, varnames, cellvars, freevars, cell2arg, constants,
92+
startLine, startColumn, endLine, endColumn, classcellIndex, selfIndex, serialized, nodes);
93+
}
94+
8995
@TruffleBoundary
9096
public PBytecodeDSLRootNode createRootNode(PythonContext context, Source source) {
9197
if (nodes != null) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/bytecode_dsl/PBytecodeDSLRootNode.java

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,32 +1053,23 @@ public static Object doYield(
10531053
@Bind Node inliningTarget,
10541054
@Bind ContinuationRootNode continuationRootNode,
10551055
@Bind PBytecodeDSLRootNode innerRoot,
1056-
@Bind BytecodeNode bytecodeNode,
1057-
@Cached InlinedConditionProfile isIterableCoroutine) {
1058-
Object result = createGenerator(continuationFrame, inliningTarget, continuationRootNode, innerRoot, isIterableCoroutine);
1056+
@Bind BytecodeNode bytecodeNode) {
1057+
Object result = createGenerator(continuationFrame, inliningTarget, continuationRootNode, innerRoot);
10591058
if (innerRoot.needsTraceAndProfileInstrumentation()) {
10601059
innerRoot.getThreadState().popInstrumentationData(innerRoot);
10611060
}
10621061
return result;
10631062
}
10641063

10651064
private static PythonAbstractObject createGenerator(MaterializedFrame continuationFrame, Node inliningTarget,
1066-
ContinuationRootNode continuationRootNode, PBytecodeDSLRootNode innerRoot,
1067-
InlinedConditionProfile isIterableCoroutine) {
1065+
ContinuationRootNode continuationRootNode, PBytecodeDSLRootNode innerRoot) {
10681066
Object[] arguments = continuationFrame.getArguments();
10691067
PFunction generatorFunction = PArguments.getFunctionObject(arguments);
10701068
assert generatorFunction != null;
10711069
PythonLanguage language = PythonLanguage.get(inliningTarget);
10721070
PArguments.setCurrentFrameInfo(continuationFrame, new PFrame.Reference(innerRoot, PFrame.Reference.EMPTY));
10731071
if (innerRoot.getCodeUnit().isGenerator()) {
1074-
// if CO_ITERABLE_COROUTINE was explicitly set (likely by types.coroutine), we have
1075-
// to pass the information to the generator .gi_code.co_flags will still be wrong,
1076-
// but at least await will work correctly
1077-
if (isIterableCoroutine.profile(inliningTarget, (generatorFunction.getCode().getFlags() & 0x100) != 0)) {
1078-
return PFactory.createIterableCoroutine(language, generatorFunction, innerRoot, arguments, continuationRootNode, continuationFrame);
1079-
} else {
1080-
return PFactory.createGenerator(language, generatorFunction, innerRoot, arguments, continuationRootNode, continuationFrame);
1081-
}
1072+
return PFactory.createGenerator(language, generatorFunction, innerRoot, arguments, continuationRootNode, continuationFrame);
10821073
} else if (innerRoot.getCodeUnit().isCoroutine()) {
10831074
return PFactory.createCoroutine(language, generatorFunction, innerRoot, arguments, continuationRootNode, continuationFrame);
10841075
} else if (innerRoot.getCodeUnit().isAsyncGenerator()) {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/object/PFactory.java

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -803,16 +803,6 @@ public static PGenerator createGenerator(PythonLanguage language, PFunction func
803803
return PGenerator.create(language, function, rootNode, arguments, PythonBuiltinClassType.PGenerator, continuationRootNode, continuationFrame);
804804
}
805805

806-
public static PGenerator createIterableCoroutine(PythonLanguage language, PFunction function, PBytecodeRootNode rootNode, RootCallTarget[] callTargets,
807-
Object[] arguments) {
808-
return PGenerator.create(language, function, rootNode, callTargets, arguments, PythonBuiltinClassType.PGenerator, true);
809-
}
810-
811-
public static PGenerator createIterableCoroutine(PythonLanguage language, PFunction function, PBytecodeDSLRootNode rootNode,
812-
Object[] arguments, ContinuationRootNode continuationRootNode, MaterializedFrame continuationFrame) {
813-
return PGenerator.create(language, function, rootNode, arguments, PythonBuiltinClassType.PGenerator, true, continuationRootNode, continuationFrame);
814-
}
815-
816806
public static PGenerator createCoroutine(PythonLanguage language, PFunction function, PBytecodeRootNode rootNode, RootCallTarget[] callTargets, Object[] arguments) {
817807
return PGenerator.create(language, function, rootNode, callTargets, arguments, PythonBuiltinClassType.PCoroutine);
818808
}

0 commit comments

Comments
 (0)