You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Paddle/paddle/fluid/operators/ngraph/ops/matmul_op.h

249 lines
8.5 KiB

/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_scalar_op.h"
#include "paddle/fluid/operators/ngraph/ops/op_bridge.h"
#include "paddle/fluid/platform/ngraph_helper.h"
namespace paddle {
namespace operators {
namespace ngraphs {
// Brings `input` into the 2-D / 3-D layout used by the matmul builders.
//
//  * rank >= 3: every axis except the last two is folded into one leading
//    axis, giving a 3-D node; when `transpose` is set the last two axes are
//    swapped as part of the same Reshape.
//  * rank == 1: the vector of length d is first reshaped to {1, d} when `x` is
//    true and to {d, 1} otherwise.
//  * rank < 3 in general: the (possibly reshaped) node is returned as is,
//    unless `transpose` is set, in which case axes 0 and 1 are swapped.
//
// `x` is only consulted for rank-1 inputs.
// NOTE(review): the transpose path for rank < 3 indexes dimensions 0 and 1, so
// it presumably is never reached with a rank-0 input -- confirm with callers.
std::shared_ptr<ngraph::Node> transposeAndFlat3D(
    const std::shared_ptr<ngraph::Node>& input, const bool transpose,
    bool x = true) {
  const auto dims = input->get_shape();
  const size_t rank = dims.size();

  if (rank >= 3) {
    // Product of all leading (batch) dimensions.
    size_t batch = 1;
    for (size_t axis = 0; axis + 2 < rank; ++axis) {
      batch *= dims[axis];
    }

    std::vector<size_t> perm(rank);
    std::iota(perm.begin(), perm.end(), 0);

    size_t rows = dims[rank - 2];
    size_t cols = dims[rank - 1];
    if (transpose) {
      // Swap the two innermost axes both in the read order and in the result.
      perm[rank - 2] = rank - 1;
      perm[rank - 1] = rank - 2;
      rows = dims[rank - 1];
      cols = dims[rank - 2];
    }
    return std::make_shared<ngraph::op::Reshape>(
        input, ngraph::AxisVector(perm), ngraph::Shape{batch, rows, cols});
  }

  std::shared_ptr<ngraph::Node> matrix = input;
  if (rank == 1) {
    // Promote a vector to a row (x == true) or a column (x == false).
    const ngraph::Shape promoted =
        x ? ngraph::Shape{1, dims[0]} : ngraph::Shape{dims[0], 1};
    matrix = std::make_shared<ngraph::op::Reshape>(
        input, ngraph::AxisVector{0}, promoted);
  }

  if (!transpose) {
    return matrix;
  }

  const auto matrix_dims = matrix->get_shape();
  return std::make_shared<ngraph::op::Reshape>(
      matrix, ngraph::AxisVector{1, 0},
      ngraph::Shape{matrix_dims[1], matrix_dims[0]});
}
// Replicates a 2-D node along a new leading axis of extent `axis0`, producing
// shape {axis0, rows, cols}. Nodes of any other rank are returned unchanged.
std::shared_ptr<ngraph::Node> broadcast3D(
    const std::shared_ptr<ngraph::Node>& input, size_t axis0) {
  const auto dims = input->get_shape();
  if (dims.size() != 2) {
    return input;
  }
  // AxisSet{0} marks axis 0 of the result as the broadcast (new) axis.
  return std::make_shared<ngraph::op::Broadcast>(
      input, ngraph::Shape{axis0, dims[0], dims[1]}, ngraph::AxisSet{0});
}
// Multiplies `a` by `b`: BatchMatMul when both operands have rank greater than
// two, plain Dot otherwise.
std::shared_ptr<ngraph::Node> dotOp(const std::shared_ptr<ngraph::Node>& a,
                                    const std::shared_ptr<ngraph::Node>& b) {
  const size_t rank_a = a->get_shape().size();
  const size_t rank_b = b->get_shape().size();
  const bool batched = rank_a > 2 && rank_b > 2;
  if (batched) {
    return std::make_shared<ngraph::op::BatchMatMul>(a, b);
  }
  return std::make_shared<ngraph::op::Dot>(a, b);
}
// Reshapes `input` to `shape`, reading its axes in their natural order
// (0, 1, ..., rank - 1), i.e. without permuting any data.
std::shared_ptr<ngraph::Node> reshapeToOriginal(
    std::shared_ptr<ngraph::Node> input, const ngraph::Shape& shape) {
  std::vector<size_t> natural_order(input->get_shape().size());
  std::iota(natural_order.begin(), natural_order.end(), 0);
  return std::make_shared<ngraph::op::Reshape>(
      input, ngraph::AxisVector(natural_order), shape);
}
// Builds the nGraph nodes for a `matmul` op and stores the result as the op's
// "Out" node in `ngb_node_map`.
//
// Inputs "X" and "Y" are normalised with transposeAndFlat3D (honouring the
// "transpose_X" / "transpose_Y" attributes), multiplied with BatchMatMul when
// either original operand has rank > 2 and with Dot otherwise, stripped of
// unit dimensions (see below) and finally scaled by the "alpha" attribute.
void BuildMatMulNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto lhs = paddle::platform::GetInputNode(op, "X", ngb_node_map);
  auto rhs = paddle::platform::GetInputNode(op, "Y", ngb_node_map);

