Skip to content

Commit 31e8657

Browse files
committed
Load code from marshalled form without Env.parse
1 parent 3bdfd98 commit 31e8657

File tree

6 files changed

+76
-72
lines changed

6 files changed

+76
-72
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/PythonLanguage.java

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import static com.oracle.graal.python.nodes.StringLiterals.J_PY_EXTENSION;
3131
import static com.oracle.graal.python.nodes.StringLiterals.T_PY_EXTENSION;
3232
import static com.oracle.graal.python.nodes.truffle.TruffleStringMigrationHelpers.isJavaString;
33+
import static com.oracle.graal.python.util.PythonUtils.ARRAY_ACCESSOR;
3334
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
3435
import static com.oracle.graal.python.util.PythonUtils.tsLiteral;
3536

@@ -60,7 +61,7 @@
6061
import com.oracle.graal.python.annotations.PythonOS;
6162
import com.oracle.graal.python.builtins.Python3Core;
6263
import com.oracle.graal.python.builtins.PythonBuiltinClassType;
63-
import com.oracle.graal.python.builtins.modules.MarshalModuleBuiltins;
64+
import com.oracle.graal.python.builtins.modules.ImpModuleBuiltins;
6465
import com.oracle.graal.python.builtins.modules.SignalModuleBuiltins;
6566
import com.oracle.graal.python.builtins.objects.PNone;
6667
import com.oracle.graal.python.builtins.objects.PNotImplemented;
@@ -160,7 +161,6 @@
160161
"text/x-python-\2\u0100-eval", "text/x-python-\2\u0100-compile", "text/x-python-\0\u0040-eval", "text/x-python-\0\u0040-compile", "text/x-python-\1\u0040-eval",
161162
"text/x-python-\1\u0040-compile", "text/x-python-\2\u0040-eval", "text/x-python-\2\u0040-compile", "text/x-python-\0\u0140-eval", "text/x-python-\0\u0140-compile",
162163
"text/x-python-\1\u0140-eval", "text/x-python-\1\u0140-compile", "text/x-python-\2\u0140-eval", "text/x-python-\2\u0140-compile"}, //
163-
byteMimeTypes = {PythonLanguage.MIME_TYPE_BYTECODE}, //
164164
defaultMimeType = PythonLanguage.MIME_TYPE, //
165165
dependentLanguages = {"nfi", "llvm"}, //
166166
interactive = true, internal = false, //
@@ -312,8 +312,6 @@ private static boolean mimeTypesComplete(ArrayList<String> mimeJavaStrings) {
312312
assert mimeTypesComplete(mimeJavaStrings) : "Expected all of {" + String.join(", ", mimeJavaStrings) + "} in the PythonLanguage characterMimeTypes";
313313
}
314314

315-
public static final String MIME_TYPE_BYTECODE = "application/x-python-bytecode";
316-
317315
public static final TruffleString[] T_DEFAULT_PYTHON_EXTENSIONS = new TruffleString[]{T_PY_EXTENSION, tsLiteral(".pyc")};
318316

