Delete Ref & VectorRef and add GetDataSafely (#22997)

* delete invalid check interface Ref & VectorRef, test=develop

* fix vector ref delete error, test=develop

* try the new check interface, test=develop

* change all related code with new check macro, test=develop

* remove static assert, test=develop

* polish detail, test=develop

* skip coverage problem, test=develop

* add new check macro, test=develop
revert-23830-2.0-beta
Chen Weihang 5 years ago committed by GitHub
parent 4c675a450f
commit 16315d3d9e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -31,8 +31,8 @@ class CudnnActivationKernel
ExtractActivationTensor(context, X, Out); ExtractActivationTensor(context, X, Out);
ActivationDescriptor act_desc; ActivationDescriptor act_desc;
TensorDescriptor x_desc, out_desc; TensorDescriptor x_desc, out_desc;
x_desc.set(detail::Ref(X)); x_desc.set(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"));
out_desc.set(detail::Ref(Out)); out_desc.set(GET_DATA_SAFELY(Out, "Output", "Out", "CudnnActivation"));
} }
}; };

@ -37,7 +37,7 @@ struct CudnnActivationFunctor {
act_desc.set(mode_, coef_); act_desc.set(mode_, coef_);
TensorDescriptor x_desc, out_desc; TensorDescriptor x_desc, out_desc;
x_desc.set(x); x_desc.set(x);
out_desc.set(detail::Ref(out)); out_desc.set(GET_DATA_SAFELY(out, "Output", "Out", "CudnnActivation"));
PADDLE_ENFORCE(platform::dynload::cudnnActivationForward( PADDLE_ENFORCE(platform::dynload::cudnnActivationForward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(), platform::CudnnDataType<T>::kOne(), x_desc.desc(), x.data<T>(),
@ -63,7 +63,7 @@ struct CudnnActivationGradFunctor {
x_desc.set(x); x_desc.set(x);
out_desc.set(out); out_desc.set(out);
dout_desc.set(dout); dout_desc.set(dout);
dx_desc.set(detail::Ref(dx)); dx_desc.set(GET_DATA_SAFELY(dx, "Output", "X@GRAD", "CudnnActivationGrad"));
PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward( PADDLE_ENFORCE(platform::dynload::cudnnActivationBackward(
ctx_.cudnn_handle(), act_desc.desc(), ctx_.cudnn_handle(), act_desc.desc(),
platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(), platform::CudnnDataType<T>::kOne(), out_desc.desc(), out.data<T>(),
@ -141,7 +141,7 @@ class CudnnActivationKernel
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>(); auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx); Functor functor(dev_ctx);
functor(detail::Ref(X), Out); functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivation"), Out);
} }
}; };
@ -161,7 +161,10 @@ class CudnnActivationGradKernel
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>(); auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx); Functor functor(dev_ctx);
functor(detail::Ref(X), detail::Ref(Out), detail::Ref(dOut), dX); functor(GET_DATA_SAFELY(X, "Input", "X", "CudnnActivationGrad"),
GET_DATA_SAFELY(Out, "Input", "Out", "CudnnActivationGrad"),
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "CudnnActivationGrad"),
dX);
} }
}; };

