commit 2f12256186
@ -0,0 +1,61 @@
# Design Doc: ProgramDesc

The basic structure of a PaddlePaddle program is a set of nested blocks, like that of a C++ or Java program.

As described in [graph.md](./graph.md), the first five lines of the following PaddlePaddle program

```python
x = layer.data("images")
l = layer.data("label")
y = layer.fc(x)
cost = layer.mse(y, l)
optimize(cost)
train(cost, reader=mnist.train())
```

generate, or compile, a PaddlePaddle program, which is represented by the following protobuf message:

```protobuf
message ProgramDesc {
  repeated BlockDesc blocks = 1;
}

message BlockDesc {
  required int32 parent = 1;
  repeated VarDesc vars = 2;
  repeated OpDesc ops = 3;
}

message OpDesc {
  repeated AttrDesc attrs = 1;
  ...
}

message AttrDesc {
  required AttrType type = 1;

  // index into ProgramDesc::blocks when type==BLOCK
  optional int32 block = 2;
  ...
}
```

When each of the first five lines runs, the corresponding Python function, e.g., `layer.fc`, calls the C++ `InferShape` function. `InferShape` needs to access the properties of the `VarDesc`s used by the current `OpDesc`. These `VarDesc`s might be defined not in the current block but in an ancestor block, so we must be able to trace the parent of each block.

A nested block is often an attribute of an operator, most likely an IfElseOp or a WhileOp. In the above design, all blocks live in `ProgramDesc::blocks`, which implicitly assigns each block a zero-based ID -- its index in `ProgramDesc::blocks`. `AttrDesc::block` can therefore simply be an integer block ID.

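For example, a WhileOp whose body is block 1 could be encoded as follows -- a minimal sketch using the C++ classes that `protoc` would generate from the messages above; the root-block parent value of `-1` and the `AttrType::BLOCK` spelling are assumptions for illustration:

```c++
ProgramDesc program;
program.add_blocks()->set_parent(-1);  // block 0: the root block
program.add_blocks()->set_parent(0);   // block 1: the loop body

// The WhileOp lives in block 0 and refers to its body block by index.
OpDesc* while_op = program.mutable_blocks(0)->add_ops();
AttrDesc* body = while_op->add_attrs();
body->set_type(AttrType::BLOCK);
body->set_block(1);  // index into ProgramDesc::blocks
```
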
With this design, the InferShape function should take the following parameters:

```c++
void InferShape(int current_block,
                int current_operator,
                ProgramDesc* program  // might change VarDesc values.
                ) {
  ...
}
```

where

- `current_block` indexes into `ProgramDesc::blocks`, and
- `current_operator` indexes into `BlockDesc::ops`.

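Below is a minimal sketch of how `InferShape` might resolve a variable that is not defined in the current block. It assumes a hypothetical `FindVar` helper, that `VarDesc` carries a `name` field, and that the root block marks its parent as `-1`:

```c++
// Walk from the current block up through its ancestors until a VarDesc
// with the given name is found (illustrative sketch only).
const VarDesc* FindVar(const ProgramDesc& program, int block_id,
                       const std::string& name) {
  while (block_id >= 0) {
    const BlockDesc& block = program.blocks(block_id);
    for (const VarDesc& var : block.vars()) {
      if (var.name() == name) return &var;
    }
    block_id = block.parent();  // move up to the ancestor block
  }
  return nullptr;  // not defined in any enclosing block
}
```
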
@ -0,0 +1,58 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_proto_maker.h"

#include <unordered_set>

namespace paddle {
namespace framework {

void OpProtoAndCheckerMaker::Validate() {
  validated_ = true;
  CheckNoDuplicatedInOutAttrs();
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
    const std::string& name, const std::string& comment) {
  auto* input = proto_->add_inputs();
  input->set_name(name);
  input->set_comment(comment);
  return OpProtoAndCheckerMaker::VariableBuilder{input};
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
    const std::string& name, const std::string& comment) {
  auto* output = proto_->add_outputs();
  output->set_name(name);
  output->set_comment(comment);
  return OpProtoAndCheckerMaker::VariableBuilder{output};
}

// Ensure that every attribute, input, and output name is unique within an op.
void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
  std::unordered_set<std::string> names;
  auto checker = [&](const std::string& name) {
    PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
    names.insert(name);
  };
  for (auto& attr : proto_->attrs()) {
    checker(attr.name());
  }
  for (auto& input : proto_->inputs()) {
    checker(input.name());
  }
  for (auto& output : proto_->outputs()) {
    checker(output.name());
  }
}

}  // namespace framework
}  // namespace paddle

@ -0,0 +1,88 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"

namespace paddle {
namespace framework {

// This class not only builds the op proto message but also initializes the
// attribute checkers.
class OpProtoAndCheckerMaker {
 public:
  OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : proto_(proto), op_checker_(op_checker) {}

  virtual ~OpProtoAndCheckerMaker() {
    PADDLE_ENFORCE(validated_, "should call Validate after build");
  }

  void Validate();

 protected:
  struct VariableBuilder {
    OpProto::Var* var_;

    VariableBuilder& AsDuplicable() {
      var_->set_duplicable(true);
      return *this;
    }

    VariableBuilder& AsIntermediate() {
      var_->set_intermediate(true);
      return *this;
    }

    VariableBuilder& NotInGradient() {
      var_->set_not_in_gradient(true);
      return *this;
    }
  };

  VariableBuilder AddInput(const std::string& name, const std::string& comment);

  VariableBuilder AddOutput(const std::string& name,
                            const std::string& comment);

  template <typename T>
  TypedAttrChecker<T>& AddAttr(const std::string& name,
                               const std::string& comment,
                               bool generated = false) {
    auto* attr = proto_->add_attrs();
    attr->set_name(name);
    attr->set_comment(comment);
    attr->set_generated(generated);
    attr->set_type(AttrTypeID<T>());
    return op_checker_->AddAttrChecker<T>(name);
  }

  void AddComment(const std::string& comment) { proto_->set_comment(comment); }

 private:
  void CheckNoDuplicatedInOutAttrs();

  OpProto* proto_;
  OpAttrChecker* op_checker_;
  bool validated_{false};
};

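// Example usage (an illustrative sketch, not part of this header): a concrete
// op maker subclasses OpProtoAndCheckerMaker and declares its inputs,
// outputs, and attributes in the constructor; the op and names below are
// hypothetical.
//
//   class ScaleOpMaker : public OpProtoAndCheckerMaker {
//    public:
//     ScaleOpMaker(OpProto* proto, OpAttrChecker* op_checker)
//         : OpProtoAndCheckerMaker(proto, op_checker) {
//       AddInput("X", "The input tensor.").AsDuplicable();
//       AddOutput("Out", "The scaled output tensor.");
//       AddAttr<float>("scale", "The scaling factor.");
//       AddComment("Scales the input by a constant factor.");
//     }
//   };
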
class NOPMaker : public OpProtoAndCheckerMaker {
 public:
  NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {}
};

}  // namespace framework
}  // namespace paddle

@ -0,0 +1,51 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_proto_maker.h"

#include "gtest/gtest.h"

class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  TestAttrProtoMaker(paddle::framework::OpProto* proto,
                     paddle::framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<float>("scale", "scale of test op");
    AddAttr<float>("scale", "scale of test op");  // duplicated on purpose
  }
};

TEST(ProtoMaker, DuplicatedAttr) {
  paddle::framework::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}

class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
 public:
  TestInOutProtoMaker(paddle::framework::OpProto* proto,
                      paddle::framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("input", "input of test op");
    AddInput("input", "input of test op");  // duplicated on purpose
  }
};

TEST(ProtoMaker, DuplicatedInOut) {
  paddle::framework::OpProto op_proto;
  paddle::framework::OpAttrChecker op_checker;
  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
  ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
}

File diff suppressed because it is too large
@ -0,0 +1,100 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(sigmoid,
                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
                                             ops::SigmoidFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
    sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                            ops::SigmoidGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(
    exp,
    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
                       ops::ActivationGradKernel<paddle::platform::GPUPlace,
                                                 float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
                                             ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
    relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                         ops::ReluGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(
    tanh,
    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
    tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                         ops::TanhGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(
    sqrt,
    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
    sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                         ops::SqrtGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(
    abs,
    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
                       ops::ActivationGradKernel<paddle::platform::GPUPlace,
                                                 float, ops::AbsGradFunctor>);

REGISTER_OP_GPU_KERNEL(reciprocal,
                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
                                             ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
    reciprocal_grad,
    ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                              ops::ReciprocalGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(
    log,
    ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
    log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                        ops::LogGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(square,
                       ops::ActivationKernel<paddle::platform::GPUPlace, float,
                                             ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
    square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
                                           ops::SquareGradFunctor<float>>);

REGISTER_OP_GPU_KERNEL(brelu,
                       ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
                       ops::BReluGradKernel<paddle::platform::GPUPlace, float>);

REGISTER_OP_GPU_KERNEL(soft_relu,
                       ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);

REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
                       ops::PowGradKernel<paddle::platform::GPUPlace, float>);

REGISTER_OP_GPU_KERNEL(stanh,
                       ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
                       ops::STanhGradKernel<paddle::platform::GPUPlace, float>);

File diff suppressed because it is too large
@ -0,0 +1,86 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/clip_op.h"

namespace paddle {
namespace operators {

using framework::LoDTensor;

class ClipOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            "Input(X) of ClipOp should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) of ClipOp should not be null.");
    auto x_dims = ctx.Input<LoDTensor>("X")->dims();
    auto max = Attr<float>("max");
    auto min = Attr<float>("min");
    PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
    ctx.Output<LoDTensor>("Out")->Resize(x_dims);
  }
};

template <typename AttrType>
class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ClipOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input of clip op. "
             "The input should be a k-D tensor (k > 0 and k < 7).");
    AddOutput("Out",
              "(Tensor) The output of clip op with the same shape as "
              "input(X).");
    AddAttr<AttrType>("min",
                      "(float) Minimum value; elements below min are "
                      "replaced by min.");
    AddAttr<AttrType>("max",
                      "(float) Maximum value; elements above max are "
                      "replaced by max.");
    AddComment(R"DOC(
Clip operator limits the given input within an interval. The interval is
specified with arguments 'min' and 'max'.
)DOC");
  }
};

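// Worked example (illustrative): with min = 0.0 and max = 5.0, an input
// tensor [-1.0, 3.0, 9.0] is clipped elementwise to [0.0, 3.0, 5.0].
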
class ClipOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) should not be null");
    auto x_dims = ctx.Input<LoDTensor>("X")->dims();
    auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    if (x_grad != nullptr) {
      x_grad->Resize(x_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(clip, ops::ClipOp, ops::ClipOpMaker<float>, clip_grad,
            ops::ClipOpGrad);
REGISTER_OP_CPU_KERNEL(clip,
                       ops::ClipKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(clip_grad,
                       ops::ClipGradKernel<paddle::platform::CPUPlace, float>);

@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/clip_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(clip,
                       ops::ClipKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(clip_grad,
                       ops::ClipGradKernel<paddle::platform::GPUPlace, float>);

@ -0,0 +1,97 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::Transform;

// Elementwise clamp: returns min_ when x < min_, max_ when x > max_, and x
// otherwise.
template <typename T>
class ClipFunctor {
 public:
  explicit ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
  HOSTDEVICE T operator()(const T& x) const {
    if (x < min_)
      return min_;
    else if (x > max_)
      return max_;
    else
      return x;
  }

 private:
  T min_;
  T max_;
};

// Gradient of clip: the incoming gradient x passes through only where the
// forward input y was strictly inside (min_, max_); elsewhere it is zero.
template <typename T>
class ClipGradFunctor {
 public:
  explicit ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
  HOSTDEVICE T operator()(const T& x, const T& y) const {
    return (y > min_ && y < max_) ? x : 0;
  }

 private:
  T min_;
  T max_;
};

template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max = context.Attr<T>("max");
    auto min = context.Attr<T>("min");
    auto* x = context.Input<Tensor>("X");
    auto* out = context.Output<Tensor>("Out");
    T* out_data = out->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();
    int64_t numel = x->numel();
    // Apply ClipFunctor to every element of x, writing into out.
    Transform<Place> trans;
    trans(context.device_context(), x_data, x_data + numel, out_data,
          ClipFunctor<T>(min, max));
  }
};

template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto max = context.Attr<T>("max");
    auto min = context.Attr<T>("min");
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    if (d_x != nullptr) {
      auto* x = context.Input<Tensor>("X");
      int64_t numel = d_out->numel();
      auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
      const T* d_out_data = d_out->data<T>();
      const T* x_data = x->data<T>();
      // Binary transform: combines d_out with the original input x through
      // ClipGradFunctor.
      Transform<Place> trans;
      trans(context.device_context(), d_out_data, d_out_data + numel, x_data,
            d_x_data, ClipGradFunctor<T>(min, max));
    }
  }
};

}  // namespace operators
}  // namespace paddle

@ -0,0 +1,133 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gemm_conv2d_op.h"

namespace paddle {
namespace operators {

// Computes the output spatial size of a convolution along one dimension.
int outputSize(int input_size, int filter_size, int padding, int stride) {
  int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
  return output_size;
}
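// For example (illustrative): input_size = 32, filter_size = 3, padding = 1,
// and stride = 1 give (32 - 3 + 2 * 1) / 1 + 1 = 32, i.e. a "same"-sized
// output along that dimension.
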
class Conv2DOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Input"),
                            "Input(Input) of Conv2DOp should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Filter"),
                            "Input(Filter) of Conv2DOp should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Output"),
                            "Output(Output) of Conv2DOp should not be null.");

    auto in = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto out = ctx.Output<framework::LoDTensor>("Output");
    std::vector<int> strides = Attr<std::vector<int>>("strides");
    std::vector<int> paddings = Attr<std::vector<int>>("paddings");
    int groups = Attr<int>("groups");
    int input_channels = in->dims()[1];
    int output_channels = filter->dims()[0];

    PADDLE_ENFORCE_EQ(in->dims().size(), 4, "Conv2DOp input should be 4-D.");
    PADDLE_ENFORCE_EQ(filter->dims().size(), 4,
                      "Conv2DOp filter should be 4-D.");
    PADDLE_ENFORCE_EQ(input_channels, filter->dims()[1] * groups,
                      "The number of input channels should be equal to filter "
                      "channels * groups.");
    PADDLE_ENFORCE_EQ(
        output_channels % groups, 0,
        "The number of output channels should be divisible by groups.");

    auto output_height =
        outputSize(in->dims()[2], filter->dims()[2], paddings[0], strides[0]);
    auto output_width =
        outputSize(in->dims()[3], filter->dims()[3], paddings[1], strides[1]);
    out->Resize(
        {in->dims()[0], filter->dims()[0], output_height, output_width});
  }
};

class Conv2DOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Conv2DOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "Input",
        "The input tensor of convolution operator. "
        "The format of the input tensor is NCHW, where N is the batch size, "
        "C is the number of channels, and H and W are the height and width "
        "of the image.");
    AddInput(
        "Filter",
        "The filter tensor of convolution operator. "
        "The format of the filter tensor is MCHW, where M is the number of "
        "output image channels, C is the number of input image channels, "
        "and H and W are the height and width of the filter. "
        "If the groups attribute is greater than 1, C equals the number of "
        "input image channels divided by groups.");
    AddOutput("Output",
              "The output tensor of convolution operator. "
              "The format of the output tensor is also NCHW.");
    AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
        .SetDefault({1, 1});
    AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
        .SetDefault({0, 0});
    AddAttr<int>(
        "groups",
        "group size of convolution operator. "
        "Refer to grouped convolution in Alex Krizhevsky's paper: "
        "when group=2, the first half of the filters are only connected to "
        "the first half of the input channels, and the second half are only "
        "connected to the second half.")
        .SetDefault(1);
    AddComment(R"DOC(
The convolution operation calculates the output based on the input and filter,
and on the strides, paddings, and groups parameters. The size of each
dimension of the parameters is checked during shape inference.
)DOC");
  }
};

class Conv2DOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    auto in = ctx.Input<Tensor>("Input");
    auto filter = ctx.Input<Tensor>("Filter");
    auto d_in =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Input"));
    auto d_filter =
        ctx.Output<framework::LoDTensor>(framework::GradVarName("Filter"));
    if (d_in) d_in->Resize(in->dims());
    if (d_filter) d_filter->Resize(filter->dims());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::Conv2DOp, ops::Conv2DOpMaker, conv2d_grad,
            ops::Conv2DOpGrad);

REGISTER_OP_CPU_KERNEL(
    conv2d, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);

@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gemm_conv2d_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(
    conv2d, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
    conv2d_grad, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);

@ -0,0 +1,139 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/crop_op.h"
#include <boost/lexical_cast.hpp>

namespace paddle {
namespace operators {

using framework::Tensor;
using framework::LoDTensor;

class CropOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
                            "Input(X) of CropOp should not be null.");
    PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
                            "Output(Out) of CropOp should not be null.");
    auto x_dim = ctx.Input<LoDTensor>("X")->dims();
    auto *y = ctx.Input<LoDTensor>("Y");
    auto *out = ctx.Output<LoDTensor>("Out");
    if (y == nullptr) {
      auto shape = Attr<std::vector<int>>("shape");
      PADDLE_ENFORCE_EQ(
          int64_t(shape.size()), x_dim.size(),
          "Shape size should be equal to the dimension size of input tensor.");
      std::vector<int64_t> tensor_shape(shape.size());
      for (size_t i = 0; i < shape.size(); ++i) {
        tensor_shape[i] = static_cast<int64_t>(shape[i]);
      }
      out->Resize(framework::make_ddim(tensor_shape));
    } else {
      PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
                        "Tensor rank of both CropOp's "
                        "inputs must be the same.");
      out->Resize(y->dims());
    }
  }
};

class CropOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "The input of crop op. "
             "The input should be a k-D tensor (k > 0 and k < 7).");
    AddInput("Y",
             "The input used as a reference for cropping, "
             "with the same dimension as X.");
    AddOutput("Out",
              "The output of crop op, "
              "with the same dimension as X.");
    AddAttr<std::vector<int>>("offsets",
                              "A list<int> describing the offsets to crop at. "
                              "The size of the offsets list should be the "
                              "same as the dimension size of input X.");
    AddAttr<std::vector<int>>("shape",
                              "A list<int> describing the shape of output. "
                              "The size of the shape list should be the same "
                              "as the dimension size of input X.")
        .SetDefault(std::vector<int>());
    AddComment(R"DOC(
Crop Operator.
Crop input into output, as specified by offsets and shape.

There are two ways to set the shape:
1. reference input: crop input X to the shape of reference input Y.
                    The dimension of the reference input should
                    be the same as that of input X.
2. shape list: crop input X to the shape described by a list<int>.
               The size of the shape list should be the same as the
               dimension size of input X.

The input should be a k-D tensor (k > 0 and k < 7). As an example:

Given:

    X = [[0, 1, 2, 0, 0],
         [0, 3, 4, 0, 0],
         [0, 0, 0, 0, 0]]

and

    offsets = [0, 1]

and

    shape = [2, 2]

then we get

    Out = [[1, 2],
           [3, 4]]

)DOC");
  }
};

class CropOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(const framework::InferShapeContext &ctx) const override {
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                            "Input(Out@GRAD) should not be null");
    auto x_dims = ctx.Input<LoDTensor>("X")->dims();
    auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
    if (x_grad != nullptr) {
      x_grad->Resize(x_dims);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(crop_grad,
                       ops::CropGradKernel<paddle::platform::CPUPlace, float>);

@ -0,0 +1,21 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/crop_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_GPU_KERNEL(crop_grad,
                       ops::CropGradKernel<paddle::platform::GPUPlace, float>);

Some files were not shown because too many files have changed in this diff.