From d3f219aa9911015bd8c4a1316b85620a07eb9f49 Mon Sep 17 00:00:00 2001
From: Yu Yang <yuyang18@baidu.com>
Date: Mon, 21 Aug 2017 18:09:17 +0800
Subject: [PATCH] Change IdentityOp to ScaleOp

---
 paddle/framework/CMakeLists.txt               |   2 +-
 paddle/framework/pybind.cc                    |   3 +-
 paddle/framework/tensor.h                     |   5 +-
 paddle/operators/CMakeLists.txt               |   2 +-
 paddle/operators/identity_op.cc               |  71 ------------
 paddle/operators/net_op.cc                    |   9 +-
 paddle/operators/scale_op.cc                  | 102 ++++++++++++++++++
 .../operators/{identity_op.cu => scale_op.cu} |   5 +-
 .../operators/{identity_op.h => scale_op.h}   |  16 ++-
 .../paddle/v2/framework/tests/CMakeLists.txt  |   2 +-
 .../v2/framework/tests/gradient_checker.py    |   7 +-
 ...ty_op.py => test_scale_and_identity_op.py} |  19 ++++
 12 files changed, 158 insertions(+), 85 deletions(-)
 delete mode 100644 paddle/operators/identity_op.cc
 create mode 100644 paddle/operators/scale_op.cc
 rename paddle/operators/{identity_op.cu => scale_op.cu} (81%)
 rename paddle/operators/{identity_op.h => scale_op.h} (66%)
 rename python/paddle/v2/framework/tests/{test_identity_op.py => test_scale_and_identity_op.py} (51%)

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index f249512f47..5df14ae78d 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -56,5 +56,5 @@ cc_library(paddle_pybind SHARED
     uniform_random_op
     gaussian_random_op
     fill_zeros_like_op
-    identity_op)
+    scale_op)
 endif(WITH_PYTHON)
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index ddb244623f..3aaf0de150 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -42,7 +42,8 @@ USE_OP(fill_zeros_like);
 USE_OP_ITSELF(recurrent_op);
 USE_OP(gaussian_random);
 USE_OP(uniform_random);
-USE_OP(identity);
+USE_OP(scale);
+USE_OP_ITSELF(identity);
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index b8c779f4e5..643f875491 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -105,7 +105,10 @@ class Tensor {
   template <typename T>
   inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
 
-  platform::Place place() const { return holder_->place(); }
+  platform::Place place() const {
+    PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contains holder");
+    return holder_->place();
+  }
 
  private:
   template <typename T>
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 20e562c7d3..0ba598823b 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -68,4 +68,4 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
     DEPS framework_proto tensor op_registry operator net_op)
 op_library(uniform_random_op
         SRCS uniform_random_op.cc uniform_random_op.cu)
-op_library(identity_op SRCS identity_op.cc identity_op.cu DEPS net_op)
+op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
deleted file mode 100644
index cac44020bc..0000000000
--- a/paddle/operators/identity_op.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/identity_op.h"
-#include "paddle/operators/net_op.h"
-
-namespace paddle {
-namespace operators {
-
-class IdentityOp : public framework::OperatorWithKernel {
- public:
-  IdentityOp(const std::string &type, const VarNameMap &inputs,
-             const VarNameMap &outputs, const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto *in = ctx.Input<framework::Tensor>("X");
-    auto *out = ctx.Output<framework::Tensor>("Out");
-    out->Resize(in->dims());
-  }
-};
-
-class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  IdentityOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of identity operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of identity operator.").NotInGradient();
-    AddComment(R"DOC(Identity operator
-
-The equation is: Out = X
-)DOC");
-  }
-};
-
-// Identity Op's gradient is identity op, too.
-// Grad(Out=identity_op(X)) => Grad(X) = identity_op(Grad(Out))
-class IdentityGradOp : public NetOp {
- public:
-  IdentityGradOp(const std::string &type, const VarNameMap &inputs,
-                 const VarNameMap &outputs,
-                 const framework::AttributeMap &attrs)
-      : NetOp(type, inputs, outputs, attrs) {
-    AddOp(framework::OpRegistry::CreateOp(
-        "identity", {{"X", {Input(framework::GradVarName("Out"))}}},
-        {{"Out", {Output(framework::GradVarName("X"))}}}, {}));
-    CompleteAddOp(false);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP(identity, ops::IdentityOp, ops::IdentityOpMaker, identity_grad,
-            ops::IdentityGradOp);
-REGISTER_OP_CPU_KERNEL(identity, ops::IdentityKernel<float>);
diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc
index a7d7105110..7e3779ed2e 100644
--- a/paddle/operators/net_op.cc
+++ b/paddle/operators/net_op.cc
@@ -68,10 +68,15 @@ std::string NetOp::DebugString() const {
 bool NetOp::IsNetOp() const { return true; }
 
 std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
+  std::vector<std::string> all;
+  for (auto& pair : this->outputs_) {
+    for (auto& var_name : pair.second) {
+      all.push_back(var_name);
+    }
+  }
   if (has_intermediate) {
-    return this->outputs_.at(kAll);
+    return all;
   }
-  auto& all = this->outputs_.at(kAll);
   std::vector<std::string> ret_val;
   for (auto& each : all) {
     if (!Contains(intermediate_outputs_, each)) {
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
new file mode 100644
index 0000000000..3b18ff078e
--- /dev/null
+++ b/paddle/operators/scale_op.cc
@@ -0,0 +1,102 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/scale_op.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ScaleOp : public framework::OperatorWithKernel {
+ public:
+  ScaleOp(const std::string &type, const VarNameMap &inputs,
+          const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *in = ctx.Input<framework::Tensor>("X");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+    out->Resize(in->dims());
+  }
+};
+
+template <typename AttrType>
+class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input tensor of scale operator.").NotInGradient();
+    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
+    AddComment(R"DOC(Scale operator
+
+The equation is: Out = scale*X
+)DOC");
+    AddAttr<AttrType>("scale", "scale of scale operator.").SetDefault(1.0);
+  }
+};
+
+// Scale Op's gradient is a scale op, too (with the same scale attribute).
+// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
+template <typename AttrType>
+class ScaleGradOp : public NetOp {
+ public:
+  ScaleGradOp(const std::string &type, const VarNameMap &inputs,
+              const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    AddOp(framework::OpRegistry::CreateOp(
+        "scale", {{"X", {Input(framework::GradVarName("Out"))}}},
+        {{"Out", {Output(framework::GradVarName("X"))}}},
+        {{"scale", GetAttr<AttrType>("scale")}}));
+    CompleteAddOp(false);
+  }
+};
+
+// identity is an alias of scale op. This is also an example for creating an
+// alias operator.
+template <typename AttrType>
+class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  IdentityOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "input tensor of identity op");
+    AddOutput("Out", "output tensor of identity op");
+    AddComment("identity operator. Just a alias of scale op which scale = 1.0");
+  }
+};
+
+template <typename AttrType>
+class IdentityOp : public NetOp {
+ public:
+  IdentityOp(const std::string &type, const VarNameMap &inputs,
+             const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    AddOp(framework::OpRegistry::CreateOp(
+        "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}},
+        {{"scale", static_cast<AttrType>(1)}}));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, scale_grad,
+            ops::ScaleGradOp<float>);
+REGISTER_OP_CPU_KERNEL(scale,
+                       ops::ScaleKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp<float>,
+                             ops::IdentityOpMaker<float>);
diff --git a/paddle/operators/identity_op.cu b/paddle/operators/scale_op.cu
similarity index 81%
rename from paddle/operators/identity_op.cu
rename to paddle/operators/scale_op.cu
index 3053104bbe..63efbe0da8 100644
--- a/paddle/operators/identity_op.cu
+++ b/paddle/operators/scale_op.cu
@@ -12,6 +12,7 @@
    See the License for the specific language governing permissions and
    limitations under the License. */
 
-#include "paddle/operators/identity_op.h"
+#include "paddle/operators/scale_op.h"
 
-REGISTER_OP_GPU_KERNEL(identity, paddle::operators::IdentityKernel<float>);
+REGISTER_OP_GPU_KERNEL(
+    scale, paddle::operators::ScaleKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/identity_op.h b/paddle/operators/scale_op.h
similarity index 66%
rename from paddle/operators/identity_op.h
rename to paddle/operators/scale_op.h
index 14a832257b..aea64f1b04 100644
--- a/paddle/operators/identity_op.h
+++ b/paddle/operators/scale_op.h
@@ -14,17 +14,25 @@
 
 #pragma once
 
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/memory/memcpy.h"
+
 namespace paddle {
 namespace operators {
-template <typename T>
-class IdentityKernel : public framework::OpKernel {
+template <typename Place, typename T, typename AttrType = T>
+class ScaleKernel : public framework::OpKernel {
  public:
   virtual void Compute(const framework::ExecutionContext& context) const {
     auto* tensor = context.Output<framework::Tensor>("Out");
     auto* in = context.Input<framework::Tensor>("X");
-    tensor->CopyFrom<T>(*in, in->place());
+    tensor->mutable_data<T>(in->place());
+
+    auto scale = static_cast<T>(context.op_.GetAttr<AttrType>("scale"));
+
+    auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
+    auto eigen_in = framework::EigenVector<T>::Flatten(*in);
+    auto& dev = context.GetEigenDevice<Place>();
+    eigen_out.device(dev) = scale * eigen_in;
   }
 };
 
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index cf7baa5556..0e8811bfe7 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -27,4 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
 py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
-py_test(test_identity_op SRCS test_identity_op.py)
+py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index 8b8e2f444b..c22c6f8831 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -160,8 +160,13 @@ class GradientChecker(unittest.TestCase):
             grad_tensor.set(data, place)
 
         # run backward op
-        for name in backward_op.outputs():
+        backward_outs = backward_op.outputs()
+        backward_names = [
+            item for key in backward_outs for item in backward_outs[key]
+        ]
+        for name in backward_names:
             scope.new_var(name)
+
         backward_op.infer_shape(scope)
         backward_op.run(scope, ctx)
 
diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py
similarity index 51%
rename from python/paddle/v2/framework/tests/test_identity_op.py
rename to python/paddle/v2/framework/tests/test_scale_and_identity_op.py
index 181d9c0c21..69b301c376 100644
--- a/python/paddle/v2/framework/tests/test_identity_op.py
+++ b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py
@@ -2,6 +2,7 @@ import unittest
 from op_test_util import OpTestMeta
 from gradient_checker import GradientChecker, create_op
 import numpy as np
+from paddle.v2.framework.op import Operator
 
 
 class IdentityTest(unittest.TestCase):
@@ -20,5 +21,23 @@ class IdentityGradOpTest(GradientChecker):
         self.check_grad(op, inputs, set("X"), "Out")
 
 
+class ScaleTest(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "scale"
+        self.inputs = {'X': np.random.random((32, 784)).astype("float32")}
+        self.attrs = {'scale': -2.3}
+        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+
+
+class ScaleGradTest(GradientChecker):
+    def test_normal(self):
+        op = Operator("scale", X="X", Out="Out", scale=3.2)
+        self.check_grad(op,
+                        {"X": np.random.random((10, 10)).astype("float32")},
+                        set("X"), "Out")
+
+
 if __name__ == '__main__':
     unittest.main()