errmsg refine of trt plugin (#27309)

revert-27520-disable_pr
Pei Yang 5 years ago committed by GitHub
parent 905e2346ac
commit fda54c0212
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,8 +25,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
const char* plugin_type;
DeserializeValue(&serial_data, &serial_length, &plugin_type);
PADDLE_ENFORCE(Has(plugin_type),
"trt plugin type %s does not exists, check it.", plugin_type);
PADDLE_ENFORCE_EQ(
Has(plugin_type), true,
platform::errors::NotFound(
"trt plugin type %s does not exists, check it.", plugin_type));
auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
owned_plugins_.emplace_back(plugin);

@ -103,7 +103,12 @@ struct Serializer<std::vector<T>,
DeserializeValue(buffer, buffer_size, &size);
value->resize(size);
size_t nbyte = value->size() * sizeof(T);
PADDLE_ENFORCE_GE(*buffer_size, nbyte);
PADDLE_ENFORCE_GE(
*buffer_size, nbyte,
platform::errors::InvalidArgument("Expect buffer size >= value size in "
"trt plugin deserialization, but got "
"buffer size = %d, value size = %d.",
*buffer_size, nbyte));
std::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte;

Loading…
Cancel
Save