Skip to content

Commit d6645d3

Browse files
committed
Added an `aten::bitwise_not` converter.
1 parent 1778011 commit d6645d3

File tree

4 files changed

+89
-16
lines changed

4 files changed

+89
-16
lines changed

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ cc_library(
5454
"NodeConverterRegistry.cpp",
5555
"impl/activation.cpp",
5656
"impl/batch_norm.cpp",
57+
"impl/bitwise.cpp",
5758
"impl/cast.cpp",
5859
"impl/concat.cpp",
5960
"impl/constant.cpp",
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#include "core/conversion/converters/converters.h"
2+
#include "core/util/prelude.h"
3+
4+
#include <torch/torch.h>
5+
6+
namespace torch_tensorrt {
7+
namespace core {
8+
namespace conversion {
9+
namespace converters {
10+
namespace impl {
11+
12+
13+
auto bitwisenot TORCHTRT_UNUSED =
14+
RegisterNodeConversionPatterns()
15+
.pattern({"aten::bitwise_not(Tensor self) -> Tensor",
16+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
17+
auto in = args[0].ITensorOrFreeze(ctx);
18+
nvinfer1::ILayer* out;
19+
20+
if(in->getType() == nvinfer1::DataType::kINT32) {
21+
// Integer case
22+
auto one = torch::tensor({1}, util::TRTDataTypeToScalarType(in->getType()));
23+
auto one_const = tensor_to_const(ctx, one);
24+
auto neg = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNEG);
25+
TORCHTRT_CHECK(neg, "Unable to create neg unary layer from node: " << *n);
26+
out = add_elementwise(
27+
ctx, nvinfer1::ElementWiseOperation::kSUB, neg->getOutput(0),
28+
one_const, util::node_info(n));
29+
TORCHTRT_CHECK(out, "Unable to create sub layer from node: " << *n);
30+
} else if(in->getType() == nvinfer1::DataType::kBOOL) {
31+
// Boolean case
32+
out = ctx->net->addUnary(*in, nvinfer1::UnaryOperation::kNOT);
33+
TORCHTRT_CHECK(out, "Unable to create logical not layer from node: " << *n);
34+
} else {
35+
LOG_ERROR("Input tensor must be 32 bit integer or boolean");
36+
return false;
37+
}
38+
39+
out->setName(util::node_info(n).c_str());
40+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0],
41+
out->getOutput(0));
42+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
43+
44+
return true;
45+
}});
46+
47+
48+
} // namespace impl
49+
} // namespace converters
50+
} // namespace conversion
51+
} // namespace core
52+
} // namespace torch_tensorrt

tests/core/conversion/converters/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ converter_test(
1515
name = "test_batch_norm",
1616
)
1717

18+
converter_test(
19+
name = "test_bitwise",
20+
)
21+
1822
converter_test(
1923
name = "test_instance_norm",
2024
)
@@ -136,6 +140,7 @@ test_suite(
136140
tests = [
137141
":test_activation",
138142
":test_batch_norm",
143+
":test_bitwise",
139144
":test_instance_norm",
140145
":test_cast",
141146
":test_clone",
Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,39 @@
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// Returns the TorchScript IR for a graph containing a single
// aten::bitwise_not node; shared by both test instantiations below.
std::string gen_test_graph() {
  return R"IR(
      graph(%0: Tensor):
        %3 : Tensor = aten::bitwise_not(%0)
        return (%3))IR";
}

// Instantiates a gtest case that runs the bitwise_not graph through both the
// JIT interpreter and the TensorRT engine and asserts exact equality.
// `dtype` selects the input: "Integer" -> random kInt in [-128, 128),
// "Boolean" -> random kBool.
// NOTE: at::randint's upper bound is EXCLUSIVE, so the Boolean case must use
// randint(0, 2, ...) — randint(0, 1, ...) would only ever produce `false`
// inputs and the true->false path would go untested.
#define test_bitwise_not(dtype)                                                             \
  TEST(Converters, ATenBitwiseNot##dtype##ConvertsCorrectly) {                              \
    const auto graph = gen_test_graph();                                                    \
                                                                                            \
    auto g = std::make_shared<torch::jit::Graph>();                                         \
    torch::jit::parseIR(graph, g.get());                                                    \
                                                                                            \
    at::Tensor in;                                                                          \
    if (strcmp(#dtype, "Integer") == 0)                                                     \
      in = at::randint(-128, 128, {10}, {at::kCUDA}).toType(at::kInt);                      \
    if (strcmp(#dtype, "Boolean") == 0)                                                     \
      in = at::randint(0, 2, {10}, {at::kCUDA}).toType(at::kBool);                          \
    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});             \
    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});              \
                                                                                            \
    in = at::clone(in);                                                                     \
    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});                  \
    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});        \
                                                                                            \
    ASSERT_TRUE(torch_tensorrt::tests::util::exactlyEqual(jit_results[0], trt_results[0])); \
  }

test_bitwise_not(Integer);
test_bitwise_not(Boolean);

#undef test_bitwise_not

0 commit comments

Comments
 (0)