Skip to content

Commit b0a33ff

Browse files
committed
Improve unmarshalling performance
1 parent 2314319 commit b0a33ff

File tree

1 file changed

+37
-36
lines changed

1 file changed

+37
-36
lines changed

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/modules/MarshalModuleBuiltins.java

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
import static com.oracle.graal.python.util.PythonUtils.EMPTY_TRUFFLESTRING_ARRAY;
4141
import static com.oracle.graal.python.util.PythonUtils.TS_ENCODING;
4242

43-
import java.io.ByteArrayInputStream;
4443
import java.io.ByteArrayOutputStream;
4544
import java.io.DataInput;
4645
import java.io.DataInputStream;
@@ -51,6 +50,7 @@
5150
import java.io.InputStream;
5251
import java.math.BigInteger;
5352
import java.nio.ByteBuffer;
53+
import java.nio.ByteOrder;
5454
import java.nio.charset.StandardCharsets;
5555
import java.util.ArrayList;
5656
import java.util.HashMap;
@@ -147,7 +147,6 @@
147147
import com.oracle.truffle.api.dsl.Specialization;
148148
import com.oracle.truffle.api.frame.VirtualFrame;
149149
import com.oracle.truffle.api.library.CachedLibrary;
150-
import com.oracle.truffle.api.memory.ByteArraySupport;
151150
import com.oracle.truffle.api.nodes.Node;
152151
import com.oracle.truffle.api.source.Source;
153152
import com.oracle.truffle.api.strings.InternalByteArray;
@@ -349,10 +348,6 @@ static final class Marshal {
349348
private static final int MARSHAL_SHIFT = 15;
350349
private static final BigInteger MARSHAL_BASE = BigInteger.valueOf(1 << MARSHAL_SHIFT);
351350

352-
private static final int BYTES_PER_LONG = Long.SIZE / Byte.SIZE;
353-
private static final int BYTES_PER_INT = Integer.SIZE / Byte.SIZE;
354-
private static final int BYTES_PER_SHORT = Short.SIZE / Byte.SIZE;
355-
356351
/**
357352
* This class exists to throw errors out of the (un)marshalling code, without having to
358353
* construct Python exceptions (yet). Since the (un)marshalling code does not have nodes or
@@ -463,9 +458,6 @@ public int read(byte[] b, int off, int len) {
463458
final int version;
464459
final PInt pyTrue;
465460
final PInt pyFalse;
466-
// CPython's marshal code is little endian
467-
final ByteArraySupport baSupport = ByteArraySupport.littleEndian();
468-
byte[] buffer = new byte[Long.BYTES];
469461
int depth = 0;
470462
/*
471463
* A DSL node needs access to its Source during deserialization, but we do not wish to
@@ -499,7 +491,7 @@ public int read(byte[] b, int off, int len) {
499491
}
500492

501493
Marshal(PythonContext context, byte[] in, int length) {
502-
this(context, new DataInputStream(new ByteArrayInputStream(in, 0, length)), null);
494+
this(context, SerializationUtils.createDataInput(ByteBuffer.wrap(in, 0, length)), null);
503495
}
504496

505497
Marshal(PythonContext context, Object in) {
@@ -581,10 +573,7 @@ private byte[] readNBytes(int sz) {
581573
if (sz == 0) {
582574
return PythonUtils.EMPTY_BYTE_ARRAY;
583575
} else {
584-
if (buffer.length < sz) {
585-
buffer = new byte[sz];
586-
}
587-
return readNBytes(sz, buffer);
576+
return readNBytes(sz, new byte[sz]);
588577
}
589578
}
590579

@@ -620,11 +609,25 @@ private void writeShort(short v) {
620609
}
621610

622611
private int readInt() {
623-
return baSupport.getInt(readNBytes(BYTES_PER_INT), 0);
612+
try {
613+
int val = in.readInt();
614+
return ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? Integer.reverseBytes(val) : val;
615+
} catch (EOFException e) {
616+
throw new MarshalError(PythonBuiltinClassType.EOFError, ErrorMessages.BAD_MARSHAL_DATA_EOF);
617+
} catch (IOException e) {
618+
throw CompilerDirectives.shouldNotReachHere();
619+
}
624620
}
625621

626622
private short readShort() {
627-
return baSupport.getShort(readNBytes(BYTES_PER_SHORT), 0);
623+
try {
624+
short val = in.readShort();
625+
return ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? Short.reverseBytes(val) : val;
626+
} catch (EOFException e) {
627+
throw new MarshalError(PythonBuiltinClassType.EOFError, ErrorMessages.BAD_MARSHAL_DATA_EOF);
628+
} catch (IOException e) {
629+
throw CompilerDirectives.shouldNotReachHere();
630+
}
628631
}
629632

630633
private void writeLong(long v) {
@@ -634,7 +637,14 @@ private void writeLong(long v) {
634637
}
635638

636639
private long readLong() {
637-
return baSupport.getLong(readNBytes(BYTES_PER_LONG), 0);
640+
try {
641+
long val = in.readLong();
642+
return ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? Long.reverseBytes(val) : val;
643+
} catch (EOFException e) {
644+
throw new MarshalError(PythonBuiltinClassType.EOFError, ErrorMessages.BAD_MARSHAL_DATA_EOF);
645+
} catch (IOException e) {
646+
throw CompilerDirectives.shouldNotReachHere();
647+
}
638648
}
639649

640650
private void writeBigInteger(BigInteger v) {
@@ -662,6 +672,7 @@ private void writeBigInteger(BigInteger v) {
662672

663673
private BigInteger readBigInteger() {
664674
boolean negative;
675+
// size is in shorts
665676
int sz = readInt();
666677
if (sz < 0) {
667678
negative = true;
@@ -670,21 +681,11 @@ private BigInteger readBigInteger() {
670681
negative = false;
671682
}
672683

673-
// size is in shorts
674-
sz *= 2;
675-
676-
byte[] data = readNBytes(sz);
677-
678-
int i = 0;
679-
int digit = baSupport.getShort(data, i);
680-
i += 2;
684+
int digit = readShort();
681685
BigInteger result = BigInteger.valueOf(digit);
682-
683-
while (i < sz) {
684-
int power = i / 2;
685-
digit = baSupport.getShort(data, i);
686-
i += 2;
687-
result = result.add(BigInteger.valueOf(digit).multiply(MARSHAL_BASE.pow(power)));
686+
for (int i = 1; i < sz; i++) {
687+
digit = readShort();
688+
result = result.add(BigInteger.valueOf(digit).multiply(MARSHAL_BASE.pow(i)));
688689
}
689690
if (negative) {
690691
return result.negate();
@@ -1169,7 +1170,7 @@ private TruffleString readString(boolean intern) {
11691170
if (sz == 0) {
11701171
return StringLiterals.T_EMPTY_STRING;
11711172
}
1172-
var utf8String = TruffleString.fromByteArrayUncached(readNBytes(sz), 0, sz, Encoding.UTF_8, true);
1173+
var utf8String = TruffleString.fromByteArrayUncached(readNBytes(sz), 0, sz, Encoding.UTF_8, false);
11731174
var value = utf8String.switchEncodingUncached(TS_ENCODING, TranscodingErrorHandler.DEFAULT_KEEP_SURROGATES_IN_UTF8);
11741175
if (intern) {
11751176
return PythonUtils.internString(value);
@@ -1189,12 +1190,12 @@ private void writeShortString(String v) throws IOException {
11891190
private TruffleString readShortString() {
11901191
int sz = readByteSize();
11911192
byte[] bytes = readNBytes(sz);
1192-
return TruffleString.fromByteArrayUncached(bytes, 0, sz, Encoding.ISO_8859_1, true).switchEncodingUncached(TS_ENCODING);
1193+
return TruffleString.fromByteArrayUncached(bytes, 0, sz, Encoding.ISO_8859_1, false).switchEncodingUncached(TS_ENCODING);
11931194
}
11941195

1195-
private Object readAscii(long sz, boolean intern) {
1196-
byte[] bytes = readNBytes((int) sz);
1197-
TruffleString value = TruffleString.fromByteArrayUncached(bytes, 0, (int) sz, Encoding.US_ASCII, true).switchEncodingUncached(TS_ENCODING);
1196+
private Object readAscii(int sz, boolean intern) {
1197+
byte[] bytes = readNBytes(sz);
1198+
TruffleString value = TruffleString.fromByteArrayUncached(bytes, 0, sz, Encoding.US_ASCII, false).switchEncodingUncached(TS_ENCODING);
11981199
if (intern) {
11991200
return PythonUtils.internString(value);
12001201
} else {

0 commit comments

Comments
 (0)