From 531ad4df70d75b0dbdb85bc9665e2900c658874c Mon Sep 17 00:00:00 2001
From: lirongzhen1
Date: Mon, 24 Aug 2020 17:03:49 +0800
Subject: [PATCH] prepare to support int64

---
 mindspore/ccsrc/common/trans.cc | 4 +--
 mindspore/ccsrc/common/trans.h | 3 +-
 mindspore/ccsrc/frontend/parallel/ps/worker.h | 21 ++++++------
 mindspore/ccsrc/pipeline/jit/pipeline.cc | 12 +++----
 .../pipeline/jit/static_analysis/prim.cc | 5 +--
 .../device/ascend/ascend_device_address.cc | 14 ++++----
 .../device/ascend/ascend_device_address.h | 14 ++++----
 .../device/ascend/ascend_kernel_runtime.cc | 9 ++---
 .../runtime/device/cpu/cpu_device_address.cc | 5 ++-
 .../runtime/device/cpu/cpu_device_address.h | 5 +--
 .../runtime/device/cpu/cpu_kernel_runtime.cc | 9 ++---
 .../ccsrc/runtime/device/device_address.h | 5 +--
 .../runtime/device/gpu/gpu_device_address.cc | 8 ++---
 .../runtime/device/gpu/gpu_device_address.h | 8 +++--
 .../runtime/device/gpu/gpu_kernel_runtime.cc | 11 ++++---
 .../ccsrc/runtime/device/kernel_adjust.cc | 5 +--
 .../ccsrc/runtime/device/kernel_runtime.cc | 3 +-
 mindspore/ccsrc/transform/graph_ir/util.cc | 22 ++++++-------
 mindspore/ccsrc/transform/graph_ir/util.h | 27 ++++++++-------
 mindspore/ccsrc/utils/callbacks_ge.cc | 7 ++--
 mindspore/ccsrc/utils/convert_utils.cc | 11 ++++---
 .../ccsrc/utils/load_onnx/anf_model_parser.cc | 9 ++---
 mindspore/ccsrc/utils/tensorprint_utils.cc | 7 ++--
 mindspore/core/abstract/abstract_value.h | 10 +++---
 mindspore/core/abstract/dshape.h | 17 +++++----
 mindspore/core/abstract/prim_arrays.cc | 23 ++++++-------
 mindspore/core/abstract/prim_nn.cc | 7 ++--
 mindspore/core/abstract/prim_others.cc | 5 +--
 mindspore/core/abstract/utils.cc | 11 ++++---
 mindspore/core/ir/device_sync.h | 6 ++--
 mindspore/core/ir/meta_tensor.cc | 4 +--
 mindspore/core/ir/meta_tensor.h | 14 ++++----
 mindspore/core/ir/pattern_matcher.h | 9 ++---
 mindspore/core/ir/tensor.cc | 33 +++++++++----------
 mindspore/core/ir/tensor.h | 23 ++++++-------
 mindspore/core/utils/shape_utils.h | 23 +++++++++++++
 36 files changed, 226 insertions(+), 183 deletions(-)
 create mode 100644 mindspore/core/utils/shape_utils.h

diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index f8f5e90d62..2ea8d7ffcb 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -371,9 +371,9 @@ bool IsNeedPadding(const std::string &format, const size_t shape_size) {
   return false;
 }
 
-std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
+ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index) {
   MS_EXCEPTION_IF_NULL(node);
-  std::vector<int> shape;
+  ShapeVector shape;
   std::vector<size_t> host_shape;
   if (node->isa<ValueNode>()) {
     auto value_node = node->cast<ValueNodePtr>();
diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h
index c815fbc31f..702dda0450 100644
--- a/mindspore/ccsrc/common/trans.h
+++ b/mindspore/ccsrc/common/trans.h
@@ -26,6 +26,7 @@
 #include "ir/dtype.h"
 #include "backend/kernel_compiler/kernel.h"
 #include "ir/dtype/type.h"
+#include "utils/shape_utils.h"
 
 namespace mindspore {
 namespace trans {
@@ -52,7 +53,7 @@
 size_t ShapeSize(const std::vector<size_t> &shape);
 size_t CubeSizeByType(const TypeId data_type);
 std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape, const std::vector<size_t> &padding_axis = {});
-std::vector<int> GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
+ShapeVector GetRuntimePaddingShape(const AnfNodePtr &node, size_t index);
 bool IsNeedPadding(const std::string &format, const size_t shape_size);
 std::vector<size_t> TransShapeToDevice(const
std::vector &shape, const std::string &format); bool TransDataType(const TypeIdArgs &args, void *result); diff --git a/mindspore/ccsrc/frontend/parallel/ps/worker.h b/mindspore/ccsrc/frontend/parallel/ps/worker.h index 7c96938773..00b588a612 100644 --- a/mindspore/ccsrc/frontend/parallel/ps/worker.h +++ b/mindspore/ccsrc/frontend/parallel/ps/worker.h @@ -28,6 +28,7 @@ #include "frontend/parallel/ps/util.h" #include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/worker_proxy.h" +#include "utils/shape_utils.h" namespace mindspore { namespace parallel { @@ -41,15 +42,15 @@ class Worker { } void Run(); - void Push(const std::vector &keys, std::vector addrs, const std::vector &sizes); + void Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes); void Pull(const size_t key, void *dev_addr, const size_t size); size_t SetParamKey(const std::string ¶m_name); void SetParamInitInServer(const std::string ¶m_name, bool init_in_server); bool GetParamInitInServer(const std::string ¶m_name); void SetKeyOptimId(size_t key, const std::string &optimizer_name); - void SetOptimInputShapes(size_t key, const std::vector &shape); + void SetOptimInputShapes(size_t key, const ShapeVector &shape); void AddEmbeddingTable(const ::ps::Key &key, const size_t &row_count); - void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const std::vector &sizes); + void InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, const ShapeVector &sizes); void InitPSParamAndOptim(const std::string ¶m_name, tensor::TensorPtr tensor); void DoPSEmbeddingLookup(const ::ps::SArray<::ps::Key> &keys, const ::ps::SArray &lookup_ids, const ::ps::SArray &lens, ::ps::SArray *lookup_result, int cmd); @@ -75,7 +76,7 @@ class Worker { std::map param_to_key_; std::map init_keys_; std::map key_to_optimId_; - std::map>> key_to_optim_shapes_; + std::map> key_to_optim_shapes_; std::map param_to_init_in_server_; }; @@ -94,7 +95,7 @@ void Worker::Run() { } template -void Worker::Push(const std::vector &keys, std::vector addrs, const std::vector &sizes) { +void Worker::Push(const std::vector &keys, std::vector addrs, const ShapeVector &sizes) { size_t total_size = 0; for (auto size : sizes) { total_size += size; @@ -154,7 +155,7 @@ void Worker::InitPSParamData(const std::vector &keys, void *origin_ad } template -void Worker::SetOptimInputShapes(size_t key, const std::vector &shape) { +void Worker::SetOptimInputShapes(size_t key, const ShapeVector &shape) { if (key_to_optim_shapes_.find(key) == key_to_optim_shapes_.end()) { key_to_optim_shapes_[key] = {shape}; } else { @@ -167,7 +168,7 @@ void Worker::InitPSOptimInputShapes(const size_t key) { ::ps::SArray<::ps::Key> keys; ::ps::SArray shape_len; ::ps::SArray all_shape; - std::vector> shapes = key_to_optim_shapes_[key]; + std::vector shapes = key_to_optim_shapes_[key]; for (auto shape : shapes) { keys.push_back(key); if (shape.size() == 0) { @@ -255,7 +256,7 @@ void Worker::InitPSOptimId(const size_t param_key) { template void Worker::InitPSEmbeddingTable(const std::vector &keys, std::vector shapes, - const std::vector &sizes) { + const ShapeVector &sizes) { bool has_init = IsKeyInit(keys[0]); if (has_init) { MS_LOG(DEBUG) << "The key embedding table of key " << keys[0] << " is initialized."; @@ -272,7 +273,7 @@ template void Worker::InitPSParamAndOptim(const std::string ¶m_name, tensor::TensorPtr tensor) { void *param_data = tensor->data_c(); size_t param_size = LongToSize(tensor->data().nbytes()); - std::vector param_shape = 
tensor->shape_c(); + ShapeVector param_shape = tensor->shape_c(); size_t param_key = GetParamKey(param_name); if (param_key == kInvalidKey) { @@ -280,7 +281,7 @@ void Worker::InitPSParamAndOptim(const std::string ¶m_name, tensor::Tenso return; } bool init_in_server = false; - std::vector shape_init_in_server = {1}; + ShapeVector shape_init_in_server = {1}; if (param_shape == shape_init_in_server) { init_in_server = true; } diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index bbd8db1e1b..062dff7255 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -42,7 +42,7 @@ #include "frontend/optimizer/py_pass_manager.h" #include "pybind_api/pybind_patch.h" #include "backend/kernel_compiler/cpu/random_op_cpu_kernel.h" - +#include "utils/shape_utils.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "frontend/parallel/ps/common.h" #include "frontend/parallel/ps/util.h" @@ -136,10 +136,10 @@ py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple i return false; } std::shared_ptr sig = input_signature[count].cast>(); - std::vector sig_shape = sig->shape(); + ShapeVector sig_shape = sig->shape(); TypePtr sig_type = sig->Dtype(); - std::vector tensor_shape = m_tensor->shape_c(); + ShapeVector tensor_shape = m_tensor->shape_c(); if (tensor_shape != sig_shape) { MS_LOG(ERROR) << "Python input shape is incompatible with input_signature"; return false; @@ -849,13 +849,13 @@ bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batc const std::vector &types, const std::vector> &shapes, const std::vector &input_indexes, bool need_run) { MS_LOG(INFO) << "Start InitDataSet Entry"; - std::vector int_input_indexes; + ShapeVector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), [](int64_t item) { return static_cast(item); }); - std::vector> int_shapes; + std::vector int_shapes; (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), [](const std::vector &item) { - std::vector vector_item; + ShapeVector vector_item; (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), [](int64_t inner_item) { return static_cast(inner_item); }); return vector_item; diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index b25fcdd38b..784323a67b 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -39,6 +39,7 @@ #include "abstract/primitive_infer_map.h" #include "abstract/param_validator.h" #include "utils/ms_utils.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -309,13 +310,13 @@ py::dict ConvertAbstractToPython(const AbstractBasePtr &abs_base) { dic["dtype"] = arg->BuildType(); dic["value"] = BuildValue(arg->BuildValue()); } else if (abs_base->isa() || abs_base->isa() || abs_base->isa()) { - std::vector shape; + ShapeVector shape; dic["shape"] = shape; dic["dtype"] = abs_base->BuildType(); dic["value"] = BuildValue(abs_base->BuildValue()); } else if (abs_base->isa()) { auto arg_slice = dyn_cast(abs_base); - std::vector shape; + ShapeVector shape; dic["shape"] = shape; dic["dtype"] = arg_slice->BuildType(); dic["value"] = BuildValue(arg_slice->BuildValue()); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc 
b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc index 0503039ec6..90f15b9e47 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.cc @@ -275,7 +275,7 @@ void AscendDeviceAddress::SyncStream() const { MS_LOG(INFO) << "Finish!"; } -bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t size, mindspore::TypeId type, +bool AscendDeviceAddress::SyncDeviceToHost(const ShapeVector &shape, size_t size, mindspore::TypeId type, void *host_ptr) const { MS_LOG(INFO) << "SyncDeviceToHost, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; @@ -462,7 +462,7 @@ std::vector AscendDeviceAddress::GetDeviceShape(std::vector *hos return device_shape; } -bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, +bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const ShapeVector &shape, size_t size, mindspore::TypeId type, void *host_ptr) const { MS_LOG(INFO) << "SyncDeviceToHostAndConvertFormat, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; @@ -513,7 +513,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, mindspore::TypeId type, +bool AscendDeviceAddress::SyncHostToDevice(const ShapeVector &shape, size_t size, mindspore::TypeId type, const void *host_ptr) const { MS_LOG(INFO) << "SyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) << ", size:" << size_ << "), Host(type_id:" << TypeIdLabel(type) << ", size:" << size << ")"; @@ -557,7 +557,7 @@ bool AscendDeviceAddress::SyncHostToDevice(const std::vector &shape, size_t return sync_ok; } -bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, +bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const ShapeVector &shape, size_t size, mindspore::TypeId type, const void *host_ptr) const { bool sync_ok = false; MS_LOG(INFO) << "ConvertFormatAndSyncHostToDevice, Device(format:" << format_ << ", type_id:" << TypeIdLabel(type_id_) @@ -622,7 +622,7 @@ AscendDeviceAddress::~AscendDeviceAddress() { #ifdef ENABLE_DUMP_E2E bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &filepath, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type) const { + const ShapeVector &host_shape, TypeId host_type) const { bool ret = false; if (filepath.empty()) { MS_LOG(ERROR) << "Dump file path is null!"; @@ -666,8 +666,8 @@ bool AscendDeviceAddress::DumpMemToFile(bool trans_flag, const std::string &file #ifdef ENABLE_DEBUGGER bool AscendDeviceAddress::LoadMemToHost(bool trans_flag, const std::string &tensor_name, int execution_order, - const std::string &host_fmt, const std::vector &host_shape, - TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const { + const std::string &host_fmt, const ShapeVector &host_shape, TypeId host_type, + size_t slot, Debugger *debugger, bool keep_prev) const { bool ret = false; DebugServices *debug_services = debugger->debug_services(); TensorLoader *tensor_loader = debug_services->tensor_loader(); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h index 944d4bce7c..95a3a15abf 
100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_device_address.h @@ -25,6 +25,7 @@ #include "runtime/device/ascend/ascend_memory_pool.h" #include "ir/dtype.h" #include "backend/kernel_compiler/kernel.h" +#include "utils/shape_utils.h" namespace mindspore { #ifdef ENABLE_DEBUGGER @@ -38,23 +39,22 @@ class AscendDeviceAddress : public DeviceAddress { explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) : DeviceAddress(ptr, size, format, type_id) {} ~AscendDeviceAddress() override; - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override; DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } #ifdef ENABLE_DUMP_E2E bool DumpMemToFile(bool dump_mode, const std::string &filepath, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type) const; + const ShapeVector &host_shape, TypeId host_type) const; #endif #ifdef ENABLE_DEBUGGER bool LoadMemToHost(bool dump_mode, const std::string &tensor_name, int execution_order, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + const ShapeVector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const; #endif private: - bool SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const; - bool ConvertFormatAndSyncHostToDevice(const std::vector &shape, size_t size, TypeId type, - const void *host_ptr) const; + bool SyncDeviceToHostAndConvertFormat(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const; + bool ConvertFormatAndSyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const; bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector &host_shape, const std::vector &device_shape, size_t size, mindspore::TypeId type, void *host_ptr) const; diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index 38d03ced80..7640d276fa 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -39,6 +39,7 @@ #include "backend/kernel_compiler/tbe/tbe_utils.h" #include "runtime/device/ascend/ascend_memory_manager.h" #include "debug/tensor_load.h" +#include "utils/shape_utils.h" #ifdef MEM_REUSE_DEBUG #include "backend/optimizer/mem_reuse/mem_reuse_checker.h" #endif @@ -231,7 +232,7 @@ void DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, auto output_size = AnfAlgo::GetOutputTensorNum(node); for (size_t j = 0; j < output_size; ++j) { auto addr = AnfAlgo::GetOutputAddr(node, j); - std::vector int_shapes; + ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(node, j); } else { @@ -266,7 +267,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p continue; } auto addr = AnfAlgo::GetOutputAddr(item, PRAMATER_OUTPUT_INDEX); - std::vector int_shapes; + 
ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); } else { @@ -351,7 +352,7 @@ void LoadOutput(mindspore::session::KernelGraph *graph, Debugger *debugger) { auto format = kOpFormat_DEFAULT; string tensor_name = kernel_name + ':' + std::to_string(j); auto ascend_addr = dynamic_cast(addr); - std::vector int_shapes; + ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(node, j); } else { @@ -387,7 +388,7 @@ void LoadParameters(mindspore::session::KernelGraph *graph, Debugger *debugger) auto format = kOpFormat_DEFAULT; string tensor_name = parameter_name + ':' + "0"; auto ascend_addr = dynamic_cast(addr); - std::vector int_shapes; + ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(item, PRAMATER_OUTPUT_INDEX); } else { diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc index c2131a541e..98523acc74 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.cc @@ -20,8 +20,7 @@ namespace mindspore { namespace device { namespace cpu { -bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size_t size, TypeId type, - void *host_ptr) const { +bool CPUDeviceAddress::SyncDeviceToHost(const ShapeVector & /*shape*/, size_t size, TypeId type, void *host_ptr) const { if (ptr_ == nullptr) { MS_LOG(ERROR) << "The pointer ptr_ is null!"; return false; @@ -50,7 +49,7 @@ bool CPUDeviceAddress::SyncDeviceToHost(const std::vector & /*shape*/, size return true; } -bool CPUDeviceAddress::SyncHostToDevice(const std::vector & /*shape*/, size_t size, TypeId type, +bool CPUDeviceAddress::SyncHostToDevice(const ShapeVector & /*shape*/, size_t size, TypeId type, const void *host_ptr) const { if (host_ptr == ptr_) { MS_LOG(DEBUG) << "host_ptr is equal to ptr_, request ignored."; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h index c06e2915e0..d73804c324 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_device_address.h @@ -19,6 +19,7 @@ #include #include #include "runtime/device/device_address.h" +#include "utils/shape_utils.h" namespace mindspore { namespace device { @@ -32,8 +33,8 @@ class CPUDeviceAddress : public DeviceAddress { ~CPUDeviceAddress() override = default; - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override; DeviceAddressType DeviceType() const override { return DeviceAddressType::kCPU; } }; } // namespace cpu diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index 114c20ae65..f9a318be23 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -25,6 +25,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/session/session_basic.h" #include "frontend/operator/ops.h" +#include "utils/shape_utils.h" namespace 
mindspore { namespace device { @@ -52,7 +53,7 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph } auto tensor = node_value->cast(); MS_EXCEPTION_IF_NULL(tensor); - std::vector data_shape = tensor->shape(); + ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); DeviceAddressPtr address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeFloat32); MS_EXCEPTION_IF_NULL(address); @@ -135,7 +136,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index); if (tensor == nullptr) { auto shape = AnfAlgo::GetOutputInferShape(node, index); - std::vector temp_shape; + ShapeVector temp_shape; (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); tensor = std::make_shared(infer_type_id, temp_shape); bool is_internal_output = kernel_graph->IsInternalOutput(node, index); @@ -149,7 +150,7 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *k } else { if (infer_type_id != device_type_id) { size_t type_size = GetTypeByte(TypeIdToType(device_type_id)); - std::vector data_shape = tensor->shape(); + ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), type_size, std::multiplies()); address->ptr_ = resource_manager_.MemMalloc(tensor_size); need_sync_outputs->emplace_back(tensor); @@ -224,7 +225,7 @@ void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const tensor->data_type() == kNumberTypeInt32) { address->ptr_ = tensor->data_c(); } else { - std::vector data_shape = tensor->shape(); + ShapeVector data_shape = tensor->shape(); size_t tensor_size = std::accumulate(data_shape.begin(), data_shape.end(), sizeof(float), std::multiplies()); address->ptr_ = resource_manager_.MemMalloc(tensor_size); diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 4775319f51..bf6aa99504 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -22,6 +22,7 @@ #include #include "ir/dtype.h" #include "ir/device_sync.h" +#include "utils/shape_utils.h" namespace mindspore { namespace device { @@ -60,7 +61,7 @@ class DeviceAddress : public mindspore::DeviceSync { size_t GetSize() const { return size_; } std::string format() const { return format_; } TypeId type_id() const { return type_id_; } - void set_host_shape(const std::vector &shape) { host_shape_ = shape; } + void set_host_shape(const ShapeVector &shape) { host_shape_ = shape; } virtual void set_status(DeviceAddressStatus status) {} virtual DeviceAddressStatus status() const { return DeviceAddressStatus::kInDevice; } virtual DeviceAddressType DeviceType() const { return DeviceAddressType::kUnknown; } @@ -77,7 +78,7 @@ class DeviceAddress : public mindspore::DeviceSync { TypeId type_id_{kNumberTypeFloat16}; bool from_mem_pool_{false}; uint8_t *communication_ptr_{nullptr}; - std::vector host_shape_{}; + ShapeVector host_shape_{}; friend class KernelRuntime; friend class MemoryManager; friend class mindspore::device::ascend::tasksink::TaskGenerator; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc index 521e280f90..ffd522fbee 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc +++ 
b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.cc @@ -30,7 +30,7 @@ namespace mindspore { namespace device { namespace gpu { -bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, TypeId, void *host_ptr) const { +bool GPUDeviceAddress::SyncDeviceToHost(const ShapeVector &, size_t size, TypeId, void *host_ptr) const { MS_EXCEPTION_IF_NULL(host_ptr); bool need_sync = (size != 0) && (size_ != 0); if (!need_sync) { @@ -50,7 +50,7 @@ bool GPUDeviceAddress::SyncDeviceToHost(const std::vector &, size_t size, T return GPUDeviceManager::GetInstance().CopyDeviceMemToHost(host_ptr, ptr_, size); } -bool GPUDeviceAddress::SyncHostToDevice(const std::vector &, size_t size, TypeId, const void *host_ptr) const { +bool GPUDeviceAddress::SyncHostToDevice(const ShapeVector &, size_t size, TypeId, const void *host_ptr) const { MS_EXCEPTION_IF_NULL(host_ptr); bool need_sync = (size != 0) && (size_ != 0); if (!need_sync) { @@ -80,8 +80,8 @@ GPUDeviceAddress::~GPUDeviceAddress() { } #ifdef ENABLE_DEBUGGER bool GPUDeviceAddress::LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type, size_t slot, - Debugger *debugger, bool keep_prev) const { + const ShapeVector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + bool keep_prev) const { bool ret = false; if (size_ == 0) { return true; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h index 8a3baccb61..c68108b9de 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_device_address.h @@ -21,6 +21,8 @@ #include #include "runtime/device/device_address.h" +using ShapeVecotr = std::vector; + namespace mindspore { #ifdef ENABLE_DEBUGGER class Debugger; @@ -34,15 +36,15 @@ class GPUDeviceAddress : public DeviceAddress { : DeviceAddress(ptr, size, format, type_id) {} ~GPUDeviceAddress() override; - bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const override; - bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, const void *host_ptr) const override; + bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override; + bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const override; void set_status(DeviceAddressStatus status) { status_ = status; } DeviceAddressStatus status() const { return status_; } DeviceAddressType DeviceType() const override { return DeviceAddressType::kGPU; } #ifdef ENABLE_DEBUGGER bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, - const std::vector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, + const ShapeVector &host_shape, TypeId host_type, size_t slot, Debugger *debugger, bool keep_prev) const; #endif private: diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 255d7b46d8..4406c3cb68 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -32,6 +32,7 @@ #include "common/trans.h" #include "ir/dtype.h" #include "profiler/device/gpu/gpu_profiling.h" +#include "utils/shape_utils.h" #ifdef ENABLE_DEBUGGER #include "debug/debug_services.h" #endif @@ -107,7 +108,7 @@ void 
DumpOutput(mindspore::session::KernelGraph *graph, const string &dump_path, auto addr = AnfAlgo::GetOutputAddr(node, j); TypeId addr_type_id = addr->type_id(); std::string addr_format = addr->format(); - std::vector int_shapes; + ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(node, j); } else { @@ -153,7 +154,7 @@ void DumpParameters(mindspore::session::KernelGraph *graph, const string &dump_p auto addr = AnfAlgo::GetOutputAddr(item, PARAMETER_OUTPUT_INDEX); TypeId addr_type_id = addr->type_id(); std::string addr_format = addr->format(); - std::vector int_shapes; + ShapeVector int_shapes; if (trans_flag) { int_shapes = trans::GetRuntimePaddingShape(item, PARAMETER_OUTPUT_INDEX); } else { @@ -251,7 +252,7 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel, auto format = kOpFormat_DEFAULT; auto gpu_addr = std::make_unique(addr->addr, addr->size, format, type); string input_tensor_name = input_kernel_name + ':' + "0"; - std::vector int_shapes; + ShapeVector int_shapes; auto shape = AnfAlgo::GetOutputDeviceShape(input_kernel, PARAMETER_OUTPUT_INDEX); (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), [](size_t inner_item) { return SizeToInt(inner_item); }); @@ -270,7 +271,7 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel, auto format = kOpFormat_DEFAULT; auto gpu_addr = std::make_unique(addr->addr, addr->size, format, type); string tensor_name = kernel_name + ':' + std::to_string(j); - std::vector int_shapes; + ShapeVector int_shapes; auto shape = AnfAlgo::GetOutputDeviceShape(kernel, j); (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), [](size_t inner_item) { return SizeToInt(inner_item); }); @@ -310,7 +311,7 @@ void LoadParameters(const session::KernelGraph *graph, Debugger *debugger, bool auto format = kOpFormat_DEFAULT; string tensor_name = parameter_name + ':' + "0"; auto gpu_addr = dynamic_cast(addr); - std::vector int_shapes; + ShapeVector int_shapes; auto shape = AnfAlgo::GetOutputDeviceShape(item, PARAMETER_OUTPUT_INDEX); (void)std::transform(shape.begin(), shape.end(), std::back_inserter(int_shapes), [](size_t inner_item) { return SizeToInt(inner_item); }); diff --git a/mindspore/ccsrc/runtime/device/kernel_adjust.cc b/mindspore/ccsrc/runtime/device/kernel_adjust.cc index 7d4d8aea5c..d0df4f1141 100644 --- a/mindspore/ccsrc/runtime/device/kernel_adjust.cc +++ b/mindspore/ccsrc/runtime/device/kernel_adjust.cc @@ -31,6 +31,7 @@ #include "runtime/device/ascend/profiling/profiling_manager.h" #include "runtime/base.h" #include "runtime/device/ascend/ascend_stream_assign.h" +#include "utils/shape_utils.h" namespace { constexpr auto kProfilingGraphId = "PROFILING_GRAPH_ID"; @@ -320,7 +321,7 @@ void KernelAdjust::CreateSwitchOpParameters(const std::shared_ptr *switch_loop_input) { MS_EXCEPTION_IF_NULL(kernel_graph_ptr); MS_EXCEPTION_IF_NULL(switch_loop_input); - std::vector shp = {1}; + ShapeVector shp = {1}; tensor::TensorPtr tensor_ptr = std::make_shared(kInt32->type_id(), shp); MS_EXCEPTION_IF_NULL(tensor_ptr); mindspore::abstract::AbstractBasePtr paremeter_abstract_ptr = tensor_ptr->ToAbstract(); @@ -559,7 +560,7 @@ void KernelAdjust::LoadSwitchInputs(std::vector *inputs) { MS_LOG(INFO) << "---------------- LoadSwitchInputs---"; MS_EXCEPTION_IF_NULL(inputs); // current loop count - std::vector shp = {1}; + ShapeVector shp = {1}; tensor::TensorPtr cur_loop_count = std::make_shared(kInt32->type_id(), shp); MS_EXCEPTION_IF_NULL(cur_loop_count); int32_t *val = 
nullptr; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 97258edde7..9d57e64caa 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -27,6 +27,7 @@ #include "backend/session/anf_runtime_algorithm.h" #include "backend/optimizer/common/helper.h" #include "ir/value.h" +#include "utils/shape_utils.h" using mindspore::kernel::Address; using mindspore::kernel::AddressPtr; @@ -681,7 +682,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { MS_LOG(EXCEPTION) << "Cannot alloc address when flag is: " << kStaticMem << ", tensor size is: " << tensor_size; } AnfAlgo::SetOutputAddr(address, 0, value_node.get()); - std::vector shape = {1, SizeToInt(tensor_size)}; + ShapeVector shape = {1, SizeToInt(tensor_size)}; if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) { MS_LOG(EXCEPTION) << "kValueNode SyncHostToDevice fail!"; } diff --git a/mindspore/ccsrc/transform/graph_ir/util.cc b/mindspore/ccsrc/transform/graph_ir/util.cc index 4c653b3c80..7cfd6f76a3 100644 --- a/mindspore/ccsrc/transform/graph_ir/util.cc +++ b/mindspore/ccsrc/transform/graph_ir/util.cc @@ -94,8 +94,8 @@ GeFormat TransformUtil::ConvertFormat(const string &format) { static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } -std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, - const MeDataType &me_type, const std::string &format) { +std::shared_ptr TransformUtil::GetGeTensorDesc(const ShapeVector &me_shape, const MeDataType &me_type, + const std::string &format) { // convert me shape to ge shape std::vector ge_shape; @@ -196,7 +196,7 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr &tensor, const std::s } std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims) { + const std::vector &request_dims) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -204,7 +204,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector empty_shape; + ShapeVector empty_shape; me_tensor_ptr = ConvertGeTensor(ge_tensors[index], empty_shape); } @@ -270,7 +270,7 @@ MeDataType TransformUtil::ConvertGeDataType(const GeDataType &type) { } namespace { -bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { +bool IsGeShapeCompatible(const GeShape &ge_shape, const ShapeVector &request_dims) { MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); @@ -307,20 +307,20 @@ bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &reques } } // namespace -GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { +GeShape TransformUtil::ConvertMeShape(const ShapeVector &me_dims) { std::vector ge_dims; (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); return GeShape(ge_dims); } -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { - std::vector me_dims; +ShapeVector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { + ShapeVector me_dims; std::vector ge_dims = ge_shape.GetDims(); (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); return me_dims; } -std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { +ShapeVector 
TransformUtil::ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims) { vector ret; if (ge_shape.GetDimNum() == 0) { MS_LOG(DEBUG) << "GeTensor's shape is scalar"; @@ -336,7 +336,7 @@ std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const st return ret; } -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims, const TypeId &me_type) { MeTensor me_tensor(me_type, me_dims); @@ -380,7 +380,7 @@ MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { } // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const ShapeVector &request_dims) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape, request_dims); diff --git a/mindspore/ccsrc/transform/graph_ir/util.h b/mindspore/ccsrc/transform/graph_ir/util.h index 14f251a26e..6333e2f99f 100644 --- a/mindspore/ccsrc/transform/graph_ir/util.h +++ b/mindspore/ccsrc/transform/graph_ir/util.h @@ -26,8 +26,8 @@ #include "ir/dtype.h" #include "ir/tensor.h" #include "transform/graph_ir/types.h" - #include "graph/tensor.h" +#include "utils/shape_utils.h" namespace mindspore { namespace transform { @@ -73,7 +73,7 @@ class TransformUtil { * Return: * [shared_ptr] the shared pointer of ge tensor description * */ - static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + static std::shared_ptr GetGeTensorDesc(const ShapeVector &shape, const MeDataType &me_type, const std::string &format); /* @@ -107,20 +107,20 @@ class TransformUtil { /* * Parameters: * tensor: [GeTensor] the data tensor in GE - * request_dims [std::vector] the output Me tensors must adjust to this shapes + * request_dims [ShapeVector] the output Me tensors must adjust to this shapes * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); + static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const ShapeVector &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE - * request_dims [std::vector>] the output Me tensors must adjust to this shapes + * request_dims [std::vector] the output Me tensors must adjust to this shapes * Return: * [std::vector] the data tensor in ME * */ static std::vector ConvertGeTensors(const std::vector &ge_tensors, - const std::vector> &request_dims); + const std::vector &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE @@ -131,13 +131,12 @@ class TransformUtil { /* * Parameters: * ge_tensor: [GeTensor] the data tensor in GE - * me_dims: [std::vector] the shape of created Me tensor + * me_dims: [ShapeVector] the shape of created Me tensor * me_type: [TypeId] the type of created Me tensor * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, - const TypeId &me_type); + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const ShapeVector &me_dims, const TypeId &me_type); /* * Parameters: * type: [GeDataType] the ge tensor data type @@ -148,11 +147,11 @@ class TransformUtil { /* * Parameters: 
- * me_dims: [std::vector] the me shape + * me_dims: [ShapeVector] the me shape * Return: * [GeShape] the ge shape * */ - static GeShape ConvertMeShape(const std::vector &me_dims); + static GeShape ConvertMeShape(const ShapeVector &me_dims); /* * Parameters: @@ -160,7 +159,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape &ge_shape); + static ShapeVector ConvertGeShape(const GeShape &ge_shape); /* Function: * Convert GeShape to Me request shape, Support pattern: @@ -176,11 +175,11 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); + static ShapeVector ConvertGeShape(const GeShape &ge_shape, const ShapeVector &request_dims); /* * Parameters: - * vec: [std::vector] the vector to print + * vec: [ShapeVector] the vector to print * Return: * [string] value string * */ diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 8f5718b0a2..808edcb5e0 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -22,6 +22,7 @@ #include "pipeline/jit/parse/data_converter.h" #include "pipeline/jit/parse/python_adapter.h" #include "utils/visible.h" +#include "utils/shape_utils.h" namespace mindspore { namespace callbacks { @@ -36,7 +37,7 @@ using mindspore::transform::Status; using mindspore::transform::TransformUtil; bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, - const std::shared_ptr> &shape) { + const std::shared_ptr &shape) { if (graph == nullptr) { MS_LOG(ERROR) << "Graph is null, can not get graph parameter"; return false; @@ -74,7 +75,7 @@ static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string &pa return nullptr; } - std::shared_ptr> parameter_shape_ptr = std::make_shared>(); + std::shared_ptr parameter_shape_ptr = std::make_shared(); if (!GetParameterShape(anf_graph, parameter_name, parameter_shape_ptr)) { MS_LOG(ERROR) << "Can not get parameter shape during callback"; return nullptr; @@ -133,7 +134,7 @@ static TensorPtr GetMeTensorForSummary(const std::string &name, const std::share // process the scalar type summary // Because the ge tensor is dim = 4, so set the (1,1,1,1)-->(1,) // We do the (1,) shape is scalar - auto shape = std::vector({ONE_SHAPE}); + auto shape = ShapeVector({ONE_SHAPE}); return TransformUtil::ConvertGeTensor(ge_tensor_ptr, shape); } if (tname == "[:Tensor]" || tname == "[:Histogram]") { diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 0b41330d58..5fed864b4d 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -33,6 +33,7 @@ #include "ir/param_info.h" #include "utils/base_ref_extends.h" #include "utils/ms_context.h" +#include "utils/shape_utils.h" namespace mindspore { py::object BuiltinsToPyData(const Any &value); @@ -374,7 +375,7 @@ py::object VectorRefToPyData(const VectorRef &value_list) { AbstractBasePtr PyListDtype2AbstractTensor(const py::object &shape_obj, const py::object &type_obj, const py::object &min_shape, const py::object &max_shape) { if ((py::isinstance(shape_obj) || py::isinstance(shape_obj)) && py::isinstance(type_obj)) { - auto ret_vec = shape_obj.cast>(); + auto ret_vec = shape_obj.cast(); auto ret_dtype = type_obj.cast(); MS_EXCEPTION_IF_NULL(ret_dtype); // if the size of shape list is empty, return an scalar abstract @@ -383,13 +384,13 @@ AbstractBasePtr 
PyListDtype2AbstractTensor(const py::object &shape_obj, const py return abs_scalar; } AbstractBasePtr tensor = nullptr; - std::vector min_shape_vec; - std::vector max_shape_vec; + ShapeVector min_shape_vec; + ShapeVector max_shape_vec; if (!min_shape.is_none()) { - min_shape_vec = min_shape.cast>(); + min_shape_vec = min_shape.cast(); } if (!max_shape.is_none()) { - max_shape_vec = max_shape.cast>(); + max_shape_vec = max_shape.cast(); } auto ret_shape = std::make_shared(ret_vec, min_shape_vec, max_shape_vec); if (ret_dtype->isa()) { diff --git a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc index 4125675bd5..fe16f88be6 100644 --- a/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc +++ b/mindspore/ccsrc/utils/load_onnx/anf_model_parser.cc @@ -26,6 +26,7 @@ #include "abstract/abstract_value.h" #include "proto/onnx.pb.h" #include "utils/log_adapter.h" +#include "utils/shape_utils.h" using std::string; @@ -96,7 +97,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons return false; } const onnx::TensorShapeProto &tensor_shape = tensor_typeproto.shape(); - std::vector shape; + ShapeVector shape; for (int i = 0; i < tensor_shape.dim_size(); ++i) { shape.push_back(tensor_shape.dim(i).dim_value()); } @@ -241,7 +242,7 @@ bool MSANFModelParser::GetAttrValueForCNode(const PrimitivePtr &prim, const onnx bool MSANFModelParser::ObtainValueNodeInTensorForm(const std::string &value_node_name, const onnx::TensorProto &attr_tensor) { const int attr_tensor_type = attr_tensor.data_type(); - std::vector shape; + ShapeVector shape; for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape.push_back(attr_tensor.dims(i)); } @@ -355,7 +356,7 @@ bool MSANFModelParser::BuildValueNodeForFuncGraph(const onnx::NodeProto &node_pr } AbstractBasePtr MSANFModelParser::GetAbstractForCNode(const onnx::AttributeProto &attr_proto) { - std::vector shape_vec; + ShapeVector shape_vec; const onnx::TensorProto &attr_tensor = attr_proto.t(); for (int i = 0; i < attr_tensor.dims_size(); ++i) { shape_vec.push_back(attr_tensor.dims(i)); @@ -471,7 +472,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra const onnx::ValueInfoProto &output_node = importProto.output(0); const onnx::TypeProto &output_typeproto = output_node.type(); int output_type = output_typeproto.tensor_type().elem_type(); - std::vector output_shape; + ShapeVector output_shape; for (int i = 0; i < output_typeproto.tensor_type().shape().dim_size(); ++i) { output_shape.push_back(output_typeproto.tensor_type().shape().dim(i).dim_value()); } diff --git a/mindspore/ccsrc/utils/tensorprint_utils.cc b/mindspore/ccsrc/utils/tensorprint_utils.cc index f427a382c1..7a92633762 100644 --- a/mindspore/ccsrc/utils/tensorprint_utils.cc +++ b/mindspore/ccsrc/utils/tensorprint_utils.cc @@ -22,6 +22,7 @@ #include "ir/tensor.h" #include "pybind11/pybind11.h" #include "utils/ms_utils.h" +#include "utils/shape_utils.h" #ifndef NO_DLIB #include "tdt/tsd_client.h" #include "tdt/tdt_host_interface.h" @@ -59,7 +60,7 @@ std::string GetParseType(const std::string &tensorType_) { return type_iter->second; } -bool ParseTensorShape(const std::string &input_shape_str, std::vector *const tensor_shape, size_t *dims) { +bool ParseTensorShape(const std::string &input_shape_str, ShapeVector *const tensor_shape, size_t *dims) { if (tensor_shape == nullptr) { return false; } @@ -189,7 +190,7 @@ bool ConvertDataItem2Tensor(const std::vector &items) { continue; } - std::vector 
tensor_shape; + ShapeVector tensor_shape; size_t totaldims = 1; if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; @@ -235,7 +236,7 @@ bool SaveDataItem2File(const std::vector &items, const std::strin } } - std::vector tensor_shape; + ShapeVector tensor_shape; size_t totaldims = 1; if (!ParseTensorShape(item.tensorShape_, &tensor_shape, &totaldims)) { MS_LOG(ERROR) << "Tensor print can not parse tensor shape, receive info" << item.tensorShape_; diff --git a/mindspore/core/abstract/abstract_value.h b/mindspore/core/abstract/abstract_value.h index fa768addc6..2eab2757bb 100644 --- a/mindspore/core/abstract/abstract_value.h +++ b/mindspore/core/abstract/abstract_value.h @@ -33,6 +33,7 @@ #include "ir/value.h" #include "ir/tensor.h" #include "abstract/dshape.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -250,7 +251,7 @@ class AbstractUndetermined : public AbstractBase { } set_shape(shape); } - AbstractUndetermined(const TypePtr &element_type, const std::vector &shape) + AbstractUndetermined(const TypePtr &element_type, const ShapeVector &shape) : AbstractBase(kAnyValue), element_(std::make_shared(kAnyValue, element_type)) { if (element_type == nullptr) { MS_LOG(EXCEPTION) << "element_type is nullptr"; @@ -273,8 +274,7 @@ class AbstractTensor : public AbstractUndetermined { // only element_ and value, shape track are valid member, type track are unknown. explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractUndetermined(element, shape) {} - AbstractTensor(const TypePtr &element_type, const std::vector &shape) - : AbstractUndetermined(element_type, shape) {} + AbstractTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {} explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {} ~AbstractTensor() override = default; MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined) @@ -632,7 +632,7 @@ class AbstractRowTensor : public AbstractUndetermined { public: explicit AbstractRowTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractUndetermined(element, shape) {} - AbstractRowTensor(const TypePtr &element_type, const std::vector &shape) + AbstractRowTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {} ~AbstractRowTensor() override = default; MS_DECLARE_PARENT(AbstractRowTensor, AbstractUndetermined) @@ -661,7 +661,7 @@ class AbstractSparseTensor : public AbstractUndetermined { public: explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared()) : AbstractUndetermined(element, shape) {} - AbstractSparseTensor(const TypePtr &element_type, const std::vector &shape) + AbstractSparseTensor(const TypePtr &element_type, const ShapeVector &shape) : AbstractUndetermined(element_type, shape) {} ~AbstractSparseTensor() override = default; MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined) diff --git a/mindspore/core/abstract/dshape.h b/mindspore/core/abstract/dshape.h index 4197f73ac0..01aa314e22 100644 --- a/mindspore/core/abstract/dshape.h +++ b/mindspore/core/abstract/dshape.h @@ -29,6 +29,7 @@ #include "utils/log_adapter.h" #include "base/base.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -69,12 
+70,12 @@ class Shape : public BaseShape { (void)std::transform(list_in.begin(), list_in.end(), std::back_inserter(shape_), [](const int64_t &value) { return static_cast(value); }); } - explicit Shape(const std::vector &list) : shape_(list) {} + explicit Shape(const ShapeVector &list) : shape_(list) {} explicit Shape(const std::vector &list) { (void)std::transform(list.begin(), list.end(), std::back_inserter(shape_), [](const int64_t &value) { return static_cast(value); }); } - Shape(const std::vector &list, const std::vector &min_shape, const std::vector &max_shape) + Shape(const ShapeVector &list, const ShapeVector &min_shape, const ShapeVector &max_shape) : shape_(list), min_shape_(min_shape), max_shape_(max_shape) {} ~Shape() override = default; MS_DECLARE_PARENT(Shape, BaseShape) @@ -83,13 +84,13 @@ class Shape : public BaseShape { bool operator==(const BaseShape &other) const override; BaseShapePtr Clone() const override { return std::make_shared(shape_, min_shape_, max_shape_); } void Broaden() override; - std::vector &shape() { return shape_; } - std::vector &min_shape() { return min_shape_; } - std::vector &max_shape() { return max_shape_; } + ShapeVector &shape() { return shape_; } + ShapeVector &min_shape() { return min_shape_; } + ShapeVector &max_shape() { return max_shape_; } - std::vector shape_; // use SHP_ANY to implement the any shape in python - std::vector min_shape_; // record mininum length for each dynamic dimention - std::vector max_shape_; // record maximum length for each dynamic dimention + ShapeVector shape_; // use SHP_ANY to implement the any shape in python + ShapeVector min_shape_; // record mininum length for each dynamic dimention + ShapeVector max_shape_; // record maximum length for each dynamic dimention }; using ShapePtr = std::shared_ptr; using ShapePtrList = std::vector; diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 40dfbc02fe..136a7b5075 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -17,11 +17,12 @@ #include "abstract/infer_functions.h" #include "abstract/utils.h" #include "abstract/param_validator.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { namespace { -std::vector BroadcastShape(std::vector shpx, std::vector shpy) { +ShapeVector BroadcastShape(ShapeVector shpx, ShapeVector shpy) { int dlen = SizeToInt(shpx.size()) - SizeToInt(shpy.size()); if (dlen < 0) { for (int i = 0; i < -dlen; ++i) { @@ -35,7 +36,7 @@ std::vector BroadcastShape(std::vector shpx, std::vector shpy) { if (shpx.size() != shpy.size()) { MS_LOG(EXCEPTION) << "Failure: shpx.size() != shpy.size()."; } - std::vector shp; + ShapeVector shp; for (size_t i = 0; i < shpx.size(); i++) { auto a = shpx[i]; auto b = shpy[i]; @@ -50,7 +51,7 @@ std::vector BroadcastShape(std::vector shpx, std::vector shpy) { } else if (a == b) { shp.push_back(a); } else { - return std::vector(); + return ShapeVector(); } } return shp; @@ -89,18 +90,18 @@ AbstractBasePtr InferImplBroadCastShape(const AnalysisEnginePtr &, const Primiti auto value_tuple_x = xs->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(value_tuple_x); auto shp_tuple_x = value_tuple_x->value(); - std::vector shp_x; + ShapeVector shp_x; (void)std::transform(std::begin(shp_tuple_x), std::end(shp_tuple_x), std::back_inserter(shp_x), [](const ValuePtr &e) -> int { return GetValue(e); }); auto value_tuple_y = ys->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(value_tuple_y); auto shp_tuple_y = value_tuple_y->value(); 
- std::vector shp_y; + ShapeVector shp_y; (void)std::transform(std::begin(shp_tuple_y), std::end(shp_tuple_y), std::back_inserter(shp_y), [](const ValuePtr &e) -> int { return GetValue(e); }); - std::vector res = BroadcastShape(shp_x, shp_y); + ShapeVector res = BroadcastShape(shp_x, shp_y); if (res.empty()) { MS_LOG(EXCEPTION) << "BroadcastShape fail: " << args_spec_list[0]->ToString() << "," << args_spec_list[1]->ToString(); @@ -130,7 +131,7 @@ AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &pri MS_LOG(EXCEPTION) << "shape's data field can't be anything: " << args_spec_list[1]->ToString(); } - std::vector mul_shp; + ShapeVector mul_shp; auto value_tuple_mul = mul_shp_value->cast(); auto mul_shp_data = value_tuple_mul->value(); (void)std::transform(std::begin(mul_shp_data), std::end(mul_shp_data), std::back_inserter(mul_shp), @@ -140,7 +141,7 @@ AbstractBasePtr InferImplTile(const AnalysisEnginePtr &, const PrimitivePtr &pri << input_shape->shape().size() << ", value size is: " << mul_shp_data.size() << "."; } - std::vector result_shp; + ShapeVector result_shp; for (size_t i = 0; i < mul_shp_data.size(); ++i) { result_shp.push_back(input_shape->shape()[i] * mul_shp[i]); } @@ -195,9 +196,9 @@ AbstractBasePtr InferImplUnique(const AnalysisEnginePtr &, const PrimitivePtr &p if (shape->shape().size() != 1) { MS_LOG(EXCEPTION) << "Rank of " << op_name << "'s input must be 1."; } - std::vector ids_shape = {Shape::SHP_ANY}; - std::vector min_shape = {1}; - std::vector max_shape = shape->shape(); + ShapeVector ids_shape = {Shape::SHP_ANY}; + ShapeVector min_shape = {1}; + ShapeVector max_shape = shape->shape(); auto ids = std::make_shared(input->element(), std::make_shared(ids_shape, min_shape, max_shape)); auto ids_idx = std::make_shared(std::make_shared(32), shape->shape()); diff --git a/mindspore/core/abstract/prim_nn.cc b/mindspore/core/abstract/prim_nn.cc index ac0228caa7..625c07fb90 100644 --- a/mindspore/core/abstract/prim_nn.cc +++ b/mindspore/core/abstract/prim_nn.cc @@ -18,6 +18,7 @@ #include "abstract/utils.h" #include "abstract/param_validator.h" #include "utils/check_convert_utils.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -82,7 +83,7 @@ AbstractBasePtr InferImplPooling(const AnalysisEnginePtr &, const PrimitivePtr & int h_out = ((h_input + 2 * padding - (window - 1) - 1) / stride) + 1; int w_out = ((w_input + 2 * padding - (window - 1) - 1) / stride) + 1; - std::vector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; + ShapeVector shape_out = {input_shape->shape()[0], input_shape->shape()[1], h_out, w_out}; AbstractBasePtr ret = input_tensor->Broaden(); ret->set_shape(std::make_shared(shape_out)); return ret; @@ -271,11 +272,11 @@ AbstractBasePtr InferImplBiasAddGrad(const AnalysisEnginePtr &, const PrimitiveP MS_EXCEPTION_IF_NULL(args_spec_list[0]); ShapePtr shape_y = dyn_cast(args_spec_list[0]->GetShapeTrack()); MS_EXCEPTION_IF_NULL(shape_y); - std::vector y_dims = shape_y->shape(); + ShapeVector y_dims = shape_y->shape(); if (y_dims.size() < 2) { MS_LOG(EXCEPTION) << primitive->name() << " input y backprop, dim should >= 2, while " << y_dims.size() << "."; } - std::vector bias_dims = {y_dims[1]}; + ShapeVector bias_dims = {y_dims[1]}; ShapePtr ret_shape = std::make_shared(bias_dims); AbstractBasePtr ret = args_spec_list[0]->Broaden(); ret->set_shape(ret_shape); diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 358ed75849..813d91943d 
100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -25,6 +25,7 @@ #include "abstract/utils.h" #include "utils/ms_context.h" #include "utils/symbolic.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -224,7 +225,7 @@ AbstractBasePtr InferImplMakeRowTensor(const AnalysisEnginePtr &, const Primitiv auto dense_shape_value = dense_shape->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(dense_shape_value); auto shp = dense_shape_value->value(); - std::vector dense_shape_vec; + ShapeVector dense_shape_vec; (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), [](const ValuePtr &e) -> int { auto elem = GetValue(e); @@ -318,7 +319,7 @@ AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const Primi auto dense_shape_value = dense_shape->BuildValue()->cast(); MS_EXCEPTION_IF_NULL(dense_shape_value); auto shp = dense_shape_value->value(); - std::vector dense_shape_vec; + ShapeVector dense_shape_vec; (void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec), [](const ValuePtr &e) -> int { auto elem = GetValue(e); diff --git a/mindspore/core/abstract/utils.cc b/mindspore/core/abstract/utils.cc index 20eeab0de5..d4366d2c3d 100644 --- a/mindspore/core/abstract/utils.cc +++ b/mindspore/core/abstract/utils.cc @@ -23,6 +23,7 @@ #include #include "utils/symbolic.h" #include "abstract/param_validator.h" +#include "utils/shape_utils.h" namespace mindspore { namespace abstract { @@ -54,7 +55,7 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { MS_LOG(WARNING) << "Unsupported shape join. shape1 = " << shape1->ToString() << ", shape2 = " << shape2->ToString(); return shape1; } - std::vector dims; + ShapeVector dims; bool has_dynamic_shape = false; dims.resize(shape1->shape().size()); for (std::size_t i = 0; i < shape1->shape().size(); i++) { @@ -72,8 +73,8 @@ ShapePtr ShapeJoin(const ShapePtr &shape1, const ShapePtr &shape2) { return std::make_shared(dims); } // calculate dynamic shape - std::vector min_dims(dims.size()); - std::vector max_dims(dims.size()); + ShapeVector min_dims(dims.size()); + ShapeVector max_dims(dims.size()); for (size_t i = 0; i < dims.size(); ++i) { if (dims[i] != Shape::SHP_ANY) { min_dims[i] = max_dims[i] = dims[i]; @@ -205,7 +206,7 @@ int GetPositiveAxis(int axis_value, size_t increment) { // Return if two shapes can be broadcast. // Broadcast shape is placed in broadcast_output_shape. -std::vector RealBroadcast(const std::string &op, std::vector x_shape, std::vector y_shape) { +ShapeVector RealBroadcast(const std::string &op, ShapeVector x_shape, ShapeVector y_shape) { std::reverse(x_shape.begin(), x_shape.end()); std::reverse(y_shape.begin(), y_shape.end()); // Fill a placeholder value 1 which will be replaced later. 
@@ -213,7 +214,7 @@ std::vector<int> RealBroadcast(const std::string &op, std::vector<int> x_shape,
   y_shape.resize(std_len, 1);
   x_shape.resize(std_len, 1);
-  std::vector<int> broadcast_shape;
+  ShapeVector broadcast_shape;
   for (size_t i = 0; i < std_len; i++) {
     int x_i = x_shape[i];  // i-th dimension of x
     int y_i = y_shape[i];  // i-th dimension of y
diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h
index d8a0079814..2cf7ecd38e 100644
--- a/mindspore/core/ir/device_sync.h
+++ b/mindspore/core/ir/device_sync.h
@@ -22,6 +22,7 @@
 #include
 #include "ir/dtype/type.h"
+#include "utils/shape_utils.h"
 using std::string;
@@ -29,9 +30,8 @@ namespace mindspore {
 // Interface for data synchornize between device and host.
 class DeviceSync {
  public:
-  virtual bool SyncDeviceToHost(const std::vector<int> &shape, size_t size, TypeId type, void *host_ptr) const = 0;
-  virtual bool SyncHostToDevice(const std::vector<int> &shape, size_t size, TypeId type,
-                                const void *host_ptr) const = 0;
+  virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0;
+  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const = 0;
   virtual void *GetMutablePtr() const = 0;
 };
 using DeviceSyncPtr = std::shared_ptr<DeviceSync>;
diff --git a/mindspore/core/ir/meta_tensor.cc b/mindspore/core/ir/meta_tensor.cc
index 41b069b770..bd7c839e30 100644
--- a/mindspore/core/ir/meta_tensor.cc
+++ b/mindspore/core/ir/meta_tensor.cc
@@ -27,9 +27,9 @@ namespace tensor {
 // MetaTensor has default type_id_ which is TypeId::kTypeUnknown.
 MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {}
-MetaTensor::MetaTensor(const TypeId data_type, const std::vector<int> &shape) : data_type_(data_type), shape_(shape) {}
+MetaTensor::MetaTensor(const TypeId data_type, const ShapeVector &shape) : data_type_(data_type), shape_(shape) {}
-MetaTensor::MetaTensor(const TypePtr &type_ptr, const std::vector<int> &shape) {
+MetaTensor::MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape) {
   TypeId data_type = TypeId::kTypeUnknown;
   if (type_ptr != nullptr) {
     data_type = type_ptr->type_id();
diff --git a/mindspore/core/ir/meta_tensor.h b/mindspore/core/ir/meta_tensor.h
index 100c3cc59e..1bb8ce9533 100644
--- a/mindspore/core/ir/meta_tensor.h
+++ b/mindspore/core/ir/meta_tensor.h
@@ -26,6 +26,7 @@
 #include "ir/dtype.h"
 #include "utils/convert_utils_base.h"
 #include "utils/hashing.h"
+#include "utils/shape_utils.h"
 // brief mindspore namespace.
 //
@@ -37,7 +38,6 @@ namespace mindspore {
 //
 // A sub namespace in ME to support tensor related definition.
 namespace tensor {
-
 // brief Device info of Tensor
 //
 // Includes the format and data type of a tensor.
@@ -63,9 +63,9 @@ class MetaTensor : public Value {
   // information of a Tensor. The following codes will create a 2x3 float
   // param data_type The data type of the tensor.
   // param shape The shape of the tensor.
-  MetaTensor(const TypeId data_type, const std::vector<int> &shape);
+  MetaTensor(const TypeId data_type, const ShapeVector &shape);
-  MetaTensor(const TypePtr &type_ptr, const std::vector<int> &shape);
+  MetaTensor(const TypePtr &type_ptr, const ShapeVector &shape);
   // brief Constructs a MetaTensor object from an existing MetaTensor instance.
   //
   // The constructed MetaTensor object will have the same data type and shape as the
@@ -115,7 +115,7 @@ class MetaTensor : public Value {
   // order it represents.
   //
   // return A const vector which represents the shape of the tensor.
-  const std::vector<int> &shape() const { return shape_; }
+  const ShapeVector &shape() const { return shape_; }
   // brief Sets the shape of a tensor.
   //
@@ -126,7 +126,7 @@ class MetaTensor : public Value {
   //
   // param shape The shape of the tensor.
   // return The shape's size.
-  size_t set_shape(const std::vector<int> &shape) {
+  size_t set_shape(const ShapeVector &shape) {
     this->shape_ = shape;
     return shape_.size();
   }
@@ -174,11 +174,11 @@ class MetaTensor : public Value {
   // brief Shape of the tensor.
   //
-  // A std::vector container is used to store the shape of a tensor.
+  // A ShapeVector container is used to store the shape of a tensor.
   // Each element of the vector represents the size of a dimension of the tensor.
   // The order of each element in the vector is as same as the the dimension's
   // order it represents. If the dimension size is not set, its value will be -1.
-  std::vector<int> shape_;
+  ShapeVector shape_;
   // brief Device info of Tensor
   //
diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h
index 3c5c3122c9..0eb31fc422 100644
--- a/mindspore/core/ir/pattern_matcher.h
+++ b/mindspore/core/ir/pattern_matcher.h
@@ -24,6 +24,7 @@
 #include "ir/visitor.h"
 #include "base/core_ops.h"
+#include "utils/shape_utils.h"
 namespace mindspore {
 ///
@@ -599,7 +600,7 @@ class PConstant : public PBase<PConstant<T> > {
     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
-    std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
+    ShapeVector tensor_shape = tensor_abstract->shape()->shape();
     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
     size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
@@ -619,7 +620,7 @@ class PConstant : public PBase<PConstant<T> > {
       return nullptr;
     }
     auto x_abstract = x->abstract()->cast<abstract::AbstractTensorPtr>();
-    std::vector<int> x_shape = x_abstract->shape()->shape();
+    ShapeVector x_shape = x_abstract->shape()->shape();
     if (x_shape != tensor_shape) {
       return nullptr;
     }
@@ -664,7 +665,7 @@ class PConstant : public PBase<PConstant<T> > {
     auto tensor_abstract = node->abstract()->cast<abstract::AbstractTensorPtr>();
     TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
-    std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
+    ShapeVector tensor_shape = tensor_abstract->shape()->shape();
     auto new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
     size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
@@ -744,7 +745,7 @@ class PConstant : public PBase<PConstant<T> > {
       return nullptr;
     }
-    std::vector<int> tensor_out_shape = tensor_3_abstract->shape()->shape();
+    ShapeVector tensor_out_shape = tensor_3_abstract->shape()->shape();
     int data_out_size = std::accumulate(tensor_out_shape.begin(), tensor_out_shape.end(), 1, std::multiplies<int>());
     if ((tensor_ptr_1->DataSize() > 1) && (tensor_ptr_1->DataSize() != data_out_size)) {
       return nullptr;
diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc
index efffc2ba26..442bf76527 100644
--- a/mindspore/core/ir/tensor.cc
+++ b/mindspore/core/ir/tensor.cc
@@ -50,11 +50,11 @@ static TypeId TypeIdOf(const TypePtr &data_type, TypeId defaultTypeId) {
   return data_type ?
 data_type->type_id() : defaultTypeId;
 }
-static size_t SizeOf(const std::vector<int> &shape) {
+static size_t SizeOf(const ShapeVector &shape) {
   return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
 }
-static std::string ShapeToString(const std::vector<int> &shape) {
+static std::string ShapeToString(const ShapeVector &shape) {
   std::string str = "[";
   const size_t count = shape.size();
   for (size_t i = 0; i < count; ++i) {
@@ -93,7 +93,7 @@ std::unique_ptr<T[]> NewData(Scalar scalar) {
 }
 template <typename T>
-std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, TypeId data_type) {
+std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, TypeId data_type) {
   const size_t size = SizeOf(shape);
   switch (data_type) {
     case kNumberTypeBool: {
@@ -151,7 +151,7 @@ std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, T
 }
 template <typename T>
-std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, size_t data_len) {
+std::unique_ptr<T[]> CopyData(const ShapeVector &shape, void *const data, size_t data_len) {
   size_t size = SizeOf(shape);
   if (size * sizeof(T) != data_len) {
     MS_LOG(EXCEPTION) << "Incorrect tensor input data length " << data_len << ", expect " << size * sizeof(T)
@@ -165,21 +165,21 @@ std::unique_ptr<T[]> CopyData(const std::vector<int> &shape, void *const data, s
 template <typename T>
 class TensorDataImpl : public TensorData {
  public:
-  explicit TensorDataImpl(const std::vector<int> &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
+  explicit TensorDataImpl(const ShapeVector &shape) : ndim_(shape.size()), data_size_(SizeOf(shape)) {}
   ~TensorDataImpl() = default;
-  TensorDataImpl(const std::vector<int> &shape, void *data, size_t data_len)
+  TensorDataImpl(const ShapeVector &shape, void *data, size_t data_len)
     : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_len)) {}
-  TensorDataImpl(const std::vector<int> &shape, void *data, TypeId data_type)
+  TensorDataImpl(const ShapeVector &shape, void *data, TypeId data_type)
     : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(CopyData(shape, data, data_type)) {}
   template <typename U>
-  TensorDataImpl(const std::vector<int> &shape, const U *input, size_t size)
+  TensorDataImpl(const ShapeVector &shape, const U *input, size_t size)
     : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData(input, size)) {}
   template <typename Scalar>
-  TensorDataImpl(const std::vector<int> &shape, Scalar scalar)
+  TensorDataImpl(const ShapeVector &shape, Scalar scalar)
     : ndim_(shape.size()), data_size_(SizeOf(shape)), data_(NewData(scalar)) {}
   ssize_t size() const override { return static_cast<ssize_t>(data_size_); }
@@ -213,7 +213,7 @@ class TensorDataImpl : public TensorData {
          std::equal(data_.get(), data_.get() + data_size_, ptr->data_.get());
   }
-  std::string ToString(const TypeId type, const std::vector<int> &shape) const override {
+  std::string ToString(const TypeId type, const ShapeVector &shape) const override {
     constexpr auto valid = std::is_same::value || std::is_same::value || std::is_same::value ||
                            std::is_same::value || std::is_same::value || std::is_same::value ||
@@ -301,8 +301,7 @@ class TensorDataImpl : public TensorData {
     }
   }
-  void SummaryStringRecursive(std::ostringstream &ss, const std::vector<int> &shape, ssize_t *cursor,
-                              ssize_t depth) const {
+  void SummaryStringRecursive(std::ostringstream &ss, const ShapeVector &shape, ssize_t *cursor, ssize_t depth) const {
     if (depth >= static_cast<ssize_t>(ndim_)) {
       return;
     }
@@ -360,7 +359,7 @@ class TensorDataImpl : public TensorData {
 };
 template <typename... Args>
-TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int>
&shape, const Args... args) {
+TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const Args... args) {
   switch (data_type) {
     case kNumberTypeBool:
       return std::make_shared<TensorDataImpl<bool>>(shape, args...);
@@ -410,16 +409,16 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
       device_sync_(tensor.device_sync_),
       padding_type_(tensor.padding_type()) {}
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data)
+Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
     : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {}
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape)
+Tensor::Tensor(TypeId data_type, const ShapeVector &shape)
     : Tensor(data_type, shape, MakeTensorData(data_type, shape)) {}
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len)
+Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len)
     : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, data_len)) {}
-Tensor::Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type)
+Tensor::Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type)
     : Tensor(data_type, shape, MakeTensorData(data_type, shape, data, src_data_type)) {}
 Tensor::Tensor(const std::vector &input, const TypePtr &data_type)
diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h
index 752edb50da..8bf499a9b4 100644
--- a/mindspore/core/ir/tensor.h
+++ b/mindspore/core/ir/tensor.h
@@ -26,6 +26,7 @@
 #include "ir/meta_tensor.h"
 #include "utils/log_adapter.h"
 #include "base/float16.h"
+#include "utils/shape_utils.h"
 // brief mindspore namespace.
 //
@@ -52,7 +53,7 @@ class TensorData {
   /// Is data equals.
   virtual bool equals(const TensorData &other) const = 0;
   /// To string.
-  virtual std::string ToString(const TypeId type, const std::vector<int> &shape) const = 0;
+  virtual std::string ToString(const TypeId type, const ShapeVector &shape) const = 0;
 };
 using TensorDataPtr = std::shared_ptr<TensorData>;
@@ -76,31 +77,31 @@ class Tensor : public MetaTensor {
   // brief Create tensor with the given shared tensor data.
   //
   // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector of the tensor.
+  // param shape The shape represented by ShapeVector of the tensor.
   // param data The shared tensor data.
-  Tensor(TypeId data_type, const std::vector<int> &shape, TensorDataPtr data);
+  Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data);
   // brief Create a lazy allocated tensor.
   //
   // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector of the tensor.
-  Tensor(TypeId data_type, const std::vector<int> &shape);
+  // param shape The shape represented by ShapeVector of the tensor.
+  Tensor(TypeId data_type, const ShapeVector &shape);
   // brief Create a tensor with input data buffer.
   //
   // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector of the tensor.
+  // param shape The shape represented by ShapeVector of the tensor.
   // param data The input data to be copied into tensor.
   // param data_len The length of data in bytes.
-  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, size_t data_len);
+  Tensor(TypeId data_type, const ShapeVector &shape, void *data, size_t data_len);
   // brief Create a tensor with input data buffer and given source data type.
   //
   // param data_type [TypeId] Data type of the tensor.
-  // param shape The shape represented by std::vector of the tensor.
+  // param shape The shape represented by ShapeVector of the tensor.
   // param data The input data to be copied into tensor.
   // param src_data_type The source data type.
-  Tensor(TypeId data_type, const std::vector<int> &shape, void *data, TypeId src_data_type);
+  Tensor(TypeId data_type, const ShapeVector &shape, void *data, TypeId src_data_type);
   // brief Create 1 dimension tensor from an int vector.
   //
@@ -170,8 +171,8 @@ class Tensor : public MetaTensor {
   // brief Get the tensor's shape for C++
   //
-  // return [std::vector]
-  std::vector<int> shape_c(void) const { return shape(); }
+  // return [ShapeVector]
+  ShapeVector shape_c(void) const { return shape(); }
   // brief Get Tensor data pointer for c++ type
   //
diff --git a/mindspore/core/utils/shape_utils.h b/mindspore/core/utils/shape_utils.h
new file mode 100644
index 0000000000..a2e43a9930
--- /dev/null
+++ b/mindspore/core/utils/shape_utils.h
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_SHAPE_UTILS_INFO_H_
+#define MINDSPORE_SHAPE_UTILS_INFO_H_
+
+#include <vector>
+using ShapeVector = std::vector<int>;
+
+#endif  // MINDSPORE_SHAPE_UTILS_INFO_H_
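
Reviewer note (not part of the patch): the point of routing every shape through the ShapeVector alias is that the element type can later be widened to int64_t by touching only shape_utils.h, while callers that are written against the alias recompile unchanged. The snippet below is a minimal standalone sketch of that idea under the assumptions stated in its comments; the ElementCount helper is hypothetical and only illustrates code written against the alias, it is not an API introduced by this change.

// sketch.cc -- standalone illustration, not part of the patch
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Assumption: mirrors the alias introduced in shape_utils.h at this commit.
// Later switching this one line to std::vector<int64_t> should not require
// edits in code that only uses ShapeVector.
using ShapeVector = std::vector<int>;

// Hypothetical helper written against the alias rather than a concrete
// vector type: computes the number of elements described by a shape.
size_t ElementCount(const ShapeVector &shape) {
  return std::accumulate(shape.begin(), shape.end(), size_t(1), std::multiplies<size_t>());
}

int main() {
  ShapeVector shape = {2, 3, 4};
  std::cout << "elements: " << ElementCount(shape) << std::endl;  // prints 24
  return 0;
}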