11#include " core/conversion/conversion.h"
2+ #include < ATen/core/operator_name.h>
23#include < torch/torch.h>
34#include < sstream>
5+ #include " c10/util/intrusive_ptr.h"
46#include " core/conversion/conversionctx/ConversionCtx.h"
7+ #include " core/conversion/converters/converter_util.h"
58#include " core/conversion/converters/converters.h"
69#include " core/conversion/evaluators/evaluators.h"
10+ #include " core/conversion/tensorcontainer/TensorContainer.h"
711#include " core/conversion/var/Var.h"
812#include " core/util/prelude.h"
9-
10- #include " c10/util/intrusive_ptr.h"
11- #include " core/conversion/converters/converter_util.h"
12- #include " core/conversion/tensorcontainer/TensorContainer.h"
1313#include " core/util/trt_util.h"
1414
1515namespace torch_tensorrt {
@@ -105,7 +105,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
105105 // Node input has not been converted yet or is a prim op
106106 TORCHTRT_THROW_ERROR (
107107 " Unable to retrieve all node inputs for node: "
108- << util::node_info (n) << " (ctx.AddLayer)\n Specifically failed to retrieve value for input: " << *input_node);
108+ << util::node_info (n) << " (ctx.AddLayer)\n Specifically failed to retrieve value for input: %"
109+ << input->debugName ());
109110 }
110111 }
111112
@@ -426,10 +427,18 @@ void ConvertBlockToNetDef(
426427 << " and node outputs size: " << n->outputs ().size () << " must match." );
427428 for (size_t i = 0 ; i < eval_list->elements ().size (); i++) {
428429 auto eval_output = eval_list.get ()->elements ()[i];
429- LOG_DEBUG (
430- ctx->logger ,
431- " Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info (n));
432- ctx->AssociateValueAndIValue (n->output (i), eval_output);
430+ if (eval_output.isCustomClass ()) {
431+ auto container = eval_output.toCustomClass <TensorContainer>();
432+ auto tensor = container->tensor ();
433+ LOG_DEBUG (
434+ ctx->logger , " Found the evaluated value(s) to be an ITensor of shape: " << tensor->getDimensions ());
435+ ctx->AssociateValueAndTensor (n->output (i), tensor);
436+ } else {
437+ LOG_DEBUG (
438+ ctx->logger ,
439+ " Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info (n));
440+ ctx->AssociateValueAndIValue (n->output (i), eval_output);
441+ }
433442 }
434443 } else {
435444 TORCHTRT_THROW_ERROR (" Unsupported return type for evaluated node" );
@@ -487,15 +496,23 @@ std::string ConvertBlockToEngine(
487496std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock (const torch::jit::Block* b) {
488497 std::unordered_map<c10::OperatorName, std::string> unsupported_ops;
489498 for (const auto n : b->nodes ()) {
490- if (n->kind () != torch::jit::prim::Loop && n->kind () != torch::jit::prim::If && !OpSupported (n)) {
491- auto schema = n->maybeSchema ();
492- TORCHTRT_CHECK (
493- schema,
494- " Unable to get schema for Node " << util::node_info (n) << " (conversion.VerifyCoverterSupportForBlock)" );
495- std::stringstream ss;
496- ss << *schema;
497- unsupported_ops[schema->operator_name ()] = ss.str ();
499+ auto schema = n->maybeSchema ();
500+ // Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema
501+ // but they are supported. torch::jit::prim::DictConstruct is supported via fallback only
502+ if (!OpSupported (n)) {
503+ if (schema) {
504+ std::stringstream ss;
505+ ss << *schema;
506+ unsupported_ops[schema->operator_name ()] = ss.str ();
507+ } else {
508+ std::stringstream ss;
509+ ss << util::node_info (n);
510+ // operator.overload is a filler name just to call the constructor.
511+ c10::OperatorName op (ss.str (), " operator.overload" );
512+ unsupported_ops[op] = ss.str ();
513+ }
498514 }
515+
499516 for (const auto sub_b : n->blocks ()) {
500517 auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock (sub_b);
501518 unsupported_ops.insert (sub_b_unsupported_ops.begin (), sub_b_unsupported_ops.end ());
@@ -530,22 +547,25 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
530547
531548bool VerifyConverterSupportForBlock (const torch::jit::Block* b, bool suppress_errors) {
532549 auto unsupported_ops = GetUnsupportedOpsInBlock (b);
533-
534550 if (unsupported_ops.size () != 0 ) {
535551 std::stringstream unsupported_msg;
536552 unsupported_msg
537- << " Method requested cannot be compiled by Torch-TensorRT.TorchScript.\n Unsupported operators listed below:"
553+ << " Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.\n Unsupported operators listed below:"
538554 << std::endl;
539555 for (auto s : unsupported_ops) {
540556 unsupported_msg << " - " << s.second << std::endl;
541557 }
542- unsupported_msg << " You can either implement converters for these ops in your application or request implementation"
543- << std::endl;
544- unsupported_msg << " https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
545- unsupported_msg << std::endl << " In Module:" << std::endl;
546558
547559 if (!suppress_errors) {
560+ unsupported_msg
561+ << " You can either implement converters for these ops in your application or request implementation"
562+ << std::endl;
563+ unsupported_msg << " https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
564+ unsupported_msg << std::endl << " In Module:" << std::endl;
565+
548566 LOG_ERROR (unsupported_msg.str ());
567+ } else {
568+ LOG_INFO (unsupported_msg.str ());
549569 }
550570
551571 std::unordered_map<std::string, std::unordered_set<std::string>> unsupported_node_locations;
@@ -571,8 +591,13 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
571591 for (const auto & str : type.second ) {
572592 traceback << str;
573593 }
594+
574595 auto tb_str = traceback.str ();
575- LOG_ERROR (tb_str);
596+ if (!suppress_errors) {
597+ LOG_ERROR (tb_str);
598+ } else {
599+ LOG_DEBUG (tb_str);
600+ }
576601 }
577602
578603 return false ;
0 commit comments