use unpack graph primitive instead add testcases for all grad interface remove debug log format code remove dumpfuncgraph resolve clang-format resolve reviews resolve cpplint fix reviewpull/62/head
parent
4b702c4c66
commit
d3f733fa25
@ -0,0 +1,94 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "operator/composite/unpack_call.h"
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "./common.h"
|
||||
#include "pipeline/static_analysis/abstract_value.h"
|
||||
#include "pipeline/static_analysis/dshape.h"
|
||||
#include "pipeline/static_analysis/param_validator.h"
|
||||
#include "operator/cc_implementations.h"
|
||||
#include "ir/anf.h"
|
||||
#include "optimizer/opt.h"
|
||||
#include "utils/symbolic.h"
|
||||
#include "pybind_api/api_register.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractKeywordArg;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
|
||||
FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) {
|
||||
// slice a tensor
|
||||
// args: tensor, slice or slice tuple
|
||||
const std::string op_name = std::string("UnpackCall");
|
||||
size_t arg_length = args_spec_list.size();
|
||||
if (arg_length < 2) {
|
||||
MS_LOG(EXCEPTION) << op_name << " requires at least two args, but got " << arg_length << ".";
|
||||
}
|
||||
|
||||
(void)abstract::CheckArg<AbstractFunction>(op_name, args_spec_list, 0);
|
||||
auto ret_graph = std::make_shared<FuncGraph>();
|
||||
ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
|
||||
|
||||
AnfNodePtr fnNode = ret_graph->add_parameter();
|
||||
std::vector<AnfNodePtr> elems;
|
||||
elems.push_back(fnNode);
|
||||
for (size_t index = 1; index < arg_length; index++) {
|
||||
MS_EXCEPTION_IF_NULL(args_spec_list[index]);
|
||||
if (args_spec_list[index]->isa<AbstractTuple>()) {
|
||||
auto arg_tuple = args_spec_list[index]->cast<AbstractTuplePtr>();
|
||||
AnfNodePtr para_tuple = ret_graph->add_parameter();
|
||||
for (size_t i = 0; i < arg_tuple->size(); ++i) {
|
||||
elems.push_back(
|
||||
ret_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), para_tuple, NewValueNode(SizeToInt(i))}));
|
||||
}
|
||||
} else if (args_spec_list[index]->isa<AbstractDictionary>()) {
|
||||
AbstractDictionaryPtr arg_dict = args_spec_list[index]->cast<AbstractDictionaryPtr>();
|
||||
AnfNodePtr para_dict = ret_graph->add_parameter();
|
||||
auto dict_elems = arg_dict->elements();
|
||||
(void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems),
|
||||
[ret_graph, para_dict](const AbstractAttribute& item) {
|
||||
auto dict_get_item = ret_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)});
|
||||
return ret_graph->NewCNode(
|
||||
{NewValueNode(prim::kPrimMakeKeywordArg), NewValueNode(item.first), dict_get_item});
|
||||
});
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << op_name << " require args should be tuple or dict, but got "
|
||||
<< args_spec_list[index]->ToString();
|
||||
}
|
||||
}
|
||||
ret_graph->set_output(ret_graph->NewCNode(elems));
|
||||
return ret_graph;
|
||||
}
|
||||
|
||||
REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) {
|
||||
(void)py::class_<UnpackCall, MetaFuncGraph, std::shared_ptr<UnpackCall>>(*m, "UnpackCall_")
|
||||
.def(py::init<std::string&>());
|
||||
}));
|
||||
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
@ -0,0 +1,54 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
||||
#define MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
|
||||
#include "pipeline/static_analysis/static_analysis.h"
|
||||
#include "utils/misc.h"
|
||||
#include "utils/any.h"
|
||||
#include "ir/dtype.h"
|
||||
#include "ir/meta_func_graph.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
|
||||
// Expand the tuple and dict parameters generated when parsing the function call,
|
||||
// and generate positional parameters and key-value pairs for function.
|
||||
class UnpackCall : public MetaFuncGraph {
|
||||
public:
|
||||
explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {}
|
||||
~UnpackCall() override = default;
|
||||
MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override;
|
||||
friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using UnpackCallPtr = std::shared_ptr<UnpackCall>;
|
||||
|
||||
} // namespace prim
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_OPERATOR_COMPOSITE_UNPACK_CALL_H_
|
@ -0,0 +1,144 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "optimizer/irpass/grad_var_prepare.h"
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
#include "operator/composite/composite.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
static AnfNodePtr GenerateUnpackGraphNode(std::vector<AnfNodePtr> inputs_y, FuncGraphPtr func_graph,
|
||||
AnfNodePtr func_node, bool is_unpack, bool sens_param) {
|
||||
MS_EXCEPTION_IF_NULL(func_graph);
|
||||
MS_EXCEPTION_IF_NULL(func_node);
|
||||
std::vector<AnfNodePtr> nodes;
|
||||
AnfNodePtr unpack_graph_node = nullptr;
|
||||
if (is_unpack) {
|
||||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, true);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
// {unpackcall, {GradOperation, ...}, args...}
|
||||
std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr& node) { return node; });
|
||||
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||
} else {
|
||||
auto unpack_graph = std::make_shared<prim::UnpackGraphPrimitive>("unpack_graph", sens_param, false);
|
||||
nodes.push_back(NewValueNode(unpack_graph));
|
||||
nodes.push_back(func_node);
|
||||
// {{GradOperation, ...}, args...}
|
||||
std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes),
|
||||
[](const AnfNodePtr& node) { return node; });
|
||||
unpack_graph_node = func_graph->NewCNode(nodes);
|
||||
}
|
||||
return unpack_graph_node;
|
||||
}
|
||||
|
||||
// get metagraph of value node
|
||||
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) {
|
||||
ValuePtr value;
|
||||
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
|
||||
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
|
||||
} else {
|
||||
value = GetValueNode(node);
|
||||
}
|
||||
if (value == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
return value->cast<MetaFuncGraphPtr>();
|
||||
}
|
||||
|
||||
// check if node is a specific metafuncgraph op
|
||||
bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) {
|
||||
if (node != nullptr) {
|
||||
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
|
||||
if (meta_func_graph_ptr == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (meta_func_graph_ptr->type_name() == meta_func_graph->type_name()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// {{GradOperation, g, w}, Ys}
|
||||
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||
AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) {
|
||||
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {{...}, Ys}
|
||||
auto inputs_y = node->cast<CNodePtr>()->inputs();
|
||||
std::vector<AnfNodePtr> inputs_x;
|
||||
if (IsCNode(inputs_y[0])) {
|
||||
inputs_x = inputs_y[0]->cast<CNodePtr>()->inputs();
|
||||
} else if (IsMetaFuncGraph(inputs_y[0], unpack_op_) && IsCNode(inputs_y[1])) {
|
||||
inputs_x = inputs_y[1]->cast<CNodePtr>()->inputs();
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {{...}, Xs}
|
||||
if (inputs_x.size() < 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// {GradOperation, g, w} or {GradOperation, g}
|
||||
if (!IsMetaFuncGraph(inputs_x[0], grad_op_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
|
||||
if (meta_func == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
|
||||
auto func_node = inputs_x[1];
|
||||
if (!IsValueNode<FuncGraph>(func_node)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AnfNodePtr unpack_graph_node =
|
||||
GenerateUnpackGraphNode(inputs_y, node->cast<CNodePtr>()->func_graph(), func_node,
|
||||
IsMetaFuncGraph(inputs_y[0], unpack_op_), grad_op_ptr->sens_param());
|
||||
// constuct new grad_opration
|
||||
inputs_x[1] = unpack_graph_node;
|
||||
auto grad_op_cnode = node->func_graph()->NewCNode(inputs_x);
|
||||
if (IsMetaFuncGraph(inputs_y[0], unpack_op_)) {
|
||||
inputs_y[1] = grad_op_cnode;
|
||||
} else {
|
||||
inputs_y[0] = grad_op_cnode;
|
||||
}
|
||||
auto cnode = node->func_graph()->NewCNode(inputs_y);
|
||||
return cnode;
|
||||
}
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <unordered_map>
|
||||
#include <memory>
|
||||
|
||||
#include "operator/composite/composite.h"
|
||||
#include "operator/ops.h"
|
||||
#include "optimizer/irpass.h"
|
||||
#include "optimizer/optimizer.h"
|
||||
#include "ir/visitor.h"
|
||||
#include "ir/func_graph.h"
|
||||
#include "ir/func_graph_cloner.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
// {{GradOperation, g, w}, Ys}
|
||||
// {UnPackCall, {GradOperation, g, w}, Ys}
|
||||
class GradVarPrepare : public AnfVisitor {
|
||||
public:
|
||||
GradVarPrepare()
|
||||
: grad_op_(std::make_shared<prim::GradOperation>("grad")),
|
||||
unpack_op_(std::make_shared<prim::UnpackCall>("unpack_call")) {}
|
||||
~GradVarPrepare() override = default;
|
||||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override;
|
||||
|
||||
private:
|
||||
MetaFuncGraphPtr grad_op_;
|
||||
MetaFuncGraphPtr unpack_op_;
|
||||
};
|
||||
} // namespace irpass
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_GRAD_VAR_PREPARE_H_
|
Loading…
Reference in new issue