Skip to content

Commit 09b3155

Browse files
author
Vadim Platonov
committed
[Rust] Correctly determine struct size
1 parent f2c9b07 commit 09b3155

File tree

2 files changed

+62
-73
lines changed

2 files changed

+62
-73
lines changed

sbe-tool/src/main/java/uk/co/real_logic/sbe/generation/rust/RustGenerator.java

Lines changed: 46 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import java.io.IOException;
2626
import java.io.Writer;
2727
import java.util.*;
28-
import java.util.stream.Collectors;
2928
import java.util.stream.IntStream;
3029

3130
import static java.lang.String.format;
@@ -69,16 +68,15 @@ public void generate() throws IOException
6968
final MessageComponents components = MessageComponents.collectMessageComponents(tokens);
7069
final String messageTypeName = formatTypeName(components.messageToken.name());
7170

72-
final Optional<FieldsRepresentationSummary> fieldsRepresentation =
73-
generateFieldsRepresentation(messageTypeName, components, outputManager);
71+
final RustStruct fieldStruct = generateMessageFieldStruct(messageTypeName, components, outputManager);
7472
generateMessageHeaderDefault(ir, outputManager, components.messageToken);
7573

7674
// Avoid the work of recomputing the group tree twice per message
7775
final List<GroupTreeNode> groupTree = buildGroupTrees(messageTypeName, components.groups);
7876
generateGroupFieldRepresentations(outputManager, groupTree);
7977

80-
generateMessageDecoder(outputManager, components, groupTree, fieldsRepresentation, headerSize);
81-
generateMessageEncoder(outputManager, components, groupTree, fieldsRepresentation, headerSize);
78+
generateMessageDecoder(outputManager, components, groupTree, fieldStruct, headerSize);
79+
generateMessageEncoder(outputManager, components, groupTree, fieldStruct, headerSize);
8280
}
8381
}
8482

@@ -117,68 +115,26 @@ private void generateGroupFieldRepresentations(
117115
}
118116
}
119117

120-
private static final class FieldsRepresentationSummary
121-
{
122-
final String typeName;
123-
final int numBytes;
124-
125-
private FieldsRepresentationSummary(final String typeName, final int numBytes)
126-
{
127-
this.typeName = typeName;
128-
this.numBytes = numBytes;
129-
}
130-
}
131-
132-
private static Optional<FieldsRepresentationSummary> generateFieldsRepresentation(
118+
private static RustStruct generateMessageFieldStruct(
133119
final String messageTypeName,
134120
final MessageComponents components,
135121
final OutputManager outputManager) throws IOException
136122
{
137123
final List<NamedToken> namedFieldTokens = NamedToken.gatherNamedNonConstantFieldTokens(components.fields);
138124

139125
final String representationStruct = messageTypeName + "Fields";
140-
try (Writer writer = outputManager.createOutput(messageTypeName + " Fixed-size Fields"))
126+
final RustStruct struct = RustStruct.fromTokens(representationStruct, namedFieldTokens,
127+
EnumSet.of(RustStruct.Modifier.PACKED));
128+
129+
try (Writer writer = outputManager.createOutput(
130+
messageTypeName + " Fixed-size Fields (" + struct.sizeBytes() + " bytes)"))
141131
{
142-
final RustStruct struct = RustStruct.fromTokens(representationStruct, namedFieldTokens,
143-
EnumSet.of(RustStruct.Modifier.PACKED));
144132
struct.appendDefinitionTo(writer);
145133
writer.append("\n");
146-
147134
generateConstantAccessorImpl(writer, representationStruct, components.fields);
148135
}
149136

150-
// Compute the total static size in bytes of the fields representation
151-
int numBytes = 0;
152-
for (int i = 0, size = components.fields.size(); i < size;)
153-
{
154-
final Token fieldToken = components.fields.get(i);
155-
if (fieldToken.signal() == Signal.BEGIN_FIELD)
156-
{
157-
final int fieldEnd = i + fieldToken.componentTokenCount();
158-
if (!fieldToken.isConstantEncoding())
159-
{
160-
for (int j = i; j < fieldEnd; j++)
161-
{
162-
final Token t = components.fields.get(j);
163-
if (t.isConstantEncoding())
164-
{
165-
continue;
166-
}
167-
if (t.signal() == ENCODING || t.signal() == BEGIN_ENUM || t.signal() == BEGIN_SET)
168-
{
169-
numBytes += t.encodedLength();
170-
}
171-
}
172-
}
173-
i += fieldToken.componentTokenCount();
174-
}
175-
else
176-
{
177-
throw new IllegalStateException("field tokens must include bounding BEGIN_FIELD and END_FIELD tokens");
178-
}
179-
}
180-
181-
return Optional.of(new FieldsRepresentationSummary(representationStruct, numBytes));
137+
return struct;
182138
}
183139

184140
private static void generateBitSets(final Ir ir, final OutputManager outputManager) throws IOException
@@ -246,7 +202,7 @@ private static void generateMessageEncoder(
246202
final OutputManager outputManager,
247203
final MessageComponents components,
248204
final List<GroupTreeNode> groupTree,
249-
final Optional<FieldsRepresentationSummary> fieldsRepresentation,
205+
final RustStruct fieldStruct,
250206
final int headerSize)
251207
throws IOException
252208
{
@@ -256,7 +212,7 @@ private static void generateMessageEncoder(
256212
String topType = codecType.generateDoneCoderType(outputManager, messageTypeName);
257213
topType = generateTopVarDataCoders(messageTypeName, components.varData, outputManager, topType, codecType);
258214
topType = generateGroupsCoders(groupTree, outputManager, topType, codecType);
259-
topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldsRepresentation, codecType);
215+
topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldStruct, codecType);
260216
topType = codecType.generateMessageHeaderCoder(messageTypeName, outputManager, topType, headerSize);
261217
generateEntryPoint(messageTypeName, outputManager, topType, codecType);
262218
}
@@ -265,7 +221,7 @@ private static void generateMessageDecoder(
265221
final OutputManager outputManager,
266222
final MessageComponents components,
267223
final List<GroupTreeNode> groupTree,
268-
final Optional<FieldsRepresentationSummary> fieldsRepresentation,
224+
final RustStruct fieldStruct,
269225
final int headerSize)
270226
throws IOException
271227
{
@@ -275,7 +231,7 @@ private static void generateMessageDecoder(
275231
String topType = codecType.generateDoneCoderType(outputManager, messageTypeName);
276232
topType = generateTopVarDataCoders(messageTypeName, components.varData, outputManager, topType, codecType);
277233
topType = generateGroupsCoders(groupTree, outputManager, topType, codecType);
278-
topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldsRepresentation, codecType);
234+
topType = generateFixedFieldCoder(messageTypeName, outputManager, topType, fieldStruct, codecType);
279235
topType = codecType.generateMessageHeaderCoder(messageTypeName, outputManager, topType, headerSize);
280236
generateEntryPoint(messageTypeName, outputManager, topType, codecType);
281237
}
@@ -308,24 +264,17 @@ private static String generateFixedFieldCoder(
308264
final String messageTypeName,
309265
final OutputManager outputManager,
310266
final String topType,
311-
final Optional<FieldsRepresentationSummary> fieldsRepresentationOptional,
267+
final RustStruct fieldStruct,
312268
final RustCodecType codecType) throws IOException
313269
{
314-
if (!fieldsRepresentationOptional.isPresent())
315-
{
316-
return topType;
317-
}
318-
319-
final FieldsRepresentationSummary fieldsRepresentation = fieldsRepresentationOptional.get();
320270
try (Writer writer = outputManager.createOutput(messageTypeName + " Fixed fields " + codecType.name()))
321271
{
322-
final String representationStruct = fieldsRepresentation.typeName;
323-
final String decoderName = representationStruct + codecType.name();
272+
final String decoderName = fieldStruct.name + codecType.name();
324273
codecType.appendScratchWrappingStruct(writer, decoderName);
325274
appendImplWithLifetimeHeader(writer, decoderName);
326275
codecType.appendWrapMethod(writer, decoderName);
327276
codecType.appendDirectCodeMethods(writer, formatMethodName(messageTypeName) + "_fields",
328-
representationStruct, topType, fieldsRepresentation.numBytes);
277+
fieldStruct.name, topType, fieldStruct.sizeBytes());
329278
writer.append("}\n");
330279
// TODO - Move read position further if in-message blockLength exceeds fixed fields representation size
331280
// will require piping some data from the previously-read message header
@@ -1406,6 +1355,7 @@ private interface RustTypeDescriptor
14061355
{
14071356
String name();
14081357
String literalValue(String valueRep);
1358+
int sizeBytes();
14091359

14101360
default String defaultValue()
14111361
{
@@ -1435,15 +1385,22 @@ public String literalValue(String valueRep)
14351385
{
14361386
return getRustStaticArrayString(valueRep + componentType.name(), length);
14371387
}
1388+
1389+
@Override
1390+
public int sizeBytes() {
1391+
return componentType.sizeBytes() * length;
1392+
}
14381393
}
14391394

14401395
private static final class RustPrimitiveType implements RustTypeDescriptor
14411396
{
14421397
private final String name;
1398+
private final int sizeBytes;
14431399

1444-
private RustPrimitiveType(String name)
1400+
private RustPrimitiveType(String name, int sizeBytes)
14451401
{
14461402
this.name = name;
1403+
this.sizeBytes = sizeBytes;
14471404
}
14481405

14491406
@Override
@@ -1457,15 +1414,22 @@ public String literalValue(String valueRep)
14571414
{
14581415
return valueRep + name;
14591416
}
1417+
1418+
@Override
1419+
public int sizeBytes() {
1420+
return sizeBytes;
1421+
}
14601422
}
14611423

14621424
private static final class AnyRustType implements RustTypeDescriptor
14631425
{
14641426
private final String name;
1427+
private final int sizeBytes;
14651428

1466-
private AnyRustType(String name)
1429+
private AnyRustType(String name, int sizeBytes)
14671430
{
14681431
this.name = name;
1432+
this.sizeBytes = sizeBytes;
14691433
}
14701434

14711435
@Override
@@ -1480,17 +1444,22 @@ public String literalValue(String valueRep)
14801444
final String msg = String.format("Cannot produce a literal value %s of type %s!", valueRep, name);
14811445
throw new UnsupportedOperationException(msg);
14821446
}
1447+
1448+
@Override
1449+
public int sizeBytes() {
1450+
return sizeBytes;
1451+
}
14831452
}
14841453

14851454
private static final class RustTypes
14861455
{
1487-
static final RustTypeDescriptor u8 = new RustPrimitiveType("u8");
1456+
static final RustTypeDescriptor u8 = new RustPrimitiveType("u8", 1);
14881457

14891458
static RustTypeDescriptor ofPrimitiveToken(Token token)
14901459
{
14911460
final PrimitiveType primitiveType = token.encoding().primitiveType();
14921461
final String rustPrimitiveType = RustUtil.rustTypeName(primitiveType);
1493-
final RustPrimitiveType type = new RustPrimitiveType(rustPrimitiveType);
1462+
final RustPrimitiveType type = new RustPrimitiveType(rustPrimitiveType, primitiveType.size());
14941463
if (token.arrayLength() > 1) {
14951464
return new RustArrayType(type, token.arrayLength());
14961465
}
@@ -1499,7 +1468,7 @@ static RustTypeDescriptor ofPrimitiveToken(Token token)
14991468

15001469
static RustTypeDescriptor ofGeneratedToken(Token token)
15011470
{
1502-
return new AnyRustType(formatTypeName(token.applicableTypeName()));
1471+
return new AnyRustType(formatTypeName(token.applicableTypeName()), token.encodedLength());
15031472
}
15041473

15051474
static RustTypeDescriptor arrayOf(RustTypeDescriptor type, int len)
@@ -1526,6 +1495,10 @@ private RustStruct(String name, List<RustStructField> fields, EnumSet<Modifier>
15261495
this.modifiers = modifiers;
15271496
}
15281497

1498+
public int sizeBytes() {
1499+
return fields.stream().mapToInt(v -> v.type.sizeBytes()).sum();
1500+
}
1501+
15291502
static RustStruct fromHeader(HeaderStructure header)
15301503
{
15311504
final List<Token> tokens = header.tokens();

sbe-tool/src/test/java/uk/co/real_logic/sbe/generation/rust/RustGeneratorTest.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,22 @@ public void checkValidRustFromAllExampleSchema() throws IOException, Interrupted
290290
}
291291
}
292292

293+
@Test
294+
public void messageWithOffsets()
295+
{
296+
final String rust = fullGenerateForResource(outputManager, "composite-offsets-schema");
297+
final String expectedHeader =
298+
"pub struct MessageHeader {\n" +
299+
" pub block_length:u16,\n" +
300+
" template_id_padding:[u8;2],\n" +
301+
" pub template_id:u16,\n" +
302+
" schema_id_padding:[u8;2],\n" +
303+
" pub schema_id:u16,\n" +
304+
" pub version:u16,\n" +
305+
"}";
306+
assertContains(rust, expectedHeader);
307+
}
308+
293309
@Test
294310
public void constantEnumFields() throws IOException, InterruptedException
295311
{

0 commit comments

Comments
 (0)