@ -26,7 +26,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
@ -156,8 +155,10 @@ class ActivationKernel
ExtractActivationTensor(context, &X, &Out); ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); auto x = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(X, "Input", "X", "Activation"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
@ -182,10 +183,14 @@ class ActivationGradKernel
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut, ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX); &dX);
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); auto dout = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto out = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
@ -1285,10 +1290,13 @@ struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
framework::Tensor* ddOut, framework::Tensor* dOut, framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const { framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>(); ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
} }
} }
@ -1308,9 +1316,12 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
framework::Tensor* dX) const { framework::Tensor* dX) const {
if (ddOut) { if (ddOut) {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
ddout.device(*d) = ddx * ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() + ((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * static_cast<T>(alpha) *
@ -1332,18 +1343,23 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const { const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));
if (dX) { if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() * dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
(x < static_cast<T>(0)).template cast<T>(); (x < static_cast<T>(0)).template cast<T>();
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
ddout.device(*d) = ddx * ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() + ((x > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * x.exp() * static_cast<T>(alpha) * x.exp() *
@ -1361,17 +1377,22 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
framework::Tensor* dOut, const framework::Tensor* dX) const { framework::Tensor* dOut, const framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx // sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
// calculate dy first, so ddy can inplace ddx // calculate dy first, so ddy can inplace ddx
if (dOut) { if (dOut) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out; dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(0.5) / out; ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
} }
} }
@ -1385,17 +1406,22 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* ddX, framework::Tensor* ddOut, const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const { const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device(); auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX)); auto ddx = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx // square GradGrad: ddy=2x*ddx, dx=2dy*ddx
// calculate dx first, so ddy can inplace ddx // calculate dx first, so ddy can inplace ddx
if (dX) { if (dX) {
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto dx = framework::EigenVector<T>::Flatten(
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
dx.device(*d) = ddx * static_cast<T>(2) * dout; dx.device(*d) = ddx * static_cast<T>(2) * dout;
} }
if (ddOut) { if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut)); auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
ddout.device(*d) = ddx * static_cast<T>(2) * x; ddout.device(*d) = ddx * static_cast<T>(2) * x;
} }
} }
@ -1557,8 +1583,10 @@ class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
ExtractActivationTensor(context, &X, &Out); ExtractActivationTensor(context, &X, &Out);
Out->mutable_data<T>(context.GetPlace()); Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); auto x = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(X, "Input", "X", "Pow"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "Pow"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;
@ -1602,10 +1630,14 @@ class PowGradKernel
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut, ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX); &dX);
dX->mutable_data<T>(context.GetPlace()); dX->mutable_data<T>(context.GetPlace());
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut)); auto dout = framework::EigenVector<T>::Flatten(
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out)); GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad"));
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX)); auto out = framework::EigenVector<T>::Flatten(
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X)); GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "PowGrad"));
auto* place = auto* place =
context.template device_context<DeviceContext>().eigen_device(); context.template device_context<DeviceContext>().eigen_device();
Functor functor; Functor functor;