  auto attrs = paddle::framework::AttrReader(op->Attrs());
  const bool transpose_x = attrs.Get<bool>("transpose_X");
  const bool transpose_y = attrs.Get<bool>("transpose_Y");
  const float alpha = attrs.Get<float>("alpha");

  // Shapes and ranks of the operands as they arrive, before normalisation.
  const auto lhs_dims = lhs->get_shape();
  const auto rhs_dims = rhs->get_shape();
  const size_t lhs_rank = lhs_dims.size();
  const size_t rhs_rank = rhs_dims.size();

  lhs = transposeAndFlat3D(lhs, transpose_x, true);
  rhs = transposeAndFlat3D(rhs, transpose_y, false);
  const auto rhs_flat_dims = rhs->get_shape();
  const auto lhs_flat_dims = lhs->get_shape();

  std::shared_ptr<ngraph::Node> product;
  if (lhs_rank <= 2 && rhs_rank <= 2) {
    product = std::make_shared<ngraph::op::Dot>(lhs, rhs);
  } else {
    // `result_dims` starts from one operand's original shape so that the
    // leading (batch) dimensions can be restored after the 3-D product; the
    // two checks below run in this order and the later one wins.
    ngraph::Shape result_dims = lhs_dims;
    if (lhs_rank != 3) {
      lhs = broadcast3D(lhs, rhs_flat_dims[0]);
      result_dims = rhs_dims;
    }
    if (rhs_rank != 3) {
      rhs = broadcast3D(rhs, lhs_flat_dims[0]);
      result_dims = lhs_dims;
    }

    auto batched = std::make_shared<ngraph::op::BatchMatMul>(lhs, rhs);
    const auto batched_dims = batched->get_shape();
    const auto result_rank = result_dims.size();
    result_dims[result_rank - 1] = batched_dims[2];
    result_dims[result_rank - 2] = batched_dims[1];
    product = std::make_shared<ngraph::op::Reshape>(
        batched, ngraph::AxisVector{0, 1, 2}, result_dims);
  }

  // Drop every dimension of extent 1 except the one at index 0.
  // NOTE(review): this also removes unit dimensions that were present in the
  // operands themselves, not only those introduced for rank-1 inputs --
  // confirm that this is intended.
  auto squeezed_dims = product->get_shape();
  std::vector<size_t> natural_order(squeezed_dims.size());
  std::iota(natural_order.begin(), natural_order.end(), 0);
  size_t axis = squeezed_dims.size() - 1;
  while (axis > 0) {
    if (squeezed_dims[axis] == 1) {
      squeezed_dims.erase(squeezed_dims.begin() + axis);
    }
    --axis;
  }

  auto squeezed = std::make_shared<ngraph::op::Reshape>(
      product, ngraph::AxisVector(natural_order), squeezed_dims);
  auto scaled = ElementwiseScalar<ngraph::op::Multiply>(alpha, squeezed);
  paddle::platform::SetOutputNode(op, "Out", scaled, ngb_node_map);
}
// Builds the nGraph nodes for a `matmul_grad` op.
//
// Reads inputs "X", "Y" and "Out@GRAD" plus the attributes "transpose_X",
// "transpose_Y" and "alpha", and stores the gradients as "X@GRAD" and/or
// "Y@GRAD" in `ngb_node_map` for whichever of those outputs the op declares.
//
// The forward op computes Out = alpha * op(X) * op(Y), so both gradients are
// scaled by `alpha` (chain rule), not by its reciprocal.
void BuildMatMulGradNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
  auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
  auto x = paddle::platform::GetInputNode(op, "X", ngb_node_map);
  const bool is_dx = paddle::platform::HasOutput(op, "X@GRAD");
  const bool is_dy = paddle::platform::HasOutput(op, "Y@GRAD");
  const bool transpose_x = op_attrs.Get<bool>("transpose_X");
  const bool transpose_y = op_attrs.Get<bool>("transpose_Y");
  const float alpha = op_attrs.Get<float>("alpha");

