@@ -505,9 +505,8 @@ class ModelInstanceState : public BackendModelInstance {
505505 const std::string& control_kind, bool required, bool * have_control);
506506 TRITONSERVER_Error* ValidateInputs (const size_t expected_input_cnt);
507507 void AddInputToMap (
508- NamingConvention naming_convention,
509- const std::vector<std::string> allowed_inputs,
510- const std::string &io_name,
508+ NamingConvention naming_convention,
509+ const std::vector<std::string> allowed_inputs, const std::string& io_name,
511510 const uint32_t index);
512511 TRITONSERVER_Error* ValidateOutputs ();
513512 void Execute (
@@ -771,7 +770,12 @@ ModelInstanceState::ValidateTypedSequenceControl(
771770 return nullptr ; // success
772771}
773772
774- void ModelInstanceState::AddInputToMap (NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
773+ void
774+ ModelInstanceState::AddInputToMap (
775+ NamingConvention naming_convention,
776+ const std::vector<std::string> allowed_inputs, const std::string& io_name,
777+ const uint32_t index)
778+ {
775779 std::string deliminator = " __" ;
776780
777781 if (is_dict_input_) {
@@ -924,11 +928,13 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
924928 }
925929
926930 triton::common::TritonJson::Value batch_inputs;
927- RETURN_IF_ERROR (model_state_->ModelConfig ().MemberAsArray (" batch_input" , &batch_inputs));
931+ RETURN_IF_ERROR (
932+ model_state_->ModelConfig ().MemberAsArray (" batch_input" , &batch_inputs));
928933 size_t i = 0 ;
929934 for (const auto & batch_input : StateForModel ()->BatchInputs ()) {
930935 for (const auto & input_name : batch_input.TargetNames ()) {
931- AddInputToMap (naming_convention, allowed_inputs, input_name, i + ios.ArraySize ());
936+ AddInputToMap (
937+ naming_convention, allowed_inputs, input_name, i + ios.ArraySize ());
932938 i++;
933939 }
934940 }
@@ -1754,6 +1760,16 @@ ModelInstanceState::SetInputTensors(
17541760 RETURN_IF_ERROR (TRITONBACKEND_RequestInputCount (requests[0 ], &input_count));
17551761
17561762 input_tensors->resize (input_count + batch_input_count_);
1763+
1764+ // The inputs must be in contiguous CPU/GPU memory.
1765+ std::vector<std::pair<TRITONSERVER_MemoryType, int64_t >> alloc_perference;
1766+ if (device_.is_cpu ()) {
1767+ alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0 },
1768+ {TRITONSERVER_MEMORY_CPU, 0 }};
1769+ } else {
1770+ alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index ()}};
1771+ }
1772+
17571773 for (uint32_t input_idx = 0 ; input_idx < input_count; input_idx++) {
17581774 TRITONBACKEND_Input* input;
17591775 RETURN_IF_ERROR (
@@ -1797,15 +1813,6 @@ ModelInstanceState::SetInputTensors(
17971813 }
17981814 }
17991815
1800- // The input must be in contiguous CPU/GPU memory.
1801- std::vector<std::pair<TRITONSERVER_MemoryType, int64_t >> alloc_perference;
1802- if (device_.is_cpu ()) {
1803- alloc_perference = {{TRITONSERVER_MEMORY_CPU_PINNED, 0 },
1804- {TRITONSERVER_MEMORY_CPU, 0 }};
1805- } else {
1806- alloc_perference = {{TRITONSERVER_MEMORY_GPU, device_.index ()}};
1807- }
1808-
18091816 const char * input_buffer;
18101817 size_t batchn_byte_size;
18111818 TRITONSERVER_MemoryType memory_type;
@@ -1868,15 +1875,14 @@ ModelInstanceState::SetInputTensors(
18681875 TRITONSERVER_MemoryType dst_memory_type;
18691876 int64_t dst_memory_type_id;
18701877
1871- // Batch inputs are always created on CPU
18721878 RESPOND_ALL_AND_SET_NULL_IF_ERROR (
18731879 (*responses), responses->size (),
18741880 collector->ProcessBatchInput (
1875- batch_input, nullptr , 0 , {{TRITONSERVER_MEMORY_CPU, 0 }},
1876- &dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
1877- &dst_memory_type_id));
1881+ batch_input, nullptr , 0 , alloc_perference, &dst_buffer,
1882+ &dst_buffer_byte_size, &dst_memory_type, &dst_memory_type_id));
18781883
1879- const auto torch_dtype = ConvertDataTypeToTorchType (batch_input.DataType ());
1884+ const auto torch_dtype =
1885+ ConvertDataTypeToTorchType (batch_input.DataType ());
18801886 torch::TensorOptions options{torch_dtype.second };
18811887 auto updated_options = options.device (torch::kCPU );
18821888
0 commit comments