Skip to content

Commit a3624d9

Browse files
committed
Materialize frames when needed in IndirectCallContext
1 parent e6f0d17 commit a3624d9

File tree

3 files changed

+53
-34
lines changed

3 files changed

+53
-34
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/nodes/frame/ReadFrameNode.java

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ PFrame read(VirtualFrame frame, PFrame.Reference startFrameInfo, FrameInstance.F
161161
}
162162
int i = 0;
163163
PFrame.Reference curFrameInfo = startFrameInfo;
164-
while (true) {
164+
while (curFrameInfo != null) {
165165
if (curFrameInfo == PFrame.Reference.EMPTY) {
166166
// We reached the top of the stack
167167
return null;
@@ -208,7 +208,7 @@ PFrame read(VirtualFrame frame, PFrame.Reference startFrameInfo, FrameInstance.F
208208
}
209209

210210
private static PFrame.Reference getBackref(PFrame.Reference reference) {
211-
if (reference.getPyFrame() != null) {
211+
if (reference.getPyFrame() != null && reference.getPyFrame().getBackref() != null) {
212212
return reference.getPyFrame().getBackref();
213213
}
214214
return reference.getCallerInfo();
@@ -346,6 +346,7 @@ public static StackWalkResult getFrame(Node requestingNode, PFrame.Reference sta
346346
return Truffle.getRuntime().iterateFrames(new FrameInstanceVisitor<>() {
347347
int i = startFrame != null ? -1 : 0;
348348
boolean first = true;
349+
RootNode prevRootNode;
349350

350351
public StackWalkResult visitFrame(FrameInstance frameInstance) {
351352
RootNode rootNode = ReadFrameNode.getRootNode(frameInstance);
@@ -369,6 +370,7 @@ public StackWalkResult visitFrame(FrameInstance frameInstance) {
369370
// through thread state. We will eventually arrive at the Python frame that did
370371
// BoundaryCallContext.enter, find the IndirectCallData via the callNode and
371372
// tell it to pass the PFrame.Reference in thread state next time
373+
prevRootNode = rootNode;
372374
return null;
373375
}
374376
IndirectCallData.setCallerFlagsOnIndirectCallData(callNode, callerFlags);
@@ -386,17 +388,19 @@ public StackWalkResult visitFrame(FrameInstance frameInstance) {
386388
if (i == level) {
387389
Frame frame = ReadFrameNode.getFrame(frameInstance, frameAccess);
388390
assert PArguments.isPythonFrame(frame);
389-
pRootNode.updateCallerFlags(callerFlags);
391+
if (prevRootNode instanceof PRootNode prevPRootNode && prevPRootNode.setsUpCalleeContext()) {
392+
// Update the flags in the callee
393+
prevPRootNode.updateCallerFlags(callerFlags);
394+
}
390395
return new StackWalkResult(pRootNode, callNode, frame);
391396
}
392397
i += 1;
393398
}
394399
}
395400
// For any Python root node we traverse we need the PFrame.Reference to be passed in
396-
// call arguments next time. If we are at the frame
397-
// that we need, we still need caller frame info if our frame is escaped, see
398-
// CalleeContext#exitEscaped
401+
// call arguments next time.
399402
pRootNode.updateCallerFlags(CallerFlags.NEEDS_FRAME_REFERENCE);
403+
prevRootNode = pRootNode;
400404
return null; // if 'null' continue iterating
401405
}
402406
});

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/ExecutionContext.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import com.oracle.graal.python.nodes.frame.ReadFrameNode;
5454
import com.oracle.graal.python.nodes.util.ExceptionStateNodes.GetCaughtExceptionNode;
5555
import com.oracle.graal.python.runtime.IndirectCallData.BoundaryCallData;
56+
import com.oracle.graal.python.runtime.IndirectCallData.IndirectCallDataBase;
5657
import com.oracle.graal.python.runtime.IndirectCallData.InteropCallData;
5758
import com.oracle.graal.python.runtime.PythonContext.PythonThreadState;
5859
import com.oracle.graal.python.runtime.exception.PException;
@@ -513,7 +514,7 @@ public static Object enter(VirtualFrame frame, BoundaryCallData boundaryCallData
513514

514515
private static Object enterWithPythonFrame(VirtualFrame frame, BoundaryCallData boundaryCallData, PythonThreadState pythonThreadState) {
515516
assert frame != null;
516-
return IndirectCallContext.enterWithPythonFrame(frame, boundaryCallData, pythonThreadState, boundaryCallData.getCallerFlags(), EMPTY_SAVED_STATE);
517+
return IndirectCallContext.enterWithPythonFrame(frame, boundaryCallData, boundaryCallData, pythonThreadState, boundaryCallData.getCallerFlags(), EMPTY_SAVED_STATE);
517518
}
518519

519520
public static void exit(VirtualFrame frame, PythonLanguage language, PythonContext context, Object savedState) {
@@ -596,7 +597,7 @@ public static Object enter(VirtualFrame frame, Node node, InteropCallData callDa
596597

597598
private static Object enterWithPythonFrame(VirtualFrame frame, InteropCallData callData, PythonThreadState pythonThreadState) {
598599
assert frame != null;
599-
return IndirectCallContext.enterWithPythonFrame(frame, null, pythonThreadState, callData.getCallerFlags(), null);
600+
return IndirectCallContext.enterWithPythonFrame(frame, callData, null, pythonThreadState, callData.getCallerFlags(), null);
600601
}
601602

602603
public static void exit(VirtualFrame frame, PythonLanguage language, PythonContext context, Object savedState) {
@@ -635,8 +636,8 @@ public static void exit(VirtualFrame frame, PythonThreadState pythonThreadState,
635636

636637
// Common code shared by BoundaryCallContext and InteropCallContext
637638
public abstract static class IndirectCallContext {
638-
private static Object enterWithPythonFrame(VirtualFrame frame, Node encapsulatingNodeToPush, PythonThreadState pythonThreadState,
639-
int callerFlags, Object defaultReturn) {
639+
private static Object enterWithPythonFrame(VirtualFrame frame, IndirectCallDataBase callData, Node encapsulatingNodeToPush,
640+
PythonThreadState pythonThreadState, int callerFlags, Object defaultReturn) {
640641
CompilerAsserts.partialEvaluationConstant(encapsulatingNodeToPush == null);
641642
CompilerAsserts.partialEvaluationConstant(defaultReturn);
642643
if (callerFlags == 0) {
@@ -654,18 +655,24 @@ private static Object enterWithPythonFrame(VirtualFrame frame, Node encapsulatin
654655
}
655656
}
656657

657-
return enterSlowPath(frame, encapsulatingNodeToPush, pythonThreadState, callerFlags, defaultReturn);
658+
return enterSlowPath(frame, callData, encapsulatingNodeToPush, pythonThreadState, callerFlags, defaultReturn);
658659
}
659660

660-
private static Object enterSlowPath(VirtualFrame frame, Node encapsulatingNodeToPush, PythonThreadState pythonThreadState,
661-
int callerFlags, Object defaultReturn) {
661+
private static Object enterSlowPath(VirtualFrame frame, IndirectCallDataBase callData, Node encapsulatingNodeToPush,
662+
PythonThreadState pythonThreadState, int callerFlags, Object defaultReturn) {
662663
PFrame.Reference info = null;
663664
if (CallerFlags.needsFrameReference(callerFlags)) {
664665
PFrame.Reference prev = pythonThreadState.popTopFrameInfo();
665666
assert prev == null : "trying to call from Python to a foreign function, but we didn't clear the topframeref. " +
666667
"This indicates that a call into Python code happened without a proper enter through IndirectCalleeContext";
667668
info = PArguments.getCurrentFrameInfo(frame);
668669
pythonThreadState.setTopFrameInfo(info);
670+
if (CallerFlags.needsPFrame(callerFlags)) {
671+
callData.getMaterializeFrameNode().executeOnStack(false, CallerFlags.needsLocals(callerFlags), frame);
672+
} else if (info.getPyFrame() != null) {
673+
// Avoid passing stale locals
674+
info.getPyFrame().setLocals(null);
675+
}
669676
}
670677
AbstractTruffleException curExc = pythonThreadState.getCaughtException();
671678
AbstractTruffleException exceptionState = PArguments.getException(frame);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/runtime/IndirectCallData.java

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
package com.oracle.graal.python.runtime;
4242

4343
import com.oracle.graal.python.PythonLanguage;
44+
import com.oracle.graal.python.nodes.frame.MaterializeFrameNode;
4445
import com.oracle.truffle.api.Assumption;
4546
import com.oracle.truffle.api.CompilerDirectives;
4647
import com.oracle.truffle.api.CompilerDirectives.CompilationFinal;
@@ -154,6 +155,32 @@ private static Assumption createAssumption() {
154155
}
155156
}
156157

158+
public abstract static class IndirectCallDataBase extends Node {
159+
@CompilationFinal protected CallerFlagsAssumptionSet callerFlagsAssumptionSet = new CallerFlagsAssumptionSet();
160+
@Child private MaterializeFrameNode materializeFrameNode;
161+
162+
public int getCallerFlags() {
163+
return callerFlagsAssumptionSet.getCallerFlags();
164+
}
165+
166+
public void updateCallerFlags(int callerFlags) {
167+
callerFlagsAssumptionSet.updateCallerFlags(callerFlags);
168+
}
169+
170+
public abstract boolean isUncached();
171+
172+
public MaterializeFrameNode getMaterializeFrameNode() {
173+
if (isUncached()) {
174+
return MaterializeFrameNode.getUncached();
175+
}
176+
if (materializeFrameNode == null) {
177+
CompilerDirectives.transferToInterpreterAndInvalidate();
178+
materializeFrameNode = insert(MaterializeFrameNode.create());
179+
}
180+
return materializeFrameNode;
181+
}
182+
}
183+
157184
/**
158185
* Truffle interop overrides the {@link com.oracle.truffle.api.nodes.EncapsulatingNodeReference}
159186
* with its own node when doing transition to uncached, or it does not transition to uncached at
@@ -162,8 +189,7 @@ private static Assumption createAssumption() {
162189
* <p>
163190
* This scheme is used also for our interop buffer Truffle libraries.
164191
*/
165-
public static final class InteropCallData {
166-
@CompilationFinal private CallerFlagsAssumptionSet callerFlagsAssumptionSet = new CallerFlagsAssumptionSet();
192+
public static final class InteropCallData extends IndirectCallDataBase {
167193

168194
private static final InteropCallData UNCACHED = new InteropCallData();
169195

@@ -175,14 +201,6 @@ public boolean isUncached() {
175201
return this == UNCACHED;
176202
}
177203

178-
public int getCallerFlags() {
179-
return callerFlagsAssumptionSet.getCallerFlags();
180-
}
181-
182-
public void updateCallerFlags(int callerFlags) {
183-
callerFlagsAssumptionSet.updateCallerFlags(callerFlags);
184-
}
185-
186204
@NeverDefault
187205
public static InteropCallData createFor(Node node) {
188206
return PythonLanguage.createInteropCallData(node);
@@ -202,9 +220,7 @@ public static InteropCallData getUncached() {
202220
* situation we still need to maintain a mapping of this node and its parent like for
203221
* {@link InteropCallData}.
204222
*/
205-
public static final class BoundaryCallData extends Node {
206-
207-
@CompilationFinal private CallerFlagsAssumptionSet callerFlagsAssumptionSet = new CallerFlagsAssumptionSet();
223+
public static final class BoundaryCallData extends IndirectCallDataBase {
208224

209225
private static final BoundaryCallData UNCACHED = new BoundaryCallData();
210226

@@ -216,14 +232,6 @@ public boolean isUncached() {
216232
return this == UNCACHED;
217233
}
218234

219-
public int getCallerFlags() {
220-
return callerFlagsAssumptionSet.getCallerFlags();
221-
}
222-
223-
public void updateCallerFlags(int callerFlags) {
224-
callerFlagsAssumptionSet.updateCallerFlags(callerFlags);
225-
}
226-
227235
@NeverDefault
228236
public static BoundaryCallData createFor(Node node) {
229237
return PythonLanguage.createBoundaryCallData(node);

0 commit comments

Comments
 (0)