commit 8ba177593b
@@ -0,0 +1,100 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>

#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_double(fraction_of_cuda_pinned_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
DECLARE_int64(gpu_allocator_retry_time);
#endif

namespace paddle {
namespace memory {
namespace allocation {

//! Run allocate test cases for different places
void AllocateTestCases() {
  auto &instance = AllocatorFacade::Instance();
  platform::Place place;
  size_t size = 1024;

  {
    place = platform::CPUPlace();
    size = 1024;
    auto cpu_allocation = instance.Alloc(place, size);
    ASSERT_NE(cpu_allocation, nullptr);
    ASSERT_NE(cpu_allocation->ptr(), nullptr);
    ASSERT_EQ(cpu_allocation->place(), place);
    ASSERT_EQ(cpu_allocation->size(), size);
  }

#ifdef PADDLE_WITH_CUDA
  {
    place = platform::CUDAPlace(0);
    size = 1024;
    auto gpu_allocation = instance.Alloc(place, size);
    ASSERT_NE(gpu_allocation, nullptr);
    ASSERT_NE(gpu_allocation->ptr(), nullptr);
    ASSERT_EQ(gpu_allocation->place(), place);
    ASSERT_GE(gpu_allocation->size(), size);
  }

  {
    // Allocate 2GB gpu memory
    place = platform::CUDAPlace(0);
    size = 2 * static_cast<size_t>(1 << 30);
    auto gpu_allocation = instance.Alloc(place, size);
    ASSERT_NE(gpu_allocation, nullptr);
    ASSERT_NE(gpu_allocation->ptr(), nullptr);
    ASSERT_EQ(gpu_allocation->place(), place);
    ASSERT_GE(gpu_allocation->size(), size);
  }

  {
    place = platform::CUDAPinnedPlace();
    size = (1 << 20);
    auto cuda_pinned_allocation =
        instance.Alloc(platform::CUDAPinnedPlace(), 1 << 20);
    ASSERT_NE(cuda_pinned_allocation, nullptr);
    ASSERT_NE(cuda_pinned_allocation->ptr(), nullptr);
    ASSERT_EQ(cuda_pinned_allocation->place(), place);
    ASSERT_GE(cuda_pinned_allocation->size(), size);
  }
#endif
}

TEST(Allocator, SpecifyGpuMemory) {
#ifdef PADDLE_WITH_CUDA
  // Set to 0.0 to test FLAGS_initial_gpu_memory_in_mb and
  // FLAGS_reallocate_gpu_memory_in_mb
  FLAGS_fraction_of_gpu_memory_to_use = 0.0;
  // 512 MB
  FLAGS_initial_gpu_memory_in_mb = 512;
  // 4 MB
  FLAGS_reallocate_gpu_memory_in_mb = 4;
  FLAGS_gpu_allocator_retry_time = 500;
  FLAGS_fraction_of_cuda_pinned_memory_to_use = 0.5;
#endif

  AllocateTestCases();
}

}  // namespace allocation
}  // namespace memory
}  // namespace paddle
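The test above depends on the gflags-defined FLAGS_* variables being set in-process before AllocateTestCases() runs. As an aside (not part of this commit; Paddle's build normally supplies its own test main), a conventional gtest/gflags entry point that would also let the same flags be passed on the command line is sketched below.

#include <gflags/gflags.h>
#include <gtest/gtest.h>

int main(int argc, char** argv) {
  // Consume --gtest_* options first, then hand the remaining
  // --initial_gpu_memory_in_mb-style flags to gflags.
  testing::InitGoogleTest(&argc, argv);
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  return RUN_ALL_TESTS();
}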
@@ -0,0 +1,133 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/memory/detail/buddy_allocator.h"

#include <memory>

#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/gpu_info.h"

#ifdef PADDLE_WITH_CUDA
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
#endif

namespace paddle {
namespace memory {
namespace detail {

constexpr static int test_gpu_id = 0;

void TestBuddyAllocator(BuddyAllocator* allocator, size_t size_bytes) {
  bool freed = false;
  size_t used_bytes = allocator->Used();

  if (size_bytes > 0) {
    void* p = allocator->Alloc(size_bytes);

    EXPECT_NE(p, nullptr);
#ifdef PADDLE_WITH_CUDA
    if (size_bytes < platform::GpuMaxChunkSize()) {
#else
    if (size_bytes < platform::CpuMaxChunkSize()) {
#endif
      // Not allocated from the SystemAllocator, so the size is counted in Used()
      EXPECT_GE(allocator->Used(), used_bytes + size_bytes);
    } else {
      // Allocations that fall through to the SystemAllocator are not counted in Used()
      EXPECT_EQ(allocator->Used(), used_bytes);
    }

    int* intp = static_cast<int*>(p);
    std::shared_ptr<int> ptr(intp, [&](void* p) {
      allocator->Free(intp);
      freed = true;
    });
  } else {
    freed = true;
  }

  EXPECT_EQ(used_bytes, allocator->Used());
  EXPECT_TRUE(freed);
}

#ifdef PADDLE_WITH_CUDA
TEST(BuddyAllocator, GpuFraction) {
  FLAGS_fraction_of_gpu_memory_to_use = 0.01;

  BuddyAllocator buddy_allocator(
      std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
      platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());

  TestBuddyAllocator(&buddy_allocator, 10);
  TestBuddyAllocator(&buddy_allocator, 10 << 10);
  TestBuddyAllocator(&buddy_allocator, 10 << 20);
  TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}

TEST(BuddyAllocator, InitRealloc) {
  FLAGS_initial_gpu_memory_in_mb = 100;
  FLAGS_reallocate_gpu_memory_in_mb = 50;

  EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(100 << 20));

  BuddyAllocator buddy_allocator(
      std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
      platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());

  // Less than both the initial size and the reallocation size
  TestBuddyAllocator(&buddy_allocator, 10 << 20);
  // Between the reallocation size and the initial size; does not exceed the pool
  TestBuddyAllocator(&buddy_allocator, 80 << 20);
  // Less than the reallocation size and exceeds the pool
  TestBuddyAllocator(&buddy_allocator, 40 << 20);
  // Greater than the reallocation size and exceeds the pool
  TestBuddyAllocator(&buddy_allocator, 80 << 20);
  // Greater than both the initial size and the reallocation size
  TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}

TEST(BuddyAllocator, ReallocSizeGreaterThanInit) {
  FLAGS_initial_gpu_memory_in_mb = 5;
  FLAGS_reallocate_gpu_memory_in_mb = 10;

  EXPECT_EQ(platform::GpuMaxChunkSize(), static_cast<size_t>(10 << 20));

  BuddyAllocator buddy_allocator(
      std::unique_ptr<SystemAllocator>(new GPUAllocator(test_gpu_id)),
      platform::GpuMinChunkSize(), platform::GpuMaxChunkSize());

  // Less than both the initial size and the reallocation size
  TestBuddyAllocator(&buddy_allocator, 1 << 20);
  // Between the initial size and the reallocation size; does not exceed the pool
  TestBuddyAllocator(&buddy_allocator, 3 << 20);
  // Less than the initial size and exceeds the pool
  TestBuddyAllocator(&buddy_allocator, 3 << 20);
  // Less than the reallocation size and does not exceed the pool
  // (the pool is now 15 MB, 7 MB of which is used)
  TestBuddyAllocator(&buddy_allocator, 7 << 20);
  // Less than the reallocation size and exceeds the pool
  TestBuddyAllocator(&buddy_allocator, 8 << 20);
  // Greater than both the initial size and the reallocation size
  TestBuddyAllocator(&buddy_allocator, 2 * static_cast<size_t>(1 << 30));
}
#endif

}  // namespace detail
}  // namespace memory
}  // namespace paddle
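All of the tests above exercise the GPU-backed path guarded by PADDLE_WITH_CUDA. The same harness could drive a CPU-backed allocator; the sketch below is illustrative only (not part of this commit) and assumes CPUAllocator and the platform::CpuMinChunkSize()/CpuMaxChunkSize() helpers are visible here, which may require including paddle/fluid/platform/cpu_info.h.

TEST(BuddyAllocator, CpuSmoke) {
  // Buddy allocator backed by the host (CPU) system allocator.
  BuddyAllocator cpu_buddy_allocator(
      std::unique_ptr<SystemAllocator>(new CPUAllocator()),
      platform::CpuMinChunkSize(), platform::CpuMaxChunkSize());

  TestBuddyAllocator(&cpu_buddy_allocator, 10);        // a few bytes
  TestBuddyAllocator(&cpu_buddy_allocator, 10 << 10);  // 10 KB
  TestBuddyAllocator(&cpu_buddy_allocator, 10 << 20);  // 10 MB
}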
@@ -0,0 +1,128 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fsp_op.h"

namespace paddle {
namespace operators {

class FSPOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FSPOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    PADDLE_ENFORCE(
        x_dims.size() == 4,
        "The Input(X) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        y_dims.size() == 4,
        "The Input(Y) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        (x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]),
        "The Input(X) and Input(Y) should have the same height and width.");

    ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]});
    ctx->ShareLoD("X", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context(), layout_, library_);
  }
};

class FSPOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input of FSP op with shape [batch_size, x_channel, "
             "height, width]");
    AddInput("Y",
             "(Tensor) The input of FSP op with shape "
             "[batch_size, y_channel, height, width]. "
             "The y_channel can be different from the x_channel of Input(X), "
             "while the other dimensions must be the same as Input(X)'s.");
    AddOutput(
        "Out",
        "(Tensor) The output of FSP op with shape "
        "[batch_size, x_channel, y_channel]. The x_channel is the channel "
        "of Input(X) and the y_channel is the channel of Input(Y).");
    AddComment(R"DOC(
This op calculates the flow of solution procedure (FSP) matrix of two feature maps.
Given feature map x with shape [x_channel, h, w] and feature map y with shape
[y_channel, h, w], the FSP matrix of x and y is computed in two steps:

step 1: reshape x into a matrix with shape [x_channel, h * w], and reshape and
transpose y into a matrix with shape [h * w, y_channel]
step 2: multiply x and y to get the FSP matrix with shape [x_channel, y_channel]

The output is a batch of FSP matrices.
)DOC");
  }
};

class FSPOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
REGISTER_OP_CPU_KERNEL(
    fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    fsp_grad, ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, double>);
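The DOC block above defines the FSP matrix as the product of the flattened feature maps, averaged over the spatial size h * w. A loop-based single-sample reference, written here purely to illustrate that arithmetic (the function name and the std::vector types are hypothetical and not part of the operator):

#include <vector>

// x: [x_channel][h * w], y: [y_channel][h * w] (already flattened).
// Returns fsp[i][j] = (1 / (h * w)) * sum_k x[i][k] * y[j][k].
std::vector<std::vector<float>> FspMatrixReference(
    const std::vector<std::vector<float>>& x,
    const std::vector<std::vector<float>>& y) {
  const size_t x_channel = x.size();
  const size_t y_channel = y.size();
  const size_t hw = x.empty() ? 0 : x[0].size();
  std::vector<std::vector<float>> out(x_channel,
                                      std::vector<float>(y_channel, 0.0f));
  for (size_t i = 0; i < x_channel; ++i) {
    for (size_t j = 0; j < y_channel; ++j) {
      float acc = 0.0f;
      for (size_t k = 0; k < hw; ++k) {
        acc += x[i][k] * y[j][k];
      }
      out[i][j] = hw > 0 ? acc / static_cast<float>(hw) : 0.0f;
    }
  }
  return out;
}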
@@ -0,0 +1,24 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fsp_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fsp, ops::FSPOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(fsp_grad,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, double>);
@@ -0,0 +1,136 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class FSPOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* output = context.Output<Tensor>("Out");
    output->mutable_data<T>(context.GetPlace());
    auto x_dims = x->dims();
    auto y_dims = y->dims();

    auto batch_size = x_dims[0];
    auto x_channel = x_dims[1];
    auto y_channel = y_dims[1];
    auto height = x_dims[2];
    auto width = x_dims[3];

    auto blas = math::GetBlas<DeviceContext, T>(context);

    math::MatDescriptor x_mat_desc;
    x_mat_desc.height_ = x_channel;
    x_mat_desc.width_ = height * width;
    x_mat_desc.batch_size_ = batch_size;
    x_mat_desc.stride_ = x_channel * height * width;

    math::MatDescriptor y_mat_desc;
    y_mat_desc.height_ = height * width;
    y_mat_desc.width_ = y_channel;
    y_mat_desc.batch_size_ = batch_size;
    y_mat_desc.stride_ = y_channel * height * width;
    y_mat_desc.trans_ = true;

    blas.MatMul(*x, x_mat_desc, *y, y_mat_desc,
                static_cast<T>(1.0 / (height * width)), output,
                static_cast<T>(0.0));
  }
};

template <typename DeviceContext, typename T>
class FSPGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* d_y = context.Output<Tensor>(framework::GradVarName("Y"));
    if (d_x == nullptr && d_y == nullptr) {
      return;
    }
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto d_out_dims = d_out->dims();
    auto batch_size = d_out_dims[0];
    auto x_channel = d_out_dims[1];
    auto y_channel = d_out_dims[2];
    int64_t h = 0;
    int64_t w = 0;

    auto blas = math::GetBlas<DeviceContext, T>(context);
    math::SetConstant<DeviceContext, T> set_zero;
    if (d_x != nullptr) {
      d_x->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_x,
               static_cast<T>(0));
      auto* y = context.Input<Tensor>("Y");
      auto y_dims = y->dims();
      h = y_dims[2];
      w = y_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = x_channel;
      d_out_mat_desc.width_ = y_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;

      math::MatDescriptor y_mat_desc;
      y_mat_desc.height_ = y_channel;
      y_mat_desc.width_ = h * w;
      y_mat_desc.batch_size_ = batch_size;
      y_mat_desc.stride_ = y_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
    }

    if (d_y != nullptr) {
      d_y->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_y,
               static_cast<T>(0));
      auto* x = context.Input<Tensor>("X");
      auto x_dims = x->dims();
      h = x_dims[2];
      w = x_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = y_channel;
      d_out_mat_desc.width_ = x_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;
      d_out_mat_desc.trans_ = true;

      math::MatDescriptor x_mat_desc;
      x_mat_desc.height_ = x_channel;
      x_mat_desc.width_ = h * w;
      x_mat_desc.batch_size_ = batch_size;
      x_mat_desc.stride_ = x_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
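For reference, the batched matrix products set up by the descriptors above correspond to the following per-sample formulas, where $X \in \mathbb{R}^{C_x \times hw}$ and $Y \in \mathbb{R}^{C_y \times hw}$ are the flattened feature maps (this is a restatement of what the kernels compute, not an addition to them):

\[
\mathrm{Out} = \frac{1}{hw}\, X Y^{\top}, \qquad
\frac{\partial L}{\partial X} = \frac{1}{hw}\, \frac{\partial L}{\partial \mathrm{Out}}\, Y, \qquad
\frac{\partial L}{\partial Y} = \frac{1}{hw}\, \Big(\frac{\partial L}{\partial \mathrm{Out}}\Big)^{\!\top} X.
\]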
@@ -0,0 +1,42 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/reader/py_reader.h"
#include <memory>

namespace paddle {
namespace operators {
namespace reader {

PyReader::PyReader(const std::shared_ptr<LoDTensorBlockingQueue>& queue)
    : framework::FileReader() {
  PADDLE_ENFORCE(queue != nullptr, "LoDTensorBlockingQueue must not be null");
  queue_ = queue;
}

void PyReader::ReadNext(std::vector<framework::LoDTensor>* out) {
  bool success;
  *out = queue_->Pop(&success);
  if (!success) out->clear();
}

PyReader::~PyReader() { queue_->Close(); }

void PyReader::Shutdown() { queue_->Close(); }

void PyReader::Start() { queue_->ReOpen(); }

}  // namespace reader
}  // namespace operators
}  // namespace paddle
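Because ReadNext() clears the output vector when Pop() reports failure, an empty result is the signal that the underlying queue was closed. A minimal consumer sketch (illustrative only; the function and the way the reader instance is obtained are assumptions, not part of this commit):

void DrainReader(paddle::operators::reader::PyReader* reader) {
  std::vector<paddle::framework::LoDTensor> batch;
  while (true) {
    reader->ReadNext(&batch);
    if (batch.empty()) break;  // the queue was closed, nothing left to pop
    // ... hand `batch` to the executor here ...
  }
}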
Some files were not shown because too many files have changed in this diff.