commit e34e12931e

@@ -0,0 +1,213 @@
#!/usr/bin/env python
from paddle.trainer_config_helpers import *

height = 224
width = 224
num_class = 1000
batch_size = get_config_arg("batch_size", int, 64)
layer_num = get_config_arg("layer_num", int, 50)
is_test = get_config_arg("is_test", bool, False)

args = {'height': height, 'width': width, 'color': True, 'num_class': num_class}
define_py_data_sources2(
    "train.list", None, module="provider", obj="process", args=args)

settings(
    batch_size=batch_size,
    learning_rate=0.01 / batch_size,
    learning_method=MomentumOptimizer(0.9),
    regularization=L2Regularization(0.0005 * batch_size))
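# A hedged note on the scaling above: in these v0-style trainer configs the
# learning rate is divided by batch_size and the L2 coefficient is multiplied
# by it, which (assuming gradients are summed over the batch, as these configs
# appear to do) keeps the effective per-sample step size and weight decay
# roughly independent of the configured batch size.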


####################### Network Configuration #############
def conv_bn_layer(name,
                  input,
                  filter_size,
                  num_filters,
                  stride,
                  padding,
                  channels=None,
                  active_type=ReluActivation()):
    """
    A wrapper for a conv layer followed by a batch normalization layer.

    Note:
        The conv layer itself has no activation; the activation is applied
        by the batch norm layer.
    """

    tmp = img_conv_layer(
        name=name + "_conv",
        input=input,
        filter_size=filter_size,
        num_channels=channels,
        num_filters=num_filters,
        stride=stride,
        padding=padding,
        act=LinearActivation(),
        bias_attr=False)
    return batch_norm_layer(
        name=name + "_bn", input=tmp, act=active_type, use_global_stats=is_test)
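
# Illustrative usage (a sketch; the layer name here is hypothetical): a 3x3
# conv + BN producing 64 feature maps from an RGB input could be written as
#   conv_bn_layer("stem", input=img, filter_size=3, num_filters=64,
#                 stride=1, padding=1, channels=3)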


def bottleneck_block(name, input, num_filters1, num_filters2):
    """
    A wrapper for a bottleneck building block in ResNet.
    The last conv_bn_layer has no activation;
    the addto layer applies a ReLU activation.
    """
    last_name = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=1,
        padding=0)
    last_name = conv_bn_layer(
        name=name + '_branch2b',
        input=last_name,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)
    last_name = conv_bn_layer(
        name=name + '_branch2c',
        input=last_name,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    return addto_layer(
        name=name + "_addto", input=[input, last_name], act=ReluActivation())
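
# Because the shortcut here is the identity (an element-wise addto of `input`
# and the branch output), the incoming feature map must already have
# num_filters2 channels; blocks that change the channel count or spatial size
# go through mid_projection below instead.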


def mid_projection(name, input, num_filters1, num_filters2, stride=2):
    """
    A wrapper for the middle projection block in ResNet.
    Projection shortcuts are used for increasing dimensions;
    the other shortcuts are identity mappings.
    branch1: the projection shortcut used for increasing
             dimensions; it has no activation.
    branch2x: the bottleneck building block whose shortcut is the identity.
    """
    # stride = 2
    branch1 = conv_bn_layer(
        name=name + '_branch1',
        input=input,
        filter_size=1,
        num_filters=num_filters2,
        stride=stride,
        padding=0,
        active_type=LinearActivation())

    last_name = conv_bn_layer(
        name=name + '_branch2a',
        input=input,
        filter_size=1,
        num_filters=num_filters1,
        stride=stride,
        padding=0)
    last_name = conv_bn_layer(
        name=name + '_branch2b',
        input=last_name,
        filter_size=3,
        num_filters=num_filters1,
        stride=1,
        padding=1)

    last_name = conv_bn_layer(
        name=name + '_branch2c',
        input=last_name,
        filter_size=1,
        num_filters=num_filters2,
        stride=1,
        padding=0,
        active_type=LinearActivation())

    return addto_layer(
        name=name + "_addto", input=[branch1, last_name], act=ReluActivation())


img = data_layer(name='image', size=height * width * 3)


def deep_res_net(res2_num=3, res3_num=4, res4_num=6, res5_num=3):
    """
    A wrapper for the 50-, 101- and 152-layer variants of ResNet.
    res2_num: number of blocks stacked in conv2_x
    res3_num: number of blocks stacked in conv3_x
    res4_num: number of blocks stacked in conv4_x
    res5_num: number of blocks stacked in conv5_x
    """
    # For ImageNet
    # conv1: 112x112
    tmp = conv_bn_layer(
        "conv1",
        input=img,
        filter_size=7,
        channels=3,
        num_filters=64,
        stride=2,
        padding=3)
    tmp = img_pool_layer(name="pool1", input=tmp, pool_size=3, stride=2)

    # conv2_x: 56x56
    tmp = mid_projection(
        name="res2_1", input=tmp, num_filters1=64, num_filters2=256, stride=1)
    for i in range(2, res2_num + 1):
        tmp = bottleneck_block(
            name="res2_" + str(i), input=tmp, num_filters1=64, num_filters2=256)

    # conv3_x: 28x28
    tmp = mid_projection(
        name="res3_1", input=tmp, num_filters1=128, num_filters2=512)
    for i in range(2, res3_num + 1):
        tmp = bottleneck_block(
            name="res3_" + str(i),
            input=tmp,
            num_filters1=128,
            num_filters2=512)

    # conv4_x: 14x14
    tmp = mid_projection(
        name="res4_1", input=tmp, num_filters1=256, num_filters2=1024)
    for i in range(2, res4_num + 1):
        tmp = bottleneck_block(
            name="res4_" + str(i),
            input=tmp,
            num_filters1=256,
            num_filters2=1024)

    # conv5_x: 7x7
    tmp = mid_projection(
        name="res5_1", input=tmp, num_filters1=512, num_filters2=2048)
    for i in range(2, res5_num + 1):
        tmp = bottleneck_block(
            name="res5_" + str(i),
            input=tmp,
            num_filters1=512,
            num_filters2=2048)

    tmp = img_pool_layer(
        name='avgpool',
        input=tmp,
        pool_size=7,
        stride=1,
        pool_type=AvgPooling())

    return fc_layer(input=tmp, size=num_class, act=SoftmaxActivation())


if layer_num == 50:
    resnet = deep_res_net(3, 4, 6, 3)
elif layer_num == 101:
    resnet = deep_res_net(3, 4, 23, 3)
elif layer_num == 152:
    resnet = deep_res_net(3, 8, 36, 3)
else:
    raise ValueError("Unsupported layer_num: %d" % layer_num)

lbl = data_layer(name="label", size=num_class)
loss = cross_entropy(name='loss', input=resnet, label=lbl)
inputs(img, lbl)
outputs(loss)

@@ -0,0 +1,152 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <numeric>
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/memory/memcpy.h"

namespace paddle {
namespace operators {

using LoD = framework::LoD;

class ArrayToLoDTensorOp : public framework::OperatorBase {
 public:
  ArrayToLoDTensorOp(const std::string &type,
                     const framework::VariableNameMap &inputs,
                     const framework::VariableNameMap &outputs,
                     const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}
  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensorArray>();
    auto &rank_table =
        scope.FindVar(Input("RankTable"))->Get<framework::LoDRankTable>();
    auto *out =
        scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();

    // Check dims, place and data type of the input's elements and infer the
    // output's dim
    PADDLE_ENFORCE(!x.empty(), "There's no element in the input array.");
    int rank = x[0].dims().size();
    platform::Place place = x[0].place();
    std::type_index data_type = x[0].type();
    framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
    int64_t batch_size = x[0].dims()[0];
    for (size_t i = 1; i < x.size(); ++i) {
      PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
                        "The dimension of the %zu'th element in LoDTensorArray "
                        "differs from previous ones.",
                        i);
      PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
                     "The place class of the %zu'th element in LoDTensorArray "
                     "differs from previous ones.",
                     i);
      PADDLE_ENFORCE(x[i].type() == data_type,
                     "The data type of the %zu'th element in LoDTensorArray "
                     "differs from previous ones.",
                     i);
      batch_size += x[i].dims()[0];
    }
    auto ins_dim_vec = framework::vectorize(ins_dims);
    ins_dim_vec.insert(ins_dim_vec.begin(), batch_size);
    framework::DDim out_dims = framework::make_ddim(ins_dim_vec);
    out->Resize(out_dims);
    out->mutable_data(place, data_type);

    auto &table_items = rank_table.items();
    std::vector<size_t> table_item_idx(table_items.size());
    // table_item_idx = range(table_items.size())
    std::iota(table_item_idx.begin(), table_item_idx.end(), 0);
    std::sort(table_item_idx.begin(), table_item_idx.end(),
              [&](size_t a, size_t b) {
                return table_items[a].index < table_items[b].index;
              });
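    // Sorting by the rank table's original `index` restores the sequences'
    // pre-sorting order, so `out` is assembled in the same sequence order as
    // the op's original input.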

    // Build LoDTensor `out`
    framework::LoD *out_lod = out->mutable_lod();
    out_lod->clear();
    size_t out_offset = 0;
    auto prefix_lod = rank_table.coarse_lod();
    prefix_lod.emplace_back();
    auto &cur_level_lod = prefix_lod.back();
    cur_level_lod.push_back(0);
    for (size_t idx : table_item_idx) {
      cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
      for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
        auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
            x[x_idx].lod(), idx, idx + 1, 0);

        auto &lod_length = lod_and_offset.first;
        framework::AppendLoD(out_lod, lod_length);

        size_t start_offset = lod_and_offset.second.first;
        size_t end_offset = lod_and_offset.second.second;
        VLOG(10) << "idx=" << idx << " x_idx=" << x_idx << " [" << start_offset
                 << ", " << end_offset << "]";
        // Copy data
        PADDLE_ENFORCE_GE(end_offset, start_offset);
        size_t len = end_offset - start_offset;
        if (len == 0) {
          continue;
        }
        out->Slice(out_offset, out_offset + len)
            .CopyFrom(x[x_idx].Slice(start_offset, end_offset), place, dev_ctx);
        out_offset += len;
      }
    }
    out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
  }
};

class ArrayToLoDTensorOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ArrayToLoDTensorOpProtoMaker(framework::OpProto *proto,
                               framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(std::vector<LoDTensor>) A vector of tensors that is going to "
             "be cast into one big LoDTensor.");
    AddInput("RankTable",
             "(LoDRankTable) RankTable provides the coarse lod information to "
             "build the output LoDTensor. See "
             "'paddle/framework/lod_rank_table.h' for more details.");
    AddOutput("Out",
              "(LoDTensor) The LoDTensor formed by the input tensor array.");
    AddComment(
        R"DOC(This op builds a big LoDTensor from a std::vector<LoDTensor>
and a LoDRankTable. It is supposed to be used to convert a dynamic RNN's
outputs back into a normal LoDTensor. The std::vector<LoDTensor>
would be the output of the RNN op, and the LoDRankTable would be built
from the RNN's input.)DOC");
  }
};

class ArrayToLoDTensorInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"),
                   "ArrayToLoDTensorOp must have input X.");
    PADDLE_ENFORCE(context->HasInput("RankTable"),
                   "ArrayToLoDTensorOp must have input RankTable.");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(array_to_lod_tensor, ops::ArrayToLoDTensorOp,
                  ops::ArrayToLoDTensorOpProtoMaker,
                  ops::ArrayToLoDTensorInferShape);

@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/compare_op.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CompareOpProtoMaker(framework::OpProto *proto,
                      framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    OpComment comment;
    AddInput("X",
             string::Sprintf("(LoDTensor) the left hand operand of %s operator",
                             comment.type));
    AddInput("Y", string::Sprintf(
                      "(LoDTensor) the right hand operand of %s operator",
                      comment.type));
    AddOutput("Out", string::Sprintf(
                         "(LoDTensor) n-dim bool tensor. Each element is %s",
                         comment.equation));
    AddComment(string::Sprintf(R"DOC(%s Operator

It operates element-wise on X and Y and returns Out. Each of them is an
N-dim tensor. X and Y may be of any supported data type. Each element of
Out is calculated by %s
)DOC",
                               comment.type, comment.equation));
  }
};

template <typename OpComment>
class CompareOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    OpComment comment;
    PADDLE_ENFORCE(context->HasInput("X"), "%s operator must have input X",
                   comment.type);
    PADDLE_ENFORCE(context->HasInput("Y"), "%s operator must have input Y",
                   comment.type);
    auto dim_x = context->GetInputDim("X");
    auto dim_y = context->GetInputDim("Y");
    PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
                      "The number of elements in X and Y should be the same");

    context->SetOutputDim("Out", context->GetInputDim("X"));
    context->ShareLoD("X", "Out");
  }
};

}  // namespace operators
}  // namespace paddle

#define REGISTER_LOGICAL_OP(op_type, _equation)                               \
  struct _##op_type##Comment {                                                \
    static char type[];                                                       \
    static char equation[];                                                   \
  };                                                                          \
  char _##op_type##Comment::type[]{#op_type};                                 \
  char _##op_type##Comment::equation[]{_equation};                            \
  REGISTER_OP_WITH_KERNEL(                                                    \
      op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
      ::paddle::operators::CompareOpInferShape<_##op_type##Comment>,          \
      ::paddle::framework::EmptyGradOpMaker);

REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
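
// Illustrative expansion (a sketch of what the macro above generates):
// REGISTER_LOGICAL_OP(less_than, "Out = X < Y") defines a `_less_thanComment`
// struct whose `type` is "less_than" and whose `equation` is "Out = X < Y",
// then registers the `less_than` operator with the comment-parameterized
// proto maker and shape inference above, and with no gradient op
// (EmptyGradOpMaker).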

@@ -0,0 +1,18 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/compare_op.h"

REGISTER_LOGICAL_KERNEL(less_than, GPU, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(equal, GPU, paddle::operators::EqualFunctor);
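
// The same functors serve both the CPU and GPU kernels: their operator() is
// marked HOSTDEVICE in compare_op.h, so a single definition compiles for both
// the host and the device.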

@@ -0,0 +1,74 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/framework/op_registry.h"
#include "paddle/platform/transform.h"

namespace paddle {
namespace operators {

template <typename T>
struct LessThanFunctor {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
};

template <typename T>
struct EqualFunctor {
  using ELEM_TYPE = T;
  HOSTDEVICE bool operator()(const T& a, const T& b) const {
    if (std::is_floating_point<T>::value) {
      // This branch is optimized away at compile time when T is an integer
      // type, so it is safe to cast a and b to double here.
      return fabs(static_cast<double>(a - b)) < 1e-8;
    } else {
      return (a == b);
    }
  }
};

template <typename Place, typename Functor>
class CompareOpKernel
    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using T = typename Functor::ELEM_TYPE;
    auto* x = context.Input<framework::Tensor>("X");
    auto* y = context.Input<framework::Tensor>("Y");
    auto* out = context.Output<framework::Tensor>("Out");
    Functor binary_func;
    platform::Transform<Place> trans;
    // Apply the binary functor element-wise over x and y, writing the
    // boolean results into out.
    trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
          y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
          binary_func);
  }
};

}  // namespace operators
}  // namespace paddle

#define REGISTER_LOGICAL_KERNEL(op_type, dev, functor)                     \
  REGISTER_OP_##dev##_KERNEL(                                              \
      op_type,                                                             \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<int>>,                  \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<int64_t>>,              \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<float>>,                \
      ::paddle::operators::CompareOpKernel<::paddle::platform::dev##Place, \
                                           functor<double>>);
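
// Illustrative expansion (a sketch): REGISTER_LOGICAL_KERNEL(less_than, CPU,
// paddle::operators::LessThanFunctor) instantiates CompareOpKernel on
// CPUPlace once per element type (int, int64_t, float, double), so the same
// less_than op dispatches on the input tensor's data type.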