|
|
|
@ -486,7 +486,8 @@ void ParameterServer<T>::InitEmbeddingTable(
|
|
|
|
|
|
|
|
|
|
// Init embedding weight
|
|
|
|
|
const std::vector<size_t> &input_shapes = lookup->input_sizes();
|
|
|
|
|
size_t total_dims = std::accumulate(input_shapes.begin(), input_shapes.end(), 1, std::multiplies<size_t>());
|
|
|
|
|
size_t total_dims =
|
|
|
|
|
std::accumulate(input_shapes.begin(), input_shapes.end(), IntToSize(1), std::multiplies<size_t>());
|
|
|
|
|
WeightPtr embedding = std::make_shared<Weight>(total_dims, 0);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(embedding);
|
|
|
|
|
T *embedding_data = embedding->data();
|
|
|
|
@ -732,7 +733,8 @@ void ParameterServer<T>::SyncEmbeddingTables() {
|
|
|
|
|
for (auto embedding_table : embedding_tables_) {
|
|
|
|
|
Key key = embedding_table.first;
|
|
|
|
|
if (embedding_lookup_ops_.count(key) == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Can't find look up PS kernel for key " << key;
|
|
|
|
|
MS_LOG(WARNING) << "Can't find look up PS kernel for key " << key;
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto lookup = embedding_lookup_ops_[key];
|
|
|
|
|
const std::vector<size_t> &input_shapes = lookup->input_sizes();
|
|
|
|
|