2525// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2626
2727#include < stdint.h>
28+
29+ #include < cstdint>
2830#include < exception>
31+
2932#include " libtorch_utils.h"
3033#include " triton/backend/backend_common.h"
3134#include " triton/backend/backend_input_collector.h"
@@ -502,6 +505,10 @@ class ModelInstanceState : public BackendModelInstance {
502505 triton::common::TritonJson::Value& sequence_batching,
503506 const std::string& control_kind, bool required, bool * have_control);
504507 TRITONSERVER_Error* ValidateInputs (const size_t expected_input_cnt);
508+ void AddInputToMap (
509+ NamingConvention naming_convention,
510+ const std::vector<std::string> allowed_inputs, const std::string& io_name,
511+ const uint32_t index);
505512 TRITONSERVER_Error* ValidateOutputs ();
506513 void Execute (
507514 std::vector<TRITONBACKEND_Response*>* responses,
@@ -538,6 +545,7 @@ class ModelInstanceState : public BackendModelInstance {
538545 // Map from configuration name for an input to the index of
539546 // that input in the model.
540547 std::unordered_map<std::string, int > input_index_map_;
548+ uint32_t batch_input_count_ = 0 ;
541549
542550 // Map from configuration name for an output to the index of
543551 // that output in the model.
@@ -607,6 +615,12 @@ ModelInstanceState::ModelInstanceState(
607615 if (model_state->ModelConfig ().Find (" input" , &inputs)) {
608616 expected_input_cnt = inputs.ArraySize ();
609617 }
618+
619+ triton::common::TritonJson::Value config_batch_inputs;
620+ if (model_state->ModelConfig ().Find (" batch_input" , &config_batch_inputs)) {
621+ batch_input_count_ = config_batch_inputs.ArraySize ();
622+ expected_input_cnt += batch_input_count_;
623+ }
610624 }
611625
612626 // If this is a sequence model then make sure that the required
@@ -757,6 +771,43 @@ ModelInstanceState::ValidateTypedSequenceControl(
757771 return nullptr ; // success
758772}
759773
774+ void
775+ ModelInstanceState::AddInputToMap (
776+ NamingConvention naming_convention,
777+ const std::vector<std::string> allowed_inputs, const std::string& io_name,
778+ const uint32_t index)
779+ {
780+ std::string deliminator = " __" ;
781+
782+ if (is_dict_input_) {
783+ // If dictionary, index is irrelevant but we use the map to store the
784+ // input names since they are the keys for the dictionary
785+ input_index_map_[io_name] = index;
786+ } else {
787+ switch (naming_convention) {
788+ case NamingConvention::FORWARD_ARGUMENT: {
789+ auto itr =
790+ std::find (allowed_inputs.begin (), allowed_inputs.end (), io_name);
791+ if (itr != allowed_inputs.end ()) {
792+ input_index_map_[io_name] =
793+ std::distance (allowed_inputs.begin (), itr);
794+ }
795+ return ;
796+ }
797+ case NamingConvention::NAMED_INDEX: {
798+ int start_pos = io_name.find (deliminator);
799+ int ip_index = std::atoi (io_name.substr (start_pos + 2 ).c_str ());
800+ input_index_map_[io_name] = ip_index;
801+ return ;
802+ }
803+ case NamingConvention::STRICT_CONFIG_ORDERING: {
804+ input_index_map_[io_name] = index;
805+ return ;
806+ }
807+ }
808+ }
809+ }
810+
760811TRITONSERVER_Error*
761812ModelInstanceState::ValidateInputs (const size_t expected_input_cnt)
762813{
@@ -822,8 +873,6 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
822873
823874 triton::common::TritonJson::Value ios;
824875 RETURN_IF_ERROR (model_state_->ModelConfig ().MemberAsArray (" input" , &ios));
825- std::string deliminator = " __" ;
826- int ip_index = 0 ;
827876
828877 if (ios.ArraySize () == 0 ) {
829878 return TRITONSERVER_ErrorNew (
@@ -842,34 +891,7 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
842891 // Validate name
843892 std::string io_name;
844893 RETURN_IF_ERROR (io.MemberAsString (" name" , &io_name));
845- if (is_dict_input_) {
846- // If dictionary, index is irrelevant but we use the map to store the
847- // input names since they are the keys for the dictionary
848- input_index_map_[io_name] = i;
849- } else {
850- switch (naming_convention) {
851- case NamingConvention::FORWARD_ARGUMENT: {
852- auto itr =
853- std::find (allowed_inputs.begin (), allowed_inputs.end (), io_name);
854- if (itr != allowed_inputs.end ()) {
855- input_index_map_[io_name] =
856- std::distance (allowed_inputs.begin (), itr);
857- }
858- break ;
859- }
860- case NamingConvention::NAMED_INDEX: {
861- int start_pos = io_name.find (deliminator);
862- ip_index = std::atoi (io_name.substr (start_pos + 2 ).c_str ());
863- input_index_map_[io_name] = ip_index;
864- break ;
865- }
866- case NamingConvention::STRICT_CONFIG_ORDERING: {
867- input_index_map_[io_name] = i;
868- break ;
869- }
870- }
871- }
872-
894+ AddInputToMap (naming_convention, allowed_inputs, io_name, i);
873895 // Validate data type
874896 std::string io_dtype;
875897 RETURN_IF_ERROR (io.MemberAsString (" data_type" , &io_dtype));
@@ -906,6 +928,18 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
906928 }
907929 }
908930
931+ triton::common::TritonJson::Value batch_inputs;
932+ RETURN_IF_ERROR (
933+ model_state_->ModelConfig ().MemberAsArray (" batch_input" , &batch_inputs));
934+ size_t i = 0 ;
935+ for (const auto & batch_input : StateForModel ()->BatchInputs ()) {
936+ for (const auto & input_name : batch_input.TargetNames ()) {
937+ AddInputToMap (
938+ naming_convention, allowed_inputs, input_name, i + ios.ArraySize ());
939+ i++;
940+ }
941+ }
942+
909943 return nullptr ; // success
910944}
911945
@@ -1725,7 +1759,8 @@ ModelInstanceState::SetInputTensors(
17251759 // request as the representative for the input tensors.
17261760 uint32_t input_count;
17271761 RETURN_IF_ERROR (TRITONBACKEND_RequestInputCount (requests[0 ], &input_count));
1728- input_tensors->resize (input_count);
1762+
1763+ input_tensors->resize (input_count + batch_input_count_);
17291764 for (uint32_t input_idx = 0 ; input_idx < input_count; input_idx++) {
17301765 TRITONBACKEND_Input* input;
17311766 RETURN_IF_ERROR (
@@ -1828,6 +1863,36 @@ ModelInstanceState::SetInputTensors(
18281863 }
18291864 }
18301865
1866+ for (const auto & batch_input : StateForModel ()->BatchInputs ()) {
1867+ std::vector<int64_t > shape;
1868+ collector->BatchInputShape (batch_input, &shape);
1869+
1870+ for (const auto & input_name : batch_input.TargetNames ()) {
1871+ input_names->emplace_back (input_name.c_str ());
1872+
1873+ const char * dst_buffer;
1874+ size_t dst_buffer_byte_size;
1875+ TRITONSERVER_MemoryType dst_memory_type;
1876+ int64_t dst_memory_type_id;
1877+
1878+ // Batch inputs are always created on CPU
1879+ RESPOND_ALL_AND_SET_NULL_IF_ERROR (
1880+ (*responses), responses->size (),
1881+ collector->ProcessBatchInput (
1882+ batch_input, nullptr , 0 , {{TRITONSERVER_MEMORY_CPU, 0 }},
1883+ &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
1884+ &dst_memory_type_id));
1885+
1886+ const auto torch_dtype =
1887+ ConvertDataTypeToTorchType (batch_input.DataType ());
1888+
1889+ torch::Tensor input_tensor = torch::from_blob (
1890+ const_cast <char *>(dst_buffer), shape,
1891+ updated_options.dtype (torch_dtype.second ));
1892+ (*input_tensors)[input_index_map_[input_name]] = input_tensor;
1893+ }
1894+ }
1895+
18311896 // Finalize...
18321897 *cuda_copy |= collector->Finalize ();
18331898
0 commit comments