Set weight size limit for PS mode.

pull/6807/head
ZPaC 4 years ago
parent 39874d133f
commit 15be0cc819

@ -29,6 +29,10 @@ void EmbeddingLookUpProxyKernel::InitKernel(const CNodePtr &kernel_node) {
for (auto dim : input_shape) {
input_dims_ *= dim;
}
if (input_dims_ * sizeof(float) > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode embedding lookup max embedding table size is " << INT_MAX << ", current shape "
<< input_shape << " is too large.";
}
if (mindspore::ps::Util::IsRoleOfWorker()) {
key_ = AnfAlgo::GetNodeAttr<size_t>(kernel_node, kAttrPsKey);

@ -486,7 +486,8 @@ void ParameterServer<T>::InitEmbeddingTable(
// Init embedding weight
const std::vector<size_t> &input_shapes = lookup->input_sizes();
size_t total_dims = std::accumulate(input_shapes.begin(), input_shapes.end(), 1, std::multiplies<size_t>());
size_t total_dims =
std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
MS_EXCEPTION_IF_NULL(embedding);
T *embedding_data = embedding->data();
@ -732,7 +733,8 @@ void ParameterServer<T>::SyncEmbeddingTables() {
for (auto embedding_table : embedding_tables_) {
Key key = embedding_table.first;
if (embedding_lookup_ops_.count(key) == 0) {
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key;
continue;
}
auto lookup = embedding_lookup_ops_[key];
const std::vector<size_t> &input_shapes = lookup->input_sizes();

@ -325,6 +325,9 @@ void Worker<T>::InitPSParamAndOptim(const std::string &param_name, const tensor:
MS_EXCEPTION_IF_NULL(tensor);
void *param_data = tensor->data_c();
size_t param_size = LongToSize(tensor->data().nbytes());
if (param_size > INT_MAX) {
MS_LOG(EXCEPTION) << "PS mode max weight size is " << INT_MAX << ", " << param_name << " size is " << param_size;
}
ShapeVector param_shape = tensor->shape_c();
size_t param_key = GetParamKey(param_name);

Loading…
Cancel
Save