Skip to content

Commit 3f901ab

Browse files
committed
Move ContinuationRoot-specific logic from ExecutionContext to ResumeGeneratorNode
1 parent 973d006 commit 3f901ab

File tree

3 files changed

+60
-60
lines changed

3 files changed

+60
-60
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/generator/CommonGeneratorBuiltins.java

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
import com.oracle.graal.python.nodes.ErrorMessages;
6868
import com.oracle.graal.python.nodes.PGuards;
6969
import com.oracle.graal.python.nodes.PRaiseNode;
70+
import com.oracle.graal.python.nodes.PRootNode;
7071
import com.oracle.graal.python.nodes.bytecode.FrameInfo;
7172
import com.oracle.graal.python.nodes.bytecode.GeneratorReturnException;
7273
import com.oracle.graal.python.nodes.bytecode.GeneratorYieldResult;
@@ -77,10 +78,13 @@
7778
import com.oracle.graal.python.nodes.function.builtins.PythonQuaternaryBuiltinNode;
7879
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryBuiltinNode;
7980
import com.oracle.graal.python.nodes.object.BuiltinClassProfiles.IsBuiltinObjectProfile;
81+
import com.oracle.graal.python.runtime.ExecutionContext;
8082
import com.oracle.graal.python.runtime.PythonOptions;
8183
import com.oracle.graal.python.runtime.exception.PException;
8284
import com.oracle.graal.python.runtime.object.PFactory;
85+
import com.oracle.truffle.api.RootCallTarget;
8386
import com.oracle.truffle.api.bytecode.ContinuationResult;
87+
import com.oracle.truffle.api.bytecode.ContinuationRootNode;
8488
import com.oracle.truffle.api.dsl.Bind;
8589
import com.oracle.truffle.api.dsl.Cached;
8690
import com.oracle.truffle.api.dsl.Cached.Exclusive;
@@ -94,6 +98,7 @@
9498
import com.oracle.truffle.api.frame.MaterializedFrame;
9599
import com.oracle.truffle.api.frame.VirtualFrame;
96100
import com.oracle.truffle.api.nodes.DirectCallNode;
101+
import com.oracle.truffle.api.nodes.IndirectCallNode;
97102
import com.oracle.truffle.api.nodes.Node;
98103
import com.oracle.truffle.api.profiles.InlinedBranchProfile;
99104
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
@@ -153,16 +158,31 @@ static Object cached(VirtualFrame frame, Node inliningTarget, PGenerator self, O
153158
@Specialization(guards = {"isBytecodeDSLInterpreter()", "sameCallTarget(self.getCurrentCallTarget(), callNode)"}, limit = "getCallSiteInlineCacheMaxDepth()")
154159
static Object cachedBytecodeDSL(VirtualFrame frame, Node inliningTarget, PGenerator self, Object sendValue,
155160
@Cached(parameters = "self.getCurrentCallTarget()") DirectCallNode callNode,
156-
@Exclusive @Cached CallDispatchers.SimpleDirectInvokeNode invoke,
161+
@Exclusive @Cached ExecutionContext.CallContext callContext,
157162
@Exclusive @Cached InlinedBranchProfile returnProfile,
158163
@Exclusive @Cached IsBuiltinObjectProfile errorProfile,
159164
@Exclusive @Cached PRaiseNode raiseNode) {
160165
self.setRunning(true);
161166
Object generatorResult;
162167
try {
163168
self.prepareResume();
164-
Object[] arguments = new Object[]{self.getGeneratorFrame(), sendValue};
165-
generatorResult = invoke.execute(frame, inliningTarget, callNode, arguments);
169+
RootCallTarget callTarget = (RootCallTarget) callNode.getCurrentCallTarget();
170+
PRootNode rootNode = PGenerator.unwrapContinuationRoot((ContinuationRootNode) callTarget.getRootNode());
171+
/*
172+
* When resuming a generator/coroutine, the call target is a ContinuationRoot with a
173+
* different calling convention from regular PRootNodes. The first argument is a
174+
* materialized frame, which will be used for the execution itself. We will, e.g.,
175+
* lookup the exception state in that frame's arguments.
176+
*
177+
* So for Bytecode DSL generators, we update the arguments array of that
178+
* materialized frame instead of the arguments array that will be used for the
179+
* actual Truffle call to the ContinuationRoot, which is not accessible to us in the
180+
* generator root.
181+
*/
182+
MaterializedFrame generatorFrame = self.getGeneratorFrame();
183+
callContext.executePrepareCall(frame, generatorFrame.getArguments(), rootNode.needsCallerFrame(), rootNode.needsExceptionState());
184+
Object[] arguments = new Object[]{generatorFrame, sendValue};
185+
generatorResult = callNode.call(arguments);
166186
} catch (PException e) {
167187
throw handleException(self, inliningTarget, errorProfile, raiseNode, e);
168188
} finally {
@@ -203,16 +223,22 @@ static Object generic(VirtualFrame frame, Node inliningTarget, PGenerator self,
203223
@Specialization(replaces = "cachedBytecodeDSL", guards = "isBytecodeDSLInterpreter()")
204224
@Megamorphic
205225
static Object genericBytecodeDSL(VirtualFrame frame, Node inliningTarget, PGenerator self, Object sendValue,
206-
@Exclusive @Cached CallDispatchers.SimpleIndirectInvokeNode invoke,
226+
@Exclusive @Cached ExecutionContext.CallContext callContext,
227+
@Exclusive @Cached IndirectCallNode callNode,
207228
@Exclusive @Cached InlinedBranchProfile returnProfile,
208229
@Exclusive @Cached IsBuiltinObjectProfile errorProfile,
209230
@Exclusive @Cached PRaiseNode raiseNode) {
210231
self.setRunning(true);
211232
Object generatorResult;
212233
try {
213234
self.prepareResume();
214-
Object[] arguments = new Object[]{self.getGeneratorFrame(), sendValue};
215-
generatorResult = invoke.execute(frame, inliningTarget, self.getCurrentCallTarget(), arguments);
235+
RootCallTarget callTarget = self.getCurrentCallTarget();
236+
// See the cached specialization for notes about the arguments handling
237+
PRootNode rootNode = PGenerator.unwrapContinuationRoot((ContinuationRootNode) callTarget.getRootNode());
238+
MaterializedFrame generatorFrame = self.getGeneratorFrame();
239+
callContext.executePrepareCall(frame, generatorFrame.getArguments(), rootNode.needsCallerFrame(), rootNode.needsExceptionState());
240+
Object[] arguments = new Object[]{generatorFrame, sendValue};
241+
generatorResult = callNode.call(callTarget, arguments);
216242
} catch (PException e) {
217243
throw handleException(self, inliningTarget, errorProfile, raiseNode, e);
218244
} finally {

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/generator/PGenerator.java

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import com.oracle.graal.python.runtime.object.PFactory;
4545
import com.oracle.truffle.api.CompilerDirectives;
4646
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
47+
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
4748
import com.oracle.truffle.api.RootCallTarget;
4849
import com.oracle.truffle.api.TruffleStackTraceElement;
4950
import com.oracle.truffle.api.bytecode.ContinuationResult;
@@ -237,13 +238,29 @@ public static Frame getDSLGeneratorFrame(Object[] continuationCallArguments) {
237238

238239
public static RootNode unwrapContinuationRoot(RootNode rootNode) {
239240
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER &&
240-
rootNode instanceof ContinuationRootNode continuationRoot &&
241-
continuationRoot.getSourceRootNode() instanceof PBytecodeDSLRootNode result) {
242-
return result;
241+
rootNode instanceof ContinuationRootNode continuationRoot) {
242+
return unwrapContinuationRoot(continuationRoot);
243243
}
244244
return rootNode;
245245
}
246246

247+
public static PBytecodeDSLRootNode unwrapContinuationRoot(ContinuationRootNode continuationRoot) {
248+
if (CompilerDirectives.isPartialEvaluationConstant(continuationRoot)) {
249+
return (PBytecodeDSLRootNode) continuationRoot.getSourceRootNode();
250+
} else {
251+
/*
252+
* TODO We know that the continuation root node is always the same type, but we can't
253+
* cast to it because it's not public. So we end up with a virtual call.
254+
*/
255+
return unwrapContinuationRootBoundary(continuationRoot);
256+
}
257+
}
258+
259+
@TruffleBoundary
260+
private static PBytecodeDSLRootNode unwrapContinuationRootBoundary(ContinuationRootNode continuationRoot) {
261+
return (PBytecodeDSLRootNode) continuationRoot.getSourceRootNode();
262+
}
263+
247264
public static boolean isGeneratorFrame(Frame frame) {
248265
Object frameInfo = frame.getFrameDescriptor().getInfo();
249266
// just to avoid interface dispatch we must cast the info object

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/ExecutionContext.java

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
import com.oracle.graal.python.builtins.objects.frame.PFrame;
4646
import com.oracle.graal.python.builtins.objects.frame.PFrame.Reference;
4747
import com.oracle.graal.python.builtins.objects.function.PArguments;
48-
import com.oracle.graal.python.builtins.objects.generator.PGenerator;
4948
import com.oracle.graal.python.nodes.PRootNode;
5049
import com.oracle.graal.python.nodes.bytecode_dsl.PBytecodeDSLRootNode;
5150
import com.oracle.graal.python.nodes.exception.TopLevelExceptionHandler;
@@ -65,7 +64,6 @@
6564
import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff;
6665
import com.oracle.truffle.api.RootCallTarget;
6766
import com.oracle.truffle.api.bytecode.BytecodeNode;
68-
import com.oracle.truffle.api.bytecode.ContinuationRootNode;
6967
import com.oracle.truffle.api.dsl.Bind;
7068
import com.oracle.truffle.api.dsl.Cached;
7169
import com.oracle.truffle.api.dsl.GenerateCached;
@@ -81,7 +79,6 @@
8179
import com.oracle.truffle.api.frame.VirtualFrame;
8280
import com.oracle.truffle.api.nodes.EncapsulatingNodeReference;
8381
import com.oracle.truffle.api.nodes.Node;
84-
import com.oracle.truffle.api.nodes.RootNode;
8582
import com.oracle.truffle.api.profiles.InlinedConditionProfile;
8683
import com.oracle.truffle.api.profiles.InlinedCountingConditionProfile;
8784

@@ -163,65 +160,25 @@ public abstract class ExecutionContext {
163160
@GenerateUncached
164161
public abstract static class CallContext extends Node {
165162

166-
/*
167-
* Bytecode DSL note: When resuming a generator/coroutine, the call target is a
168-
* ContinuationRoot with a different calling convention from regular PRootNodes. The first
169-
* argument is a materialized frame, which will be used for the execution itself. We will,
170-
* e.g., lookup the exception state in that frame's arguments.
171-
*
172-
* So for Bytecode DSL generators, we update the arguments array of that materialized frame
173-
* instead of the arguments array that will be used for the actual Truffle call to the
174-
* ContinuationRoot, which is not accessible to us in the generator root.
175-
*/
176-
177163
/**
178164
* Prepare an indirect call from a Python frame to a Python function.
179165
*/
180166
public void prepareIndirectCall(VirtualFrame frame, Object[] callArguments, RootCallTarget callTarget) {
181-
PRootNode pRootNode;
182-
RootNode rootNode = callTarget.getRootNode();
183-
if (rootNode instanceof ContinuationRootNode continuationRoot) {
184-
pRootNode = (PRootNode) continuationRoot.getSourceRootNode();
185-
} else {
186-
pRootNode = (PRootNode) rootNode;
187-
}
188-
executePrepareCall(frame, getActualCallArguments(callArguments), pRootNode.needsCallerFrame(), pRootNode.needsExceptionState());
189-
}
190-
191-
private static Object[] getActualCallArguments(Object[] callArguments) {
192-
// See Bytecode DSL note at the top
193-
if (callArguments.length == 2 && callArguments[0] instanceof MaterializedFrame materialized) {
194-
return materialized.getArguments();
195-
}
196-
return callArguments;
167+
PRootNode rootNode = (PRootNode) callTarget.getRootNode();
168+
executePrepareCall(frame, callArguments, rootNode.needsCallerFrame(), rootNode.needsExceptionState());
197169
}
198170

199171
/**
200172
* Prepare a call from a Python frame to a Python function.
201173
*/
202174
public void prepareCall(VirtualFrame frame, Object[] callArguments, RootCallTarget callTarget) {
203-
RootNode rootNode = callTarget.getRootNode();
204-
205-
PRootNode calleeRootNode;
206-
Object[] actualCallArguments;
207-
boolean needsExceptionState;
208-
if (rootNode instanceof ContinuationRootNode continuationRoot) {
209-
// See Bytecode DSL note at the top
210-
calleeRootNode = (PRootNode) continuationRoot.getSourceRootNode();
211-
assert callArguments.length == 2;
212-
actualCallArguments = ((MaterializedFrame) callArguments[0]).getArguments();
213-
needsExceptionState = calleeRootNode.needsExceptionState();
214-
} else {
215-
// n.b.: The class cast should always be correct, since this context
216-
// must only be used when calling from Python to Python
217-
calleeRootNode = (PRootNode) rootNode;
218-
actualCallArguments = callArguments;
219-
needsExceptionState = calleeRootNode.needsExceptionState();
220-
}
221-
executePrepareCall(frame, actualCallArguments, calleeRootNode.needsCallerFrame(), needsExceptionState);
175+
// n.b.: The class cast should always be correct, since this context
176+
// must only be used when calling from Python to Python
177+
PRootNode calleeRootNode = (PRootNode) callTarget.getRootNode();
178+
executePrepareCall(frame, callArguments, calleeRootNode.needsCallerFrame(), calleeRootNode.needsExceptionState());
222179
}
223180

224-
protected abstract void executePrepareCall(VirtualFrame frame, Object[] callArguments, boolean needsCallerFrame, boolean needsExceptionState);
181+
public abstract void executePrepareCall(VirtualFrame frame, Object[] callArguments, boolean needsCallerFrame, boolean needsExceptionState);
225182

226183
@Specialization
227184
protected static void prepareCall(VirtualFrame frame, Object[] callArguments, boolean needsCallerFrame, boolean needsExceptionState,
@@ -791,7 +748,7 @@ public static Object enter(PythonThreadState threadState, Object[] pArguments, R
791748
}
792749

793750
private static boolean needsExceptionState(RootCallTarget callTarget) {
794-
PRootNode calleeRootNode = (PRootNode) PGenerator.unwrapContinuationRoot(callTarget.getRootNode());
751+
PRootNode calleeRootNode = (PRootNode) callTarget.getRootNode();
795752
return calleeRootNode.needsExceptionState();
796753
}
797754

0 commit comments

Comments
 (0)