|
|
|
@ -99,8 +99,9 @@ class ElementwiseWeightOpConverter : public OpConverter {
|
|
|
|
|
regist_eltwise_weight(scale_mode);
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"TensorRT Dynamic shape unsupported weight shape for Elementwise "
|
|
|
|
|
"op!"));
|
|
|
|
|
"The size of input bias's dims is %d, but TensorRT dynamic shape "
|
|
|
|
|
"only support size = 1 for Elementwise op!",
|
|
|
|
|
Y_t->dims().size()));
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -132,12 +133,24 @@ class ElementwiseWeightOpConverter : public OpConverter {
|
|
|
|
|
if (scale_mode == nvinfer1::ScaleMode::kCHANNEL) {
|
|
|
|
|
for (size_t i = 1; i < no_batch_dims.size(); i++) {
|
|
|
|
|
if (dims_y[i] != 1)
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"TensorRT unsupported weight shape for Elementwise op!");
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The bias's %d dim is %d, but TensorRT dynamic shape only "
|
|
|
|
|
"support it equals to 1 for Elementwise op!",
|
|
|
|
|
i, dims_y[i]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("TensorRT unsupported weight Shape for Elementwise op!");
|
|
|
|
|
if (dims_y.size() >= 1) {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The size of bias's dims is %d and bias's size is %d. TensorRT "
|
|
|
|
|
"doesn't support this shape for Elementwise op!",
|
|
|
|
|
dims_y.size(), dims_y[0]));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::InvalidArgument(
|
|
|
|
|
"The size of bias's dims is %d. TensorRT doesn't support "
|
|
|
|
|
"this shape for Elementwise op!",
|
|
|
|
|
dims_y.size()));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
regist_eltwise_weight(scale_mode);
|
|
|
|
|
}
|
|
|
|
@ -152,7 +165,11 @@ class ElementwiseTensorOpConverter : public OpConverter {
|
|
|
|
|
void operator()(const framework::proto::OpDesc& op,
|
|
|
|
|
const framework::Scope& scope, bool test_mode) override {
|
|
|
|
|
auto op_pair = ops.find(op_type_);
|
|
|
|
|
PADDLE_ENFORCE(op_pair != ops.end(), "Wrong elementwise op type!");
|
|
|
|
|
PADDLE_ENFORCE_NE(op_pair, ops.end(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Elementwise op's type(%s) is not supported. Please "
|
|
|
|
|
"check if the op_type is correct.",
|
|
|
|
|
op_type_));
|
|
|
|
|
|
|
|
|
|
// Here the two nullptr looks strange, that's because the
|
|
|
|
|
// framework::OpDesc's constructor is strange.
|
|
|
|
|