Merge pull request #3371 from zchen0211/develop
scatter update implementedrevert-3824-remove_grad_op_type
commit
973618b6ab
@ -0,0 +1,92 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <cstring>
|
||||
|
||||
#include "paddle/framework/ddim.h"
|
||||
#include "paddle/framework/eigen.h"
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/place.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using Tensor = framework::Tensor;
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
||||
|
||||
// Implementation of CPU copy
|
||||
template <typename T>
|
||||
void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index,
|
||||
const size_t index_size,
|
||||
paddle::framework::Tensor* output) {
|
||||
paddle::framework::DDim output_dims = output->dims();
|
||||
|
||||
for (size_t i = 0; i < index_size; ++i) {
|
||||
int index_ = index[i];
|
||||
|
||||
paddle::framework::Tensor src_ = *src;
|
||||
paddle::framework::Tensor output_ = *output;
|
||||
if (index_size > 1) src_ = src->Slice<T>(i, i + 1);
|
||||
if (output_dims[0] > 1) output_ = output->Slice<T>(index_, index_ + 1);
|
||||
|
||||
auto X = EigenVector<T>::Flatten(src_);
|
||||
auto Y = EigenVector<T>::Flatten(output_);
|
||||
|
||||
Y = X + Y;
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of GPU scatter:
|
||||
template <typename T>
|
||||
void GPUScatterUpdate(const T* src, const int* index, const int slice_size,
|
||||
const int index_size, T* output);
|
||||
|
||||
/**
|
||||
* Return a updated tensor from source tensor, scattered according to index:
|
||||
* dst[i] += src[index[i]]
|
||||
* input[src]: type-T source Tensor
|
||||
* input[index]: type-int index Tensor (1-D)
|
||||
* return: output tensor
|
||||
*/
|
||||
template <typename T>
|
||||
void ScatterUpdate(const platform::Place& place,
|
||||
const paddle::framework::Tensor* src,
|
||||
const paddle::framework::Tensor* index,
|
||||
paddle::framework::Tensor* output) {
|
||||
// check index of shape 1-D
|
||||
PADDLE_ENFORCE(index->dims().size() == 1);
|
||||
int index_size = index->dims()[0];
|
||||
|
||||
auto src_dims = src->dims();
|
||||
auto dst_dims = output->dims();
|
||||
|
||||
// check src shape and dst shape should match
|
||||
for (int i = 1; i < src_dims.size(); i++)
|
||||
PADDLE_ENFORCE(src_dims[i] == dst_dims[i]);
|
||||
|
||||
// slice size
|
||||
size_t slice_size = 1;
|
||||
for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i];
|
||||
|
||||
if (platform::is_cpu_place(place)) {
|
||||
CPUScatterUpdate<T>(src, index->data<int>(), index_size, output);
|
||||
} else {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/scatter.h"
|
||||
#include "paddle/framework/ddim.h"
|
||||
#include "paddle/framework/tensor.h"
|
||||
#include "paddle/platform/place.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
TEST(scatter, ScatterUpdate) {
|
||||
using namespace paddle::framework;
|
||||
using namespace paddle::platform;
|
||||
using namespace paddle::operators;
|
||||
|
||||
Tensor* src = new Tensor();
|
||||
Tensor* index = new Tensor();
|
||||
Tensor* output = new Tensor();
|
||||
|
||||
float* p_src = nullptr;
|
||||
int* p_index = nullptr;
|
||||
p_src = src->mutable_data<float>(make_ddim({1, 4}), CPUPlace());
|
||||
p_index = index->mutable_data<int>(make_ddim({1}), CPUPlace());
|
||||
|
||||
for (size_t i = 0; i < 4; ++i) p_src[i] = float(i);
|
||||
p_index[0] = 1;
|
||||
|
||||
float* p_output = output->mutable_data<float>(make_ddim({4, 4}), CPUPlace());
|
||||
|
||||
ScatterUpdate<float>(CPUPlace(), src, index, output);
|
||||
|
||||
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0));
|
||||
for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
|
||||
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], float(i - 4));
|
||||
for (size_t i = 4; i < 8; ++i)
|
||||
EXPECT_EQ(output->data<float>()[i], float(i - 4));
|
||||
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0));
|
||||
for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data<float>()[i], float(0));
|
||||
}
|
Loading…
Reference in new issue