7171import com .oracle .graal .python .builtins .objects .buffer .PythonBufferAccessLibrary ;
7272import com .oracle .graal .python .builtins .objects .buffer .PythonBufferAcquireLibrary ;
7373import com .oracle .graal .python .builtins .objects .bytes .PByteArray ;
74- import com .oracle .graal .python .builtins .objects .code .CodeNodes .CreateCodeNode ;
7574import com .oracle .graal .python .builtins .objects .code .PCode ;
7675import com .oracle .graal .python .builtins .objects .common .EconomicMapStorage ;
7776import com .oracle .graal .python .builtins .objects .common .HashingStorage ;
120119import com .oracle .graal .python .nodes .function .PythonBuiltinNode ;
121120import com .oracle .graal .python .nodes .function .builtins .PythonBinaryClinicBuiltinNode ;
122121import com .oracle .graal .python .nodes .function .builtins .PythonTernaryClinicBuiltinNode ;
123- import com .oracle .graal .python .nodes .function .builtins .PythonUnaryClinicBuiltinNode ;
124122import com .oracle .graal .python .nodes .function .builtins .clinic .ArgumentClinicProvider ;
125123import com .oracle .graal .python .runtime .ExecutionContext .BoundaryCallContext ;
126124import com .oracle .graal .python .runtime .IndirectCallData .BoundaryCallData ;
131129import com .oracle .graal .python .runtime .sequence .storage .ByteSequenceStorage ;
132130import com .oracle .graal .python .runtime .sequence .storage .SequenceStorage ;
133131import com .oracle .graal .python .util .PythonUtils ;
132+ import com .oracle .truffle .api .CallTarget ;
134133import com .oracle .truffle .api .CompilerAsserts ;
135134import com .oracle .truffle .api .CompilerDirectives ;
136135import com .oracle .truffle .api .CompilerDirectives .TruffleBoundary ;
136+ import com .oracle .truffle .api .RootCallTarget ;
137137import com .oracle .truffle .api .bytecode .BytecodeConfig ;
138138import com .oracle .truffle .api .bytecode .BytecodeRootNodes ;
139139import com .oracle .truffle .api .bytecode .serialization .BytecodeDeserializer ;
@@ -261,20 +261,28 @@ static Object doit(VirtualFrame frame, Object file,
261261 }
262262 }
263263
264- @ Builtin (name = "loads" , minNumOfPositionalArgs = 1 , numOfPositionalOnlyArgs = 1 , parameterNames = {"bytes" })
264+ // cache_key is a GraalPy-specific keyword
265+ @ Builtin (name = "loads" , minNumOfPositionalArgs = 1 , numOfPositionalOnlyArgs = 1 , parameterNames = {"bytes" }, keywordOnlyNames = {"cache_key" })
265266 @ ArgumentClinic (name = "bytes" , conversion = ClinicConversion .ReadableBuffer )
267+ @ ArgumentClinic (name = "cache_key" , conversion = ClinicConversion .Long , defaultValue = "0" )
266268 @ GenerateNodeFactory
267- abstract static class LoadsNode extends PythonUnaryClinicBuiltinNode {
269+ abstract static class LoadsNode extends PythonBinaryClinicBuiltinNode {
268270
269271 @ Specialization
270- static Object doit (VirtualFrame frame , Object buffer ,
272+ static Object doit (VirtualFrame frame , Object buffer , long cacheKey ,
271273 @ Bind Node inliningTarget ,
272274 @ Bind PythonContext context ,
273275 @ Cached ("createFor($node)" ) InteropCallData callData ,
274276 @ CachedLibrary (limit = "3" ) PythonBufferAccessLibrary bufferLib ,
275277 @ Cached PRaiseNode raiseNode ) {
278+ PythonLanguage language = context .getLanguage (inliningTarget );
276279 try {
277- return Marshal .load (context , bufferLib .getInternalOrCopiedByteArray (buffer ), bufferLib .getBufferLength (buffer ));
280+ byte [] bytes = bufferLib .getInternalOrCopiedByteArray (buffer );
281+ int length = bufferLib .getBufferLength (buffer );
282+ if (!language .isSingleContext () && cacheKey < 0 ) {
283+ cacheKey = language .cacheKeyForBytecode (bytes , length );
284+ }
285+ return Marshal .load (context , bytes , length , cacheKey );
278286 } catch (NumberFormatException e ) {
279287 throw raiseNode .raise (inliningTarget , ValueError , ErrorMessages .BAD_MARSHAL_DATA_S , e .getMessage ());
280288 } catch (Marshal .MarshalError me ) {
@@ -384,8 +392,8 @@ static byte[] dump(PythonContext context, Object value, int version) throws IOEx
384392 }
385393
386394 @ TruffleBoundary
387- static Object load (PythonContext context , byte [] ary , int length ) throws NumberFormatException , MarshalError {
388- Marshal inMarshal = new Marshal (context , ary , length );
395+ static Object load (PythonContext context , byte [] ary , int length , long cacheKey ) throws NumberFormatException , MarshalError {
396+ Marshal inMarshal = new Marshal (context , ary , length , cacheKey );
389397 Object result = inMarshal .readObject ();
390398 if (result == null ) {
391399 throw new MarshalError (PythonBuiltinClassType .TypeError , ErrorMessages .BAD_MARSHAL_DATA_NULL );
@@ -459,6 +467,7 @@ public int read(byte[] b, int off, int len) {
459467 final PInt pyTrue ;
460468 final PInt pyFalse ;
461469 int depth = 0 ;
470+ long cacheKey ;
462471 /*
463472 * A DSL node needs access to its Source during deserialization, but we do not wish to
464473 * actually encode it in the serialized representation. Instead, we supply a Source to the
@@ -490,8 +499,9 @@ public int read(byte[] b, int off, int len) {
490499 this .refList = null ;
491500 }
492501
493- Marshal (PythonContext context , byte [] in , int length ) {
502+ Marshal (PythonContext context , byte [] in , int length , long cacheKey ) {
494503 this (context , SerializationUtils .createDataInput (ByteBuffer .wrap (in , 0 , length )), null );
504+ this .cacheKey = cacheKey ;
495505 }
496506
497507 Marshal (PythonContext context , Object in ) {
@@ -908,14 +918,13 @@ private void writeComplexObject(Object v, int flag) {
908918 writeByte (TYPE_ARRAY | flag );
909919 writeByte (ARRAY_TYPE_OBJECT );
910920 writeObjectArray ((Object []) v );
911- } else if (v instanceof PCode ) {
921+ } else if (v instanceof PCode c ) {
912922 // we always store code objects in our format, CPython will not read our
913923 // marshalled data when that contains code objects
914- PCode c = (PCode ) v ;
915924 writeByte (TYPE_GRAALPYTHON_CODE | flag );
916925 writeString (c .getFilename ());
917926 writeInt (c .getFlags ());
918- writeBytes (c .getCodestring ());
927+ writeCodeUnit (c .getCodeUnit ());
919928 writeInt (c .getFirstLineNo ());
920929 byte [] lnotab = c .getLinetable ();
921930 if (lnotab == null ) {
@@ -1507,21 +1516,23 @@ private void writeBytecodeDSLCodeUnit(BytecodeDSLCodeUnit code) throws IOExcepti
15071516 private PCode readCode () {
15081517 TruffleString fileName = readString (true );
15091518 int flags = readInt ();
1510-
1511- int codeLen = readSize ();
1512- byte [] codeString = new byte [codeLen + Long .BYTES ];
1513- try {
1514- in .readFully (codeString , 0 , codeLen );
1515- } catch (IOException e ) {
1516- throw CompilerDirectives .shouldNotReachHere ();
1517- }
1518- // get a new ID every time we deserialize the same filename in the same context. We use
1519- // slow-path context lookup, since this code is likely dominated by the deserialization
1520- // time
1521- ByteBuffer .wrap (codeString ).putLong (codeLen , context .getDeserializationId (fileName ));
1519+ CodeUnit code = readCodeUnit ();
15221520 int firstLineNo = readInt ();
15231521 byte [] lnoTab = readBytes ();
1524- return CreateCodeNode .createCode (context , flags , codeString , fileName , firstLineNo , lnoTab );
1522+ com .oracle .graal .python .util .Supplier <CallTarget > supplier = () -> {
1523+ String jFilename = fileName .toJavaStringUncached ();
1524+ Source subSource = Source .newBuilder (PythonLanguage .ID , "" , jFilename ).content (Source .CONTENT_NONE ).build ();
1525+ return PythonLanguage .callTargetFromBytecode (context , subSource , code );
1526+ };
1527+ CallTarget callTarget ;
1528+ if (context .getLanguage ().isSingleContext () || cacheKey == 0 ) {
1529+ callTarget = supplier .get ();
1530+ } else {
1531+ // get a new ID every time we deserialize the same filename in the same context
1532+ long fullCacheKey = cacheKey + context .getDeserializationId (fileName );
1533+ callTarget = context .getLanguage ().cacheCode (new PythonLanguage .CodeCacheKey (fileName , fullCacheKey ), supplier );
1534+ }
1535+ return PFactory .createCode (context .getLanguage (), (RootCallTarget ) callTarget , flags , firstLineNo , lnoTab , fileName );
15251536 }
15261537 }
15271538
@@ -1541,7 +1552,7 @@ public static byte[] serializeCodeUnit(Node node, PythonContext context, CodeUni
15411552 @ TruffleBoundary
15421553 public static CodeUnit deserializeCodeUnit (Node node , PythonContext context , byte [] bytes ) {
15431554 try {
1544- Marshal marshal = new Marshal (context , bytes , bytes .length );
1555+ Marshal marshal = new Marshal (context , bytes , bytes .length , 0 );
15451556 return marshal .readCodeUnit ();
15461557 } catch (Marshal .MarshalError me ) {
15471558 throw PRaiseNode .raiseStatic (node , me .type , me .message , me .arguments );
0 commit comments