use operator context and infer context (#3024)

* use operator context

* optimize code

* update net infershape

* update InferShape

* disable override InferShape(scope) in OperatorBase

* change InferShapeImpl to InferShape

* add template to OperatorContext Input/Output

* merge Input InputVar, Output OutputVar

* change Inputs to MultiInput

* fix conflict

* fix MultiInput bugs and add unit test

* rename KernelContext to ExecutionContext

* clean code

* change InferShape to protected

* fix template bug

* refine code

* use InputVar instead of Input<Variable>

* typo

* optimize code
cblas_new
Qiao Longfei 8 years ago committed by GitHub
parent 0b68077221
commit 61ebacbcd3

@ -16,8 +16,7 @@ static int run_cnt = 0;
class TestOp : public OperatorBase { class TestOp : public OperatorBase {
public: public:
void InferShape( void InferShape(const std::shared_ptr<Scope>& scope) const override {
const std::shared_ptr<framework::Scope>& scope) const override {
++infer_shape_cnt; ++infer_shape_cnt;
} }
void Run(const std::shared_ptr<framework::Scope>& scope, void Run(const std::shared_ptr<framework::Scope>& scope,

@ -20,7 +20,7 @@ namespace paddle {
namespace framework { namespace framework {
template <> template <>
Eigen::DefaultDevice* KernelContext::GetEigenDevice< Eigen::DefaultDevice* ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return device_context_.get_eigen_device<Eigen::DefaultDevice>(); return device_context_.get_eigen_device<Eigen::DefaultDevice>();
} }
@ -28,7 +28,7 @@ Eigen::DefaultDevice* KernelContext::GetEigenDevice<
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice* Eigen::GpuDevice*
KernelContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return device_context_.get_eigen_device<Eigen::GpuDevice>(); return device_context_.get_eigen_device<Eigen::GpuDevice>();
} }
#endif #endif

File diff suppressed because it is too large Load Diff

@ -24,7 +24,8 @@ static int op_run_num = 0;
class OpWithoutKernelTest : public OperatorBase { class OpWithoutKernelTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
op_run_num++; op_run_num++;
@ -73,6 +74,7 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope->CreateVariable("OUT1"); scope->CreateVariable("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context); op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1); ASSERT_EQ(paddle::framework::op_run_num, 1);
} }
@ -97,14 +99,13 @@ static int cpu_kernel_run_num = 0;
class OpWithKernelTest : public OperatorWithKernel { class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor*>& inputs, void InferShape(const framework::InferShapeContext& ctx) const override {}
const std::vector<Tensor*>& outputs) const override {}
}; };
template <typename T1, typename T2> template <typename T1, typename T2>
class CPUKernelTest : public OpKernel { class CPUKernelTest : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl; std::cout << "this is cpu kernel" << std::endl;
std::cout << ctx.op_.DebugString() << std::endl; std::cout << ctx.op_.DebugString() << std::endl;
cpu_kernel_run_num++; cpu_kernel_run_num++;
@ -117,7 +118,8 @@ class CPUKernelTest : public OpKernel {
class OperatorMultiInputsTest : public OperatorBase { class OperatorMultiInputsTest : public OperatorBase {
public: public:
void Init() override { x = 1; } void Init() override { x = 1; }
void InferShape(const std::shared_ptr<Scope>& scope) const override {} void InferShape(
const std::shared_ptr<framework::Scope>& scope) const override {}
void Run(const std::shared_ptr<Scope>& scope, void Run(const std::shared_ptr<Scope>& scope,
const platform::DeviceContext& dev_ctx) const override { const platform::DeviceContext& dev_ctx) const override {
ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
@ -149,13 +151,31 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
class CPUKernalMultiInputsTest : public OpKernel { class CPUKernalMultiInputsTest : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op_.Inputs("xs"); auto xs = ctx.op_.Inputs("xs");
ASSERT_EQ(xs.size(), 3UL); ASSERT_EQ(xs.size(), 3UL);
ASSERT_EQ(xs[0], "x0"); ASSERT_EQ(xs[0], "x0");
ASSERT_EQ(xs[1], "x1"); ASSERT_EQ(xs[1], "x1");
ASSERT_EQ(xs[2], "x2"); ASSERT_EQ(xs[2], "x2");
auto inVar0 = ctx.MultiInputVar("xs");
ASSERT_EQ(inVar0.size(), 3);
auto intVar1 = ctx.InputVar("k");
ASSERT_NE(intVar1, nullptr);
auto outVar0 = ctx.MultiOutputVar("ys");
ASSERT_EQ(outVar0.size(), 2);
auto inTensor0 = ctx.MultiInput<Tensor>("xs");
ASSERT_EQ(inTensor0.size(), 3);
auto intTensor1 = ctx.Input<Tensor>("k");
ASSERT_NE(intTensor1, nullptr);
auto outTensor0 = ctx.MultiOutput<Tensor>("ys");
ASSERT_EQ(outTensor0.size(), 2);
auto k = ctx.op_.Input("k"); auto k = ctx.op_.Input("k");
ASSERT_EQ(k, "k0"); ASSERT_EQ(k, "k0");
@ -233,6 +253,12 @@ TEST(OpKernel, multi_inputs) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
scope->CreateVariable("x0")->GetMutable<Tensor>();
scope->CreateVariable("x1")->GetMutable<Tensor>();
scope->CreateVariable("x2")->GetMutable<Tensor>();
scope->CreateVariable("k0")->GetMutable<Tensor>();
scope->CreateVariable("y0")->GetMutable<Tensor>();
scope->CreateVariable("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);

@ -19,16 +19,16 @@ namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of AddOp must be two");
PADDLE_ENFORCE(inputs.size() == 2, "Input size of AddOp must be two"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of AddOp must be one"); PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
PADDLE_ENFORCE( "Inputs of AddOp must all be set");
inputs[0] != nullptr && inputs[1] != nullptr && outputs[0] != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Inputs/Outputs of AddOp must all be set"); "Outputs of AddOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of Add Op's dimension must be same."); "Two input of Add Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
@ -49,8 +49,7 @@ The equation is: Out = X + Y
class AddOpGrad : public OperatorWithKernel { class AddOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "AddOpGrad"; LOG(INFO) << "AddOpGrad";
return ""; return "";

@ -21,16 +21,17 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input0 = context.Input(0)->Get<Tensor>(); auto input0 = context.Input<Tensor>(0);
auto input1 = context.Input(1)->Get<Tensor>(); auto input1 = context.Input<Tensor>(1);
auto output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(input0) + EigenVector<T>::Flatten(input1); framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
} }
}; };

@ -19,20 +19,20 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2,
PADDLE_ENFORCE(inputs.size() == 2,
"Input size of OnehotCrossEntropyOp must be two"); "Input size of OnehotCrossEntropyOp must be two");
PADDLE_ENFORCE(outputs.size() == 1, PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Output size of OnehotCrossEntropyOp must be one"); "Output size of OnehotCrossEntropyOp must be one");
PADDLE_ENFORCE(inputs[0] != nullptr && inputs[1] != nullptr, PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of OnehotCrossEntropyOp must all be set"); "Inputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(outputs[0] != nullptr, PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of OnehotCrossEntropyOp must all be set"); "Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(inputs[0]->dims().size() == 2, "X's dimension must be 2."); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
PADDLE_ENFORCE(outputs[0]->dims().size() == 1, "X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
"label's dimension must be 1."); "label's dimension must be 1.");
outputs[0]->Resize({inputs[0]->dims()[0]}); ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
} }
}; };

@ -23,18 +23,18 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
public: public:
constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); } constexpr T LOG_THRESHOLD() const { return static_cast<T>(1e-20); }
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& ctx) const override {
auto X = context.Input(0)->Get<Tensor>(); auto X = ctx.Input<Tensor>(0);
const T* X_data = X.data<T>(); const T* X_data = X->data<T>();
const int* label_data = context.Input(1)->Get<Tensor>().data<int>(); const int* label_data = ctx.Input<Tensor>(1)->data<int>();
auto* Y = context.Output(0)->GetMutable<Tensor>(); auto Y = ctx.Output<Tensor>(0);
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(ctx.GetPlace());
T* Y_data = Y->data<T>(); T* Y_data = Y->data<T>();
int batch_size = X.dims()[0]; int batch_size = X->dims()[0];
int class_num = X.dims()[1]; int class_num = X->dims()[1];
// Y[i] = -log(X[i][j]) // Y[i] = -log(X[i][j])
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {

@ -19,18 +19,17 @@ namespace operators {
class MulOp : public OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs");
PADDLE_ENFORCE(inputs.size() == 2, "The mul op must take two inputs"); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim0 = inputs[0]->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
auto dim1 = inputs[1]->dims();
PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2, PADDLE_ENFORCE(dim0.size() == 2 && dim1.size() == 2,
"The input of mul op must be matrix"); "The input of mul op must be matrix");
PADDLE_ENFORCE( PADDLE_ENFORCE(
dim0[1] == dim1[0], dim0[1] == dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(outputs.size() == 1, "The mul op must take one output"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "The mul op must take one output");
outputs[0]->Resize({dim0[0], dim1[1]}); ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
} }
}; };
@ -51,8 +50,7 @@ The equation is: Out = X * Y
class MulOpGrad : public OperatorWithKernel { class MulOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "MulGrad"; LOG(INFO) << "MulGrad";
return ""; return "";

@ -22,19 +22,17 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public OpKernel { class MulKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input(0)->Get<Tensor>(); auto output = context.Output<Tensor>(0);
auto input1 = context.Input(1)->Get<Tensor>();
auto* output = context.Output(0)->GetMutable<Tensor>();
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(input0).contract(EigenMatrix<T>::From(input1), EigenMatrix<T>::From(*context.Input<Tensor>("X"))
dim_pair); .contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
} }
}; };
} // namespace operators } // namespace operators

