|
|
|
@ -26,13 +26,37 @@ class BatchNormOpConverter : public OpConverter {
|
|
|
|
|
VLOG(3) << "convert a fluid batch norm op to tensorrt batch_norm";
|
|
|
|
|
|
|
|
|
|
framework::OpDesc op_desc(op, nullptr);
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1); // Bias is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1); // Mean is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1); // Scale is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Variance").size(),
|
|
|
|
|
1); // Variance is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid input X's size of batch_norm TRT converter. "
|
|
|
|
|
"Expected 1, received %d.",
|
|
|
|
|
op_desc.Input("X").size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Bias").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid input Bias's size of batch_norm TRT "
|
|
|
|
|
"converter. Expected 1, received %d.",
|
|
|
|
|
op_desc.Input("Bias").size())); // Bias is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Mean").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid input Mean's size of batch_norm TRT "
|
|
|
|
|
"converter. Expected 1, received %d.",
|
|
|
|
|
op_desc.Input("Mean").size())); // Mean is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Input("Scale").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid input Scale's size of batch_norm TRT "
|
|
|
|
|
"converter. Expected 1, received %d.",
|
|
|
|
|
op_desc.Input("Scale").size())); // Scale is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op_desc.Input("Variance").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid input Variance's size of batch_norm TRT converter. "
|
|
|
|
|
"Expected 1, received %d.",
|
|
|
|
|
op_desc.Input("Variance").size())); // Variance is a weight
|
|
|
|
|
PADDLE_ENFORCE_EQ(op_desc.Output("Y").size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Invalid output Y's size of batch_norm TRT "
|
|
|
|
|
"converter. Expected 1, received %d.",
|
|
|
|
|
op_desc.Output("Y").size()));
|
|
|
|
|
|
|
|
|
|
auto* X = engine_->GetITensor(op_desc.Input("X").front());
|
|
|
|
|
// Declare weights
|
|
|
|
@ -42,10 +66,22 @@ class BatchNormOpConverter : public OpConverter {
|
|
|
|
|
auto* Variance_v = scope.FindVar(op_desc.Input("Variance").front());
|
|
|
|
|
const float eps = boost::get<float>(op_desc.GetAttr("epsilon"));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(Bias_v);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(Mean_v);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(Scale_v);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(Variance_v);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
Bias_v,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Variable of Bias of batch_norm TRT converter is not found."));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
Mean_v,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Variable of Mean of batch_norm TRT converter is not found."));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
Scale_v,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Variable of Scale of batch_norm TRT converter is not found."));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
Variance_v,
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Variable of Variance of batch_norm TRT converter is not found."));
|
|
|
|
|
|
|
|
|
|
// get tensor
|
|
|
|
|
auto* Bias_t = Bias_v->GetMutable<framework::LoDTensor>();
|
|
|
|
|