Merge pull request #5872 from qingqing01/op_debug

Fix lstm_op and gru_op in debug mode.
release/0.11.0
qingqing01 8 years ago committed by GitHub
commit 52007ea662
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -297,7 +297,25 @@ void set_constant_with_place<platform::GPUPlace>(
template struct RowwiseAdd<platform::GPUPlace, float>; template struct RowwiseAdd<platform::GPUPlace, float>;
template struct RowwiseAdd<platform::GPUPlace, double>; template struct RowwiseAdd<platform::GPUPlace, double>;
template struct ColwiseSum<platform::GPUPlace, float>; template struct ColwiseSum<platform::GPUPlace, float>;
template struct ColwiseSum<platform::GPUPlace, double>; // template struct ColwiseSum<platform::GPUPlace, double>;
// The ColwiseSum<platform::GPUPlace, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
template <>
void ColwiseSum<platform::GPUPlace, double>::operator()(
const platform::DeviceContext& context, const framework::Tensor& input,
framework::Tensor* vector) {
auto in_dims = input.dims();
auto size = input.numel() / in_dims[0];
PADDLE_ENFORCE_EQ(vector->numel(), size);
framework::Tensor one;
one.mutable_data<double>({in_dims[0]}, context.GetPlace());
SetConstant<platform::GPUPlace, double> set;
set(context, &one, static_cast<double>(1.0));
gemv<platform::GPUPlace, double>(context, true, static_cast<int>(in_dims[0]),
static_cast<int>(in_dims[1]), 1.0,
input.data<double>(), one.data<double>(),
0.0, vector->data<double>());
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators

Loading…
Cancel
Save