@ -15,7 +15,6 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {

@ -14,7 +14,6 @@ limitations under the License. */
#include <thread> // NOLINT #include <thread> // NOLINT
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
@ -56,9 +55,8 @@ class GetPlacesOp : public framework::OperatorBase {
is_gpu ? "GPU" : "CPU"); is_gpu ? "GPU" : "CPU");
auto out_var_name = Output("Out"); auto out_var_name = Output("Out");
auto &places = auto &places = *(GET_DATA_SAFELY(scope.FindVar(out_var_name), "Output",
*(detail::Ref(scope.FindVar(out_var_name), "Out", "GetPlaces")
"Output variable %s cannot be found", out_var_name)
.GetMutable<platform::PlaceList>()); .GetMutable<platform::PlaceList>());
places.reserve(device_count); places.reserve(device_count);
if (is_gpu) { if (is_gpu) {

@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/array_operator.h" #include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
namespace paddle { namespace paddle {

@ -19,7 +19,6 @@
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h" #include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -198,23 +197,18 @@ class WhileGradOp : public framework::OperatorBase {
continue; continue;
} }
auto &og_outside = auto &og_outside = *scope.FindVar(outside_og_name);
detail::Ref(scope.FindVar(outside_og_name), auto &og_inside = *cur_scope.Var(inside_og_name);
"Cannot find Outside Gradient %s", outside_og_name);
auto &og_inside =
detail::Ref(cur_scope.Var(inside_og_name),
"Cannot find inside gradient %s", inside_og_name);
if (og_outside.IsType<framework::LoDTensor>()) { if (og_outside.IsType<framework::LoDTensor>()) {
auto &outside_tensor = og_outside.Get<framework::LoDTensor>(); auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
auto &inside_tensor = auto &inside_tensor = *og_inside.GetMutable<framework::LoDTensor>();
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
inside_tensor.set_lod(outside_tensor.lod()); inside_tensor.set_lod(outside_tensor.lod());
inside_tensor.ShareDataWith(outside_tensor); inside_tensor.ShareDataWith(outside_tensor);
} else if (og_outside.IsType<framework::LoDTensorArray>()) { } else if (og_outside.IsType<framework::LoDTensorArray>()) {
auto outside_array = auto outside_array =
og_outside.GetMutable<framework::LoDTensorArray>(); og_outside.GetMutable<framework::LoDTensorArray>();
auto &inside_array = auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>()); *og_inside.GetMutable<framework::LoDTensorArray>();
inside_array.clear(); inside_array.clear();
inside_array.resize(outside_array->size()); inside_array.resize(outside_array->size());
VLOG(8) << outside_og_name << " size = " << outside_array->size(); VLOG(8) << outside_og_name << " size = " << outside_array->size();

@ -20,7 +20,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h" #include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
@ -674,9 +673,8 @@ class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
Tensor* ddY = ctx.Output<Tensor>("DDOutput"); Tensor* ddY = ctx.Output<Tensor>("DDOutput");
Tensor* dW = ctx.Output<Tensor>("DFilter"); Tensor* dW = ctx.Output<Tensor>("DFilter");
Tensor* dX = ctx.Output<Tensor>("DInput"); Tensor* dX = ctx.Output<Tensor>("DInput");
Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"), Tensor W = GET_DATA_SAFELY(ctx.Input<Tensor>("Filter"), "Input", "Filter",
"Cannot find input Filter(%s) in scope)", "GemmConvDoubleGrad");
ctx.InputNames("Filter")[0]);
if (!ddY && !dW && !dX) return; if (!ddY && !dW && !dX) return;
const int groups = ctx.Attr<int>("groups"); const int groups = ctx.Attr<int>("groups");

@ -18,7 +18,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -29,13 +28,11 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
using T = typename Functor::ELEMENT_TYPE; using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto& X = detail::Ref(context.Input<framework::Tensor>("X"), auto& X = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
"Cannot get input tensor X, variable name = %s", "X", "Cum");
context.InputName("X"));
auto& Out = detail::Ref(context.Output<framework::Tensor>("Out"), auto& Out = GET_DATA_SAFELY(context.Output<framework::Tensor>("Out"),
"Cannot get output tensor Out, variable name = %s", "Output", "Out", "Cum");
context.OutputName("Out"));
int axis = context.Attr<int>("axis"); int axis = context.Attr<int>("axis");
bool exclusive = context.Attr<bool>("exclusive"); bool exclusive = context.Attr<bool>("exclusive");
bool reverse = context.Attr<bool>("reverse"); bool reverse = context.Attr<bool>("reverse");
@ -46,7 +43,7 @@ class CumKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis, x_dims.size(), axis, x_dims.size(),
"axis should be less than the dimensiotn of the input tensor"); "axis should be less than the dimensiotn of the input tensor");
Out.mutable_data<T>(context.GetPlace()); Out.template mutable_data<T>(context.GetPlace());
int pre = 1; int pre = 1;
int post = 1; int post = 1;

@ -1,45 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace detail {
/**
 * Get a reference from a pointer, with a null check.
 *
 * `ptr` is checked with PADDLE_ENFORCE_NOT_NULL and then dereferenced.
 * `args` are handed to ::paddle::string::Sprintf to build the message passed
 * to that check; the message is printf format (format string first, then its
 * arguments).
 *
 * NOTE(review): what happens when `ptr` is null is decided by the macro in
 * paddle/fluid/platform/enforce.h, which is not visible here -- presumably it
 * raises Paddle's enforce error; confirm against that header.
 * NOTE(review): `args` is a forwarding-reference pack but is passed on as
 * lvalues (no std::forward).
 */
template <typename T, typename... ARGS>
inline T& Ref(T* ptr, ARGS&&... args) {
PADDLE_ENFORCE_NOT_NULL(ptr, ::paddle::string::Sprintf(args...));
return *ptr;
}
/**
 * Convert a vector of pointers into a vector of reference wrappers.
 *
 * Every pointer in `vec` is passed through Ref together with `args`, so each
 * element goes through the same check Ref applies. The result has one entry
 * per input pointer, in the same order as `vec`.
 */
template <typename T, typename... ARGS>
inline std::vector<std::reference_wrapper<T>> VectorRef(
    const std::vector<T*>& vec, ARGS&&... args) {
  std::vector<std::reference_wrapper<T>> refs;
  // Size is known up front, so allocate once.
  refs.reserve(vec.size());
  for (auto it = vec.begin(); it != vec.end(); ++it) {
    // `args` is reused on every iteration, so it is passed as lvalues.
    refs.emplace_back(Ref(*it, args...));
  }
  return refs;
}
} // namespace detail
} // namespace operators
} // namespace paddle

