add cache pass

pull/11314/head
fangzehua 4 years ago
parent dfa6daaa57
commit f97e19f23f

@@ -19,55 +19,20 @@
#include <memory>
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace kernel {
template <typename T>
struct HashmapEntry {
T key;
T value;
T step;
T tag;
bool IsEmpty() {
if (this->tag == NULLTAG)
return true;
else
return false;
}
bool IsUsing(const T &train_step) {
if (this->step >= (train_step - 1))
return true;
else
return false;
}
bool IsKey(const T &emb_idx) {
if (this->key == emb_idx)
return true;
else
return false;
}
void SetEmpty() { this->tag = NULLTAG; }
};
template <typename T>
T HashFunc(const T &key, const size_t &m) {
return (T)(((0.6180339 * key) - floor(0.6180339 * key)) * m);
}
template <typename T>
int Compress(HashmapEntry<T> *entry_p, const size_t &length, T entry) {
T i = (entry + 1) % length, off = 1;
int compress_count = 0;
for (; !entry_p[i].IsEmpty(); i = (i + 1) % length, off++) {
if (entry_p[i].tag > off) {
entry_p[entry].key = entry_p[i].key;
entry_p[entry].value = entry_p[i].value;
entry_p[entry].step = entry_p[i].step;
entry_p[entry].tag = entry_p[i].tag - off;
if (entry_p[i].tag_ > off) {
entry_p[entry].key_ = entry_p[i].key_;
entry_p[entry].value_ = entry_p[i].value_;
entry_p[entry].step_ = entry_p[i].step_;
entry_p[entry].tag_ = entry_p[i].tag_ - off;
entry_p[i].SetEmpty();
off = 0;
entry = i;
@@ -127,6 +92,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
float total_count = 0;
int count_size = 0;
float hit_count = 0;
// search_cache_idx
for (size_t i = 0; i < batch_size_; ++i) {
T key = input_indices[i] - offset;
@@ -140,7 +106,7 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
while ((!hashmap[tmp_entry].IsEmpty() && !hashmap[tmp_entry].IsKey(key))) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, search cache idx failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, search cache idx failed, please set a larger vocab_cache_size!";
break;
}
count += 1;
@@ -153,8 +119,8 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
miss_count++;
} else {
hit_count += 1;
output_cache_idx[i] = hashmap[tmp_entry].value;
hashmap[tmp_entry].step = step_[0];
output_cache_idx[i] = hashmap[tmp_entry].value_;
hashmap[tmp_entry].step_ = step_[0];
}
}
if (miss_count != 0) {
@@ -175,27 +141,27 @@ void MapCacheIdxCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
while (!hashmap[entry].IsEmpty()) {
entry = (entry + 1) % hashmap_length_;
if (tag_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, insert new key failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, insert new key failed, please set a larger vocab_cache_size!";
break;
}
tag_count++;
}
hashmap[entry].key = emb_idx;
hashmap[entry].step = step_[0];
hashmap[entry].tag = tag_count;
hashmap[entry].key_ = emb_idx;
hashmap[entry].step_ = step_[0];
hashmap[entry].tag_ = tag_count;
T tmp_entry = (entry + 1) % hashmap_length_;
size_t delete_count = 1;
while (hashmap[tmp_entry].IsEmpty() || hashmap[tmp_entry].IsUsing(step_[0])) {
tmp_entry = (tmp_entry + 1) % hashmap_length_;
if (delete_count > hashmap_length_) {
MS_LOG(ERROR) << "Hashmap is full, delete old key failed!";
MS_LOG(EXCEPTION) << "Hashmap is full, delete old key failed, please set a larger vocab_cache_size!";
break;
}
delete_count++;
}
output_swap_cache_idx[i] = hashmap[tmp_entry].value;
output_old_emb_idx[i] = hashmap[tmp_entry].key;
hashmap[entry].value = output_swap_cache_idx[i];
output_swap_cache_idx[i] = hashmap[tmp_entry].value_;
output_old_emb_idx[i] = hashmap[tmp_entry].key_;
hashmap[entry].value_ = output_swap_cache_idx[i];
hashmap[tmp_entry].SetEmpty();
int compress_count = Compress(hashmap, hashmap_length_, tmp_entry);
total_delete_count += (compress_count + delete_count);
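For readers tracing the kernel above: the lookup half boils down to golden-ratio hashing into the cache hashmap followed by linear probing until either the key or an empty slot is found. The following Python sketch mirrors that logic under illustrative names (Entry and search_cache_idx are not MindSpore APIs):

import math

NULL_TAG = 0            # mirrors NULLTAG / kNullTag
GOLDEN_RATIO = 0.6180339

class Entry:
    """Illustrative stand-in for HashmapEntry with key_, value_, step_, tag_."""
    def __init__(self):
        self.key = self.value = self.step = 0
        self.tag = NULL_TAG

def hash_func(key, length):
    # Multiplicative hashing: fractional part of key * golden ratio, scaled to the table size.
    return int(((GOLDEN_RATIO * key) - math.floor(GOLDEN_RATIO * key)) * length)

def search_cache_idx(hashmap, key, step):
    """Return the cache index for key, or None on a cache miss."""
    length = len(hashmap)
    entry = hash_func(key, length)
    count = 1
    # Linear probing: stop at the matching key (hit) or at an empty slot (miss).
    while hashmap[entry].tag != NULL_TAG and hashmap[entry].key != key:
        entry = (entry + 1) % length
        if count > length:
            raise RuntimeError("Hashmap is full, search cache idx failed!")
        count += 1
    if hashmap[entry].tag == NULL_TAG:
        return None                # miss: the key has to be swapped in from the host table
    hashmap[entry].step = step     # hit: refresh the step so the slot is not evicted this pass
    return hashmap[entry].value

On a miss the kernel then takes the insert path shown above: the key is written into the first empty slot, and a victim slot whose step is older than the current step is evicted so that its cache index can be reused.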

@@ -23,8 +23,6 @@
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
#define NULLTAG 0
namespace mindspore {
namespace kernel {
class MapCacheIdxCPUKernel : public CPUKernel {

@@ -188,12 +188,18 @@ void ReplaceOldNode(std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusi
MS_EXCEPTION_IF_NULL(manager);
auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id];
if (buffer_fusion_info.outputs_list.size() == 1) { // single output
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
}
(void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel);
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0],
buffer_fusion_kernel);
} else { // multiple output
for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) {
auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index);
if (kernel_graph != nullptr) {
kernel_graph->FrontBackendlMapUpdate(buffer_fusion_info.outputs_list[index], tuple_item);
}
(void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item);
ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index],
tuple_item);

@@ -274,6 +274,10 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto target = GetCNodeTarget(node);
if (target == kCPUDevice) {
return false;
}
if (context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kAscendDevice &&
context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
return false;

@@ -0,0 +1,28 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_
#include "ir/anf.h"
namespace mindspore {
namespace parallel {
// Automatically add cache embedding to the func graph for parameters with cache enable.
void AddCacheEmbedding(const FuncGraphPtr &graph);
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_CACHE_EMBEDDING_CACHE_EMBEDDING_H_

@@ -36,11 +36,14 @@
#include "frontend/optimizer/graph_transform.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/cache_embedding/cache_embedding.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "frontend/optimizer/recompute.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
#include "ps/util.h"
#endif
namespace mindspore {
namespace pipeline {
using OptPassGroupMap = opt::OptPassGroupMap;
@@ -391,6 +394,26 @@ bool AddRecomputationPass(const ResourcePtr &res) {
return true;
}
bool AddCacheEmbeddingPass(const ResourcePtr &res) {
#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
if (ps::Util::IsParamServerMode()) {
return true;
}
#endif
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
parallel::AddCacheEmbedding(func_graph);
if (func_graph->has_flag(GRAPH_FLAG_CACHE_ENABLE)) {
auto params = func_graph->parameters();
AbstractBasePtrList args_spec_list;
std::for_each(params.begin(), params.end(),
[&args_spec_list](const AnfNodePtr &node) { args_spec_list.push_back(node->abstract()); });
func_graph = pipeline::Renormalize(res, func_graph, args_spec_list);
}
return true;
}
bool MergeDupGraphPass(const ResourcePtr &res) {
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
@@ -500,6 +523,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_cache_embedding", AddCacheEmbeddingPass},
{"add_control_depend", AddControlDependPass},
{"add_recomputation", AddRecomputationPass}};

@@ -37,6 +37,7 @@ bool PipelineSplitPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res);
bool ConvertPrepareAdapt(const ResourcePtr &res);
bool AddControlDependPass(const ResourcePtr &res);
bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);
void ReclaimOptimizer();
bool PynativeOptPass(const ResourcePtr &res);

@@ -32,6 +32,8 @@ REGISTER_PYBIND_DEFINE(ParamInfo, ([](const py::module *m) {
.def_property("parallel_optimizer", &ParamInfo::parallel_optimizer,
&ParamInfo::set_parallel_optimizer)
.def_property("comm_fusion", &ParamInfo::comm_fusion, &ParamInfo::set_comm_fusion)
.def_property("cache_enable", &ParamInfo::cache_enable, &ParamInfo::set_cache_enable)
.def_property("cache_shape", &ParamInfo::cache_shape, &ParamInfo::set_cache_shape)
.def(py::pickle(
[](const ParamInfo &p) { // __getstate__
return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());

@@ -24,6 +24,7 @@
#include "pybind_api/api_register.h"
#include "abstract/abstract_value.h"
#include "utils/shape_utils.h"
#include "utils/cache_embedding_hashmap_struct.h"
namespace mindspore {
namespace tensor {
@@ -272,6 +273,68 @@ py::int_ TensorPy::GetPyItemSize(const Tensor &tensor) { return tensor.data().it
py::int_ TensorPy::GetPyNBytes(const Tensor &tensor) { return tensor.data().nbytes(); }
template <typename T>
void MemCopyFromCacheToHost(void *hashmap_addr, void *host_addr, void *cache_addr, size_t host_max, size_t cache_max,
size_t hashmap_size, size_t col_size) {
auto host_data = static_cast<char *>(host_addr);
auto cache_data = static_cast<char *>(cache_addr);
auto hashmap_data = static_cast<HashmapEntry<T> *>(hashmap_addr);
// The parameter type defaults to float, i.e. 4 bytes per element.
size_t param_type_size = 4;
size_t single_col_bytes = param_type_size * col_size;
for (size_t i = 0; i < hashmap_size; ++i) {
if (!hashmap_data[i].IsEmpty()) {
size_t host_offset = single_col_bytes * hashmap_data[i].key_;
size_t cache_offset = single_col_bytes * hashmap_data[i].value_;
if (cache_offset + single_col_bytes <= cache_max) {
auto ret =
memcpy_s(host_data + host_offset, host_max - host_offset, cache_data + cache_offset, single_col_bytes);
if (ret != 0) {
MS_LOG(EXCEPTION) << "Memcpy failed.";
}
}
}
}
MS_LOG(INFO) << "Memcpy from cache to host success!";
}
void TensorPy::FlushFromCache(const Tensor &tensor) {
py::gil_scoped_release gil_release;
if (tensor.NeedWait()) {
tensor.Wait();
}
tensor.data_sync();
if (tensor.cache_enable()) {
MS_LOG(INFO) << tensor.ToString() << " is cache enabled.";
auto hashmap_tensor_ptr = tensor.hashmap_tensor_ptr();
auto cache_tensor_ptr = tensor.cache_tensor_ptr();
if (hashmap_tensor_ptr != nullptr && cache_tensor_ptr != nullptr) {
hashmap_tensor_ptr->data_sync();
cache_tensor_ptr->data_sync();
auto hashmap_size = hashmap_tensor_ptr->shape_c()[0];
auto host_shape = tensor.shape_c();
auto cache_shape = cache_tensor_ptr->shape_c();
if (host_shape.size() != 2 || cache_shape.size() != 2 || host_shape[1] != cache_shape[1]) {
MS_LOG(EXCEPTION) << "Got invalid host shape or cache shape."
<< "host shape:" << host_shape << ", cache shape:" << cache_shape;
}
auto host_data_max_size = tensor.Size();
auto cache_data_max_size = cache_tensor_ptr->Size();
auto hashmap_data_type = hashmap_tensor_ptr->data_type();
if (hashmap_data_type == TypeId::kNumberTypeInt32) {
MemCopyFromCacheToHost<int32_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else if (hashmap_data_type == TypeId::kNumberTypeInt64) {
MemCopyFromCacheToHost<int64_t>(hashmap_tensor_ptr->data_c(), tensor.data_c(), cache_tensor_ptr->data_c(),
host_data_max_size, cache_data_max_size, hashmap_size, host_shape[1]);
} else {
MS_LOG(ERROR) << "Hashmap dtype only suppotr int32, in64.";
}
}
}
}
py::array TensorPy::SyncAsNumpy(const Tensor &tensor) {
{
py::gil_scoped_release gil_release;
@@ -457,6 +520,16 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
array([[1., 1., 1.],
[1., 1., 1.]])
)mydelimiter")
.def("_flush_from_cache", TensorPy::FlushFromCache, R"mydelimiter(
Flush cache data to host if the tensor is cache enabled.
Returns:
None.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data._flush_from_cache()
)mydelimiter")
.def("is_init", &Tensor::is_init, R"mydelimiter(
Get tensor init_flag.

@@ -115,6 +115,8 @@ class TensorPy {
static py::int_ GetPyItemSize(const Tensor &tensor);
static py::int_ GetPyNBytes(const Tensor &tensor);
static void FlushFromCache(const Tensor &tensor);
};
} // namespace tensor
} // namespace mindspore

@@ -268,7 +268,7 @@ void CPUKernelRuntime::CreateOutputTensors(session::KernelGraph *kernel_graph,
bound_addresses_.clear();
auto output_nodes = kernel_graph->outputs();
for (const auto &item : output_nodes) {
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, false);
auto out = CreatTensorForOutput(kernel_graph, item_with_index, tensor_to_node);
outputs->push_back(std::move(out));
}

@@ -0,0 +1,51 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTIL_CACHE_EMBEDDING_HASHMAP_STRUCT_H_
#define MINDSPORE_CCSRC_UTIL_CACHE_EMBEDDING_HASHMAP_STRUCT_H_
#include <math.h>
namespace mindspore {
const int64_t kNullTag = 0;
const int64_t kInitStep = -5;
const int64_t kEmptyRate = 4;
const double kGoldenRatio = 0.6180339;
template <typename T>
struct HashmapEntry {
T key_;
T value_;
T step_;
T tag_;
bool IsEmpty() { return tag_ == kNullTag; }
bool IsUsing(const T train_step) { return step_ >= (train_step - 1); }
bool IsKey(const T emb_idx) { return key_ == emb_idx; }
void SetEmpty() { tag_ = kNullTag; }
};
template <typename T>
T HashFunc(const T key, const size_t m) {
return (T)(((kGoldenRatio * key) - floor(kGoldenRatio * key)) * m);
}
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_UTIL_CACHE_EMBEDDING_HASHMAP_STRUCT_H_
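The tag_ field records how far an entry has been displaced from its home bucket, which is what the kernel's Compress() relies on when it closes the hole left by an eviction. A rough Python rendering of that compaction step, using illustrative dict-based entries instead of the real struct:

NULL_TAG = 0  # corresponds to kNullTag above

def set_empty(entry):
    entry["tag"] = NULL_TAG

def compress(hashmap, hole):
    """Shift probe-chain entries back after the entry at `hole` has been emptied.

    An entry whose tag (its probe distance) exceeds its offset from the hole can
    be moved into the hole without breaking linear-probe lookups; its tag shrinks
    by that offset and the freed slot becomes the new hole.
    """
    length = len(hashmap)
    i, off, moves = (hole + 1) % length, 1, 0
    while hashmap[i]["tag"] != NULL_TAG:
        if hashmap[i]["tag"] > off:
            hashmap[hole] = dict(hashmap[i])   # copy key/value/step/tag into the hole
            hashmap[hole]["tag"] -= off
            set_empty(hashmap[i])              # the moved-from slot is the new hole
            hole, off = i, 0
            moves += 1
        i = (i + 1) % length
        off += 1
    return moves

Keeping the displaced entries packed this way is what preserves the linear-probing invariant after MapCacheIdx empties a slot.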

@@ -350,6 +350,7 @@ constexpr auto kAttrPrimitiveTarget = "primitive_target";
constexpr auto kAttrUseLocking = "use_locking";
constexpr auto kAttrReduceScatterFlag = "reduce_scatter_flag";
constexpr auto kAttrOffset = "offset";
constexpr auto kAttrCacheEnable = "cache_enable";
constexpr auto kAttrPsKey = "ps_key";
constexpr auto kAttrOptimizerType = "optim_type";
constexpr auto kAttrChildGraph = "child_graph";

@@ -131,7 +131,7 @@ class Parameter(Tensor_):
if self.init_mode is not None:
data = self.init_mode
else:
# cast to break deep infinit loop while deepcopy
# cast to break deep infinite loop while deepcopy
data = Tensor(self)
return (
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
@@ -348,6 +348,8 @@ class Parameter(Tensor_):
x.is_param_ps = self.is_param_ps
x.init_in_server = self.init_in_server
x.cache_enable = self.cache_enable
if self.cache_shape:
x.cache_shape = self.cache_shape
if init != 'same':
shape = self.shape
dtype = self.dtype
@@ -375,6 +377,28 @@
raise TypeError("`parallel_optimizer` parameter must be bool type")
self._param_info.parallel_optimizer = value
@property
def cache_enable(self):
"""Return whether the parameter is cache enable."""
return self._param_info.cache_enable
@cache_enable.setter
def cache_enable(self, value=True):
if not isinstance(value, bool):
raise TypeError("`cache_enable` parameter must be bool type")
self._param_info.cache_enable = value
@property
def cache_shape(self):
"""Return the cache shape corresponding to the parameter if use cache."""
return self._param_info.cache_shape
@cache_shape.setter
def cache_shape(self, value):
if not isinstance(value, (tuple, list)):
raise TypeError("`cache_shape` parameter must be tuple or list type")
self._param_info.cache_shape = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""

@@ -308,6 +308,10 @@ class Tensor(Tensor_):
"""Convert tensor to numpy array."""
return Tensor_.asnumpy(self)
def _flush_from_cache(self):
"""Flush cache data to host if tensor is cache enable."""
Tensor_._flush_from_cache(self)
def all(self, axis=(), keep_dims=False):
"""
Check all array elements along a given axis evaluate to True.

@@ -60,6 +60,7 @@ using ValueNodePtr = std::shared_ptr<ValueNode>;
class CNode;
using CNodePtr = std::shared_ptr<CNode>;
using CNodePtrList = std::vector<CNodePtr>;
class FuncGraph;
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
@@ -88,7 +89,7 @@ using ParamInfoPtr = std::shared_ptr<ParamInfo>;
// intermediate_abstract: return the cached inferring abstract value.
// Type/Shape: return the related info of this AnfNode. When this AnfNode is an
// input of other CNodes, you can get the related info by this method.
// debug_info: return the information retrived from parser. Set it using set_debug_info.
// debug_info: return the information retrieved from parser. Set it using set_debug_info.
// fullname_with_scope: return the detailed debug info.
class AnfNode : public Base {
public:

@@ -167,7 +167,6 @@ class MetaTensor : public Value {
// Get tensor's param_info info.
ParamInfoPtr param_info() const { return param_info_; }
bool is_parameter() const { return is_parameter_; }
// Set tensor's param_info info.
void set_param_info(const ParamInfoPtr &param_info) {
is_parameter_ = true;

@@ -81,6 +81,12 @@ class ParamInfo {
bool parallel_optimizer() const { return parallel_optimizer_; }
void set_parallel_optimizer(bool parallel_optimizer) { parallel_optimizer_ = parallel_optimizer; }
bool cache_enable() const { return cache_enable_; }
void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; }
std::vector<int64_t> cache_shape() const { return cache_shape_; }
void set_cache_shape(const std::vector<int64_t> &cache_shape) { cache_shape_ = cache_shape; }
private:
std::string name_{"Parameter"};
bool requires_grad_{true};
@@ -92,6 +98,8 @@ class ParamInfo {
int32_t cloned_index_{0};
int32_t fusion_type_{1};
bool parallel_optimizer_{true};
bool cache_enable_{false};
std::vector<int64_t> cache_shape_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_PARAM_INFO_H_

@@ -449,6 +449,9 @@ Tensor::Tensor(const Tensor &tensor)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
@@ -459,6 +462,9 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
event_(tensor.event_),
sync_status_(tensor.sync_status_),
device_sync_(tensor.device_sync_),
cache_enable_(tensor.cache_enable_),
cache_tensor_ptr_(tensor.cache_tensor_ptr_),
hashmap_tensor_ptr_(tensor.hashmap_tensor_ptr_),
padding_type_(tensor.padding_type()) {}
Tensor::Tensor(TypeId data_type, const ShapeVector &shape, TensorDataPtr data)
@@ -511,7 +517,7 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
return (&tensor == this || (MetaTensor::operator==(tensor) && data_->equals(*tensor.data_)));
}
// assgin value to this tensor
// assign value to this tensor
Tensor &Tensor::AssignValue(const Tensor &tensor) {
if (this != &tensor) {
MetaTensor::operator=(tensor);

@@ -206,7 +206,7 @@ class Tensor : public MetaTensor {
// it does real value comparison.
bool ValueEqual(const Tensor &tensor) const;
// assgin value to this tensor
// assign value to this tensor
Tensor &AssignValue(const Tensor &tensor);
bool operator==(const Value &other) const override {
@@ -291,6 +291,18 @@ class Tensor : public MetaTensor {
TypePtr cast_dtype() { return cast_dtype_; }
void set_cast_dtype(TypePtr dtype = nullptr) { cast_dtype_ = dtype; }
// used if cache_enable, in order to update tensor from cache to host
bool cache_enable() const { return cache_enable_; }
void set_cache_enable(bool cache_enable = true) { cache_enable_ = cache_enable; }
std::shared_ptr<Tensor> hashmap_tensor_ptr() const { return hashmap_tensor_ptr_; }
void set_hashmap_tensor_ptr(std::shared_ptr<Tensor> hashmap_tensor_ptr = nullptr) {
hashmap_tensor_ptr_ = hashmap_tensor_ptr;
}
std::shared_ptr<Tensor> cache_tensor_ptr() const { return cache_tensor_ptr_; }
void set_cache_tensor_ptr(std::shared_ptr<Tensor> cache_tensor_ptr = nullptr) {
cache_tensor_ptr_ = cache_tensor_ptr;
}
void SetNeedWait(bool need_wait) {
if (event_ != nullptr) {
event_->set_need_wait(need_wait);
@@ -335,6 +347,9 @@ class Tensor : public MetaTensor {
mutable TensorSyncStatus sync_status_{kNeedSyncHostToDevice};
bool graph_output_{false};
DeviceSyncPtr device_sync_{nullptr};
bool cache_enable_{false};
std::shared_ptr<Tensor> cache_tensor_ptr_{nullptr};
std::shared_ptr<Tensor> hashmap_tensor_ptr_{nullptr};
std::vector<Axis> padding_type_;
TypePtr cast_dtype_{nullptr};
};

@@ -21,6 +21,7 @@ const char GRAPH_FLAG_MIX_PRECISION_FP16[] = "fp16";
const char GRAPH_FLAG_MIX_PRECISION_FP32[] = "fp32";
const char GRAPH_FLAG_HAS_EFFECT[] = "has_effect";
const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[] = "_effect_patial_order";
const char GRAPH_FLAG_CACHE_ENABLE[] = "cache_enable";
const char GRAPH_FLAG_RANDOM_EFFECT[] = "_random_effect";
const char GRAPH_FLAG_SIDE_EFFECT[] = "_side_effect";

@@ -21,6 +21,7 @@ extern const char GRAPH_FLAG_MIX_PRECISION_FP16[];
extern const char GRAPH_FLAG_MIX_PRECISION_FP32[];
extern const char GRAPH_FLAG_HAS_EFFECT[];
extern const char GRAPH_FLAG_EFFECT_PATIAL_ORDER[];
extern const char GRAPH_FLAG_CACHE_ENABLE[];
extern const char GRAPH_FLAG_RANDOM_EFFECT[];
extern const char GRAPH_FLAG_SIDE_EFFECT[];

@@ -172,8 +172,8 @@ class EmbeddingLookup(Cell):
or None. Default: None
sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True.
vocab_cache_size (int): Cache size of the dictionary of embeddings. Default: 0. It is valid only in
parameter server training mode and 'DEVICE' target. And the moment parameter of corresponding
optimizer will also be set to the cache size. In addition, it should be noted that it will cost the 'DEVICE'
'DEVICE' target. And the moment parameter of corresponding optimizer will also be set to the cache size.
In addition, it should be noted that it will cost the 'DEVICE'
memory, so it is suggested to set a reasonable value to avoid running out of memory.
Inputs:
@@ -205,7 +205,12 @@ class EmbeddingLookup(Cell):
max_norm=None, sparse=True, vocab_cache_size=0):
super(EmbeddingLookup, self).__init__()
validator.check_value_type('sparse', sparse, [bool], self.cls_name)
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self.target = target
self.sparse = sparse
self.cache_enable = self.vocab_cache_size > 0
self.forward_unique = False
if target not in ('CPU', 'DEVICE'):
raise ValueError('Attr \'target\' of \'EmbeddingLookup\' Op passed '
+ str(target) + ', should be one of values in \'CPU\', \'DEVICE\'.')
@@ -216,21 +221,23 @@
else:
self.gatherv2 = P.GatherV2()
self.embeddinglookup = P.EmbeddingLookup().add_prim_attr('primitive_target', 'CPU')
self.vocab_size = validator.check_positive_int(vocab_size, 'vocab_size')
self.vocab_cache_size = validator.check_non_negative_int(vocab_cache_size, 'vocab_cache_size')
self._process_vocab_cache(slice_mode)
enable_ps = _get_ps_context("enable_ps")
if enable_ps:
self._process_vocab_cache(slice_mode)
self.embedding_size = validator.check_positive_int(embedding_size, 'embedding_size')
self.embedding_table = Parameter(initializer(param_init, [self.vocab_size, self.embedding_size]),
name='embedding_table')
if self.cache_enable:
self._set_voacb_cache_enable(vocab_cache_size, embedding_size, vocab_size)
if self.cache_enable and enable_ps:
self._set_voacb_cache_enable_for_ps(vocab_cache_size, embedding_size, vocab_size)
parallel_mode = _get_parallel_mode()
is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
self.forward_unique = False
self.gather_revert = P.GatherV2()
self.unique = P.Unique().shard(((1,),))
self.reshape_first = P.Reshape()
self.reshape = P.Reshape()
self.unique = P.Unique()
self.shape = P.Shape()
if is_auto_parallel:
self.unique = P.Unique().shard(((1,),))
indices_shape_size = 2
if slice_mode == "field_slice" and is_auto_parallel:
if not manual_shapes:
@@ -270,12 +277,34 @@
if is_auto_parallel:
raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get "
+ str(slice_mode))
if self.cache_enable and not enable_ps:
if is_auto_parallel:
raise ValueError("parallel mode haven't supported cache enable yet.")
self._set_cache_enable()
self.embedding_table.unique = self.forward_unique
self.max_norm = max_norm
if self.max_norm is not None:
self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name)
self.max_norm = Tensor(self.max_norm, dtype=mstype.float32)
def _set_cache_enable(self):
"""EmbeddingLookup cache check for not ps env."""
if self.target != 'DEVICE':
logger.warning("The configuration of 'vocab_cache_size' is valid only in 'DEVICE' target, "
"so it will be ignored.")
return
if not self.sparse:
logger.warning("The configuration of 'vocab_cache_size' is valid only 'sparse' is true, "
"so it will be ignored.")
return
logger.info("EmbeddingLookup cache enable takes effect.")
self.forward_unique = True
self.unique = P.Unique().add_prim_attr('primitive_target', 'CPU')
self.unique.add_prim_attr('cache_enable', True)
self.embedding_table.cache_enable = self.cache_enable
self.embedding_table.cache_shape = (self.vocab_cache_size, self.embedding_size)
self.reshape_first = P.Reshape().add_prim_attr('primitive_target', 'CPU')
def _process_vocab_cache(self, slice_mode):
"""PS embeddingLookup cache check and process."""
self.cache_enable = False
@@ -302,7 +331,7 @@
if _is_role_worker():
self.vocab_size = self.vocab_cache_size
def _set_voacb_cache_enable(self, vocab_cache_size, embedding_size, vocab_size):
def _set_voacb_cache_enable_for_ps(self, vocab_cache_size, embedding_size, vocab_size):
"""PS embeddingLookup cache enable set."""
self.embedding_table.cache_enable = True
self.embedding_table.is_param_ps = True
@@ -316,7 +345,7 @@
else:
if self.forward_unique:
shp = self.shape(indices) + (self.embedding_size,)
indices_flatten = self.reshape(indices, (-1,))
indices_flatten = self.reshape_first(indices, (-1,))
unique_id, unique_idx = self.unique(indices_flatten)
weight_unique = self.gatherv2(self.embedding_table, unique_id, 0)
weight_flatten = self.gather_revert(weight_unique, unique_idx, 0)
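With the non-PS path above, enabling the device cache amounts to passing a non-zero vocab_cache_size together with target='DEVICE' and sparse=True, so that _set_cache_enable() switches the layer to the unique-based lookup. A hedged usage sketch with illustrative sizes, assuming an Ascend context is available:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# vocab_cache_size > 0 with target='DEVICE' and sparse=True takes the _set_cache_enable() path above.
embedding = nn.EmbeddingLookup(vocab_size=100000, embedding_size=16,
                               target='DEVICE', sparse=True, vocab_cache_size=3000)
ids = Tensor(np.random.randint(0, 100000, size=(32, 20)).astype(np.int32))
out = embedding(ids)   # shape (32, 20, 16); indices are flattened, deduplicated, gathered, then restored

In parameter-server mode the existing _process_vocab_cache() / _set_voacb_cache_enable_for_ps() path is taken instead.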

@@ -156,8 +156,8 @@ class Optimizer(Cell):
break
ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
ps_cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(ps_cache_filter(x) for x in self.parameters)
cache_filter = lambda x: x.cache_enable
self.cache_enable = tuple(cache_filter(x) for x in self.parameters)
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.need_scale = loss_scale != 1.0
self.global_step_increase_tensor = Tensor(1, mstype.int32)

Some files were not shown because too many files have changed in this diff.
