@@ -60,34 +60,45 @@ template <typename DeviceContext, typename T>
 class ConcatGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const {
-    auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* out_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto ins = ctx.MultiInput<framework::Tensor>("X");
+    auto out_var_names = ctx.Outputs(framework::GradVarName("X"));
     auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
     int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
 
+    // Collect only the output tensors whose names are not kEmptyVarName.
+    std::vector<framework::Tensor*> outputs;
+    for (size_t j = 0; j < outs.size(); ++j) {
+      if (out_var_names[j] != framework::kEmptyVarName) {
+        outs[j]->mutable_data<T>(ctx.GetPlace());
+        outputs.push_back(outs[j]);
+      } else {
+        outputs.push_back(nullptr);
+      }
+    }
+
     // Sometimes direct copies are faster; this may need deeper analysis.
     if (axis == 0 && outs.size() < 10) {
       size_t input_offset = 0;
-      auto in_stride = framework::stride_numel(in->dims());
+      const auto in_stride = framework::stride_numel(out_grad->dims());
 
-      for (auto& out : outs) {
-        out->mutable_data<T>(ctx.GetPlace());
-        auto out_stride = framework::stride_numel(out->dims());
-        StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
-                                    out_stride, in->data<T>() + input_offset,
-                                    in_stride, out_stride[axis]);
+      for (size_t i = 0; i < outs.size(); ++i) {
+        auto out_stride = framework::stride_numel(ins[i]->dims());
+        auto* out = outputs[i];
+        if (out != nullptr) {
+          StridedNumelCopyWithAxis<T>(
+              ctx.device_context(), axis, out->data<T>(), out_stride,
+              out_grad->data<T>() + input_offset, in_stride, out_stride[axis]);
+        }
         input_offset += out_stride[axis];
       }
     } else {
-      std::vector<framework::Tensor> outputs(outs.size());
-      for (size_t j = 0; j < outs.size(); ++j) {
-        outs[j]->mutable_data<T>(ctx.GetPlace());
-        outputs[j] = *outs[j];
-      }
-
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
           concat_grad_functor;
-      concat_grad_functor(dev_ctx, *in, static_cast<int>(axis), &outputs);
+      concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
+                          &outputs);
     }
   }
 };
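For intuition, here is a minimal, self-contained sketch of what the rewritten fast path does, stripped of the Paddle types: the gradient of a concat along axis 0 is just the concatenated gradient buffer sliced back into per-input pieces, and a slot whose gradient variable is `kEmptyVarName` (modeled below as a `nullptr`) is skipped while the read offset still advances. The names here (`SplitGradAlongAxis0`, the flat `float` buffers) are illustrative only, not part of the Paddle API.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Split the concatenated gradient buffer back into per-input slices along
// axis 0. A nullptr slot stands for a gradient nobody requested; nothing is
// written for it, but the read offset still advances past its elements.
void SplitGradAlongAxis0(const std::vector<float>& out_grad,
                         const std::vector<std::size_t>& input_sizes,
                         std::vector<std::vector<float>*>* in_grads) {
  std::size_t offset = 0;
  for (std::size_t i = 0; i < input_sizes.size(); ++i) {
    std::vector<float>* grad = (*in_grads)[i];
    if (grad != nullptr) {
      grad->assign(out_grad.begin() + offset,
                   out_grad.begin() + offset + input_sizes[i]);
    }
    offset += input_sizes[i];  // advance even when the slot is skipped
  }
}

int main() {
  // Gradient of concat({a, b}, axis=0) where a has 2 elements and b has 3.
  const std::vector<float> out_grad = {1, 2, 3, 4, 5};
  std::vector<float> b_grad;
  // The caller did not ask for a's gradient, so its slot is nullptr.
  std::vector<std::vector<float>*> in_grads = {nullptr, &b_grad};
  SplitGradAlongAxis0(out_grad, {2, 3}, &in_grads);
  for (float v : b_grad) std::cout << v << ' ';  // prints: 3 4 5
  std::cout << '\n';
  return 0;
}
```

The point the diff preserves is visible in the sketch: `input_offset` advances by the skipped input's extent even when no gradient is written, so the remaining outputs still read from the correct region of `out_grad`.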