@ -20,7 +20,6 @@ limitations under the License.*/
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"

@ -20,7 +20,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"

@ -17,7 +17,6 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.h" #include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
@ -293,12 +292,10 @@ class GenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"), auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Cannot find input Anchors(%s) in scope", "Anchors", "GenerateProposals");
context.InputNames("Anchors")[0]); auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
auto variances = detail::Ref(context.Input<Tensor>("Variances"), "Input", "Variances", "GenerateProposals");
"Cannot find input Variances(%s) in scope",
context.InputNames("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

@ -20,7 +20,6 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/gather.cu.h" #include "paddle/fluid/operators/gather.cu.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h" #include "paddle/fluid/platform/for_range.h"
@ -367,12 +366,10 @@ class CUDAGenerateProposalsKernel : public framework::OpKernel<T> {
auto *scores = context.Input<Tensor>("Scores"); auto *scores = context.Input<Tensor>("Scores");
auto *bbox_deltas = context.Input<Tensor>("BboxDeltas"); auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
auto *im_info = context.Input<Tensor>("ImInfo"); auto *im_info = context.Input<Tensor>("ImInfo");
auto anchors = detail::Ref(context.Input<Tensor>("Anchors"), auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
"Cannot find input Anchors(%s) in scope", "Anchors", "GenerateProposals");
context.InputNames("Anchors")[0]); auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
auto variances = detail::Ref(context.Input<Tensor>("Variances"), "Input", "Variances", "GenerateProposals");
"Cannot find input Variances(%s) in scope",
context.InputNames("Variances")[0]);
auto *rpn_rois = context.Output<LoDTensor>("RpnRois"); auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs"); auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

