diff --git a/paddle/fluid/operators/array_to_lod_tensor_op.cc b/paddle/fluid/operators/array_to_lod_tensor_op.cc
index 149226e92d..bc725e5357 100644
--- a/paddle/fluid/operators/array_to_lod_tensor_op.cc
+++ b/paddle/fluid/operators/array_to_lod_tensor_op.cc
@@ -11,6 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <paddle/fluid/operators/math/concat.h>
 #include <vector>
 
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -24,6 +25,50 @@ namespace operators {
 
 using LoD = framework::LoD;
 
+class ArrayToLoDFunctor;
+template <typename DeviceContext>
+struct ArrayToLoDFunctorImpl {
+  const ArrayToLoDFunctor *prev_functor_;
+  DeviceContext *dev_ctx_;
+
+  template <typename T>
+  void apply();
+};
+
+struct ArrayToLoDFunctor : public boost::static_visitor<void> {
+  std::vector<framework::Tensor> in;
+  mutable framework::Tensor *out;
+
+  template <typename Place>
+  void operator()(Place place) const {
+    auto &pool = platform::DeviceContextPool::Instance();
+    if (std::is_same<Place, platform::CPUPlace>::value) {
+      Apply(static_cast<platform::CPUDeviceContext *>(pool.Get(place)));
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      Apply(static_cast<platform::CUDADeviceContext *>(pool.Get(place)));
+#else
+      PADDLE_THROW("Fluid is not compiled with CUDA");
+#endif
+    }
+  }
+
+  template <typename DeviceContext>
+  void Apply(DeviceContext *dev_ctx) const {
+    ArrayToLoDFunctorImpl<DeviceContext> functor;
+    functor.dev_ctx_ = dev_ctx;
+    functor.prev_functor_ = this;
+    framework::VisitDataType(framework::ToDataType(out->type()), functor);
+  }
+};
+
+template <typename DeviceContext>
+template <typename T>
+void ArrayToLoDFunctorImpl<DeviceContext>::apply() {
+  math::ConcatFunctor<DeviceContext, T> func;
+  func(*dev_ctx_, prev_functor_->in, 0, prev_functor_->out);
+}
+
 class ArrayToLoDTensorOp : public framework::OperatorBase {
  public:
   ArrayToLoDTensorOp(const std::string &type,
@@ -47,14 +92,18 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     int rank = x[0].dims().size();
     platform::Place place = x[0].place();
     std::type_index data_type = x[0].type();
-    framework::DDim ins_dims = framework::slice_ddim(x[0].dims(), 1, rank);
     int64_t batch_size = x[0].dims()[0];
+    framework::DDim ins_dims = rank > 1
+                                   ? framework::slice_ddim(x[0].dims(), 1, rank)
+                                   : framework::make_ddim({0});
     for (size_t i = 1; i < x.size(); ++i) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x[i].dims(), 1, rank), ins_dims,
+      auto ins_i_dims = rank > 1 ? framework::slice_ddim(x[i].dims(), 1, rank)
+                                 : framework::make_ddim({0});
+      PADDLE_ENFORCE_EQ(ins_i_dims, ins_dims,
                         "The dimension of the %zu'th element in LoDTensorArray "
                         "differs from previous ones.",
                         i);
-      PADDLE_ENFORCE(platform::places_are_same_class(x[i].place(), place),
+      PADDLE_ENFORCE(x[i].place() == place,
                      "The place class of the %zu'th element in LoDTensorArray "
                      "differs from previous ones.",
                      i);
@@ -82,13 +131,14 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
     // Build LoDTensor `out`
     framework::LoD *out_lod = out->mutable_lod();
     out_lod->clear();
-    size_t out_offset = 0;
     auto prefix_lod = rank_table.coarse_lod();
     prefix_lod.emplace_back();
     auto &cur_level_lod = prefix_lod.back();
     cur_level_lod.push_back(0);
+    ArrayToLoDFunctor functor;
     for (size_t idx : table_item_idx) {
       cur_level_lod.push_back(cur_level_lod.back() + table_items[idx].length);
+      PADDLE_ENFORCE_LE(table_items[idx].length, x.size());
       for (size_t x_idx = 0; x_idx < table_items[idx].length; ++x_idx) {
         auto lod_and_offset = framework::GetSubLoDAndAbsoluteOffset(
             x[x_idx].lod(), idx, idx + 1, 0);
@@ -106,17 +156,11 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
         if (len == 0) {
           continue;
         }
-        auto slice = out->Slice(out_offset, out_offset + len);
-
-        platform::DeviceContextPool &pool =
-            platform::DeviceContextPool::Instance();
-        auto &dev_ctx = *pool.Get(place);
-
-        framework::TensorCopy(x[x_idx].Slice(start_offset, end_offset), place,
-                              dev_ctx, &slice);
-        out_offset += len;
+        functor.in.emplace_back(x[x_idx].Slice(start_offset, end_offset));
       }
     }
+    functor.out = out;
+    platform::VisitPlace(place, functor);
     out_lod->insert(out_lod->begin(), prefix_lod.begin(), prefix_lod.end());
   }
 };
diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index c724055937..bc58612f9d 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -95,6 +95,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
 
   void InferShape(framework::InferShapeContext *ctx) const override {
     ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
+    ctx->ShareLoD("X", framework::GradVarName("X"));
   }
 };
 
diff --git a/paddle/fluid/operators/concat_op.h b/paddle/fluid/operators/concat_op.h
index 78be2e1e1f..b2c6495c44 100644
--- a/paddle/fluid/operators/concat_op.h
+++ b/paddle/fluid/operators/concat_op.h
@@ -109,8 +109,9 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto& dev_ctx = ctx.template device_context<DeviceContext>();
       paddle::operators::math::ConcatGradFunctor<DeviceContext, T>
           concat_grad_functor;
-      concat_grad_functor(dev_ctx, *out_grad, ins, static_cast<int>(axis),
-                          &outputs);
+      concat_grad_functor(dev_ctx, *out_grad,
+                          ctx.MultiInput<framework::Tensor>("X"),
+                          static_cast<int>(axis), &outputs);
     }
   }
 };
diff --git a/paddle/fluid/operators/lod_tensor_to_array_op.cc b/paddle/fluid/operators/lod_tensor_to_array_op.cc
index b3f7e0c009..8eab83fcd2 100644
--- a/paddle/fluid/operators/lod_tensor_to_array_op.cc
+++ b/paddle/fluid/operators/lod_tensor_to_array_op.cc
@@ -11,10 +11,13 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
+#include <map>
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor_array.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
+#include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/port.h"
 
@@ -26,6 +29,61 @@ struct CopyRange {
   size_t end;
 };
 
+struct LoDTensorToArrayFunctor;
+
+template <typename DeviceContext>
+struct LoDTensorToArrayFunctorImpl {
+  const LoDTensorToArrayFunctor *prev_functor_;
+  DeviceContext *dev_ctx_;
+  template <typename T>
+  void apply();
+};
+
+struct LoDTensorToArrayFunctor : public boost::static_visitor<void> {
+  std::vector<const framework::Tensor *> ref_inputs_;
+  mutable std::vector<framework::Tensor *> outputs_;
+  const framework::Tensor &input_;
+
+  explicit LoDTensorToArrayFunctor(const framework::Tensor &input)
+      : input_(input) {}
+
+  void AddOutput(framework::Tensor *t) {
+    outputs_.emplace_back(t);
+    ref_inputs_.emplace_back(t);
+  }
+
+  template <typename Place>
+  void operator()(Place place) const {
+    auto &pool = platform::DeviceContextPool::Instance();
+    auto *dev_ctx = pool.Get(place);
+    if (std::is_same<Place, platform::CPUPlace>::value) {
+      Apply(static_cast<platform::CPUDeviceContext *>(dev_ctx));
+    } else {
+#ifdef PADDLE_WITH_CUDA
+      Apply(static_cast<platform::CUDADeviceContext *>(dev_ctx));
+#else
+      PADDLE_THROW("Not compiled with cuda");
+#endif
+    }
+  }
+
+  template <typename DeviceContext>
+  void Apply(DeviceContext *dev_ctx) const {
+    LoDTensorToArrayFunctorImpl<DeviceContext> func;
+    func.prev_functor_ = this;
+    func.dev_ctx_ = dev_ctx;
+    framework::VisitDataType(framework::ToDataType(input_.type()), func);
+  }
+};
+
+template <typename DeviceContext>
+template <typename T>
+void LoDTensorToArrayFunctorImpl<DeviceContext>::apply() {
+  math::ConcatGradFunctor<DeviceContext, T> func;
+  func(*dev_ctx_, prev_functor_->input_, prev_functor_->ref_inputs_, 0,
+       &prev_functor_->outputs_);
+}
+
 class LoDTensorToArrayOp : public framework::OperatorBase {
  public:
   LoDTensorToArrayOp(const std::string &type,
@@ -72,6 +130,11 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
         copy_ranges[t].emplace_back(CopyRange{start_offset, end_offset});
       }
     }
+
+    auto &outputs = *const_cast<framework::Scope &>(scope)
+                         .Var()
+                         ->GetMutable<std::map<size_t, framework::Tensor>>();
+
     for (size_t i = 0; i < max_seq_len; ++i) {
       auto &ranges = copy_ranges[i];
       size_t height = std::accumulate(
@@ -90,17 +153,16 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
         // out[i][offset: offset+len] = x[each_range.begin: each_range.end]
         auto slice = out[i].Slice(static_cast<int>(offset),
                                   static_cast<int>(offset + len));
-
-        platform::DeviceContextPool &pool =
-            platform::DeviceContextPool::Instance();
-        auto &dev_ctx = *pool.Get(place);
-
-        framework::TensorCopy(x.Slice(static_cast<int>(each_range.begin),
-                                      static_cast<int>(each_range.end)),
-                              x.place(), dev_ctx, &slice);
+        outputs.insert({each_range.begin, slice});
         offset += len;
       }
     }
+
+    LoDTensorToArrayFunctor functor(x);
+    for (auto &out_pair : outputs) {
+      functor.AddOutput(&out_pair.second);
+    }
+    platform::VisitPlace(place, functor);
   }
 };
 
diff --git a/paddle/fluid/operators/math/concat.cc b/paddle/fluid/operators/math/concat.cc
index c3c5c160db..7b79f10e33 100644
--- a/paddle/fluid/operators/math/concat.cc
+++ b/paddle/fluid/operators/math/concat.cc
@@ -27,7 +27,7 @@ template <typename T>
 class ConcatFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
     int num = input.size();
@@ -71,7 +71,7 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
  public:
   void operator()(const platform::CPUDeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
+                  const std::vector<const framework::Tensor*>& ref_inputs,
                   const int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     size_t num = outputs->size();
@@ -109,16 +109,11 @@ class ConcatGradFunctor<platform::CPUDeviceContext, T> {
     }
   }
 };
+#define DEFINE_FUNCTOR(type)                                      \
+  template class ConcatFunctor<platform::CPUDeviceContext, type>; \
+  template class ConcatGradFunctor<platform::CPUDeviceContext, type>;
 
-template class ConcatFunctor<platform::CPUDeviceContext, int>;
-template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
-template class ConcatFunctor<platform::CPUDeviceContext, float>;
-template class ConcatFunctor<platform::CPUDeviceContext, double>;
-
-template class ConcatGradFunctor<platform::CPUDeviceContext, int>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, int64_t>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, float>;
-template class ConcatGradFunctor<platform::CPUDeviceContext, double>;
+FOR_ALL_TYPES(DEFINE_FUNCTOR);
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/concat.cu b/paddle/fluid/operators/math/concat.cu
index 342379268b..b59d86e661 100644
--- a/paddle/fluid/operators/math/concat.cu
+++ b/paddle/fluid/operators/math/concat.cu
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/operators/math/concat.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -118,7 +119,7 @@ template <typename T>
 class ConcatFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output) {
     // TODO(zcd): Add input data validity checking
     int in_num = input.size();
@@ -192,8 +193,8 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
  public:
   void operator()(const platform::CUDADeviceContext& context,
                   const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
-                  const int axis, std::vector<framework::Tensor*>* outputs) {
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  int axis, std::vector<framework::Tensor*>* outputs) {
     // TODO(zcd): Add input data validity checking
     int o_num = outputs->size();
     int out_row = 1;
@@ -261,15 +262,11 @@ class ConcatGradFunctor<platform::CUDADeviceContext, T> {
   }
 };
 
-template class ConcatFunctor<platform::CUDADeviceContext, int>;
-template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
-template class ConcatFunctor<platform::CUDADeviceContext, float>;
-template class ConcatFunctor<platform::CUDADeviceContext, double>;
+#define DEFINE_FUNCTOR(type)                                       \
+  template class ConcatFunctor<platform::CUDADeviceContext, type>; \
+  template class ConcatGradFunctor<platform::CUDADeviceContext, type>
 
-template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
-template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+FOR_ALL_TYPES(DEFINE_FUNCTOR);
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/concat.h b/paddle/fluid/operators/math/concat.h
index e5d7d860b3..867a84fa87 100644
--- a/paddle/fluid/operators/math/concat.h
+++ b/paddle/fluid/operators/math/concat.h
@@ -37,7 +37,7 @@ template <typename DeviceContext, typename T>
 class ConcatFunctor {
  public:
   void operator()(const DeviceContext& context,
-                  const std::vector<framework::Tensor>& input, const int axis,
+                  const std::vector<framework::Tensor>& input, int axis,
                   framework::Tensor* output);
 };
 
@@ -57,10 +57,21 @@ template <typename DeviceContext, typename T>
 class ConcatGradFunctor {
  public:
   void operator()(const DeviceContext& context, const framework::Tensor& input,
-                  const std::vector<const framework::LoDTensor*>& ref_inputs,
-                  const int axis, std::vector<framework::Tensor*>* outputs);
+                  const std::vector<const framework::Tensor*>& ref_inputs,
+                  int axis, std::vector<framework::Tensor*>* outputs);
 };
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
+
+#define FOR_ALL_TYPES(macro) \
+  macro(int);                \
+  macro(float);              \
+  macro(double);             \
+  macro(bool);               \
+  macro(int64_t);            \
+  macro(int16_t);            \
+  macro(uint8_t);            \
+  macro(int8_t);             \
+  macro(::paddle::platform::float16)
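
Note on the macro-based instantiations introduced above: instead of hand-maintaining one explicit instantiation per element type in both concat.cc and concat.cu, the type list now lives once in concat.h (FOR_ALL_TYPES) and each translation unit supplies its own DEFINE_FUNCTOR. The sketch below is a minimal, self-contained illustration of that expansion pattern; Functor and GradFunctor are hypothetical stand-ins for ConcatFunctor/ConcatGradFunctor, and the type list is shortened (the real one also covers bool, int16_t, uint8_t, int8_t, and platform::float16).

// instantiation_sketch.cc -- illustrative only; compile with: c++ -std=c++11 instantiation_sketch.cc
#include <cstdint>

template <typename T>
struct Functor {
  void operator()() const {}  // stand-in for the real concat kernel
};

template <typename T>
struct GradFunctor {
  void operator()() const {}  // stand-in for the real concat-grad kernel
};

// Single source of truth for the element types, mirroring concat.h.
#define FOR_ALL_TYPES(macro) \
  macro(int);                \
  macro(float);              \
  macro(double);             \
  macro(int64_t)

// Each source file decides what one macro(type) line expands to; here it
// forces explicit instantiation of both functors for that type, just as
// concat.cc and concat.cu do for their respective device contexts.
#define DEFINE_FUNCTOR(type)     \
  template struct Functor<type>; \
  template struct GradFunctor<type>

// Expands to: template struct Functor<int>; template struct GradFunctor<int>; ...
FOR_ALL_TYPES(DEFINE_FUNCTOR);

int main() {
  Functor<float>{}();  // resolved by one of the instantiations above
  return 0;
}

With this pattern, adding a new element type (as this diff does for float16) only touches the FOR_ALL_TYPES list, so the CPU and CUDA translation units cannot drift out of sync.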