|
|
|
@ -50,7 +50,7 @@ template <typename T>
|
|
|
|
|
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
const bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"MKLDNN LRN must use CPUPlace.");
|
|
|
|
@ -132,8 +132,8 @@ template <typename T>
|
|
|
|
|
class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(std::is_same<T, float>::value,
|
|
|
|
|
"MKLDNN LRN must use float data.");
|
|
|
|
|
const bool is_float_type = std::is_same<T, float>::value;
|
|
|
|
|
PADDLE_ENFORCE(is_float_type, "MKLDNN LRN must use float data.");
|
|
|
|
|
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"MKLDNN LRN must use CPUPlace.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|