@ -19,7 +19,6 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/data_type.h" #include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
@ -44,10 +43,8 @@ template <typename T>
class FillKernel : public framework::OpKernel<T> { class FillKernel : public framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override { void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto &out = auto &out = GET_DATA_SAFELY(ctx.Output<framework::LoDTensor>("Out"),
detail::Ref(ctx.Output<framework::LoDTensor>("Out"), "Output", "Out", "Fill");
"Cannot get output lod tensor Out, variable name = %s",
ctx.OutputName("Out"));
out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape"))); out.Resize(framework::make_ddim(ctx.Attr<std::vector<int>>("shape")));
auto dtype = auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")); static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));

@ -18,7 +18,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/operators/math/functors.h"
@ -383,12 +382,10 @@ template <typename DeviceContext, typename T>
class FusedElemwiseActivationKernel : public framework::OpKernel<T> { class FusedElemwiseActivationKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto &in_x = detail::Ref(ctx.Input<framework::Tensor>("X"), auto &in_x = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("X"), "Input",
"Cannot get input tensor %s, variable name = %s", "X", "FusedElemwiseActivation");
"X", ctx.InputName("X")); auto &in_y = GET_DATA_SAFELY(ctx.Input<framework::Tensor>("Y"), "Input",
auto &in_y = detail::Ref(ctx.Input<framework::Tensor>("Y"), "Y", "FusedElemwiseActivation");
"Cannot get input tensor %s, variable name = %s",
"Y", ctx.InputName("Y"));
PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty"); PADDLE_ENFORCE(ctx.HasOutput("Out"), "The output(Out) should not be empty");
auto output = ctx.Output<framework::Tensor>("Out"); auto output = ctx.Output<framework::Tensor>("Out");

@ -14,7 +14,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {

@ -19,7 +19,6 @@
#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"

@ -14,7 +14,6 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/platform/errors.h" #include "paddle/fluid/platform/errors.h"
namespace paddle { namespace paddle {

@ -17,7 +17,6 @@
#include <algorithm> #include <algorithm>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h" #include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
@ -142,14 +141,13 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
auto *input = context.Input<framework::Tensor>("Input"); auto *input = context.Input<framework::Tensor>("Input");
auto *w = context.Input<framework::Tensor>("W"); auto *w = context.Input<framework::Tensor>("W");
auto *bias = context.Input<framework::Tensor>("Bias"); auto *bias = context.Input<framework::Tensor>("Bias");
auto &bias_qk = GET_DATA_SAFELY(context.Input<framework::Tensor>("BiasQK"),
auto &bias_qk = detail::Ref(context.Input<framework::Tensor>("BiasQK"), "Input", "BiasQK", "MultiHeadMatMulV2");
"Cannot find QK");
auto *input_d = input->data<T>(); auto *input_d = input->data<T>();
auto *w_d = w->data<T>(); auto *w_d = w->data<T>();
auto *bias_d = bias->data<T>(); auto *bias_d = bias->data<T>();
auto *bias_qk_d = bias_qk.data<T>(); auto *bias_qk_d = bias_qk.template data<T>();
T scale = static_cast<T>(context.Attr<float>("alpha")); T scale = static_cast<T>(context.Attr<float>("alpha"));
int head_number = context.Attr<int>("head_number"); int head_number = context.Attr<int>("head_number");

@ -24,7 +24,6 @@ limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/clip_op.h" #include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_bit_code.h" #include "paddle/fluid/operators/math/matrix_bit_code.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
@ -40,8 +39,9 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
using platform::Transform; using platform::Transform;
using framework::LoDTensor;
static std::vector<int64_t> PathToRows(const framework::LoDTensor& path) { static std::vector<int64_t> PathToRows(const LoDTensor& path) {
std::set<int64_t> rows; std::set<int64_t> rows;
const int64_t* paths = path.data<int64_t>(); const int64_t* paths = path.data<int64_t>();
for (int64_t i = 0; i < path.numel(); ++i) { for (int64_t i = 0; i < path.numel(); ++i) {
@ -57,14 +57,17 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); "HierarchicalSigmoid");
auto* path = ctx.Input<framework::LoDTensor>("PathTable"); auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); "HierarchicalSigmoid");
auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto* path = ctx.Input<LoDTensor>("PathTable");
auto* bias = ctx.Input<framework::LoDTensor>("Bias"); auto* code = ctx.Input<LoDTensor>("PathCode");
auto* out = ctx.Output<framework::LoDTensor>("Out"); auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
auto* pre_out = ctx.Output<framework::LoDTensor>("PreOut"); "Label", "HierarchicalSigmoid");
auto* bias = ctx.Input<LoDTensor>("Bias");
auto* out = ctx.Output<LoDTensor>("Out");
auto* pre_out = ctx.Output<LoDTensor>("PreOut");
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes")); size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch // for remote prefetch
@ -75,7 +78,7 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
int64_t code_length = int64_t code_length =
path ? path->dims()[1] : math::FindLastSet(num_classes - 1); path ? path->dims()[1] : math::FindLastSet(num_classes - 1);
int64_t batch_size = in.dims()[0]; int64_t batch_size = in.dims()[0];
framework::LoDTensor sum; LoDTensor sum;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto* pre_out_data = pre_out->mutable_data<T>( auto* pre_out_data = pre_out->mutable_data<T>(
framework::make_ddim({batch_size, code_length}), ctx.GetPlace()); framework::make_ddim({batch_size, code_length}), ctx.GetPlace());
@ -89,11 +92,11 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code; std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) { if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); num_classes, label.template data<int64_t>()));
} else { } else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); *path, *code, label.template data<int64_t>()));
} }
std::vector<int64_t> sum_dims({batch_size, 1UL}); std::vector<int64_t> sum_dims({batch_size, 1UL});
@ -126,20 +129,24 @@ template <typename DeviceContext, typename T>
class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> { class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& in = detail::Ref(ctx.Input<framework::LoDTensor>("X")); auto& in = GET_DATA_SAFELY(ctx.Input<LoDTensor>("X"), "Input", "X",
auto& w = detail::Ref(ctx.Input<framework::LoDTensor>("W")); "HierarchicalSigmoidGrad");
auto* path = ctx.Input<framework::LoDTensor>("PathTable"); auto& w = GET_DATA_SAFELY(ctx.Input<LoDTensor>("W"), "Input", "W",
auto* code = ctx.Input<framework::LoDTensor>("PathCode"); "HierarchicalSigmoidGrad");
auto* in_grad = auto* path = ctx.Input<LoDTensor>("PathTable");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X")); auto* code = ctx.Input<LoDTensor>("PathCode");
auto* in_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
bool is_sparse = ctx.Attr<bool>("is_sparse"); bool is_sparse = ctx.Attr<bool>("is_sparse");
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero; math::SetConstant<DeviceContext, T> zero;
auto& label = detail::Ref(ctx.Input<framework::LoDTensor>("Label")); auto& label = GET_DATA_SAFELY(ctx.Input<LoDTensor>("Label"), "Input",
auto& pre_out = detail::Ref(ctx.Input<framework::LoDTensor>("PreOut")); "Label", "HierarchicalSigmoidGrad");
auto& out_grad = detail::Ref( auto& pre_out = GET_DATA_SAFELY(ctx.Input<LoDTensor>("PreOut"), "Input",
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))); "PreOut", "HierarchicalSigmoidGrad");
framework::LoDTensor pre_out_grad; auto& out_grad = GET_DATA_SAFELY(
ctx.Input<LoDTensor>(framework::GradVarName("Out")), "Input",
framework::GradVarName("Out"), "HierarchicalSigmoidGrad");
LoDTensor pre_out_grad;
pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace()); pre_out_grad.mutable_data<T>(pre_out.dims(), ctx.GetPlace());
in_grad->mutable_data<T>(ctx.GetPlace()); in_grad->mutable_data<T>(ctx.GetPlace());
@ -154,11 +161,11 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code; std::unique_ptr<math::MatrixBitCodeFunctor<T>> bit_code;
if (!is_custom) { if (!is_custom) {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(num_classes, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); num_classes, label.template data<int64_t>()));
} else { } else {
bit_code.reset(new math::MatrixBitCodeFunctor<T>(*path, *code, bit_code.reset(new math::MatrixBitCodeFunctor<T>(
label.data<int64_t>())); *path, *code, label.template data<int64_t>()));
} }
// softrelu derivative // softrelu derivative
@ -166,7 +173,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto* pre_out_grad_data = pre_out_grad.data<T>(); auto* pre_out_grad_data = pre_out_grad.data<T>();
auto* pre_out_data = pre_out.data<T>(); auto* pre_out_data = pre_out.template data<T>();
auto n = pre_out.numel(); auto n = pre_out.numel();
blas.VEXP(n, pre_out_data, pre_out_grad_data); blas.VEXP(n, pre_out_data, pre_out_grad_data);
blas.VINV(n, pre_out_grad_data, pre_out_grad_data); blas.VINV(n, pre_out_grad_data, pre_out_grad_data);
@ -174,7 +181,7 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i]; pre_out_grad_data[i] = 1.0 - pre_out_grad_data[i];
} }
bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b) bit_code->Sub(&pre_out_grad); // the gradient of clip(w * x + b)
auto* out_grad_data = out_grad.data<T>(); auto* out_grad_data = out_grad.template data<T>();
int64_t dim0 = pre_out_grad.dims()[0]; int64_t dim0 = pre_out_grad.dims()[0];
int64_t dim1 = pre_out_grad.dims()[1]; int64_t dim1 = pre_out_grad.dims()[1];
@ -184,16 +191,14 @@ class HierarchicalSigmoidGradOpKernel : public framework::OpKernel<T> {
} }
// TODO(guosheng): multiply pre_out_grad with subgradient of clipping to // TODO(guosheng): multiply pre_out_grad with subgradient of clipping to
// be consistent with the clipping in forward. // be consistent with the clipping in forward.
auto* bias_grad = auto* bias_grad = ctx.Output<LoDTensor>(framework::GradVarName("Bias"));
ctx.Output<framework::LoDTensor>(framework::GradVarName("Bias"));
if (bias_grad) { if (bias_grad) {
bias_grad->mutable_data<T>(ctx.GetPlace()); bias_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, bias_grad, static_cast<T>(0.0)); zero(dev_ctx, bias_grad, static_cast<T>(0.0));
bit_code->AddGrad(pre_out_grad, bias_grad); bit_code->AddGrad(pre_out_grad, bias_grad);
} }
if (!is_sparse) { if (!is_sparse) {
auto* w_grad = auto* w_grad = ctx.Output<LoDTensor>(framework::GradVarName("W"));
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
w_grad->mutable_data<T>(ctx.GetPlace()); w_grad->mutable_data<T>(ctx.GetPlace());
zero(dev_ctx, w_grad, static_cast<T>(0.0)); zero(dev_ctx, w_grad, static_cast<T>(0.0));
bit_code->MulGradWeight(pre_out_grad, w_grad, in); bit_code->MulGradWeight(pre_out_grad, w_grad, in);

@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
@ -95,13 +94,15 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
private: private:
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
auto &x = detail::Ref(scope.FindVar(Input("X")), "Cannot find input %s", auto &x = GET_DATA_SAFELY(scope.FindVar(Input("X")), "Input", "X",
Input("X")) "LoDTensorToArray")
.Get<framework::LoDTensor>(); .Get<framework::LoDTensor>();
auto &rank_table = detail::Ref(scope.FindVar(Input("RankTable"))) auto &rank_table = GET_DATA_SAFELY(scope.FindVar(Input("RankTable")),
"Input", "RankTable", "LoDTensorToArray")
.Get<framework::LoDRankTable>(); .Get<framework::LoDRankTable>();
auto &out = *detail::Ref(scope.FindVar(Output("Out"))) auto &out = *(GET_DATA_SAFELY(scope.FindVar(Output("Out")), "Output", "Out",
.GetMutable<framework::LoDTensorArray>(); "LoDTensorToArray")
.GetMutable<framework::LoDTensorArray>());
auto &items = rank_table.items(); auto &items = rank_table.items();
auto max_seq_len = items[0].length; auto max_seq_len = items[0].length;
auto rank_level = rank_table.level(); auto rank_level = rank_table.level();

@ -16,7 +16,6 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
@ -58,10 +57,10 @@ template <typename DeviceContext, typename T>
class MatMulKernel : public framework::OpKernel<T> { class MatMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto &x = auto &x = GET_DATA_SAFELY(context.Input<framework::Tensor>("X"), "Input",
detail::Ref(context.Input<framework::Tensor>("X"), "Cannot find X"); "X", "MatMul");
auto &y = auto &y = GET_DATA_SAFELY(context.Input<framework::Tensor>("Y"), "Input",
detail::Ref(context.Input<framework::Tensor>("Y"), "Cannot find Y"); "Y", "MatMul");
auto *out = context.Output<framework::Tensor>("Out"); auto *out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());

@ -17,7 +17,6 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"

@ -128,7 +128,6 @@ class AdamOpCUDAKernel : public framework::OpKernel<T> {
framework::ToTypeName(param_var->Type()))); framework::ToTypeName(param_var->Type())));
using paddle::framework::LoDTensor; using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
int64_t min_row_size_to_use_multithread = int64_t min_row_size_to_use_multithread =
ctx.Attr<int64_t>("min_row_size_to_use_multithread"); ctx.Attr<int64_t>("min_row_size_to_use_multithread");

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save