@ -18,17 +18,17 @@ namespace operators {
class RowWiseAddOp : public OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2UL,
PADDLE_ENFORCE(inputs.size() == 2UL, "Two inputs is needed by rowwise add"); "Two inputs is needed by rowwise add");
auto dim0 = inputs[0]->dims(); auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = inputs[1]->dims(); auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
PADDLE_ENFORCE(outputs.size() == 1, "The output size must be 1"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };

@ -21,14 +21,12 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RowWiseAddKernel : public OpKernel { class RowWiseAddKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto in0 = context.Input(0)->Get<Tensor>(); auto out = context.Output<Tensor>(0);
auto in1 = context.Input(1)->Get<Tensor>();
auto* out = context.Output(0)->GetMutable<Tensor>();
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto input = EigenMatrix<T>::From(in0); auto input = EigenMatrix<T>::From(*context.Input<Tensor>(0));
auto bias = EigenVector<T>::From(in1); auto bias = EigenVector<T>::From(*context.Input<Tensor>(1));
auto output = EigenMatrix<T>::From(*out); auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);

@ -19,16 +19,15 @@ namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two");
PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one");
PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one"); PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] mast be set"); PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] mast be set"); PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] mast be set"); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };

@ -21,16 +21,16 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public OpKernel { class SGDOpKernel : public OpKernel {
public: public:
void Compute(const KernelContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
auto param = ctx.Input("param")->Get<Tensor>(); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input("grad")->Get<Tensor>(); auto grad = ctx.Input<Tensor>("grad");
auto* param_out = ctx.Output(0)->GetMutable<Tensor>(); auto param_out = ctx.Output<Tensor>(0);
float lr = ctx.op_.GetAttr<float>("learning_rate"); float lr = ctx.op_.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) = EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(param) - lr * EigenVector<T>::Flatten(grad); EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
} }
}; };

