Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_api_reference_docs
commit
a83b792ada
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,128 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/merge_ids_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Describes the merge_ids op: its inputs, output, and user-facing docs.
// Fix: the example in the op comment said "# 3 shared" where it means the
// ids are split into 3 *shards* (by id % 3, as the example itself shows).
class MergeIdsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Ids", "(LoDTensor) the input ids with shape{batch_num, 1}");
    AddInput(
        "X",
        "(LoDTensors) multi input tensor with shape{batch_num, N}, N is the "
        "size of embedding table")
        .AsDuplicable();  // "X" accepts a variable number of tensors (one per shard).
    AddOutput("Out", "(LoDTensor) The merged outputs of the input tensors.");

    AddComment(R"DOC(
Merge multi LoDTensor's into one according to Ids's shard num.


split_ids_op -> prefetch_op -> merge_ids_op


merge_ids_op should be used after split_ids_op and prefetch_op, split_ids_op
will split input Ids into multiple tensors according to Id's shard number.
prefetch_op will send them to parameter server to prefetch embedding value
back. During split, the order of ids is disordered. In merge_ids_op we use
the original Ids to restore the order of the fetched embedding value and
also pass the lod information to the merged output.


Example:

    Ids = [1,2,3,4,5,6] # 3 shards

    split_ids_op ->

    Id0 = [3, 6] # id % 3 == 0
    Id1 = [1, 4] # id % 3 == 1
    Id2 = [2, 5] # id % 3 == 2

    prefetch_op ->

    X0 = [[0.3 0.3]   # 3
          [0.6 0.6]]  # 6
    X1 = [[0.1 0.1]   # 1
          [0.4 0.4]]  # 4
    X2 = [[0.2 0.2]   # 2
          [0.5 0.5]]  # 5

    merge_ids_op ->

    Out = [[0.1 0.1]   # 1
           [0.2 0.2]   # 2
           [0.3 0.3]   # 3
           [0.4 0.4]   # 4
           [0.5 0.5]   # 5
           [0.6 0.6]]  # 6
)DOC");
  }
};
|
||||
|
||||
class MergeIdsOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Ids"), "MergeIdsOp must has input Ids.");
|
||||
PADDLE_ENFORCE(ctx->HasInputs("X"), "MergeIdsOp must has input X.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"), "MergeIdsOp must has output Out.");
|
||||
|
||||
auto ids_var_type = ctx->GetInputsVarType("Ids").front();
|
||||
auto ids_dims = ctx->GetInputDim("Ids");
|
||||
if (ids_var_type == framework::proto::VarType::LOD_TENSOR) {
|
||||
PADDLE_ENFORCE_EQ(ids_dims.size(), 2);
|
||||
PADDLE_ENFORCE_EQ(ids_dims[1], 1);
|
||||
}
|
||||
auto x_var_type = ctx->GetInputsVarType("X");
|
||||
for (auto &var_type : x_var_type) {
|
||||
PADDLE_ENFORCE_EQ(var_type, framework::proto::VarType::LOD_TENSOR,
|
||||
"input X only support lod tensors");
|
||||
}
|
||||
ctx->ShareLoD("Ids", "Out");
|
||||
}
|
||||
|
||||
private:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext &ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(
|
||||
ctx.MultiInput<framework::Tensor>("X").front()->type()),
|
||||
ctx.GetPlace());
|
||||
}
|
||||
};
|
||||
|
||||
// Propagates the variable type of the "Ids" input to every "Out" variable,
// so Out is a LoDTensor whenever Ids is.
class MergeIdsOpInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    const auto ids_type = block->Var(op_desc.Input("Ids")[0])->GetType();
    for (const auto &out_name : op_desc.Output("Out")) {
      block->Var(out_name)->SetType(ids_type);
    }
  }
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
// Register the op with its shape inference, proto maker, and var-type
// inference. No gradient op: merge_ids is only used in the distributed
// prefetch path (see the op comment).
REGISTER_OPERATOR(merge_ids, ops::MergeIdsOp, ops::MergeIdsOpMaker,
                  ops::MergeIdsOpInferVarType);
// CPU-only, float-only kernel; the kernel itself rejects non-CPU places.
REGISTER_OP_CPU_KERNEL(
    merge_ids, ops::MergeIdsOpKernel<paddle::platform::CPUPlace, float>);
|
@ -0,0 +1,92 @@
|
||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/tensor_util.h"
|
||||
#include "paddle/fluid/operators/math/selected_rows_functor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class MergeIdsOpKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &ctx) const override {
|
||||
auto place = ctx.GetPlace();
|
||||
if (!platform::is_cpu_place(place)) {
|
||||
PADDLE_THROW("MergeIds do not support GPU kernel");
|
||||
}
|
||||
VLOG(3) << "run in MergeIdsOpKernel";
|
||||
|
||||
const auto *ids_var = ctx.InputVar("Ids");
|
||||
PADDLE_ENFORCE(ids_var->IsType<framework::LoDTensor>(),
|
||||
"only support to merge Ids of LoDTensor");
|
||||
|
||||
const auto &ids_tensor = ids_var->Get<framework::LoDTensor>();
|
||||
const auto &ids_dims = ids_tensor.dims();
|
||||
const int64_t *ids = ids_tensor.data<int64_t>();
|
||||
|
||||
auto x_tensors = ctx.MultiInput<framework::LoDTensor>("X");
|
||||
|
||||
auto *out = ctx.Output<framework::LoDTensor>("Out");
|
||||
|
||||
int batch_size = 0;
|
||||
int embedding_size = 0;
|
||||
for (auto &input : x_tensors) {
|
||||
if (framework::product(input->dims()) != 0) {
|
||||
if (embedding_size == 0) {
|
||||
embedding_size = input->dims()[1];
|
||||
}
|
||||
PADDLE_ENFORCE_EQ(embedding_size, input->dims()[1],
|
||||
"embedding size of all input should be the same");
|
||||
batch_size += input->dims()[0];
|
||||
}
|
||||
}
|
||||
PADDLE_ENFORCE_EQ(
|
||||
batch_size, ids_dims[0],
|
||||
"the batch size of ids and merged embedding value should be the same");
|
||||
|
||||
const size_t shard_num = x_tensors.size();
|
||||
|
||||
if (shard_num == 1) {
|
||||
VLOG(3) << "only one shard, we can copy the data directly";
|
||||
TensorCopy(*x_tensors[0], place, out);
|
||||
} else {
|
||||
std::vector<int> in_indexs(shard_num, 0);
|
||||
auto *out_data = out->mutable_data<T>(
|
||||
framework::make_ddim({batch_size, embedding_size}), place);
|
||||
// copy data from ins[shard_num] to out.
|
||||
for (int i = 0; i < ids_dims[0]; ++i) {
|
||||
int64_t id = ids[i];
|
||||
size_t shard_id = static_cast<size_t>(id) % shard_num;
|
||||
int index = in_indexs[shard_id];
|
||||
memcpy(out_data + embedding_size * i,
|
||||
x_tensors[shard_id]->data<T>() + index * embedding_size,
|
||||
sizeof(T) * embedding_size);
|
||||
in_indexs[shard_id] += 1;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < shard_num; ++i) {
|
||||
PADDLE_ENFORCE_EQ(in_indexs[i], x_tensors[i]->dims()[0],
|
||||
"after merge, all data in x_tensor should be used");
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue