Skip to content

Commit 828d120

Browse files
committed
tests: Update fp16 test for new API
Signed-off-by: Naren Dasan <naren@narendasan.com> Signed-off-by: Naren Dasan <narens@nvidia.com>
1 parent c275dd0 commit 828d120

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/accuracy/test_fp16_accuracy.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,10 @@ TEST_P(AccuracyTests, FP16AccuracyIsClose) {
2525
}
2626
torch::Tensor jit_accuracy = (jit_correct / jit_total) * 100;
2727

28-
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
29-
auto compile_spec = torch_tensorrt::ts::CompileSpec({input_shape});
28+
std::vector<int64_t> input_shape = {32, 3, 32, 32};
29+
auto input = torch_tensorrt::Input(input_shape);
30+
input.dtype = torch::kF16;
31+
auto compile_spec = torch_tensorrt::ts::CompileSpec({input});
3032
compile_spec.enabled_precisions.insert(torch::kF16);
3133

3234
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_spec);

0 commit comments

Comments
 (0)