parent
fe7ed285d1
commit
5368e50d84
@ -1,2 +1,3 @@
|
||||
cc_library(var_handle SRCS var_handle.cc DEPS place)
|
||||
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
|
||||
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
|
||||
|
@ -0,0 +1,47 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
|
||||
platform::Place place)
|
||||
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {}
|
||||
|
||||
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
|
||||
|
||||
void ScaleLossGradOpHandle::RunImpl() {
|
||||
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
|
||||
|
||||
float *tmp =
|
||||
scope_->FindVar(var_name)->GetMutable<LoDTensor>()->mutable_data<float>(
|
||||
make_ddim({1}), place_);
|
||||
|
||||
if (platform::is_cpu_place(place_)) {
|
||||
*tmp = coeff_;
|
||||
} else {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
auto stream =
|
||||
static_cast<platform::CUDADeviceContext *>(this->dev_ctx_[place_])
|
||||
->stream();
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
|
||||
platform::CPUPlace(), &coeff_, sizeof(float), stream);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,39 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct ScaleLossGradOpHandle : public OpHandleBase {
|
||||
float coeff_;
|
||||
Scope *scope_;
|
||||
platform::Place place_;
|
||||
|
||||
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place);
|
||||
|
||||
~ScaleLossGradOpHandle() final;
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue