!1189 Decoupling py default param from Parameter

Merge pull request !1189 from leopz/master
pull/1189/MERGE
mindspore-ci-bot committed by Gitee 5 years ago
commit 04ac611fe8
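
In brief: this change removes the direct pybind11 dependency from the core IR Parameter. Instead of holding a py::object, Parameter now holds an opaque ParamValuePtr, with two concrete holders: ParamValuePy (Python frontend) and ParamValueMinnie (raw device memory). A condensed before/after view of the Parameter API, assembled from the hunks below:

// Before: ir/anf.h needed pybind11 just to store a default value.
py::object default_param();
void set_default_param(const py::object &obj);

// After: the default value is an opaque, frontend-agnostic handle.
ParamValuePtr default_param() const;
void set_default_param(ParamValuePtr param);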

@@ -26,6 +26,7 @@
#include "utils/graph_utils.h"
#include "utils/symbolic.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
#include "operator/composite/composite.h"
@@ -469,7 +470,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode
MS_LOG(EXCEPTION) << "Param could not cast to parameter";
}
if (param_ptr->has_default()) {
ofs << " = @" << DumpObject(param_ptr->default_param(), "D");
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
ofs << " = @" << DumpObject(param_value->value(), "D");
}
// output comment
@@ -1650,7 +1652,8 @@ class IrParser {
// load parameter default value from serialized file
py::object default_obj = LoadObject(lexer_.GetTokenText());
param->set_default_param(default_obj);
auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
param->set_default_param(param_value_new);
tok = lexer_.GetNextToken();
}
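
Note that the serialized IR text format is unchanged by this hunk pair: the exporter unwraps the ParamValuePy holder to reach the py::object before dumping, and the parser re-wraps the loaded object. A condensed sketch of the round-trip, using the same calls as the hunks above:

// Dump side (AnfExporter): unwrap the holder before serializing.
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
ofs << " = @" << DumpObject(param_value->value(), "D");

// Load side (IrParser): re-wrap the deserialized py::object.
py::object default_obj = LoadObject(lexer_.GetTokenText());
param->set_default_param(std::make_shared<ParamValuePy>(default_obj));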

@@ -21,12 +21,17 @@
#include <vector>
#include <string>
#include "pybind11/pybind11.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "ir/primitive.h"
#include "utils/graph_utils.h"
#include "utils/utils.h"
#include "operator/composite/composite.h"
#include "ir/meta_tensor.h"
namespace py = pybind11;
namespace mindspore {
// namespace to support debug utils
@@ -312,17 +317,21 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
for (auto &parameter : key->parameters()) {
buffer_ << "<tr><td>";
buffer_ << parameter->ToString();
auto py_p = dyn_cast<Parameter>(parameter)->default_param();
if (py::hasattr(py_p, "default_input")) {
py_p = py_p.attr("default_input");
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
std::shared_ptr<tensor::Tensor> m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
std::shared_ptr<tensor::MetaTensor> m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << std::string(py::str(shape)) << "]";
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
auto py_p = param_value->value();
if (py::hasattr(py_p, "default_input")) {
py_p = py_p.attr("default_input");
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << py::str(shape) << "]";
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
py::tuple shape = m_tensor->GetPyTupleShape();
buffer_ << "[" << py::str(shape) << "]";
}
}
}
buffer_ << "</td></tr>";

@@ -18,9 +18,9 @@
#include <utility>
#include <fstream>
#include <sstream>
#include <climits>
#include "ir/anf.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
#include "utils/convert_utils.h"
namespace mindspore {
std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) {

@@ -24,13 +24,10 @@
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "ir/base.h"
#include "debug/trace_info.h"
namespace mindspore {
namespace py = pybind11;
// namespace to support intermediate representation definition
enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSourceLineTipInLine = 2 };

@@ -21,7 +21,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "utils/any.h"
#include "ir/anf.h"
namespace mindspore {

@@ -19,8 +19,6 @@
#include <fstream>
#include <sstream>
#include "ir/anf.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) {

@@ -24,12 +24,9 @@
#include <utility>
#include <vector>
#include "pybind11/pybind11.h"
#include "ir/base.h"
namespace mindspore {
namespace py = pybind11;
class TraceInfo;
using TraceInfoPtr = std::shared_ptr<TraceInfo>;
class Location;

@@ -23,21 +23,11 @@
#include <vector>
#include <unordered_map>
#include "ir/visitor.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "operator/ops.h"
#include "parallel/ops_info/ops_utils.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
namespace mindspore {
// namespace to support intermediate representation definition
// Methods of AnfNode
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
CNode::CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph)
: AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {}
@@ -85,66 +75,6 @@ std::string CNode::DebugString(int recursive_level) const {
return buffer.str();
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {
return fullname_with_scope_;
}
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
IsApply(prim::kPrimHistogramSummary)) {
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
if (tag == "") {
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
}
std::string name;
if (IsApply(prim::kPrimScalarSummary)) {
name = tag + "[:Scalar]";
} else if (IsApply(prim::kPrimImageSummary)) {
name = tag + "[:Image]";
} else if (IsApply(prim::kPrimHistogramSummary)) {
name = tag + "[:Histogram]";
} else {
name = tag + "[:Tensor]";
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim);
fullname_with_scope_ =
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
}
return fullname_with_scope_;
}
std::string ValueNode::ToString() const {
MS_EXCEPTION_IF_NULL(value_);
if (value_->isa<FuncGraph>()) {
@@ -173,10 +103,6 @@ std::string ValueNode::fullname_with_scope() {
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();

@@ -52,6 +52,7 @@ class AbstractBase;
} // namespace abstract
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
class ValueNode;
using ValueNodePtr = std::shared_ptr<ValueNode>;
@@ -78,6 +79,13 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
class AnfVisitor;
class ParamValue {
public:
ParamValue() = default;
virtual ~ParamValue() = default;
};
using ParamValuePtr = std::shared_ptr<ParamValue>;
// AnfNode is the basic class of the IR definition derived from Base.
// Only two types of nodes are derived: CNode and ANode.
// Methods:
@@ -239,11 +247,11 @@ class ANode : public AnfNode {
// Parameter represents the parameter inputs of a function. They have no value.
// Attributes:
// default_param_: used to hold the inputting tensor of the model.
// default_param_value_: used to hold the inputting tensor of the model.
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(py::none()), tensor_layout_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
@@ -254,12 +262,11 @@ class Parameter : public ANode {
std::string fullname_with_scope() override { return name(); };
bool has_default() const { return has_default_; }
py::object default_param() { return default_param_; }
void set_default_param(const py::object &obj) {
default_param_ = obj;
void set_default_param(ParamValuePtr param) {
default_param_ = param;
has_default_ = true;
}
ParamValuePtr default_param() const { return default_param_; }
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
@@ -280,7 +287,7 @@ class Parameter : public ANode {
private:
std::string name_;
bool has_default_;
py::object default_param_;
ParamValuePtr default_param_;
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
};
using ParameterPtr = std::shared_ptr<Parameter>;
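
This is the heart of the decoupling: ir/anf.h now defines only an empty polymorphic base, so the core IR header no longer includes pybind11, and each frontend subclasses ParamValue and recovers its holder with a dynamic cast. A minimal, self-contained sketch of the pattern with stand-in types (not the MindSpore classes):

#include <iostream>
#include <memory>
#include <string>

struct ParamValue {  // what the core IR header sees: an empty base
  virtual ~ParamValue() = default;
};
using ParamValuePtr = std::shared_ptr<ParamValue>;

struct ParamValueStr : ParamValue {  // stand-in for a frontend holder like ParamValuePy
  explicit ParamValueStr(std::string v) : value(std::move(v)) {}
  std::string value;
};

struct Parameter {  // stand-in for ir::Parameter
  bool has_default() const { return default_param_ != nullptr; }
  void set_default_param(ParamValuePtr param) { default_param_ = std::move(param); }
  ParamValuePtr default_param() const { return default_param_; }
 private:
  ParamValuePtr default_param_;
};

int main() {
  Parameter param;
  param.set_default_param(std::make_shared<ParamValueStr>("weights"));
  // Only frontend code that knows the concrete holder casts back down.
  if (auto holder = std::dynamic_pointer_cast<ParamValueStr>(param.default_param())) {
    std::cout << holder->value << std::endl;  // prints "weights"
  }
  return 0;
}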

@@ -0,0 +1,103 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/anf.h"
#include <algorithm>
#include <sstream>
#include <vector>
#include <unordered_map>
#include "ir/visitor.h"
#include "pipeline/static_analysis/static_analysis.h"
#include "operator/ops.h"
#include "parallel/ops_info/ops_utils.h"
namespace mindspore {
// namespace to support intermediate representation definition
// Methods of AnfNode
TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildType(); }
BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); }
std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {
return fullname_with_scope_;
}
if (IsApply(prim::kPrimScalarSummary) || IsApply(prim::kPrimTensorSummary) || IsApply(prim::kPrimImageSummary) ||
IsApply(prim::kPrimHistogramSummary)) {
std::string tag = GetValue<std::string>(GetValueNode(input(1)));
if (tag == "") {
MS_LOG(EXCEPTION) << "The tag name is null, should be valid string";
}
std::string name;
if (IsApply(prim::kPrimScalarSummary)) {
name = tag + "[:Scalar]";
} else if (IsApply(prim::kPrimImageSummary)) {
name = tag + "[:Image]";
} else if (IsApply(prim::kPrimHistogramSummary)) {
name = tag + "[:Histogram]";
} else {
name = tag + "[:Tensor]";
}
fullname_with_scope_ = name;
} else {
// cnode input 0 should be primitive ptr
auto value_ptr = input(0)->cast<ValueNodePtr>();
if (value_ptr == nullptr) {
MS_LOG(WARNING) << "Input 0 of cnode is not a value node, its type is " << input(0)->type_name() << ".";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
auto input_value = value_ptr->value();
if (input_value == nullptr) {
MS_LOG(WARNING) << "Value of input 0 of cnode is nullptr.";
fullname_with_scope_ = id_generator::get_id(shared_from_base<CNode>());
return fullname_with_scope_;
}
PrimitivePtr prim = GetValue<PrimitivePtr>(input_value);
MS_EXCEPTION_IF_NULL(scope());
MS_EXCEPTION_IF_NULL(prim);
fullname_with_scope_ =
scope()->name() + "/" + prim->name() + "-op" + id_generator::get_id(shared_from_base<CNode>());
}
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
} // namespace mindspore
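
The file above (presumably ir/anf_extends.cc, new in this change) collects the AnfNode/CNode methods whose bodies drag in heavier dependencies (static analysis, operator info, id generation), which is what permits the matching deletions from ir/anf.cc earlier in the diff. For reference, a tiny standalone illustration of the two name shapes fullname_with_scope builds (tag, scope, and id are invented):

#include <iostream>
#include <string>

int main() {
  // Summary nodes: tag + kind marker.
  std::string summary_name = std::string("loss") + "[:Scalar]";
  // Ordinary nodes: scope/prim-op + id.
  std::string op_name = std::string("Default/network") + "/" + "Conv2D" + "-op" + "35";
  std::cout << summary_name << "\n" << op_name << "\n";  // loss[:Scalar], Default/network/Conv2D-op35
  return 0;
}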

File diff suppressed because it is too large.

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
static std::string DumpTypeVector(const std::vector<TypePtr> &elements, bool is_dumptext) {

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
bool Number::operator==(const Type &other) const {

@@ -19,9 +19,6 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
TypePtr RefType::DeepCopy() const {

@@ -21,9 +21,8 @@
#include <cstdlib>
#include <algorithm>
#include "utils/log_adapter.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "ir/dtype/number.h"
#include "utils/convert_utils.h"
namespace mindspore {
TypeId IntBitsToTypeId(const int nbits) {
@@ -227,11 +226,6 @@ bool Type::operator==(const Value &other) const {
}
}
abstract::AbstractBasePtr Type::ToAbstract() {
abstract::AbstractBasePtr ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
return ptr;
}
std::ostream &operator<<(std::ostream &os, const Type &type) {
os << type.ToString();
return os;

@@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/dtype/type.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace mindspore {
abstract::AbstractBasePtr Type::ToAbstract() {
auto ptr = std::make_shared<abstract::AbstractType>(shared_from_base<Type>());
return ptr;
}
} // namespace mindspore
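
The new ir/dtype/type_extends.cc isolates the one Type method that needs pipeline/static_analysis, so the other ir/dtype/*.cc files can drop that include (and the pybind registration headers), as the surrounding hunks show. A self-contained sketch of the split, with stand-in names:

// Stand-in sketch: declare in the core header, define the heavy-dependency
// method in a separate translation unit (the "_extends" file).
#include <iostream>
#include <memory>
#include <string>

// --- core header (role of ir/dtype/type.h): a forward declaration suffices.
struct AbstractBase;
struct Type {
  std::shared_ptr<AbstractBase> ToAbstract();  // declared here, defined below
};

// --- extends TU (role of type_extends.cc): the only place paying for the
// full definition of the analysis types.
struct AbstractBase {
  std::string name = "AbstractType";
};
std::shared_ptr<AbstractBase> Type::ToAbstract() { return std::make_shared<AbstractBase>(); }

int main() {
  Type t;
  std::cout << t.ToAbstract()->name << std::endl;
  return 0;
}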

File diff suppressed because it is too large.

@@ -19,6 +19,7 @@
#include <algorithm>
#include "ir/manager.h"
#include "ir/param_value_py.h"
#include "operator/ops.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"
@@ -69,7 +70,9 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
new_param->set_abstract(old_param->abstract());
new_param->set_name(old_param->name());
if (old_param->has_default()) {
new_param->set_default_param(old_param->default_param());
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
new_param->set_default_param(param_value_new);
}
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_param->set_scope(scope);
@@ -248,7 +251,9 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
if (node->isa<Parameter>()) {
ParameterPtr old_param = dyn_cast<Parameter>(node);
if (old_param->has_default()) {
param->set_default_param(old_param->default_param());
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
param->set_default_param(param_value_new);
}
param->set_name(old_param->name());
}
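
Both clone sites deliberately build a fresh ParamValuePy per cloned parameter rather than sharing the holder. Note, though, that py::object is a reference-counted handle, so the underlying Python value is still shared. A sketch of the resulting semantics (same calls as above; the .ptr() comparison is illustrative):

auto old_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto new_value = std::make_shared<ParamValuePy>(old_value->value());
// Independent C++ wrappers: new_value != old_value.
// Shared Python payload:    new_value->value().ptr() == old_value->value().ptr()
new_param->set_default_param(new_value);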

@@ -28,6 +28,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/manager.h"
namespace mindspore {
class Cloner;

@@ -31,6 +31,7 @@
#include "ir/dtype.h"
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/signature.h"
#include "pipeline/static_analysis/abstract_value.h"
namespace py = pybind11;

@@ -21,7 +21,6 @@
#include <memory>
#include <functional>
#include "ir/base.h"
#include "ir/anf.h"
namespace mindspore {

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
#include <memory>
#include "ir/anf.h"
namespace mindspore {
class ParamValueMinnie : public ParamValue {
public:
ParamValueMinnie() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueMinnie() = default;
size_t tensor_size() const { return tensor_size_; }
void set_tensor_size(size_t size) { tensor_size_ = size; }
void *tensor_addr() const { return tensor_addr_; }
void set_tensor_addr(void *addr) { tensor_addr_ = addr; }
private:
void *tensor_addr_;
size_t tensor_size_;
};
using ParamValueMinniePtr = std::shared_ptr<ParamValueMinnie>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_MINNIE_H_
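
"Minnie" appears to be the lightweight/device runtime; its holder carries a raw buffer address and size instead of a Python object. A hypothetical usage sketch (the allocator call is invented for illustration):

// Attach raw tensor memory to a Parameter without touching Python.
void *device_buf = AcquireDeviceMemory(4096);  // hypothetical allocator
auto value = std::make_shared<ParamValueMinnie>();
value->set_tensor_size(4096);
value->set_tensor_addr(device_buf);
param->set_default_param(value);  // param is a ParameterPtr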

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#include <memory>
#include "ir/anf.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace py = pybind11;
class ParamValuePy : public ParamValue {
public:
ParamValuePy() : value_(py::none()) {}
explicit ParamValuePy(py::object value) : value_(value) {}
virtual ~ParamValuePy() = default;
py::object value() { return value_; }
void set_value(const py::object &obj) { value_ = obj; }
private:
py::object value_;
};
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
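
Because the ParamValue base is intentionally empty, code that must handle either frontend recovers the concrete holder by dynamic cast and branches. A sketch using only the accessors defined in the two headers above:

ParamValuePtr v = param->default_param();
if (auto py_value = std::dynamic_pointer_cast<ParamValuePy>(v)) {
  // Python frontend: py_value->value() is the py::object default.
} else if (auto minnie_value = std::dynamic_pointer_cast<ParamValueMinnie>(v)) {
  // Device runtime: minnie_value->tensor_addr() / minnie_value->tensor_size().
}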

@@ -17,20 +17,23 @@
#ifndef MINDSPORE_CCSRC_IR_SCALAR_H_
#define MINDSPORE_CCSRC_IR_SCALAR_H_
namespace mindspore {
/* namespace to support inference engine */
#include <type_traits>
#include <algorithm>
#include <cmath>
#include <vector>
#include <string>
#include <memory>
#include <sstream>
#include <utility>
#include <cfloat>
#include "ir/base.h"
#include "ir/dtype.h"
#include "ir/dtype/number.h"
using std::fabs;
namespace mindspore {
class Scalar : public Value {
public:
Scalar() = default;

@@ -19,9 +19,7 @@
#include <memory>
#include <cmath>
#include <cfloat>
#include "pybind_api/api_register.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "utils/convert_utils.h"
namespace mindspore {
const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const {
@@ -208,41 +206,6 @@ bool AnyValue::operator==(const Value &other) const {
}
}
const ValuePtr kAnyValue = std::make_shared<AnyValue>();
using ContextPtr = abstract::AnalysisContextPtr;
abstract::AbstractBasePtr Scalar::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>());
}
abstract::AbstractBasePtr StringImm::ToAbstract() {
return std::make_shared<abstract::AbstractScalar>(shared_from_base<Value>(), std::make_shared<String>());
}
abstract::AbstractBasePtr RefKey::ToAbstract() {
auto refkey = std::make_shared<abstract::AbstractRefKey>();
refkey->set_value(shared_from_base<Value>());
return refkey;
}
abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_shared<abstract::AbstractScalar>(); }
abstract::AbstractBasePtr ValueTuple::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractTuple>(a_list);
}
abstract::AbstractBasePtr ValueList::ToAbstract() {
abstract::AbstractBasePtrList a_list;
(void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) {
MS_EXCEPTION_IF_NULL(ele);
return ele->ToAbstract();
});
return std::make_shared<abstract::AbstractList>(a_list);
}
std::size_t ValueSlice::hash() const {
MS_EXCEPTION_IF_NULL(start_);
@@ -280,16 +243,6 @@ std::string ValueSlice::ToString() const {
return buffer.str();
}
abstract::AbstractBasePtr ValueSlice::ToAbstract() {
MS_EXCEPTION_IF_NULL(start_);
MS_EXCEPTION_IF_NULL(stop_);
MS_EXCEPTION_IF_NULL(step_);
abstract::AbstractBasePtr start = start_->ToAbstract();
abstract::AbstractBasePtr end = stop_->ToAbstract();
abstract::AbstractBasePtr step = step_->ToAbstract();
return std::make_shared<abstract::AbstractSlice>(start, end, step);
}
std::size_t KeywordArg::hash() const {
MS_EXCEPTION_IF_NULL(value_);
return hash_combine({tid(), std::hash<std::string>{}(key_), value_->hash()});
@@ -316,12 +269,6 @@ std::string KeywordArg::ToString() const {
return buffer.str();
}
abstract::AbstractBasePtr KeywordArg::ToAbstract() {
MS_EXCEPTION_IF_NULL(value_);
abstract::AbstractBasePtr argument = value_->ToAbstract();
return std::make_shared<abstract::AbstractKeywordArg>(key_, argument);
}
const ValuePtr ValueDictionary::operator[](const std::string &key) const {
auto it = std::find_if(key_values_.begin(), key_values_.end(),
[key](const std::pair<std::string, ValuePtr> &item) { return item.first == key; });
@@ -354,17 +301,4 @@ bool ValueDictionary::operator==(const ValueDictionary &other) const {
}
return true;
}
abstract::AbstractBasePtr ValueDictionary::ToAbstract() {
std::vector<std::pair<std::string, abstract::AbstractBasePtr>> kv;
(void)std::transform(
key_values_.begin(), key_values_.end(), std::back_inserter(kv),
[](const std::pair<std::string, ValuePtr> &item) { return std::make_pair(item.first, item.second->ToAbstract()); });
return std::make_shared<abstract::AbstractDictionary>(kv);
}
REGISTER_PYBIND_DEFINE(
RefKey, ([](const py::module *m) {
(void)py::class_<RefKey, std::shared_ptr<RefKey>>(*m, "RefKey").def(py::init<std::string>(), py::arg("tag"));
}));
} // namespace mindspore

Some files were not shown because too many files have changed in this diff.