@ -18,11 +18,10 @@ namespace operators {
class SigmoidOp : public OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input");
PADDLE_ENFORCE(inputs.size() == 1, "Sigmoid Op only have one input"); PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
PADDLE_ENFORCE(outputs.size() == 1, "Sigmoid Op only have one output"); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
outputs[0]->Resize(inputs[0]->dims());
} }
}; };
@ -38,8 +37,7 @@ public:
class SigmoidOpGrad : public OperatorWithKernel { class SigmoidOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SigmoidGrad"; LOG(INFO) << "SigmoidGrad";
return ""; return "";

@ -22,15 +22,14 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidKernel : public OpKernel { class SigmoidKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input(0)->Get<Tensor>(); auto input = context.Input<Tensor>(0);
auto* output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) = *(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(input)).exp()); 1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
} }
}; };
} // namespace operators } // namespace operators

@ -18,14 +18,13 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {
const std::vector<Tensor *> &outputs) const override { PADDLE_ENFORCE(ctx.InputSize() == 1, "Only one input is need for softmax");
PADDLE_ENFORCE(inputs.size() == 1, "Only one input is need for softmax"); PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
PADDLE_ENFORCE(inputs[0]->dims().size() == 2,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
PADDLE_ENFORCE(outputs.size() == 1, "Only one output is need for softmax"); PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Only one output is need for softmax");
outputs[0]->Resize(inputs[0]->dims()); ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
@ -41,8 +40,7 @@ public:
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const std::vector<const Tensor *> &inputs, void InferShape(const InferShapeContext &ctx) const override {}
const std::vector<Tensor *> &outputs) const override {}
std::string DebugString() const override { std::string DebugString() const override {
LOG(INFO) << "SoftmaxOpGrad"; LOG(INFO) << "SoftmaxOpGrad";
return ""; return "";

@ -22,12 +22,12 @@ namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public OpKernel { class SoftmaxKernel : public OpKernel {
public: public:
void Compute(const KernelContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input = context.Input(0)->Get<Tensor>(); auto input = context.Input<Tensor>(0);
auto* output = context.Output(0)->GetMutable<Tensor>(); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
auto logits = EigenMatrix<T>::From(input); auto logits = EigenMatrix<T>::From(*input);
auto softmax = EigenMatrix<T>::From(*output); auto softmax = EigenMatrix<T>::From(*output);
const int kBatchDim = 0; const int kBatchDim = 0;

@ -22,7 +22,9 @@ namespace paddle {
namespace operators { namespace operators {
using OpKernel = framework::OpKernel; using OpKernel = framework::OpKernel;
using KernelContext = framework::KernelContext; using InferShapeContext = framework::InferShapeContext;
using ExecutionContext = framework::ExecutionContext;
using Variable = framework::Variable;
template <typename T, template <typename T,
int MajorType = Eigen::RowMajor, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>

Loading…
Cancel
Save