fix trt weight bug ()

added splitter "__" between weight name and suffix number to avoid conflicts.
revert-21172-masked_select_api
Pei Yang 5 years ago committed by GitHub
parent 29b63f0aa1
commit 2e2f92a5b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -211,7 +211,8 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name,
const std::vector<float> &scale) {
static int name_suffix_counter = 0;
std::string name_suffix = std::to_string(name_suffix_counter);
std::string name_with_suffix = name + name_suffix;
std::string splitter = "__";
std::string name_with_suffix = name + splitter + name_suffix;
auto w_dims = weight_tensor->dims();
platform::CPUPlace cpu_place;
PADDLE_ENFORCE_EQ(

@ -159,7 +159,8 @@ class TensorRTEngine {
std::unique_ptr<framework::Tensor> w_tensor) {
static int suffix_counter = 0;
std::string suffix = std::to_string(suffix_counter);
weight_map[w_name + suffix] = std::move(w_tensor);
std::string splitter = "__";
weight_map[w_name + splitter + suffix] = std::move(w_tensor);
suffix_counter += 1;
}

Loading…
Cancel
Save