|
|
|
@ -49,14 +49,14 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
memcpy(weight_data_tmp.data(), weight_data,
|
|
|
|
|
weight_t->numel() * sizeof(float));
|
|
|
|
|
|
|
|
|
|
// (hidden, 3, all_head_size)
|
|
|
|
|
// (hidden_in, 3, hidden_out)
|
|
|
|
|
auto weight_dims = weight_t->dims();
|
|
|
|
|
|
|
|
|
|
int hidden = weight_dims[0]; // channels_in
|
|
|
|
|
int hidden_in = weight_dims[0]; // channels_in
|
|
|
|
|
int three = weight_dims[1]; // channels_out
|
|
|
|
|
int all_head_size = weight_dims[2]; // channels_out
|
|
|
|
|
int m = hidden;
|
|
|
|
|
int n = three * all_head_size;
|
|
|
|
|
int hidden_out = weight_dims[2]; // channels_out
|
|
|
|
|
int m = hidden_in;
|
|
|
|
|
int n = three * hidden_out;
|
|
|
|
|
auto tranpose_weight = [](const float* src, float* dst, int m, int n) {
|
|
|
|
|
for (int i = 0; i < m; i++) {
|
|
|
|
|
for (int j = 0; j < n; j++) {
|
|
|
|
@ -72,21 +72,23 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
|
|
|
|
|
if (engine_->with_dynamic_shape()) {
|
|
|
|
|
if (engine_->use_oss()) {
|
|
|
|
|
int head_size = hidden / head_number;
|
|
|
|
|
// [3, Nout, Hout, Nin, Hin] -> [Nout, 3, Hout, Nin, Hin]
|
|
|
|
|
auto transpose_weight_v2 = [](const float* src, float* dst, int N,
|
|
|
|
|
int H) {
|
|
|
|
|
const int HNH = H * N * H;
|
|
|
|
|
for (int i = 0; i < 3; ++i) {
|
|
|
|
|
for (int n = 0; n < N; ++n) {
|
|
|
|
|
for (int hnh = 0; hnh < HNH; ++hnh) {
|
|
|
|
|
dst[n * 3 * HNH + i * HNH + hnh] =
|
|
|
|
|
src[i * N * HNH + n * HNH + hnh];
|
|
|
|
|
int head_size = hidden_out / head_number;
|
|
|
|
|
// [3, head_number, head_size, hidden_in] -> [head_number, 3, head_size,
|
|
|
|
|
// hidden_in]
|
|
|
|
|
auto transpose_weight_v2 = [](const float* src, float* dst, int three,
|
|
|
|
|
int head_number, int head_size,
|
|
|
|
|
int hidden_in) {
|
|
|
|
|
const int HH = head_size * hidden_in;
|
|
|
|
|
for (int i = 0; i < three; ++i) {
|
|
|
|
|
for (int n = 0; n < head_number; ++n) {
|
|
|
|
|
for (int hh = 0; hh < HH; ++hh) {
|
|
|
|
|
dst[n * three * HH + i * HH + hh] =
|
|
|
|
|
src[i * head_number * HH + n * HH + hh];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
// [3, N, H] -> [N, 3, H]
|
|
|
|
|
// [3, head_number, head_size] -> [head_number, 3, head_size]
|
|
|
|
|
auto transpose_bias_v2 = [](const float* src, float* dst, int N,
|
|
|
|
|
int H) {
|
|
|
|
|
for (int i = 0; i < 3; ++i) {
|
|
|
|
@ -99,8 +101,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
};
|
|
|
|
|
memcpy(weight_data_tmp.data(), weight_data,
|
|
|
|
|
weight_t->numel() * sizeof(float));
|
|
|
|
|
transpose_weight_v2(weight_data_tmp.data(), weight_data, head_number,
|
|
|
|
|
head_size);
|
|
|
|
|
transpose_weight_v2(weight_data_tmp.data(), weight_data, three,
|
|
|
|
|
head_number, head_size, hidden_in);
|
|
|
|
|
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
|
|
|
|
|
static_cast<void*>(weight_data),
|
|
|
|
|
static_cast<int32_t>(weight_t->numel())};
|
|
|
|
@ -130,7 +132,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
int var_seqlen = 1;
|
|
|
|
|
const std::vector<nvinfer1::PluginField> fields{
|
|
|
|
|
{"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
|
{"hidden_size", &hidden, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
|
{"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
|
{"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
|
{"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
|
{"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, 1},
|
|
|
|
@ -186,7 +188,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
n, weight.get(), bias.get());
|
|
|
|
|
auto* fc_out = fc_layer->getOutput(0);
|
|
|
|
|
// add qkv to context
|
|
|
|
|
int head_size = all_head_size / head_number;
|
|
|
|
|
int head_size = hidden_out / head_number;
|
|
|
|
|
float scale = BOOST_GET_CONST(float, op_desc.GetAttr("alpha"));
|
|
|
|
|
|
|
|
|
|
std::vector<nvinfer1::ITensor*> plugin_inputs;
|
|
|
|
@ -195,7 +197,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
|
|
|
|
|
bool with_fp16 =
|
|
|
|
|
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
|
|
|
|
|
plugin::DynamicPluginTensorRT* plugin =
|
|
|
|
|
new plugin::QkvToContextPluginDynamic(hidden, head_number,
|
|
|
|
|
new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
|
|
|
|
|
head_size, scale, with_fp16);
|
|
|
|
|
layer = engine_->AddPluginV2(plugin_inputs.data(), 2, plugin);
|
|
|
|
|
}
|
|
|
|
|