apply more general dims for multiplex_op

update-doc-pybind
Yibing Liu 7 years ago
parent 089f8e2d37
commit 58ac8f46b8

@ -44,7 +44,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
"one candidate input tensors.");
auto in_dim = ins[0]->dims();
PADDLE_ENFORCE(in_dim.size() == 2, "Candidate tensors must be matrix.");
PADDLE_ENFORCE(in_dim.size() >= 2,
"The rank of candidate tensors must be not less than 2.");
for (size_t i = 1; i < num_ins; i++) {
auto dim = ins[i]->dims();
PADDLE_ENFORCE(in_dim == dim,
@ -65,8 +66,7 @@ class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "The output tensor of multiplex operator.");
AddComment(R"DOC(Multiplex operator
Multiplex multiple tensors according to the index provided by the first
input tensor.
Multiplex multiple tensors according to the index provided by the index tensor.
Ids: the index tensor.
X[0 : N - 1]: the candidate tensors for output (N >= 2).
@ -75,7 +75,7 @@ the (Ids[i])-th tensor.
For i-th row of the output tensor:
y[i][j] = x_{k}[i][j], j = 0,1, ... , (x_{0}.width - 1)
y[i] = x_{k}[i]
where y is the output tensor. `x_{k}` is the k-th input tensor
and `k = Ids[i]`.

@ -30,7 +30,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());
@ -67,7 +67,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel {
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
// copy index to cpu
Tensor index_t_cpu;
index_t_cpu.CopyFrom<int32_t>(*ids, platform::CPUPlace());

@ -33,7 +33,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
out->mutable_data<T>(ctx.GetPlace());
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
auto index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
@ -65,7 +65,7 @@ class MultiplexGradCPUKernel : public framework::OpKernel {
}
auto rows = ins[0]->dims()[0];
auto cols = ins[0]->dims()[1];
auto cols = ins[0]->numel() / rows;
auto* index = ids->data<int32_t>();
Place place = boost::get<Place>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {

Loading…
Cancel
Save