|
|
|
@ -27,13 +27,9 @@ namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
namespace ngraphs {
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropyNode(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
std::shared_ptr<ngraph::Node> GetCrossEntropy(
|
|
|
|
|
std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
|
|
|
|
|
const bool is_soft_label, int ignore_index) {
|
|
|
|
|
auto label_shape = label->get_shape();
|
|
|
|
|
auto x_shape = x->get_shape();
|
|
|
|
|
auto label_rank = label_shape.size();
|
|
|
|
@ -46,18 +42,16 @@ void BuildCrossEntropyNode(
|
|
|
|
|
label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
|
|
|
|
|
}
|
|
|
|
|
if (x_rank > 2) {
|
|
|
|
|
x_2d_shape = paddle::platform::FlattenTo2d(x_shape, x_rank - 1);
|
|
|
|
|
x_2d = paddle::platform::NgReshaper(x, x_2d_shape);
|
|
|
|
|
x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
|
|
|
|
|
x_2d = platform::NgReshaper(x, x_2d_shape);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto batch_size = x_2d_shape.at(0);
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto label_1d = paddle::platform::NgReshaper(
|
|
|
|
|
label_2d, ngraph::Shape{label_2d_shape.at(0)});
|
|
|
|
|
auto label_1d =
|
|
|
|
|
platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
|
|
|
|
|
node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
|
|
|
|
|
}
|
|
|
|
|
if (x->get_element_type() != node_1_hot->get_element_type()) {
|
|
|
|
@ -76,11 +70,9 @@ void BuildCrossEntropyNode(
|
|
|
|
|
auto node_sum =
|
|
|
|
|
std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
|
|
|
|
|
auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
|
|
|
|
|
auto xe =
|
|
|
|
|
paddle::platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
|
|
|
|
|
auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
|
|
|
|
|
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
auto ignore_node = ngraph::op::Constant::create(
|
|
|
|
|
label->get_element_type(), label_2d_shape, {ignore_index});
|
|
|
|
|
auto not_equal_node =
|
|
|
|
@ -89,21 +81,13 @@ void BuildCrossEntropyNode(
|
|
|
|
|
xe->get_element_type());
|
|
|
|
|
xe = xe * mask;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
|
|
|
|
|
return xe;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropyGradNode(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
|
|
|
|
|
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
|
|
|
|
|
std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
|
|
|
|
|
std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
|
|
|
|
|
std::shared_ptr<ngraph::Node> dy, const bool is_soft_label,
|
|
|
|
|
int ignore_index) {
|
|
|
|
|
auto x_shape = x->get_shape();
|
|
|
|
|
auto rank = x_shape.size();
|
|
|
|
|
|
|
|
|
@ -111,9 +95,8 @@ void BuildCrossEntropyGradNode(
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
auto label_shape = label->get_shape();
|
|
|
|
|
label_shape.pop_back();
|
|
|
|
|
label = paddle::platform::NgReshaper(label, label_shape);
|
|
|
|
|
label = platform::NgReshaper(label, label_shape);
|
|
|
|
|
|
|
|
|
|
auto ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
auto ignore_node = ngraph::op::Constant::create(
|
|
|
|
|
label->get_element_type(), label_shape, {ignore_index});
|
|
|
|
|
auto not_equal_node =
|
|
|
|
@ -128,7 +111,7 @@ void BuildCrossEntropyGradNode(
|
|
|
|
|
|
|
|
|
|
auto dy_shape = dy->get_shape();
|
|
|
|
|
dy_shape.pop_back();
|
|
|
|
|
auto dy_reshape = paddle::platform::NgReshaper(dy, dy_shape);
|
|
|
|
|
auto dy_reshape = platform::NgReshaper(dy, dy_shape);
|
|
|
|
|
auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
|
|
|
|
|
dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
|
|
|
|
|
if (x->get_element_type() != label->get_element_type()) {
|
|
|
|
@ -140,7 +123,35 @@ void BuildCrossEntropyGradNode(
|
|
|
|
|
if (!is_soft_label) {
|
|
|
|
|
xe_grad = xe_grad * mask;
|
|
|
|
|
}
|
|
|
|
|
return xe_grad;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropyNode(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
|
|
|
|
|
int ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
auto xe = GetCrossEntropy(x, label, is_soft_label, ignore_index);
|
|
|
|
|
paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BuildCrossEntropyGradNode(
|
|
|
|
|
const std::shared_ptr<paddle::framework::OperatorBase>& op,
|
|
|
|
|
std::shared_ptr<
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
|
|
|
|
|
ngb_node_map) {
|
|
|
|
|
auto op_attrs = paddle::framework::AttrReader(op->Attrs());
|
|
|
|
|
const bool is_soft_label = op_attrs.Get<bool>("soft_label");
|
|
|
|
|
int ignore_index = op_attrs.Get<int>("ignore_index");
|
|
|
|
|
auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
|
|
|
|
|
auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
|
|
|
|
|
auto dy = paddle::platform::GetInputNode(op, "Y@GRAD", ngb_node_map);
|
|
|
|
|
auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
|
|
|
|
|
paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
|
|
|
|
|
}
|
|
|
|
|
} // namespace ngraphs
|
|
|
|
|