Fix the correctness of async mode at distributed training (#18863)
* fix correctness of the communicator * fix a bug in send thread when sending var context is empty, test=develop * add lookup_table_prefetch_op and prefetch optimize, test=develop * remove remote prefetch GPU supported * word2vec force with CPU, test=develop * test dist remote lookup table force with CPU, test=developassert
parent
61389ae5aa
commit
65c7368400
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,166 @@
|
|||||||
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/data_type.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class DistributedLookupTableOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInputs("Ids"),
|
||||||
|
"Input(Ids) of LookupTableOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("W"),
|
||||||
|
"Input(W) of LookupTableOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutputs("Outputs"),
|
||||||
|
"Output(Outs) of LookupTableOp should not be null.");
|
||||||
|
|
||||||
|
auto ids_dims = ctx->GetInputsDim("Ids");
|
||||||
|
auto table_dims = ctx->GetInputDim("W");
|
||||||
|
|
||||||
|
PADDLE_ENFORCE_EQ(table_dims.size(), 2,
|
||||||
|
"Only 2 dimensions of the 'Embedding' is supported.");
|
||||||
|
|
||||||
|
for (auto &ids_dim : ids_dims) {
|
||||||
|
PADDLE_ENFORCE_EQ(ids_dim.size(), 2,
|
||||||
|
"The dimension of the 'Ids' tensor must be 2.");
|
||||||
|
PADDLE_ENFORCE_EQ(ids_dim[1], 1,
|
||||||
|
"The last dimension of the 'Ids' tensor must be 1.");
|
||||||
|
}
|
||||||
|
|
||||||
|
auto lookup_tables =
|
||||||
|
ctx->Attrs().Get<std::vector<std::string>>("table_names");
|
||||||
|
auto height_sections =
|
||||||
|
ctx->Attrs().Get<std::vector<int64_t>>("height_sections");
|
||||||
|
auto endpoints = ctx->Attrs().Get<std::vector<std::string>>("endpoints");
|
||||||
|
|
||||||
|
PADDLE_ENFORCE(lookup_tables.size() == height_sections.size() &&
|
||||||
|
lookup_tables.size() == endpoints.size() &&
|
||||||
|
lookup_tables.size() != 0,
|
||||||
|
"Attrs lookup_tables/height_sections/endpoints must have "
|
||||||
|
"save size and can not be 0.");
|
||||||
|
|
||||||
|
auto outputs_dims = std::vector<framework::DDim>();
|
||||||
|
|
||||||
|
for (auto &ids_dim : ids_dims) {
|
||||||
|
outputs_dims.push_back(framework::make_ddim({ids_dim[0], table_dims[1]}));
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx->SetOutputsDim("Outputs", outputs_dims);
|
||||||
|
ctx->ShareLoD("Ids", /*->*/ "Outputs");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext &ctx) const override {
|
||||||
|
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
|
||||||
|
return framework::OpKernelType(data_type, ctx.device_context());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class DistributedLookupTableKernel : public framework::OpKernel<T> {
|
||||||
|
public:
|
||||||
|
void Compute(const framework::ExecutionContext &context) const override {
|
||||||
|
auto ids_vars = context.MultiInputVar("Ids");
|
||||||
|
auto emb_vars = context.MultiOutput<framework::Tensor>("Embeddings");
|
||||||
|
|
||||||
|
auto id_names = context.Inputs("Ids");
|
||||||
|
auto embedding_name = context.Inputs("W").front();
|
||||||
|
auto out_names = context.Outputs("Outputs");
|
||||||
|
|
||||||
|
auto lookup_tables = context.Attr<std::vector<std::string>>("table_names");
|
||||||
|
auto height_sections =
|
||||||
|
context.Attr<std::vector<int64_t>>("height_sections");
|
||||||
|
auto endpoints = context.Attr<std::vector<std::string>>("endpoints");
|
||||||
|
|
||||||
|
operators::distributed::prefetchs(
|
||||||
|
id_names, out_names, embedding_name, false, lookup_tables, endpoints,
|
||||||
|
height_sections, context, context.scope());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
class DistributedLookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
void Make() override {
|
||||||
|
AddInput("Ids",
|
||||||
|
"(LoDTensor) Ids's type should be LoDTensor"
|
||||||
|
"THe ids to be looked up in W.")
|
||||||
|
.AsDuplicable();
|
||||||
|
|
||||||
|
AddInput("W",
|
||||||
|
"(Tensor) The input represents embedding tensors, "
|
||||||
|
"which is a learnable parameter.");
|
||||||
|
|
||||||
|
AddOutput("Outputs",
|
||||||
|
"(LoDTensor) The lookup results, which have the same type as W.")
|
||||||
|
.AsDuplicable();
|
||||||
|
|
||||||
|
AddAttr<std::vector<std::string>>(
|
||||||
|
"table_names",
|
||||||
|
"(string vector, such as emb_block0, emb_block1)"
|
||||||
|
"Server endpoints in the order of input variables for mapping")
|
||||||
|
.SetDefault({""});
|
||||||
|
|
||||||
|
AddAttr<std::vector<int64_t>>("height_sections",
|
||||||
|
"Height for each output SelectedRows.")
|
||||||
|
.SetDefault(std::vector<int64_t>({}));
|
||||||
|
|
||||||
|
AddAttr<std::vector<std::string>>(
|
||||||
|
"endpoints",
|
||||||
|
"(string vector, default 127.0.0.1:6164)"
|
||||||
|
"Server endpoints in the order of input variables for mapping")
|
||||||
|
.SetDefault({"127.0.0.1:6164"});
|
||||||
|
|
||||||
|
AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
|
||||||
|
|
||||||
|
AddAttr<int64_t>("padding_idx",
|
||||||
|
"(int64, default -1) "
|
||||||
|
"If the value is -1, it makes no effect to lookup. "
|
||||||
|
"Otherwise the given value indicates padding the output "
|
||||||
|
"with zeros whenever lookup encounters it in Ids.")
|
||||||
|
.SetDefault(distributed::kNoPadding);
|
||||||
|
|
||||||
|
AddComment(R"DOC(
|
||||||
|
Lookup Tablel Prefetch Operator.
|
||||||
|
|
||||||
|
This operator is used to perform lookup on parameter W,
|
||||||
|
then concatenated into a sparse tensor.
|
||||||
|
|
||||||
|
The type of Ids(Input) is SelectedRows, the rows of Ids contains
|
||||||
|
the ids to be looked up in W;
|
||||||
|
if the Id is not in the sparse table, this operator will return a
|
||||||
|
random value and set the value into the table for the next looking up.
|
||||||
|
|
||||||
|
)DOC");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
|
||||||
|
REGISTER_OPERATOR(distributed_lookup_table, ops::DistributedLookupTableOp,
|
||||||
|
ops::DistributedLookupTableOpMaker);
|
||||||
|
|
||||||
|
REGISTER_OP_CPU_KERNEL(distributed_lookup_table,
|
||||||
|
ops::DistributedLookupTableKernel<float>);
|
Loading…
Reference in new issue