319317
public static final TruffleLogger LOGGER = TruffleLogger.getLogger(ID, PythonLanguage.class);
@@ -557,11 +555,6 @@ protected CallTarget parse(ParsingRequest request) {
557555
if (!request.getArgumentNames().isEmpty()) {
558556
throw new IllegalStateException("parse with arguments is only allowed for " + MIME_TYPE + " mime type");
559557
}
560-
if (MIME_TYPE_BYTECODE.equals(source.getMimeType())) {
561-
byte[] bytes = source.getBytes().toByteArray();
562-
CodeUnit code = MarshalModuleBuiltins.deserializeCodeUnit(null, context, bytes);
563-
return callTargetFromBytecode(context, source, code);
564-
}
565558

566559
String mime = source.getMimeType();
567560
String prefix = mime.substring(0, MIME_PREFIX.length());
@@ -586,7 +579,7 @@ protected CallTarget parse(ParsingRequest request) {
586579
return parse(context, source, type, false, optimize, false, null, FutureFeature.fromFlags(flags));
587580
}
588581

589-
public RootCallTarget callTargetFromBytecode(PythonContext context, Source source, CodeUnit code) {
582+
public static RootCallTarget callTargetFromBytecode(PythonContext context, Source source, CodeUnit code) {
590583
boolean internal = shouldMarkSourceInternal(context);
591584
SourceBuilder builder = null;
592585
// The original file path should be passed as the name
@@ -616,7 +609,7 @@ public RootCallTarget callTargetFromBytecode(PythonContext context, Source sourc
616609
// TODO lazily load source in bytecode DSL interpreter too
617610
rootNode = ((BytecodeDSLCodeUnit) code).createRootNode(context, lazySource.getSource());
618611
} else {
619-
rootNode = PBytecodeRootNode.create(this, (BytecodeCodeUnit) code, lazySource, internal);
612+
rootNode = PBytecodeRootNode.create(context.getLanguage(), (BytecodeCodeUnit) code, lazySource, internal);
620613
}
621614

622615
return PythonUtils.getOrCreateCallTarget(rootNode);
@@ -999,10 +992,17 @@ protected void initializeMultipleContexts() {
999992
singleContext = false;
1000993
}
1001994

1002-
private final ConcurrentHashMap<TruffleString, CallTarget> cachedCode = new ConcurrentHashMap<>();
995+
public record CodeCacheKey(TruffleString filename, long codeHash) {
996+
}
997+
998+
private final ConcurrentHashMap<CodeCacheKey, CallTarget> cachedCode = new ConcurrentHashMap<>();
1003999

1004-
@TruffleBoundary
10051000
public CallTarget cacheCode(TruffleString filename, Supplier<CallTarget> createCode) {
1001+
return cacheCode(new CodeCacheKey(filename, 0), createCode);
1002+
}
1003+
1004+
@TruffleBoundary
1005+
public CallTarget cacheCode(CodeCacheKey filename, Supplier<CallTarget> createCode) {
10061006
if (!singleContext) {
10071007
return cachedCode.computeIfAbsent(filename, f -> {
10081008
LOGGER.log(Level.FINEST, () -> "Caching CallTarget for " + filename);
@@ -1013,6 +1013,19 @@ public CallTarget cacheCode(TruffleString filename, Supplier<CallTarget> createC
10131013
}
10141014
}
10151015

1016+
public long cacheKeyForBytecode(byte[] code, int length) {
1017+
if (singleContext) {
1018+
// No caching in single context
1019+
return 0;
1020+
}
1021+
byte[] hashBytes = ImpModuleBuiltins.SourceHashNode.hashSource(0, code, length);
1022+
return ARRAY_ACCESSOR.getLong(hashBytes, 0);
1023+
}
1024+
1025+
public long cacheKeyForBytecode(byte[] code) {
1026+
return cacheKeyForBytecode(code, code.length);
1027+
}
1028+
10161029
private static final Source LINEBREAK_REGEX_SOURCE = Source.newBuilder("regex", "/\r\n|[\n\u000B\u000C\r\u0085\u2028\u2029]/", "re_linebreak") //
10171030
.option("regex.Flavor", "Python") //
10181031
.option("regex.Encoding", "UTF-32") //

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/ImpModuleBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ static Object run(TruffleString name, Object dataObj,
484484
Object code = null;
485485

486486
try {
487-
code = MarshalModuleBuiltins.Marshal.load(context, bytes, size);
487+
code = MarshalModuleBuiltins.Marshal.load(context, bytes, size, 0);
488488
} catch (MarshalError | NumberFormatException e) {
489489
raiseFrozenError(inliningTarget, raiseNode, FROZEN_INVALID, name);
490490
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MarshalModuleBuiltins.java

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@
7171
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAccessLibrary;
7272
import com.oracle.graal.python.builtins.objects.buffer.PythonBufferAcquireLibrary;
7373
import com.oracle.graal.python.builtins.objects.bytes.PByteArray;
74-
import com.oracle.graal.python.builtins.objects.code.CodeNodes.CreateCodeNode;
7574
import com.oracle.graal.python.builtins.objects.code.PCode;
7675
import com.oracle.graal.python.builtins.objects.common.EconomicMapStorage;
7776
import com.oracle.graal.python.builtins.objects.common.HashingStorage;
@@ -120,7 +119,6 @@
120119
import com.oracle.graal.python.nodes.function.PythonBuiltinNode;
121120
import com.oracle.graal.python.nodes.function.builtins.PythonBinaryClinicBuiltinNode;
122121
import com.oracle.graal.python.nodes.function.builtins.PythonTernaryClinicBuiltinNode;
123-
import com.oracle.graal.python.nodes.function.builtins.PythonUnaryClinicBuiltinNode;
124122
import com.oracle.graal.python.nodes.function.builtins.clinic.ArgumentClinicProvider;
125123
import com.oracle.graal.python.runtime.ExecutionContext.BoundaryCallContext;
126124
import com.oracle.graal.python.runtime.IndirectCallData.BoundaryCallData;
@@ -131,9 +129,11 @@
131129
import com.oracle.graal.python.runtime.sequence.storage.ByteSequenceStorage;
132130
import com.oracle.graal.python.runtime.sequence.storage.SequenceStorage;
133131
import com.oracle.graal.python.util.PythonUtils;
132+
import com.oracle.truffle.api.CallTarget;
134133
import com.oracle.truffle.api.CompilerAsserts;
135134
import com.oracle.truffle.api.CompilerDirectives;
136135
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
136+
import com.oracle.truffle.api.RootCallTarget;
137137
import com.oracle.truffle.api.bytecode.BytecodeConfig;
138138
import com.oracle.truffle.api.bytecode.BytecodeRootNodes;
139139
import com.oracle.truffle.api.bytecode.serialization.BytecodeDeserializer;
@@ -261,20 +261,28 @@ static Object doit(VirtualFrame frame, Object file,
261261
}
262262
}
263263

264-
@Builtin(name = "loads", minNumOfPositionalArgs = 1, numOfPositionalOnlyArgs = 1, parameterNames = {"bytes"})
264+
// cache_key is a GraalPy-specific keyword
265+
@Builtin(name = "loads", minNumOfPositionalArgs = 1, numOfPositionalOnlyArgs = 1, parameterNames = {"bytes"}, keywordOnlyNames = {"cache_key"})
265266
@ArgumentClinic(name = "bytes", conversion = ClinicConversion.ReadableBuffer)
267+
@ArgumentClinic(name = "cache_key", conversion = ClinicConversion.Long, defaultValue = "0")
266268
@GenerateNodeFactory
267-
abstract static class LoadsNode extends PythonUnaryClinicBuiltinNode {
269+
abstract static class LoadsNode extends PythonBinaryClinicBuiltinNode {
268270

269271
@Specialization
270-
static Object doit(VirtualFrame frame, Object buffer,
272+
static Object doit(VirtualFrame frame, Object buffer, long cacheKey,
271273
@Bind Node inliningTarget,
272274
@Bind PythonContext context,
273275
@Cached("createFor($node)") InteropCallData callData,
274276
@CachedLibrary(limit = "3") PythonBufferAccessLibrary bufferLib,
275277
@Cached PRaiseNode raiseNode) {
278+
PythonLanguage language = context.getLanguage(inliningTarget);
276279
try {
277-
return Marshal.load(context, bufferLib.getInternalOrCopiedByteArray(buffer), bufferLib.getBufferLength(buffer));
280+
byte[] bytes = bufferLib.getInternalOrCopiedByteArray(buffer);
281+
int length = bufferLib.getBufferLength(buffer);
282+
if (!language.isSingleContext() && cacheKey < 0) {
283+
cacheKey = language.cacheKeyForBytecode(bytes, length);
284+
}
285+
return Marshal.load(context, bytes, length, cacheKey);
278286
} catch (NumberFormatException e) {
279287
throw raiseNode.raise(inliningTarget, ValueError, ErrorMessages.BAD_MARSHAL_DATA_S, e.getMessage());
280288
} catch (Marshal.MarshalError me) {
@@ -384,8 +392,8 @@ static byte[] dump(PythonContext context, Object value, int version) throws IOEx
384392
}
385393

386394
@TruffleBoundary
387-
static Object load(PythonContext context, byte[] ary, int length) throws NumberFormatException, MarshalError {
388-
Marshal inMarshal = new Marshal(context, ary, length);
395+
static Object load(PythonContext context, byte[] ary, int length, long cacheKey) throws NumberFormatException, MarshalError {
396+
Marshal inMarshal = new Marshal(context, ary, length, cacheKey);
389397
Object result = inMarshal.readObject();
390398
if (result == null) {
391399
throw new MarshalError(PythonBuiltinClassType.TypeError, ErrorMessages.BAD_MARSHAL_DATA_NULL);
@@ -459,6 +467,7 @@ public int read(byte[] b, int off, int len) {
459467
final PInt pyTrue;
460468
final PInt pyFalse;
461469
int depth = 0;
470+
long cacheKey;
462471
/*
463472
* A DSL node needs access to its Source during deserialization, but we do not wish to
464473
* actually encode it in the serialized representation. Instead, we supply a Source to the
@@ -490,8 +499,9 @@ public int read(byte[] b, int off, int len) {
490499
this.refList = null;
491500
}
492501

493-
Marshal(PythonContext context, byte[] in, int length) {
502+
Marshal(PythonContext context, byte[] in, int length, long cacheKey) {
494503
this(context, SerializationUtils.createDataInput(ByteBuffer.wrap(in, 0, length)), null);
504+
this.cacheKey = cacheKey;
495505
}
496506

497507
Marshal(PythonContext context, Object in) {
@@ -908,14 +918,13 @@ private void writeComplexObject(Object v, int flag) {
908918
writeByte(TYPE_ARRAY | flag);
909919
writeByte(ARRAY_TYPE_OBJECT);
910920
writeObjectArray((Object[]) v);
911-
} else if (v instanceof PCode) {
921+
} else if (v instanceof PCode c) {
912922
// we always store code objects in our format, CPython will not read our
913923
// marshalled data when that contains code objects
914-
PCode c = (PCode) v;
915924
writeByte(TYPE_GRAALPYTHON_CODE | flag);
916925
writeString(c.getFilename());
917926
writeInt(c.getFlags());
918-
writeBytes(c.getCodestring());
927+
writeCodeUnit(c.getCodeUnit());
919928
writeInt(c.getFirstLineNo());
920929
byte[] lnotab = c.getLinetable();
921930
if (lnotab == null) {
@@ -1507,21 +1516,23 @@ private void writeBytecodeDSLCodeUnit(BytecodeDSLCodeUnit code) throws IOExcepti
15071516
private PCode readCode() {
15081517
TruffleString fileName = readString(true);
15091518
int flags = readInt();
1510-
1511-
int codeLen = readSize();
1512-
byte[] codeString = new byte[codeLen + Long.BYTES];
1513-
try {
1514-
in.readFully(codeString, 0, codeLen);
1515-
} catch (IOException e) {
1516-
throw CompilerDirectives.shouldNotReachHere();
1517-
}
1518-
// get a new ID every time we deserialize the same filename in the same context. We use
1519-
// slow-path context lookup, since this code is likely dominated by the deserialization
1520-
// time
1521-
ByteBuffer.wrap(codeString).putLong(codeLen, context.getDeserializationId(fileName));
1519+
CodeUnit code = readCodeUnit();
15221520
int firstLineNo = readInt();
15231521
byte[] lnoTab = readBytes();
1524-
return CreateCodeNode.createCode(context, flags, codeString, fileName, firstLineNo, lnoTab);
1522+
com.oracle.graal.python.util.Supplier<CallTarget> supplier = () -> {
1523+
String jFilename = fileName.toJavaStringUncached();
1524+
Source subSource = Source.newBuilder(PythonLanguage.ID, "", jFilename).content(Source.CONTENT_NONE).build();
1525+
return PythonLanguage.callTargetFromBytecode(context, subSource, code);
1526+
};
1527+
CallTarget callTarget;
1528+
if (context.getLanguage().isSingleContext() || cacheKey == 0) {
1529+
callTarget = supplier.get();
1530+
} else {
1531+
// get a new ID every time we deserialize the same filename in the same context
1532+
long fullCacheKey = cacheKey + context.getDeserializationId(fileName);
1533+
callTarget = context.getLanguage().cacheCode(new PythonLanguage.CodeCacheKey(fileName, fullCacheKey), supplier);
1534+
}
1535+
return PFactory.createCode(context.getLanguage(), (RootCallTarget) callTarget, flags, firstLineNo, lnoTab, fileName);
15251536
}
15261537
}
15271538

@@ -1541,7 +1552,7 @@ public static byte[] serializeCodeUnit(Node node, PythonContext context, CodeUni
15411552
@TruffleBoundary
15421553
public static CodeUnit deserializeCodeUnit(Node node, PythonContext context, byte[] bytes) {
15431554
try {
1544-
Marshal marshal = new Marshal(context, bytes, bytes.length);
1555+
Marshal marshal = new Marshal(context, bytes, bytes.length, 0);
15451556
return marshal.readCodeUnit();
15461557
} catch (Marshal.MarshalError me) {
15471558
throw PRaiseNode.raiseStatic(node, me.type, me.message, me.arguments);

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/code/CodeNodes.java

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,6 @@
4242

4343
import java.util.Arrays;
4444

45-
import org.graalvm.polyglot.io.ByteSequence;
46-
4745
import com.oracle.graal.python.PythonLanguage;
4846
import com.oracle.graal.python.builtins.modules.MarshalModuleBuiltins;
4947
import com.oracle.graal.python.builtins.objects.code.CodeNodesFactory.GetCodeRootNodeGen;
@@ -64,8 +62,6 @@
6462
import com.oracle.graal.python.runtime.object.PFactory;
6563
import com.oracle.graal.python.util.LazySource;
6664
import com.oracle.graal.python.util.PythonUtils;
67-
import com.oracle.graal.python.util.Supplier;
68-
import com.oracle.truffle.api.CallTarget;
6965
import com.oracle.truffle.api.CompilerAsserts;
7066
import com.oracle.truffle.api.CompilerDirectives;
7167
import com.oracle.truffle.api.CompilerDirectives.TruffleBoundary;
@@ -79,7 +75,6 @@
7975
import com.oracle.truffle.api.frame.VirtualFrame;
8076
import com.oracle.truffle.api.nodes.Node;
8177
import com.oracle.truffle.api.nodes.RootNode;
82-
import com.oracle.truffle.api.source.Source;
8378
import com.oracle.truffle.api.strings.TruffleString;
8479

8580
public abstract class CodeNodes {
@@ -146,7 +141,7 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
146141
parameterNames,
147142
kwOnlyNames);
148143
} else {
149-
ct = create().deserializeForBytecodeInterpreter(language, context, codedata, cellvars, freevars);
144+
ct = deserializeForBytecodeInterpreter(context, codedata, cellvars, freevars);
150145
signature = ((PRootNode) ct.getRootNode()).getSignature();
151146
}
152147
if (filename != null) {
@@ -155,10 +150,9 @@ private static PCode createCode(PythonLanguage language, PythonContext context,
155150
return PFactory.createCode(language, ct, signature, nlocals, stacksize, flags, constants, names, varnames, freevars, cellvars, filename, name, qualname, firstlineno, linetable);
156151
}
157152

158-
@SuppressWarnings("static-method")
159-
private RootCallTarget deserializeForBytecodeInterpreter(PythonLanguage language, PythonContext context, byte[] data, TruffleString[] cellvars, TruffleString[] freevars) {
153+
private static RootCallTarget deserializeForBytecodeInterpreter(PythonContext context, byte[] data, TruffleString[] cellvars, TruffleString[] freevars) {
160154
CodeUnit codeUnit = MarshalModuleBuiltins.deserializeCodeUnit(null, context, data);
161-
RootNode rootNode = null;
155+
RootNode rootNode;
162156

163157
if (PythonOptions.ENABLE_BYTECODE_DSL_INTERPRETER) {
164158
BytecodeDSLCodeUnit code = (BytecodeDSLCodeUnit) codeUnit;
@@ -182,25 +176,6 @@ private RootCallTarget deserializeForBytecodeInterpreter(PythonLanguage language
182176
return PythonUtils.getOrCreateCallTarget(rootNode);
183177
}
184178

185-
@TruffleBoundary
186-
public static PCode createCode(PythonContext context, int flags, byte[] codedata, TruffleString filename, int firstlineno, byte[] lnotab) {
187-
boolean isNotAModule = (flags & PCode.CO_GRAALPYHON_MODULE) == 0;
188-
String jFilename = filename.toJavaStringUncached();
189-
PythonLanguage language = context.getLanguage();
190-
Supplier<CallTarget> createCode = () -> {
191-
ByteSequence bytes = ByteSequence.create(codedata);
192-
Source source = Source.newBuilder(PythonLanguage.ID, bytes, jFilename).mimeType(PythonLanguage.MIME_TYPE_BYTECODE).cached(!language.isSingleContext()).build();
193-
return context.getEnv().parsePublic(source);
194-
};
195-
196-
if (context.isCoreInitialized() || isNotAModule) {
197-
return PFactory.createCode(language, (RootCallTarget) createCode.get(), flags, firstlineno, lnotab, filename);
198-
} else {
199-
RootCallTarget ct = (RootCallTarget) language.cacheCode(filename, createCode);
200-
return PFactory.createCode(language, ct, flags, firstlineno, lnotab, filename);
201-
}
202-
}
203-
204179
@NeverDefault
205180
public static CreateCodeNode create() {
206181
return new CreateCodeNode();

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/compiler/Compiler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@
228228
* Compiler for bytecode interpreter.
229229
*/
230230
public class Compiler implements SSTreeVisitor<Void> {
231-
public static final int BYTECODE_VERSION = 31;
231+
public static final int BYTECODE_VERSION = 32;
232232

233233
private final ParserCallbacks parserCallbacks;
234234

0 commit comments

Comments
 (0)