|
|
|
@ -63,7 +63,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
|
|
|
|
|
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const {
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
|
|
|
|
|
framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -113,7 +113,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
auto* w = ctx.Input<Tensor>("FCWeight");
|
|
|
|
|
auto* b = ctx.Input<Tensor>("FCBias");
|
|
|
|
|
auto* out = ctx.Output<LoDTensor>("Out");
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>("FCOUT");
|
|
|
|
|
auto* fc_out = ctx.Output<Tensor>("FCOut");
|
|
|
|
|
|
|
|
|
|
auto* ref_in = ins[0];
|
|
|
|
|
auto ref_lod = ref_in->lod();
|
|
|
|
@ -164,7 +164,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
|
|
|
|
|
out_data, b ? b->data<T>() : NULL);
|
|
|
|
|
w_data = w_data + M0 * D;
|
|
|
|
|
// first one use write on
|
|
|
|
|
// first write on
|
|
|
|
|
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
|
|
|
|
|
w_data = w_data + M1 * D;
|
|
|
|
|
for (size_t i = 2; i < ins.size(); ++i) {
|
|
|
|
@ -175,16 +175,15 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
K, w_data, D, static_cast<T>(1), fc_out_data, D);
|
|
|
|
|
w_data = w_data + K * D;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
T* cur_out_data = out_data;
|
|
|
|
|
for (int i = 0; i < N; ++i) {
|
|
|
|
|
int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
|
|
|
|
|
T* src = fc_out_data + i * D;
|
|
|
|
|
for (int step = 0; step < seq_len; ++step) {
|
|
|
|
|
blas.VADD(D, out_data, src, out_data);
|
|
|
|
|
out_data = out_data + D;
|
|
|
|
|
blas.VADD(D, cur_out_data, src, cur_out_data);
|
|
|
|
|
cur_out_data = cur_out_data + D;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fc_act(total_T * D, out_data, out_data);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|