  auto dout_shape = dout->get_shape();
  auto x_shape = x->get_shape();
  auto y_shape = y->get_shape();
  const size_t nx = x_shape.size();
  const size_t ny = y_shape.size();
  const size_t ndout = dout_shape.size();

  // Normalise all three tensors to 2-D / 3-D without transposing yet.
  std::shared_ptr<ngraph::Node> x2 = transposeAndFlat3D(x, false);
  std::shared_ptr<ngraph::Node> y2 = transposeAndFlat3D(y, false, false);
  std::shared_ptr<ngraph::Node> dout2 = transposeAndFlat3D(dout, false);
  auto x2_shape = x2->get_shape();
  auto y2_shape = y2->get_shape();

  if (nx >= 3 || ny >= 3) {
    if (ndout == 2) {
      // A 2-D upstream gradient in the batched case gets a trailing unit axis
      // so that it can take part in the batched product.
      std::shared_ptr<ngraph::Node> dout_temp =
          std::make_shared<ngraph::op::Reshape>(
              dout, ngraph::AxisVector{0, 1},
              ngraph::Shape{dout_shape[0], dout_shape[1], 1});
      if (ny < 3) {
        dout2 = dout_temp;
      } else {
        dout2 = transposeAndFlat3D(dout_temp, true);
      }
    }
    // Broadcast a 2-D operand to the *flattened* batch extent of the other
    // one. Using the original leading dimension would give the wrong batch
    // size whenever the other operand has rank > 3; broadcast3D leaves nodes
    // that are already 3-D untouched.
    x2 = broadcast3D(x2, y2_shape[0]);
    y2 = broadcast3D(y2, x2_shape[0]);
  } else {
    dout2 = transposeAndFlat3D(dout, false, nx == 1 && !transpose_x);
  }

  if (!transpose_y) {
    y2 = transposeAndFlat3D(y2, true);
  }
  if (!transpose_x) {
    x2 = transposeAndFlat3D(x2, true);
  }

  auto dx = dotOp(dout2, y2);
  auto dy = dotOp(x2, dout2);
  if (transpose_x) {
    dx = transposeAndFlat3D(dx, true);
  }
  if (transpose_y) {
    dy = transposeAndFlat3D(dy, true);
  }

  // An operand that was broadcast over the batch axis accumulates its
  // gradient across that axis.
  if (nx < 3 && ny >= 3) {
    dx = std::make_shared<ngraph::op::Sum>(dx, ngraph::AxisSet{0});
  }
  if (ny < 3 && nx >= 3) {
    dy = std::make_shared<ngraph::op::Sum>(dy, ngraph::AxisSet{0});
  }

  auto dx_t = reshapeToOriginal(dx, x_shape);
  auto dy_t = reshapeToOriginal(dy, y_shape);
  // d(alpha * X * Y) = alpha * d(X * Y); this also stays finite for alpha == 0.
  auto dx_scale = ElementwiseScalar<ngraph::op::Multiply>(alpha, dx_t);
  auto dy_scale = ElementwiseScalar<ngraph::op::Multiply>(alpha, dy_t);
  if (is_dx) {
    paddle::platform::SetOutputNode(op, "X@GRAD", dx_scale, ngb_node_map);
  }
  if (is_dy) {
    paddle::platform::SetOutputNode(op, "Y@GRAD", dy_scale, ngb_node_map);
  }
}
} // namespace ngraphs
} // namespace operators
} // namespace paddle
// Associate the op types `matmul` and `matmul_grad` with the node builders
// defined above. REGISTER_NG_OP is not defined in this file (presumably it
// comes from op_bridge.h -- confirm); it is invoked at global scope, outside
// the paddle::operators::ngraphs namespace.
REGISTER_NG_OP(matmul, BuildMatMulNode);
REGISTER_NG_OP(matmul_grad, BuildMatMulGradNode);