| 
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -26,59 +26,82 @@ limitations under the License. */
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					namespace paddle {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					namespace operators {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					namespace ngraphs {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> remove_trailing_one(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& input) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto shape = input->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (shape.back() == 1) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    shape.pop_back();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    return platform::NgReshaper(input, shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  } else {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    return input;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> GetCrossEntropy(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const bool is_soft_label, int ignore_index) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label_shape = label->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto x_shape = x->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label_rank = label_shape.size();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto x_rank = x_shape.size();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  std::shared_ptr<ngraph::Node> x_2d = x, label_2d = label;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label_2d_shape = label_shape, x_2d_shape = x_shape;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (label_rank > 2) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label_2d_shape = paddle::platform::FlattenTo2d(label_shape, label_rank - 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label_2d = paddle::platform::NgReshaper(label, label_2d_shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> flatten_node(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& input) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto shape = input->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto rank = shape.size();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto output = input;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (rank > 2) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto shape_2d = paddle::platform::FlattenTo2d(shape, rank - 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    output = paddle::platform::NgReshaper(input, shape_2d);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (x_rank > 2) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    x_2d_shape = platform::FlattenTo2d(x_shape, x_rank - 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    x_2d = platform::NgReshaper(x, x_2d_shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return output;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> convert_to_node_type(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& input,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& ref) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto output = input;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (input->get_element_type() != ref->get_element_type()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    output =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        std::make_shared<ngraph::op::Convert>(input, ref->get_element_type());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return output;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto batch_size = x_2d_shape.at(0);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> create_xe(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& one_hot,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& x) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_log = std::make_shared<ngraph::op::Log>(x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  std::shared_ptr<ngraph::Node> node_1_hot = label_2d;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_mul = one_hot * node_log;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_sum = std::make_shared<ngraph::op::Sum>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      node_mul, ngraph::AxisSet{x->get_shape().size() - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto shape = x->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  shape.back() = 1;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return platform::NgReshaper(-node_sum, shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> create_mask(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& label, int ignore_index) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto ignore_node = paddle::platform::CreateConstant(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      label->get_element_type(), label->get_shape(), {ignore_index});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto not_equal_node =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      std::make_shared<ngraph::op::NotEqual>(label, ignore_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return not_equal_node;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> create_one_hot(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& label,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<ngraph::Node>& x) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label_shape = label->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return std::make_shared<ngraph::op::OneHot>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      remove_trailing_one(label), x->get_shape(), x->get_shape().size() - 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					std::shared_ptr<ngraph::Node> GetCrossEntropy(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::shared_ptr<ngraph::Node> x, std::shared_ptr<ngraph::Node> label,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const bool is_soft_label, int ignore_index) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  std::shared_ptr<ngraph::Node> node_1_hot = label;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!is_soft_label) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto label_1d =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        platform::NgReshaper(label_2d, ngraph::Shape{label_2d_shape.at(0)});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    node_1_hot = std::make_shared<ngraph::op::OneHot>(label_1d, x_2d_shape, 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (x->get_element_type() != node_1_hot->get_element_type()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    node_1_hot = std::make_shared<ngraph::op::Convert>(node_1_hot,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                       x->get_element_type());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    node_1_hot = create_one_hot(label, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  node_1_hot = convert_to_node_type(node_1_hot, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_log = std::make_shared<ngraph::op::Log>(x_2d);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto high_clip = ngraph::op::Constant::create(node_log->get_element_type(),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                node_log->get_shape(), {1e20});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto low_clip = ngraph::op::Constant::create(node_log->get_element_type(),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                               node_log->get_shape(), {-1e20});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_min = std::make_shared<ngraph::op::Minimum>(node_log, high_clip);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_max = std::make_shared<ngraph::op::Maximum>(node_min, low_clip);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_mul = node_1_hot * node_log;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_sum =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      std::make_shared<ngraph::op::Sum>(node_mul, ngraph::AxisSet{1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto node_neg = std::make_shared<ngraph::op::Negative>(node_sum);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe = platform::NgReshaper(node_neg, ngraph::Shape{batch_size, 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe = create_xe(node_1_hot, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!is_soft_label) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto ignore_node = ngraph::op::Constant::create(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        label->get_element_type(), label_2d_shape, {ignore_index});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto not_equal_node =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        std::make_shared<ngraph::op::NotEqual>(label_2d, ignore_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto mask = std::make_shared<ngraph::op::Convert>(not_equal_node,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                      xe->get_element_type());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto mask = convert_to_node_type(create_mask(label, ignore_index), xe);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    xe = xe * mask;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  return xe;
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
				
				 | 
				 | 
				
					@ -93,30 +116,17 @@ std::shared_ptr<ngraph::Node> GetCrossEntropyGrad(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  std::shared_ptr<ngraph::Node> mask;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (!is_soft_label) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto label_shape = label->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label_shape.pop_back();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label = platform::NgReshaper(label, label_shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto ignore_node = ngraph::op::Constant::create(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        label->get_element_type(), label_shape, {ignore_index});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    auto not_equal_node =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        std::make_shared<ngraph::op::NotEqual>(label, ignore_node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    mask = std::make_shared<ngraph::op::Convert>(not_equal_node,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                 x->get_element_type());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    mask = std::make_shared<ngraph::op::Broadcast>(mask, x_shape,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                                   ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label = std::make_shared<ngraph::op::OneHot>(label, x_shape, rank - 1);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    mask = convert_to_node_type(create_mask(label, ignore_index), x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    mask = std::make_shared<ngraph::op::Broadcast>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        remove_trailing_one(mask), x_shape, ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label = create_one_hot(label, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy_shape = dy->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  dy_shape.pop_back();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy_reshape = platform::NgReshaper(dy, dy_shape);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy_reshape = remove_trailing_one(dy);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      dy_reshape, x_shape, ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  if (x->get_element_type() != label->get_element_type()) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    label = std::make_shared<ngraph::op::Convert>(label, x->get_element_type());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  }
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  label = convert_to_node_type(label, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe_grad = -label * dy_bcast / x;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
	
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
				
				 | 
				 | 
				
					@ -154,9 +164,80 @@ void BuildCrossEntropyGradNode(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe_grad = GetCrossEntropyGrad(x, label, dy, is_soft_label, ignore_index);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  paddle::platform::SetOutputNode(op, "X@GRAD", xe_grad, ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void BuildCrossEntropy2Node(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<paddle::framework::OperatorBase>& op,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::shared_ptr<
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        ngb_node_map) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  int ignore_index = op_attrs.Get<int>("ignore_index");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto rank = x->get_shape().size();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto one_hot = convert_to_node_type(create_one_hot(label, x), x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe = create_xe(one_hot, x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto mask = convert_to_node_type(create_mask(label, ignore_index), xe);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  xe = xe * mask;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  std::shared_ptr<ngraph::Node> node_sum =
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      std::make_shared<ngraph::op::Sum>(one_hot * x, ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  node_sum = paddle::platform::NgReshaper(node_sum, mask->get_shape());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto matchx = mask * node_sum;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  paddle::platform::SetOutputNode(op, "MatchX", matchx, ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  platform::SetOutputNode(op, "XShape", x, ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  paddle::platform::SetOutputNode(op, "Y", xe, ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					void BuildCrossEntropyGrad2Node(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    const std::shared_ptr<paddle::framework::OperatorBase>& op,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					    std::shared_ptr<
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					        ngb_node_map) {
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  int ignore_index = op_attrs.Get<int>("ignore_index");
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto matchx = paddle::platform::GetInputNode(op, "MatchX", ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto label = paddle::platform::GetInputNode(op, "Label", ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto x = paddle::platform::GetInputNode(op, "XShape", ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy = paddle::platform::GetInputNode(op, framework::GradVarName("Y"),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                           ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  matchx = remove_trailing_one(matchx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  label = remove_trailing_one(label);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  x = remove_trailing_one(x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  dy = remove_trailing_one(dy);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto x_shape = x->get_shape();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto rank = x_shape.size();
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto one_hot = convert_to_node_type(create_one_hot(label, x), x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto mask = convert_to_node_type(create_mask(label, ignore_index), x);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto zero = paddle::platform::CreateConstant(matchx->get_element_type(),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                               matchx->get_shape(), {0});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto one = paddle::platform::CreateConstant(matchx->get_element_type(),
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                              matchx->get_shape(), {1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto is_zero = std::make_shared<ngraph::op::Equal>(matchx, zero);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  matchx = std::make_shared<ngraph::op::Select>(is_zero, one, matchx);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto dy_bcast = std::make_shared<ngraph::op::Broadcast>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      mask * dy, x_shape, ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto matchx_bcast = std::make_shared<ngraph::op::Broadcast>(
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					      matchx, x_shape, ngraph::AxisSet{rank - 1});
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  auto xe_grad = -dy_bcast * one_hot / matchx_bcast;
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					  paddle::platform::SetOutputNode(op, framework::GradVarName("X"), xe_grad,
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					                                  ngb_node_map);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}  // namespace ngraphs
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}  // namespace operators
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					}  // namespace paddle
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					REGISTER_NG_OP(cross_entropy, BuildCrossEntropyNode);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					REGISTER_NG_OP(cross_entropy_grad, BuildCrossEntropyGradNode);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					REGISTER_NG_OP(cross_entropy2, BuildCrossEntropy2Node);
 | 
				
			
			
		
	
		
			
				
					 | 
					 | 
				
				 | 
				 | 
				
					REGISTER_NG_OP(cross_entropy_grad2, BuildCrossEntropyGrad2Node);
 | 
				
			
			
		
	
	
		
			
				
					| 
						
						
						
					 | 
				
				 | 
				 | 
				
					
 
 |