Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_croplayer
commit 7c09999d57

@@ -0,0 +1,84 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/tensor.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {

// EigenDim converts paddle::platform::DDim into Eigen::DSizes.
template <int D>
struct EigenDim {
  using Type = Eigen::DSizes<Eigen::DenseIndex, D>;

  static Type From(const DDim& dims) {
    PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
    Type ret;
    for (int d = 0; d < arity(dims); d++) {
      ret[d] = dims[d];
    }
    return ret;
  }
};

// Interpret paddle::platform::Tensor as EigenTensor and EigenConstTensor.
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
struct EigenTensor {
  // TODO(qijun) The default type is currently unaligned; we will benchmark
  // the aligned and unaligned versions in the future.
  using Type = Eigen::TensorMap<Eigen::Tensor<T, D, MajorType, IndexType>>;

  using ConstType =
      Eigen::TensorMap<Eigen::Tensor<const T, D, MajorType, IndexType>>;

  static Type From(Tensor& tensor, DDim dims) {
    return Type(tensor.data<T>(), EigenDim<D>::From(dims));
  }

  static Type From(Tensor& tensor) { return From(tensor, tensor.dims_); }

  static ConstType From(const Tensor& tensor, DDim dims) {
    return ConstType(tensor.data<T>(), EigenDim<D>::From(dims));
  }

  static ConstType From(const Tensor& tensor) {
    return From(tensor, tensor.dims_);
  }
};

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
  // Flatten reshapes a Tensor into a one-dimensional EigenVector.
  static typename EigenTensor<T, 1>::Type Flatten(Tensor& tensor) {
    return EigenTensor<T, 1>::From(
        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
  }

  static typename EigenTensor<T, 1>::ConstType Flatten(const Tensor& tensor) {
    return EigenTensor<T, 1>::From(
        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
  }
};

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = EigenTensor<T, 2, MajorType, IndexType>;

}  // namespace framework
}  // namespace paddle
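Note for readers new to Eigen's unsupported Tensor module: the From helpers above copy no data; they wrap the tensor's existing buffer in an Eigen::TensorMap. The idea can be sketched with plain Eigen, independent of the Paddle Tensor class (a minimal standalone illustration, not part of this diff):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // A raw row-major buffer holding a 2x3 matrix, standing in for
  // Tensor::data<float>().
  float buf[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};

  // EigenDim<2>::From(make_ddim({2, 3})) would produce these sizes.
  Eigen::DSizes<Eigen::DenseIndex, 2> dims(2, 3);

  // This is the kind of object EigenTensor<float, 2>::From returns:
  // a view over buf, not a copy.
  Eigen::TensorMap<Eigen::Tensor<float, 2, Eigen::RowMajor>> mat(buf, dims);

  std::cout << mat(1, 2) << "\n";  // prints 5: row-major element (1, 2)

  // Writing through the map writes into the original buffer.
  mat(0, 0) = 42.f;
  std::cout << buf[0] << "\n";  // prints 42
  return 0;
}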
@@ -0,0 +1,101 @@

/*
  Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
  Licensed under the Apache License, Version 2.0 (the "License");
  you may not use this file except in compliance with the License.
  You may obtain a copy of the License at
  http://www.apache.org/licenses/LICENSE-2.0
  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/

#include "paddle/framework/eigen.h"
#include <gtest/gtest.h>

namespace paddle {
namespace framework {

TEST(EigenDim, From) {
  EigenDim<3>::Type ed = EigenDim<3>::From(make_ddim({1, 2, 3}));
  ASSERT_EQ(1, ed[0]);
  ASSERT_EQ(2, ed[1]);
  ASSERT_EQ(3, ed[2]);
}

TEST(Eigen, Tensor) {
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
  for (int i = 0; i < 1 * 2 * 3; i++) {
    p[i] = static_cast<float>(i);
  }

  EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);

  ASSERT_EQ(1, et.dimension(0));
  ASSERT_EQ(2, et.dimension(1));
  ASSERT_EQ(3, et.dimension(2));

  for (int i = 0; i < 1; i++) {
    for (int j = 0; j < 2; j++) {
      for (int k = 0; k < 3; k++) {
        ASSERT_NEAR((i * 2 + j) * 3 + k, et(i, j, k), 1e-6f);
      }
    }
  }
}

TEST(Eigen, VectorFrom) {
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({6}), platform::CPUPlace());
  for (int i = 0; i < 6; i++) {
    p[i] = static_cast<float>(i);
  }

  EigenVector<float>::Type ev = EigenVector<float>::From(t);

  ASSERT_EQ(6, ev.dimension(0));

  for (int i = 0; i < 6; i++) {
    ASSERT_NEAR(i, ev(i), 1e-6f);
  }
}

TEST(Eigen, VectorFlatten) {
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
  for (int i = 0; i < 1 * 2 * 3; i++) {
    p[i] = static_cast<float>(i);
  }

  EigenVector<float>::Type ev = EigenVector<float>::Flatten(t);

  ASSERT_EQ(1 * 2 * 3, ev.dimension(0));

  for (int i = 0; i < 1 * 2 * 3; i++) {
    ASSERT_NEAR(i, ev(i), 1e-6f);
  }
}

TEST(Eigen, Matrix) {
  Tensor t;
  float* p = t.mutable_data<float>(make_ddim({2, 3}), platform::CPUPlace());
  for (int i = 0; i < 2 * 3; i++) {
    p[i] = static_cast<float>(i);
  }

  EigenMatrix<float>::Type em = EigenMatrix<float>::From(t);

  ASSERT_EQ(2, em.dimension(0));
  ASSERT_EQ(3, em.dimension(1));

  for (int i = 0; i < 2; i++) {
    for (int j = 0; j < 3; j++) {
      ASSERT_NEAR(i * 3 + j, em(i, j), 1e-6f);
    }
  }
}

}  // namespace framework
}  // namespace paddle
@@ -1,15 +0,0 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/enforce.h"
@@ -1,75 +0,0 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <paddle/string/printf.h>
#include <exception>
#include <sstream>

namespace paddle {
namespace framework {

/**
 * @brief Enforce exception. Inherits std::exception
 *
 * Whenever an enforced condition is not met, an EnforceNotMet exception is
 * thrown.
 */
class EnforceNotMet : public std::exception {
 public:
  EnforceNotMet(const std::string& msg, const char* file, int fileline) {
    std::ostringstream sout;
    sout << msg << " at [" << file << ":" << fileline << "];";
    all_msg_ = sout.str();
  }

  const char* what() const noexcept override { return all_msg_.c_str(); }

 private:
  std::string all_msg_;
};

// From https://stackoverflow.com/questions/30130930/
// __builtin_expect hints the compiler about branch likelihood. Since an
// enforced condition should be true in most situations, wrapping it in the
// `UNLIKELY` macro helps the compiler generate faster code for the common path.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

/**
 * @brief Throw an EnforceNotMet exception, automatically filling in __FILE__
 * and __LINE__.
 *
 * This macro takes __VA_ARGS__; the user can pass arguments of any type that
 * can be serialized to std::ostream.
 */
#define PADDLE_THROW(...)                                            \
  do {                                                               \
    throw ::paddle::framework::EnforceNotMet(                        \
        ::paddle::string::Sprintf(__VA_ARGS__), __FILE__, __LINE__); \
  } while (0)

/**
 * @brief Enforce a condition, otherwise throw an EnforceNotMet
 */
#ifdef NDEBUG
#define PADDLE_ENFORCE(condition, ...) \
  do {                                 \
    if (UNLIKELY(!(condition))) {      \
      PADDLE_THROW(__VA_ARGS__);       \
    }                                  \
  } while (0)
#else
#define PADDLE_ENFORCE(condition, ...) \
  CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
#endif

}  // namespace framework
}  // namespace paddle
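Although this header and its test are removed by the diff, PADDLE_ENFORCE is still used by the new code above and below, so a brief usage reminder may help: the macro takes a condition followed by printf-style message arguments. A minimal sketch, with a made-up function name and message:

// Hypothetical illustration of PADDLE_ENFORCE usage; not part of this diff.
#include "paddle/framework/enforce.h"

void SetBatchSize(int batch_size) {
  // In release builds (NDEBUG) this throws EnforceNotMet with the formatted
  // message; in debug builds it falls back to glog's CHECK.
  PADDLE_ENFORCE(batch_size > 0, "batch_size must be positive, but got %d",
                 batch_size);
}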
@@ -1,67 +0,0 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
namespace framework {

// Helper to define Tensor types given that the scalar is of type T.
template <typename T, int NDIMS = 1, typename IndexType = Eigen::DenseIndex>
struct TTypes {
  // Rank-<NDIMS> tensor of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, NDIMS, Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      Tensor;
  typedef Eigen::TensorMap<
      Eigen::Tensor<const T, NDIMS, Eigen::RowMajor, IndexType>, Eigen::Aligned>
      ConstTensor;

  // Scalar tensor (implemented as a rank-0 tensor) of scalar type T.
  typedef Eigen::TensorMap<
      Eigen::TensorFixedSize<T, Eigen::Sizes<>, Eigen::RowMajor, IndexType>,
      Eigen::Aligned>
      Scalar;
  typedef Eigen::TensorMap<Eigen::TensorFixedSize<const T, Eigen::Sizes<>,
                                                  Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      ConstScalar;

  // Rank-1 tensor (vector) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      Flat;
  typedef Eigen::TensorMap<
      Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
      ConstFlat;
  typedef Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      Vec;
  typedef Eigen::TensorMap<
      Eigen::Tensor<const T, 1, Eigen::RowMajor, IndexType>, Eigen::Aligned>
      ConstVec;

  // Rank-2 tensor (matrix) of scalar type T.
  typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, IndexType>,
                           Eigen::Aligned>
      Matrix;
  typedef Eigen::TensorMap<
      Eigen::Tensor<const T, 2, Eigen::RowMajor, IndexType>, Eigen::Aligned>
      ConstMatrix;
};

}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,76 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/net.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

class FullyConnectedOp : public framework::PlainNet {
 public:
  void Init() override {
    AddOp(framework::OpRegistry::CreateOp("mul",
                                          {
                                              Input("X"), Input("W"),
                                          },
                                          {Output("before_act")},
                                          {}));
    auto b = Input("b");
    if (b != framework::OperatorBase::EMPTY_VAR_NAME()) {
      AddOp(framework::OpRegistry::CreateOp("rowwise_add",
                                            {Output("before_act"), Input("b")},
                                            {Output("before_act")},
                                            {}));
    }

    auto activation = GetAttr<std::string>("activation");
    AddOp(framework::OpRegistry::CreateOp(
        activation, {Output("before_act")}, {Output("Y")}, {}));
    CompleteAddOp(false);
  }
};

class FullyConnectedOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  FullyConnectedOpMaker(framework::OpProto *proto,
                        framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "the input of fc operator");
    AddInput("W", "the weight of fc operator");
    AddInput("b", "the bias of fc operator");

    AddOutput("Y", "the output of fc operator");
    AddOutput(
        "before_act", "the before activation output of fc operator", true);
    AddAttr<std::string>("activation", "The activation key for fc layer")
        .SetDefault("sigmoid")
        .InEnum({"sigmoid", "softmax"});

    //! TODO(yuyang18): Complete comment;
    AddComment("FullyConnected Operator");
  }
};
}  // namespace operators
}  // namespace paddle

USE_OP(mul);
USE_OP(rowwise_add);
USE_OP(sigmoid);
USE_OP(softmax);

REGISTER_OP(fc,
            paddle::operators::FullyConnectedOp,
            paddle::operators::FullyConnectedOpMaker);
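The net composed above is the usual fully connected layer: multiply the input by the weight, optionally add a row-wise bias, then apply the chosen activation. In plain C++ over small arrays, the computation being assembled is roughly the following (an editor's sketch with made-up sizes and values, not Paddle code):

#include <cmath>
#include <cstdio>

// Y = activation(X * W + b): what the mul, rowwise_add and activation ops
// above compose. Shapes (1x2 input, 2x3 weight) are hypothetical.
int main() {
  const float X[2] = {1.f, 2.f};                              // input
  const float W[2][3] = {{.1f, .2f, .3f}, {.4f, .5f, .6f}};   // weight
  const float b[3] = {.01f, .02f, .03f};                      // bias per column

  float Y[3];
  for (int j = 0; j < 3; ++j) {
    float before_act = b[j];                    // "before_act" in the op above
    for (int k = 0; k < 2; ++k) before_act += X[k] * W[k][j];
    Y[j] = 1.f / (1.f + std::exp(-before_act)); // default "sigmoid" activation
  }
  std::printf("%f %f %f\n", Y[0], Y[1], Y[2]);
  return 0;
}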
@@ -0,0 +1,61 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"

namespace paddle {
namespace operators {

class SGDOp : public framework::OperatorWithKernel {
 protected:
  void InferShape(
      const std::vector<const framework::Tensor *> &inputs,
      const std::vector<framework::Tensor *> &outputs) const override {
    PADDLE_ENFORCE(inputs.size() == 2, "Input size of SGDOp must be two");
    PADDLE_ENFORCE(outputs.size() == 1, "Output size of SGDOp must be one");
    PADDLE_ENFORCE(inputs[0] != nullptr, "inputs[0] must be set");
    PADDLE_ENFORCE(inputs[1] != nullptr, "inputs[1] must be set");
    PADDLE_ENFORCE(outputs[0] != nullptr, "outputs[0] must be set");
    PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(),
                   "The two inputs of SGDOp must have the same dimensions.");
    outputs[0]->set_dims(inputs[0]->dims());
  }
};

class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SGDOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : framework::OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("param", "input parameter");
    AddInput("grad", "input gradient");
    AddOutput("param_out", "output parameter");
    AddAttr<float>("learning_rate", "learning rate of sgd");
    AddComment(R"DOC(

Simplest sgd algorithm.

param_out = param - learning_rate * grad;

)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

REGISTER_OP(sgd, paddle::operators::SGDOp, paddle::operators::SGDOpMaker);
typedef paddle::operators::SGDOpKernel<::paddle::platform::CPUPlace, float>
    SGDOpKernel_CPU_float;
REGISTER_OP_CPU_KERNEL(sgd, SGDOpKernel_CPU_float);
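The DOC string above pins down exactly what the operator computes: an element-wise update param_out = param - learning_rate * grad. As a plain-loop sketch over raw arrays (an editor's illustration; the real kernel is the Eigen expression shown two hunks below):

#include <cstddef>

// Element-wise SGD update: param_out[i] = param[i] - lr * grad[i].
// Buffer names and the function itself are hypothetical.
void SgdUpdate(const float* param, const float* grad, float lr,
               float* param_out, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) {
    param_out[i] = param[i] - lr * grad[i];
  }
}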
@@ -0,0 +1,5 @@

#include "paddle/operators/sgd_op.h"
#include "paddle/framework/op_registry.h"

typedef paddle::operators::SGDOpKernel<::paddle::platform::GPUPlace, float>
    SGDOpKernel_GPU_float;
REGISTER_OP_GPU_KERNEL(sgd, SGDOpKernel_GPU_float);
@@ -0,0 +1,42 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "glog/logging.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/operator.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel {
 public:
  void Compute(const framework::KernelContext& ctx) const override {
    auto param = ctx.Input("param")->Get<framework::Tensor>();
    auto grad = ctx.Input("grad")->Get<framework::Tensor>();
    auto* param_out = ctx.Output(0)->GetMutable<framework::Tensor>();
    float lr = ctx.op_.GetAttr<float>("learning_rate");

    param_out->mutable_data<T>(ctx.GetPlace());

    framework::EigenVector<T>::Flatten(*param_out)
        .device(*(ctx.GetEigenDevice<Place>())) =
        framework::EigenVector<T>::Flatten(param) -
        lr * framework::EigenVector<T>::Flatten(grad);
  }
};

}  // namespace operators
}  // namespace paddle
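The kernel's single assignment is an Eigen tensor expression evaluated on the chosen device: the flattened output is set to the flattened param minus lr times the flattened grad, with no explicit loop. A self-contained sketch of the same expression using raw Eigen::TensorMap on the CPU default device (an editor's illustration, independent of the Paddle Tensor and KernelContext types):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  // Stand-ins for the buffers that EigenVector<T>::Flatten(...) would map.
  float param[4] = {1.f, 2.f, 3.f, 4.f};
  float grad[4] = {0.5f, 0.5f, 0.5f, 0.5f};
  float param_out[4];
  const float lr = 0.1f;

  Eigen::TensorMap<Eigen::Tensor<float, 1>> p(param, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> g(grad, 4);
  Eigen::TensorMap<Eigen::Tensor<float, 1>> out(param_out, 4);

  // Same shape of expression as in SGDOpKernel::Compute, but evaluated on the
  // default single-threaded CPU device instead of ctx.GetEigenDevice<Place>().
  Eigen::DefaultDevice cpu;
  out.device(cpu) = p - g * lr;

  std::cout << param_out[0] << "\n";  // 1 - 0.1 * 0.5 = 0.95
  return 0;
}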
@@ -0,0 +1,22 @@

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include <paddle/framework/op_registry.h>

USE_OP(sgd);

TEST(SGDOp, GetOpProto) {
  auto& protos = paddle::framework::OpRegistry::protos();
  auto it = protos.find("sgd");
  ASSERT_NE(it, protos.end());
}
Some files were not shown because too many files have changed in this diff.