commit
						3e812383bc
					
				@ -0,0 +1,45 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/operators/mean_op.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
class MeanOp : public OperatorWithKernel {
 | 
				
			||||
protected:
 | 
				
			||||
  void InferShape(const InferShapeContext &ctx) const override {
 | 
				
			||||
    PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one");
 | 
				
			||||
    PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one");
 | 
				
			||||
    PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr,
 | 
				
			||||
                   "Input/Output of MeanOp must be initialized.");
 | 
				
			||||
    ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
class MeanOpMaker : public OpProtoAndCheckerMaker {
 | 
				
			||||
public:
 | 
				
			||||
  MeanOpMaker(OpProto *proto, OpAttrChecker *op_checker)
 | 
				
			||||
      : OpProtoAndCheckerMaker(proto, op_checker) {
 | 
				
			||||
    AddInput("X", "The input of mean op");
 | 
				
			||||
    AddOutput("Out", "The output of mean op");
 | 
				
			||||
    AddComment("Mean Operator");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
 | 
				
			||||
REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker);
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel<ops::CPUPlace, float>);
 | 
				
			||||
@ -0,0 +1,5 @@
 | 
				
			||||
#define EIGEN_USE_GPU
 | 
				
			||||
 | 
				
			||||
#include "paddle/operators/mean_op.h"
 | 
				
			||||
 | 
				
			||||
REGISTER_OP_GPU_KERNEL(mean, ops::MeanKernel<ops::GPUPlace, float>);
 | 
				
			||||
@ -0,0 +1,36 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
#include "paddle/operators/type_alias.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
template <typename Place, typename T>
 | 
				
			||||
class MeanKernel : public OpKernel {
 | 
				
			||||
public:
 | 
				
			||||
  void Compute(const ExecutionContext& context) const override {
 | 
				
			||||
    auto input = context.Input<Tensor>(0);
 | 
				
			||||
    auto output = context.Output<Tensor>(0);
 | 
				
			||||
 | 
				
			||||
    output->mutable_data<T>(context.GetPlace());
 | 
				
			||||
 | 
				
			||||
    EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
 | 
				
			||||
        EigenVector<T>::Flatten(*input).mean();
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,25 @@
 | 
				
			||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include <gtest/gtest.h>
 | 
				
			||||
 | 
				
			||||
#include <paddle/framework/op_registry.h>
 | 
				
			||||
 | 
				
			||||
USE_OP(mean);
 | 
				
			||||
 | 
				
			||||
TEST(MeanOp, GetOpProto) {
 | 
				
			||||
  auto& protos = paddle::framework::OpRegistry::protos();
 | 
				
			||||
  auto it = protos.find("mean");
 | 
				
			||||
  ASSERT_NE(it, protos.end());
 | 
				
			||||
}
 | 
				
			||||
@ -1,2 +1,9 @@
 | 
				
			||||
cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python
 | 
				
			||||
        add_op fc_op sgd_op cross_entropy_op recurrent_network_op)
 | 
				
			||||
cc_library(paddle_pybind SHARED
 | 
				
			||||
    SRCS pybind.cc
 | 
				
			||||
    DEPS pybind python
 | 
				
			||||
	fc_op
 | 
				
			||||
	sgd_op
 | 
				
			||||
	add_op
 | 
				
			||||
	mean_op
 | 
				
			||||
	cross_entropy_op
 | 
				
			||||
	recurrent_network_op)
 | 
				
			||||
 | 
				
			||||
@ -0,0 +1,16 @@
 | 
				
			||||
import unittest
 | 
				
			||||
from op_test_util import OpTestMeta
 | 
				
			||||
import numpy as np
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestMeanOp(unittest.TestCase):
 | 
				
			||||
    __metaclass__ = OpTestMeta
 | 
				
			||||
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        self.type = "mean"
 | 
				
			||||
        self.X = np.random.random((32, 784)).astype("float32")
 | 
				
			||||
        self.Out = np.mean(self.X)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == '__main__':
 | 
				
			||||
    unittest.main()
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue