|
|
|
@ -26,59 +26,82 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace ngraphs {
|
|
|
|
|
std::shared_ptr<ngraph::Node> remove_trailing_one(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& input) {
|
|
|
|
|
auto shape = input->get_shape();
|
|
|
|
|
if (shape.back() == 1) {
|
|
|
|
|
shape.pop_back();
|
|
|
|
|
return platform::NgReshaper(input, shape);
|
|
|
|
|
} else {
|
|
|
|
|
return input;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> GetCrossEntropy(
|
|
|
|
|
std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
|
|
|
|
|
const bool is_soft_label, int ignore_index) {
|
|
|
|
|
auto label_shape = label->get_shape();
|
|
|
|
|
auto x_shape = x->get_shape();
|
|
|
|
|
auto label_rank = label_shape.size();
|
|
|
|
|
auto x_rank = x_shape.size();
|
|
|
|
|
std::shared_ptr<ngraph::Node> x_2d = x, label_2d = label;
|
|
|
|
|
auto label_2d_shape = label_shape, x_2d_shape = x_shape;
|
|
|
|
|
|
|
|
|
|
if (label_rank > 2) {
|
|
|
|
|
label_2d_shape = paddle::platform::FlattenTo2d(label_shape, label_rank - 1);
|
|
|
|
|
label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
|
|
|
|
|
std::shared_ptr<ngraph::Node> flatten_node(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& input) {
|
|
|
|
|
auto shape = input->get_shape();
|
|
|
|
|
auto rank = shape.size();
|
|
|
|
|
auto output = input;
|
|
|
|
|
if (rank > 2) {
|
|
|
|
|
auto shape_2d = paddle::platform::FlattenTo2d(shape, rank - 1);
|
|
|
|
|
output = paddle::platform::NgReshaper(input, shape_2d);
|
|
|
|
|
}
|
|
|
|
|
if (x_rank > 2) {
|
|
|
|
|
x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
|
|
|
|
|
x_2d = platform::NgReshaper(x, x_2d_shape);
|
|
|
|
|
return output;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> convert_to_node_type(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& input,
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& ref) {
|
|
|
|
|
auto output = input;
|
|
|
|
|
if (input->get_element_type() != ref->get_element_type()) {
|
|
|
|
|
output =
|
|
|
|
|
std::make_shared<ngraph::op::Convert>(input, ref->get_element_type());
|
|
|
|
|
}
|
|
|
|
|
return output;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto batch_size = x_2d_shape.at(0);
|
|
|
|
|
std::shared_ptr<ngraph::Node> create_xe(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& one_hot,
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& x) {
|
|
|
|
|
auto node_log = std::make_shared<ngraph::op::Log>(x);
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
|
|
|
|
|
auto node_mul = one_hot * node_log;
|
|
|
|
|
auto node_sum = std::make_shared<ngraph::op::Sum>(
|
|
|
|
|
node_mul, ngraph::AxisSet{x->get_shape().size() - 1});
|
|
|
|
|
|
|
|
|
|
auto shape = x->get_shape();
|
|
|
|
|
shape.back() = 1;
|
|
|
|
|
return platform::NgReshaper(-node_sum, shape);
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<ngraph::Node> create_mask(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& label, int ignore_index) {
|
|
|
|
|
auto ignore_node = paddle::platform::CreateConstant(
|
|
|
|
|
label->get_element_type(), label->get_shape(), {ignore_index});
|
|
|
|
|
auto not_equal_node =
|
|
|
|
|
std::make_shared<ngraph::op::NotEqual>(label, ignore_node);
|
|
|
|
|
return not_equal_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> create_one_hot(
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& label,
|
|
|
|
|
const std::shared_ptr<ngraph::Node>& x) {
|
|
|
|
|
auto label_shape = label->get_shape();
|
|
|
|
|
return std::make_shared<ngraph::op::OneHot>(
|
|
|
|
|
remove_trailing_one(label), x->get_shape(), x->get_shape().size() - 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> GetCrossEntropy(
|
|
|
|
|
std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
|
|
|
|
|
const bool is_soft_label, int ignore_index) {
|
|
|
|
|
std::shared_ptr<ngraph::Node> node_1_hot = label;
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto label_1d =
|
|
|
|
|
platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
|
|
|
|
|
node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
|
|
|
|
|
}
|
|
|
|
|
if (x->get_element_type() != node_1_hot->get_element_type()) {
|
|
|
|
|
node_1_hot = std::make_shared<ngraph::op::Convert>(node_1_hot,
|
|
|
|
|
x->get_element_type());
|
|
|
|
|
node_1_hot = create_one_hot(label, x);
|
|
|
|
|
}
|
|
|
|
|
node_1_hot = convert_to_node_type(node_1_hot, x);
|
|
|
|
|
|
|
|
|
|
auto node_log = std::make_shared<ngraph::op::Log>(x_2d);
|
|
|
|
|
auto high_clip = ngraph::op::Constant::create(node_log->get_element_type(),
|
|
|
|
|
node_log->get_shape(), {1e20});
|
|
|
|
|
auto low_clip = ngraph::op::Constant::create(node_log->get_element_type(),
|
|
|
|
|
node_log->get_shape(), {-1e20});
|
|
|
|
|
auto node_min = std::make_shared<ngraph::op::Minimum>(node_log, high_clip);
|
|
|
|
|
auto node_max = std::make_shared<ngraph::op::Maximum>(node_min, low_clip);
|
|
|
|
|
auto node_mul = node_1_hot * node_log;
|
|
|
|
|
auto node_sum =
|
|
|
|
|
std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
|
|
|
|
|
auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
|
|
|
|
|
auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
|
|
|
|
|
|
|
|
|
|
auto xe = create_xe(node_1_hot, x);
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto ignore_node = ngraph::op::Constant::create(
|
|
|
|
|
label->get_element_type(), label_2d_shape, {ignore_index});
|
|
|
|
|
auto not_equal_node =
|
|
|
|
|
std::make_shared<ngraph::op::NotEqual>(label_2d, ignore_node);
|
|
|
|
|
auto mask = std::make_shared<ngraph::op::Convert>(not_equal_node,
|
|
|
|
|
xe->get_element_type());
|
|
|
|
|
auto mask = convert_to_node_type(create_mask(label, ignore_index), xe);
|
|
|
|
|
xe = xe * mask;
|
|
|
|
|
}
|
|
|
|
|
return xe;
|
|
|
|
@ -93,30 +116,17 @@ std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> mask;
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto label_shape = label->get_shape();
|
|
|
|
|
label_shape.pop_back();
|
|
|
|
|
label = platform::NgReshaper(label, label_shape);
|
|
|
|
|
|
|
|
|
|
auto ignore_node = ngraph::op::Constant::create(
|
|
|
|
|
label->get_element_type(), label_shape, {ignore_index});
|
|
|
|
|
auto not_equal_node =
|
|
|
|
|
std::make_shared<ngraph::op::NotEqual>(label, ignore_node);
|
|
|
|
|
mask = std::make_shared<ngraph::op::Convert>(not_equal_node,
|
|
|
|
|
x->get_element_type());
|
|
|
|
|
mask = std::make_shared<ngraph::op::Broadcast>(mask, x_shape,
|
|
|
|
|
ngraph::AxisSet{rank - 1});
|
|
|
|
|
|
|
|
|
|
label = std::make_shared<ngraph::op::OneHot>(label, x_shape, rank - 1);
|
|
|
|
|
mask = convert_to_node_type(create_mask(label, ignore_index), x);
|
|
|
|
|
mask = std::make_shared<ngraph::op::Broadcast>(
|
|
|
|
|
remove_trailing_one(mask), x_shape, ngraph::AxisSet{rank - 1});
|
|
|
|
|
label = create_one_hot(label, x);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto dy_shape = dy->get_shape();
|
|
|
|
|
dy_shape.pop_back();
|
|
|
|
|
auto dy_reshape = platform::NgReshaper(dy, dy_shape);
|
|
|
|
|
auto dy_reshape = remove_trailing_one(dy);
|
|
|
|
|
auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
|
|
|
|
|
dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
|
|
|
|
|
if (x->get_element_type() != label->get_element_type()) {
|
|
|
|
|
label = std::make_shared<ngraph::op::Convert>(label, x->get_element_type());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
label = convert_to_node_type(label, x);
|
|
|
|
|
|
|
|
|
|
auto xe_grad = -label * dy_bcast / x;
|
|
|
|
|
|
|
|
|
@ -154,9 +164,80 @@ void BuildCrossEntropyGradNode(
|
|
|
|
|
auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
|
|
|
|
|
paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropy2Node(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
int ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
|
|
|
|
|
auto rank = x->get_shape().size();
|
|
|
|
|
|
|
|
|
|
auto one_hot = convert_to_node_type(create_one_hot(label, x), x);
|
|
|
|
|
auto xe = create_xe(one_hot, x);
|
|
|
|
|
auto mask = convert_to_node_type(create_mask(label, ignore_index), xe);
|
|
|
|
|
|
|
|
|
|
xe = xe * mask;
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> node_sum =
|
|
|
|
|
std::make_shared<ngraph::op::Sum>(one_hot * x, ngraph::AxisSet{rank - 1});
|
|
|
|
|
node_sum = paddle::platform::NgReshaper(node_sum, mask->get_shape());
|
|
|
|
|
auto matchx = mask * node_sum;
|
|
|
|
|
|
|
|
|
|
paddle::platform::SetOutputNode(op, "MatchX", matchx, ngb_node_map);
|
|
|
|
|
platform::SetOutputNode(op, "XShape", x, ngb_node_map);
|
|
|
|
|
paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropyGrad2Node(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
int ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
auto matchx = paddle::platform::GetInputNode(op, "MatchX", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "XShape", ngb_node_map);
|
|
|
|
|
auto dy = paddle::platform::GetInputNode(op, framework::GradVarName("Y"),
|
|
|
|
|
ngb_node_map);
|
|
|
|
|
|
|
|
|
|
matchx = remove_trailing_one(matchx);
|
|
|
|
|
label = remove_trailing_one(label);
|
|
|
|
|
x = remove_trailing_one(x);
|
|
|
|
|
dy = remove_trailing_one(dy);
|
|
|
|
|
|
|
|
|
|
auto x_shape = x->get_shape();
|
|
|
|
|
auto rank = x_shape.size();
|
|
|
|
|
|
|
|
|
|
auto one_hot = convert_to_node_type(create_one_hot(label, x), x);
|
|
|
|
|
auto mask = convert_to_node_type(create_mask(label, ignore_index), x);
|
|
|
|
|
|
|
|
|
|
auto zero = paddle::platform::CreateConstant(matchx->get_element_type(),
|
|
|
|
|
matchx->get_shape(), {0});
|
|
|
|
|
auto one = paddle::platform::CreateConstant(matchx->get_element_type(),
|
|
|
|
|
matchx->get_shape(), {1});
|
|
|
|
|
auto is_zero = std::make_shared<ngraph::op::Equal>(matchx, zero);
|
|
|
|
|
matchx = std::make_shared<ngraph::op::Select>(is_zero, one, matchx);
|
|
|
|
|
|
|
|
|
|
auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
|
|
|
|
|
mask * dy, x_shape, ngraph::AxisSet{rank - 1});
|
|
|
|
|
auto matchx_bcast = std::make_shared<ngraph::op::Broadcast>(
|
|
|
|
|
matchx, x_shape, ngraph::AxisSet{rank - 1});
|
|
|
|
|
|
|
|
|
|
auto xe_grad = -dy_bcast * one_hot / matchx_bcast;
|
|
|
|
|
paddle::platform::SetOutputNode(op, framework::GradVarName("X"), xe_grad,
|
|
|
|
|
ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ngraphs
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
// Op registrations: each REGISTER_NG_OP pairs an op name with one of the
// builder functions of this file. NOTE(review): the macro is defined
// elsewhere; presumably it adds the pair to the nGraph bridge's op table -
// confirm against its definition.
REGISTER_NG_OP(cross_entropy, BuildCrossEntropyNode);
|
|
|
|
|
REGISTER_NG_OP(cross_entropy_grad, BuildCrossEntropyGradNode);
|
|
|
|
|
REGISTER_NG_OP(cross_entropy2, BuildCrossEntropy2Node);
|
|
|
|
|
REGISTER_NG_OP(cross_entropy_grad2, BuildCrossEntropyGrad2Node);
|
|
|
|
|