|
|
@ -89,7 +89,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename AttrType = T>
|
|
|
|
template <typename T>
|
|
|
|
class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
@ -101,7 +101,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
|
|
|
auto* c_tensor = ctx.Output<framework::Tensor>("C");
|
|
|
|
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
|
|
|
auto* h_tensor = ctx.Output<framework::Tensor>("H");
|
|
|
|
|
|
|
|
|
|
|
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
|
|
|
auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
|
|
|
|
|
|
|
|
|
|
|
|
int b_size = c_tensor->dims()[0];
|
|
|
|
int b_size = c_tensor->dims()[0];
|
|
|
|
int D = c_tensor->dims()[1];
|
|
|
|
int D = c_tensor->dims()[1];
|
|
|
@ -120,7 +120,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T, typename AttrType = T>
|
|
|
|
template <typename T>
|
|
|
|
class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
@ -153,7 +153,7 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
int N = c_tensor->dims()[0];
|
|
|
|
int N = c_tensor->dims()[0];
|
|
|
|
int D = c_tensor->dims()[1];
|
|
|
|
int D = c_tensor->dims()[1];
|
|
|
|
|
|
|
|
|
|
|
|
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
|
|
|
|
auto forget_bias = static_cast<T>(ctx.Attr<float>("forget_bias"));
|
|
|
|
|
|
|
|
|
|
|
|
int block = 512;
|
|
|
|
int block = 512;
|
|
|
|
int n = N * D;
|
|
|
|
int n = N * D;
|
|
|
@ -169,5 +169,7 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
|
|
|
|
} // namespace paddle
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
|
|
|
|
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>,
|
|
|
|
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
|
|
|
|
ops::LstmUnitOpCUDAKernel<double>);
|
|
|
|
|
|
|
|
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>,
|
|
|
|
|
|
|
|
ops::LstmUnitGradOpCUDAKernel<double>);
|
|
|
|