|
|
|
@ -90,12 +90,6 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic<T>::getOutputDimensions(
|
|
|
|
|
"so the index should be zero,"
|
|
|
|
|
"but it's (%d)",
|
|
|
|
|
output_index));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
nb_inputs, 3,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The Input of the EmbEltwiseLayernorm should be 3, but we found "
|
|
|
|
|
"it has (%d) inputs",
|
|
|
|
|
nb_inputs));
|
|
|
|
|
nvinfer1::DimsExprs ret;
|
|
|
|
|
ret.nbDims = 5;
|
|
|
|
|
ret.d[0] = inputs[0].d[0];
|
|
|
|
@ -113,13 +107,18 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
in_out, platform::errors::InvalidArgument(
|
|
|
|
|
"The input of swish plugin shoule not be nullptr."));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(nb_outputs, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The EmbEltwiseLayerNorm's output should be one"
|
|
|
|
|
"but it's (%d) outputs.",
|
|
|
|
|
nb_outputs));
|
|
|
|
|
PADDLE_ENFORCE_LT(
|
|
|
|
|
pos, nb_inputs + nb_outputs,
|
|
|
|
|
platform::errors::InvalidArgument("The pos(%d) should be less than the "
|
|
|
|
|
"num(%d) of the input and the output.",
|
|
|
|
|
pos, nb_inputs + nb_outputs));
|
|
|
|
|
(in_out && pos < (nb_inputs + nb_outputs));
|
|
|
|
|
|
|
|
|
|
int all_nums = nb_inputs + nb_outputs;
|
|
|
|
|
|
|
|
|
|
const nvinfer1::PluginTensorDesc &desc = in_out[pos];
|
|
|
|
|
if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
|
|
|
|
@ -131,18 +130,19 @@ bool EmbEltwiseLayernormPluginDynamic<T>::supportsFormatCombination(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const nvinfer1::PluginTensorDesc &prev = in_out[pos - 1];
|
|
|
|
|
if (pos == 1 || pos == 2) {
|
|
|
|
|
if (pos < all_nums - 1) {
|
|
|
|
|
return desc.type == nvinfer1::DataType::kINT32 &&
|
|
|
|
|
desc.dims.d[0] == prev.dims.d[0] && desc.dims.d[1] == prev.dims.d[1];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (pos == 3) {
|
|
|
|
|
if (pos == all_nums - 1) {
|
|
|
|
|
if (sizeof(T) == sizeof(float)) {
|
|
|
|
|
return desc.type == nvinfer1::DataType::kFLOAT;
|
|
|
|
|
} else {
|
|
|
|
|
return desc.type == nvinfer1::DataType::kHALF;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|