Add multi_gru op and tests (#28591)
* Add multi_gru op and tests * removed redundant disable_dygraph()musl/fix_failed_unittests_in_musl
parent
fe2cf39f77
commit
04bcc13fac
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,203 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/fused/multi_gru_op.h"
|
||||
// #include "paddle/fluid/operators/fused/fusion_gru_op.h"
|
||||
#include <cstring> // for memcpy
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "paddle/fluid/operators/jit/kernels.h"
|
||||
#include "paddle/fluid/operators/math/blas.h"
|
||||
#include "paddle/fluid/operators/math/fc.h"
|
||||
#include "paddle/fluid/operators/math/sequence2batch.h"
|
||||
#ifdef PADDLE_WITH_MKLDNN
|
||||
#include "paddle/fluid/platform/mkldnn_helper.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const {
|
||||
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "multi_gru");
|
||||
OP_INOUT_CHECK(ctx->HasInputs("WeightX"), "Input", "WeightX", "multi_gru");
|
||||
OP_INOUT_CHECK(ctx->HasInputs("WeightH"), "Input", "WeightH", "multi_gru");
|
||||
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "multi_gru");
|
||||
auto x_dims = ctx->GetInputDim("X");
|
||||
auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
|
||||
? framework::flatten_to_2d(x_dims, 1)
|
||||
: x_dims;
|
||||
PADDLE_ENFORCE_EQ(
|
||||
x_mat_dims.size(), 2,
|
||||
platform::errors::InvalidArgument("The size of input X dims should be 2, "
|
||||
"or 3 with second dimension equal to "
|
||||
"1, but now Input X dim is:[%s] ",
|
||||
x_dims));
|
||||
|
||||
auto layers = ctx->Attrs().Get<int>("layers");
|
||||
auto wx_dims = ctx->GetInputsDim("WeightX");
|
||||
for (int i : {0, 1}) {
|
||||
PADDLE_ENFORCE_EQ(
|
||||
wx_dims[i][0], x_mat_dims[1],
|
||||
platform::errors::InvalidArgument(
|
||||
"The first dimension of flattened WeightX #%d"
|
||||
"should equal to last dimension of flattened input X, but "
|
||||
"received fattened WeightX dimension is:%d, flattened X dimension "
|
||||
"is:%d",
|
||||
i, wx_dims[i][0], x_mat_dims[1]));
|
||||
}
|
||||
|
||||
auto wh_dims = ctx->GetInputsDim("WeightH");
|
||||
for (int i = 0; i < 2 * layers; ++i) {
|
||||
PADDLE_ENFORCE_EQ(wx_dims[i].size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"The rank of WeightX #%d should be 2, but received "
|
||||
"WeightX dim size is:%d, WeightX dim is:[%s] ",
|
||||
i, wx_dims[i].size(), wx_dims[i]));
|
||||
PADDLE_ENFORCE_EQ(wh_dims[i].size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"The rank of WeightH #%d should be 2, but received "
|
||||
"WeightH dim size is:%d, WeightH dim is:[%s] ",
|
||||
i, wh_dims[i].size(), wh_dims[i]));
|
||||
int frame_size = wh_dims[i][0];
|
||||
PADDLE_ENFORCE_EQ(
|
||||
wh_dims[i][1], 3 * frame_size,
|
||||
platform::errors::InvalidArgument(
|
||||
"The second dimension of WeightH #%d "
|
||||
"should equal to 3 * frame_size, but received WeightH's "
|
||||
"second dimension is: %d, frame size is:%d",
|
||||
i, wh_dims[1], frame_size));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
wx_dims[i][1], 3 * frame_size,
|
||||
platform::errors::InvalidArgument(
|
||||
"The second dimension of WeightX #%d "
|
||||
"should equal to 3 * frame_size, but received WeightX's "
|
||||
"second dimension is: %d, frame size is:%d",
|
||||
i, wx_dims[i][1], frame_size));
|
||||
}
|
||||
|
||||
if (ctx->HasInputs("Bias")) {
|
||||
auto b_dims = ctx->GetInputsDim("Bias");
|
||||
for (int i = 0; i < 2 * layers; ++i) {
|
||||
int frame_size = wh_dims[i][0];
|
||||
PADDLE_ENFORCE_EQ(b_dims[i].size(), 2,
|
||||
platform::errors::InvalidArgument(
|
||||
"The rank of Bias #%d should be 2, but received "
|
||||
"Bias rank is:%d, Bias dim is:[%s]",
|
||||
i, b_dims[i].size(), b_dims[i]));
|
||||
PADDLE_ENFORCE_EQ(b_dims[i][0], 1,
|
||||
platform::errors::InvalidArgument(
|
||||
"The first dimension of Bias #%d should be 1, but "
|
||||
"received Bias first dim is:%d, Bias dim is:[%s]",
|
||||
i, b_dims[i][0], b_dims[i]));
|
||||
PADDLE_ENFORCE_EQ(
|
||||
b_dims[i][1], frame_size * 3,
|
||||
platform::errors::InvalidArgument(
|
||||
"The shape of Bias #%d must be [1, frame_size * 3], but "
|
||||
"received bias dim is:[%s], frame size is:%d",
|
||||
i, b_dims[i], frame_size));
|
||||
}
|
||||
}
|
||||
|
||||
int last_frame_size = wh_dims.back()[0];
|
||||
framework::DDim out_dims({x_mat_dims[0], 2 * last_frame_size});
|
||||
ctx->SetOutputDim("Hidden", out_dims);
|
||||
ctx->ShareLoD("X", "Hidden");
|
||||
}
|
||||
|
||||
framework::OpKernelType MultiGRUOp::GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const {
|
||||
framework::LibraryType library = framework::LibraryType::kMKLDNN;
|
||||
framework::DataLayout layout = framework::DataLayout::kMKLDNN;
|
||||
|
||||
return framework::OpKernelType(
|
||||
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
|
||||
library);
|
||||
}
|
||||
|
||||
void MultiGRUOpMaker::Make() {
|
||||
AddInput("X",
|
||||
"(LoDTensor) the input is an LodTensor, which support "
|
||||
"variable-time length input sequence. The underlying tensor in "
|
||||
"this LoDTensor is a matrix with shape (T X M), where T is the "
|
||||
"total time steps in this mini-batch, M is the dim size of x.");
|
||||
AddInput("WeightX",
|
||||
"(MultiTensor) The FC weight with shape (M x 3D),"
|
||||
"where M is the dim size of x, D is the hidden size. ")
|
||||
.AsDuplicable();
|
||||
AddInput("WeightH",
|
||||
"(MultiTensor) (D x 3D) Same as GRUOp, where D is the hidden size. "
|
||||
"This weight is not exactly D x 3D as: {W_update, W_reset, W_state}"
|
||||
"Acutally they are D x 2D and D x D two part weights."
|
||||
"{W_update, W_reset; W_state}"
|
||||
"{D x (D + D); D x D}")
|
||||
.AsDuplicable();
|
||||
AddInput("Bias",
|
||||
"(MultiTensor, optional) (1 x 3D)."
|
||||
"Almost same as GRUOp."
|
||||
"Note: if have FC bias it should be added on this bias.")
|
||||
.AsDuplicable()
|
||||
.AsDispensable();
|
||||
AddInput(
|
||||
"Scale_weights",
|
||||
"(MultiTensor, optional) Scale_weights to be used for int8 weights data."
|
||||
"Only used with MKL-DNN INT8.")
|
||||
.AsDuplicable()
|
||||
.AsDispensable();
|
||||
AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
|
||||
AddAttr<std::string>("activation",
|
||||
"(string, default tanh) "
|
||||
"The activation type used for output candidate {h}_t.")
|
||||
.SetDefault("tanh");
|
||||
AddAttr<std::string>(
|
||||
"gate_activation",
|
||||
"(string, default sigmoid) "
|
||||
"The activation type used in update gate and reset gate.")
|
||||
.SetDefault("sigmoid");
|
||||
AddAttr<int>("layers",
|
||||
"(int, default: 1) "
|
||||
"Number of stacked GRU layers.")
|
||||
.SetDefault(1);
|
||||
AddAttr<bool>("origin_mode",
|
||||
"bool"
|
||||
"use origin mode in article https://arxiv.org/abs/1412.3555")
|
||||
.SetDefault(false);
|
||||
AddAttr<std::string>(
|
||||
"mkldnn_data_type",
|
||||
"(string, default \"float32\"). Data type of mkldnn kernel")
|
||||
.SetDefault("float32")
|
||||
.InEnum({"float32", "int8", "bfloat16"});
|
||||
AddAttr<float>("Scale_data",
|
||||
"Scales to be used for int8 input/output data."
|
||||
"Only used with MKL-DNN INT8.")
|
||||
.SetDefault({1.f});
|
||||
AddAttr<float>("Shift_data",
|
||||
"Shifts to be used for int8 input/output data."
|
||||
"Only used with MKL-DNN INT8.")
|
||||
.SetDefault({0.f});
|
||||
AddAttr<bool>("force_fp32_output",
|
||||
"(bool, default: false) Force INT8 kernel output FP32, only "
|
||||
"used in MKL-DNN INT8")
|
||||
.SetDefault(false);
|
||||
AddComment(R"DOC(
|
||||
The Fusion complete GRU Operator.
|
||||
This operator fuse the fully-connected operator into GRU,
|
||||
more details can refer to GRU op.
|
||||
)DOC");
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(multi_gru, ops::MultiGRUOp, ops::MultiGRUOpMaker);
|
@ -0,0 +1,43 @@
|
||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::LoDTensor;
|
||||
using framework::Tensor;
|
||||
using framework::ExecutionContext;
|
||||
|
||||
class MultiGRUOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override;
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const ExecutionContext& ctx) const override;
|
||||
};
|
||||
|
||||
class MultiGRUOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override;
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,248 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from paddle.fluid.tests.unittests.op_test import OpTest
|
||||
from paddle.fluid.tests.unittests.test_fusion_gru_op import fusion_gru, ACTIVATION
|
||||
from paddle.fluid.dygraph.base import disable_dygraph
|
||||
|
||||
|
||||
def multi_gru(
|
||||
x, # T x M
|
||||
lod, # 1 x N
|
||||
h0, # N x D
|
||||
wx, # M x 3D
|
||||
wh, # D x 3D
|
||||
bias, # 1 x 3D
|
||||
origin_mode,
|
||||
layers):
|
||||
act_state = ACTIVATION['tanh']
|
||||
act_gate = ACTIVATION['sigmoid']
|
||||
input = x
|
||||
for i in range(0, layers * 2, 2):
|
||||
_, _, _, gru1_out = fusion_gru(input, lod, h0[i], wx[i], wh[i], bias[i],
|
||||
False, origin_mode, act_state, act_gate)
|
||||
_, _, _, gru2_out = fusion_gru(input, lod, h0[i + 1], wx[i + 1],
|
||||
wh[i + 1], bias[i + 1], True,
|
||||
origin_mode, act_state, act_gate)
|
||||
input = np.concatenate((gru1_out, gru2_out), axis=1)
|
||||
return input
|
||||
|
||||
|
||||
class TestMultiGruMkldnnOp(OpTest):
|
||||
def set_confs(self):
|
||||
pass
|
||||
|
||||
def set_dtype(self):
|
||||
pass
|
||||
|
||||
def set_force_fp32_output(self):
|
||||
pass
|
||||
|
||||
def setUp(self):
|
||||
self.op_type = "multi_gru"
|
||||
self.lod = [[2, 4, 3]]
|
||||
self.ICs = [3]
|
||||
self.OCs = [5]
|
||||
self.with_bias = True
|
||||
self.layers = 1
|
||||
self.origin_mode = False
|
||||
self._cpu_only = True
|
||||
self.error_margin = 1e-5
|
||||
self.set_confs()
|
||||
self.dtype = "float32"
|
||||
self.set_dtype()
|
||||
self.force_fp32_output = False
|
||||
self.set_force_fp32_output()
|
||||
|
||||
is_int8 = self.dtype == 'int8'
|
||||
scale_data = 63
|
||||
shift_data = 64
|
||||
|
||||
T = sum(self.lod[0])
|
||||
N = len(self.lod[0])
|
||||
|
||||
self.inputs = {}
|
||||
if is_int8:
|
||||
x_f32 = np.random.rand(T, self.ICs[0]).astype('float32') * 2 - 1
|
||||
x_u8 = np.rint(x_f32 * scale_data + shift_data).astype(np.uint8)
|
||||
self.inputs['X'] = (x_u8, self.lod)
|
||||
|
||||
else:
|
||||
x_f32 = np.random.rand(T, self.ICs[0]).astype('float32')
|
||||
self.inputs['X'] = (x_f32, self.lod)
|
||||
|
||||
wx = []
|
||||
wh = []
|
||||
bias = []
|
||||
h0 = []
|
||||
|
||||
for layer in range(self.layers):
|
||||
IC = self.ICs[layer]
|
||||
OC = self.OCs[layer]
|
||||
for j in range(2):
|
||||
wx.append(np.random.rand(IC, 3 * OC).astype('float32'))
|
||||
wh.append(np.random.rand(OC, 3 * OC).astype('float32'))
|
||||
bias.append(
|
||||
np.random.rand(1, 3 * OC).astype('float32')
|
||||
if self.with_bias else np.zeros(
|
||||
(1, 3 * OC), dtype='float32'))
|
||||
h0.append(np.zeros((N, OC), dtype='float32'))
|
||||
|
||||
self.inputs['WeightX'] = [('wx' + str(i), wx[i])
|
||||
for i in range(self.layers * 2)]
|
||||
self.inputs['WeightH'] = [('wh' + str(i), wh[i])
|
||||
for i in range(self.layers * 2)]
|
||||
if self.with_bias:
|
||||
self.inputs['Bias'] = [('b' + str(i), bias[i])
|
||||
for i in range(self.layers * 2)]
|
||||
|
||||
if is_int8:
|
||||
s8_max = 127.0
|
||||
scale_weights = []
|
||||
for layer in range(self.layers):
|
||||
OC = self.OCs[layer]
|
||||
for j in range(2):
|
||||
scale_ur = s8_max / np.max(np.abs(
|
||||
np.concatenate(
|
||||
[
|
||||
wx[2 * layer + j][:, :2 * OC], wh[2 * layer + j]
|
||||
.flatten()[:2 * OC * OC].reshape(OC, 2 * OC)
|
||||
],
|
||||
axis=0)),
|
||||
axis=0)
|
||||
scale_o = s8_max / np.max(np.abs(
|
||||
np.concatenate(
|
||||
[
|
||||
wx[2 * layer + j][:, 2 * OC:], wh[2 * layer + j]
|
||||
.flatten()[2 * OC * OC:].reshape(OC, OC)
|
||||
],
|
||||
axis=0)),
|
||||
axis=0)
|
||||
|
||||
scale_weights.append(
|
||||
np.concatenate([scale_ur, scale_o]).astype('float32'))
|
||||
self.inputs['Scale_weights'] = [('w_scale' + str(i),
|
||||
scale_weights[i])
|
||||
for i in range(self.layers * 2)]
|
||||
self.error_margin = 1e-1 if self.force_fp32_output else 1
|
||||
|
||||
hidden_f32 = multi_gru(x_f32, self.lod, h0, wx, wh, bias,
|
||||
self.origin_mode, self.layers)
|
||||
|
||||
if self.dtype == 'float32' or self.force_fp32_output:
|
||||
self.outputs = {'Hidden': (hidden_f32, self.lod)}
|
||||
else:
|
||||
hidden_u8 = np.rint(hidden_f32 * scale_data + shift_data).astype(
|
||||
np.uint8)
|
||||
self.outputs = {'Hidden': (hidden_u8, self.lod)}
|
||||
|
||||
self.attrs = {
|
||||
'activation': 'tanh',
|
||||
'gate_activation': 'sigmoid',
|
||||
'layers': self.layers,
|
||||
'origin_mode': self.origin_mode,
|
||||
'use_mkldnn': True,
|
||||
}
|
||||
|
||||
if is_int8:
|
||||
self.attrs['force_fp32_output'] = self.force_fp32_output
|
||||
self.attrs['Scale_data'] = scale_data
|
||||
self.attrs['Shift_data'] = shift_data
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output(check_dygraph=False, atol=self.error_margin)
|
||||
|
||||
|
||||
class TestMultiGruMkldnnOpNoBias(TestMultiGruMkldnnOp):
|
||||
def set_confs(self):
|
||||
self.with_bias = False
|
||||
|
||||
|
||||
class TestMultiGruMkldnnOpLayers2(TestMultiGruMkldnnOp):
|
||||
def set_confs(self):
|
||||
self.layers = 2
|
||||
self.ICs = [2, 6]
|
||||
self.OCs = [3, 8]
|
||||
|
||||
|
||||
class TestMultiGruMkldnnOpLayers3(TestMultiGruMkldnnOp):
|
||||
def set_confs(self):
|
||||
self.layers = 3
|
||||
self.ICs = [2, 6, 12]
|
||||
self.OCs = [3, 6, 14]
|
||||
|
||||
|
||||
class TestMultiGruMkldnnOpOriginMode(TestMultiGruMkldnnOp):
|
||||
def set_confs(self):
|
||||
self.origin_mode = True
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8Op(TestMultiGruMkldnnOp):
|
||||
def set_dtype(self):
|
||||
self.dtype = 'int8'
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpForceFP32Output(TestMultiGruMkldnnInt8Op):
|
||||
def set_force_fp32_output(self):
|
||||
self.force_fp32_output = True
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpNoBias(TestMultiGruMkldnnOpNoBias):
|
||||
def set_dtype(self):
|
||||
self.dtype = 'int8'
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpNoBiasForceFP32Output(
|
||||
TestMultiGruMkldnnInt8OpNoBias):
|
||||
def set_force_fp32_output(self):
|
||||
self.force_fp32_output = True
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpLayers2(TestMultiGruMkldnnOpLayers2):
|
||||
def set_dtype(self):
|
||||
self.dtype = 'int8'
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpLayers2ForceFP32Output(
|
||||
TestMultiGruMkldnnInt8OpLayers2):
|
||||
def set_force_fp32_output(self):
|
||||
self.force_fp32_output = True
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpLayers3(TestMultiGruMkldnnOpLayers3):
|
||||
def set_dtype(self):
|
||||
self.dtype = 'int8'
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpLayers3ForceFP32Output(
|
||||
TestMultiGruMkldnnInt8OpLayers3):
|
||||
def set_force_fp32_output(self):
|
||||
self.force_fp32_output = True
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpOriginMode(TestMultiGruMkldnnOpOriginMode):
|
||||
def set_dtype(self):
|
||||
self.dtype = 'int8'
|
||||
|
||||
|
||||
class TestMultiGruMkldnnInt8OpOriginModeForceFP32Output(
|
||||
TestMultiGruMkldnnInt8OpOriginMode):
|
||||
def set_force_fp32_output(self):
|
||||
self.force_fp32_output = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue