Add SelectedRows Support for Transpose OP and Fix a Bug Where SelectedRows Could Not Be Used in SimNet (#25536)

This PR fixes a bug where SelectedRows could not be used in SimNet. The root cause is that the dygraph basic_engine did not copy a var's type when that var needed to be accumulated during backward. So when a SelectedRows var had to be accumulated, as in SimNet, which calls the net twice, its type was silently changed to the default LoDTensor and the bug appeared. The fix is simply to copy the var's type onto the temporary accumulation var as well.
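
To make the accumulation path concrete, here is a minimal dygraph sketch (not taken from this PR or its tests; the layer size and input ids are made up, and it assumes the `fluid.dygraph` API of this Paddle version): the same sparse `Embedding` is applied twice, so its SelectedRows weight gradient must be accumulated during backward.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # is_sparse=True makes the gradient of the embedding table a SelectedRows var.
    emb = fluid.dygraph.Embedding(size=[100, 8], is_sparse=True)

    ids_a = fluid.dygraph.to_variable(np.array([[1], [2]], dtype='int64'))
    ids_b = fluid.dygraph.to_variable(np.array([[2], [3]], dtype='int64'))

    # The same layer runs twice (SimNet runs its net on the positive and the
    # negative pair), so the two SelectedRows gradients of the embedding
    # weight have to be accumulated during backward. Before this fix, the
    # temporary accumulation var created by basic_engine did not inherit the
    # SelectedRows type and fell back to the default LoDTensor.
    loss = fluid.layers.reduce_sum(emb(ids_a)) + fluid.layers.reduce_sum(emb(ids_b))
    loss.backward()
```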

Without this PR, accumulated SelectedRows parameters in dygraph were silently converted to LoDTensor, so the Transpose OP never actually received a SelectedRows var. After fixing the SimNet SelectedRows bug, `test_imperative_lod_tensor_to_selected_rows` failed with an error that SelectedRows was not supported by the Transpose OP. To fix that as well, this PR adds SelectedRows support to the Transpose OP.
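
For context, a rough sketch of how Transpose can end up touching a SelectedRows var (hypothetical, not the actual `test_imperative_lod_tensor_to_selected_rows` code; it assumes the embedding layer exposes its table as `emb.weight`, and the exact op at which the SelectedRows type surfaces is an assumption): the embedding table is tied to the output projection through a transpose, so during backward the dense gradient flowing out of `transpose_grad` is accumulated into the same gradient var as the sparse SelectedRows gradient from the embedding lookup. With the type now propagated correctly, the transpose kernels must be able to read from and write to the tensor inside a SelectedRows, which is what `GetLoDTensorOrSelectedRowsValueFromVar` provides in the diffs below.

```python
import numpy as np
import paddle.fluid as fluid

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding(size=[100, 8], is_sparse=True)
    ids = fluid.dygraph.to_variable(np.array([[1], [2]], dtype='int64'))

    x = fluid.layers.reshape(emb(ids), shape=[2, 8])
    # Weight tying: the projection reuses the (transposed) embedding table,
    # whose gradient is a SelectedRows because of is_sparse=True.
    logits = fluid.layers.matmul(x, fluid.layers.transpose(emb.weight, perm=[1, 0]))
    loss = fluid.layers.reduce_sum(logits)
    loss.backward()
```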
Author: Huihuang Zheng (committed by GitHub)
parent 0f8dc611c8
commit d8fe517bf8

@@ -205,7 +205,9 @@ void BasicEngine::Execute() {
             continue;
           }
-          var = std::make_shared<VariableWrapper>(var->Name());
+          auto tmp_var = std::make_shared<VariableWrapper>(var->Name());
+          tmp_var->SetType(var->Type());
+          var = tmp_var;
           need_accu_var_list_.emplace_back(iter->second.get(), var);
         }
       }

@@ -660,19 +660,26 @@ template <typename DeviceContext, typename T>
 class TransposeGPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<framework::Tensor>("X");
-    auto* out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    if (out->numel() == 0) {
+    auto* x = context.InputVar("X");
+    auto* out = context.OutputVar("Out");
+    const framework::Tensor* x_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*x);
+    framework::Tensor* out_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
+    out_tensor->mutable_data<T>(context.GetPlace());
+    if (out_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     const auto& dev_ctx = context.template device_context<DeviceContext>();
-    auto ret = TransposeSimple<T>::run(dev_ctx, *x, axis, out);
+    auto ret = TransposeSimple<T>::run(dev_ctx, *x_tensor, axis, out_tensor);
     if (!ret) {
-      TransCompute<DeviceContext, T>(ndims, dev_ctx, *x, out, axis);
+      TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor,
+                                     axis);
     }
   }
 };
@@ -680,14 +687,19 @@ template <typename DeviceContext, typename T>
 class TransposeGradGPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* out_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* x_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    if (!x_grad) return;
-    x_grad->mutable_data<T>(context.GetPlace());
-    if (x_grad->numel() == 0) {
+    auto* out_grad = context.InputVar(framework::GradVarName("Out"));
+    auto* x_grad = context.OutputVar(framework::GradVarName("X"));
+    if (!x_grad) {
+      return;
+    }
+    const framework::Tensor* out_grad_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
+    framework::Tensor* x_grad_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
+    x_grad_tensor->mutable_data<T>(context.GetPlace());
+    if (x_grad_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
@@ -699,11 +711,11 @@ class TransposeGradGPUKernel : public framework::OpKernel<T> {
     int ndims = axis.size();
     const auto& dev_ctx = context.template device_context<DeviceContext>();
-    auto ret =
-        TransposeSimple<T>::run(dev_ctx, *out_grad, reversed_axis, x_grad);
+    auto ret = TransposeSimple<T>::run(dev_ctx, *out_grad_tensor, reversed_axis,
+                                       x_grad_tensor);
     if (!ret) {
-      TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad, x_grad,
-                                     reversed_axis);
+      TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor,
+                                     x_grad_tensor, reversed_axis);
     }
   }
 };

@@ -64,16 +64,23 @@ template <typename DeviceContext, typename T>
 class TransposeKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* x = context.Input<framework::Tensor>("X");
-    auto* out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    if (out->numel() == 0) {
+    auto* x = context.InputVar("X");
+    auto* out = context.OutputVar("Out");
+    const framework::Tensor* x_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*x);
+    framework::Tensor* out_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(out);
+    out_tensor->mutable_data<T>(context.GetPlace());
+    if (out_tensor->numel() == 0) {
       return;
     }
     std::vector<int> axis = context.Attr<std::vector<int>>("axis");
     int ndims = axis.size();
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    TransCompute<DeviceContext, T>(ndims, dev_ctx, *x, out, axis);
+    TransCompute<DeviceContext, T>(ndims, dev_ctx, *x_tensor, out_tensor, axis);
   }
 };
@@ -81,14 +88,19 @@ template <typename DeviceContext, typename T>
 class TransposeGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* out_grad =
-        context.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* x_grad =
-        context.Output<framework::Tensor>(framework::GradVarName("X"));
-    if (!x_grad) return;
-    x_grad->mutable_data<T>(context.GetPlace());
-    if (x_grad->numel() == 0) {
+    auto* out_grad = context.InputVar(framework::GradVarName("Out"));
+    auto* x_grad = context.OutputVar(framework::GradVarName("X"));
+    if (!x_grad) {
+      return;
+    }
+    const framework::Tensor* out_grad_tensor =
+        GetLoDTensorOrSelectedRowsValueFromVar(*out_grad);
+    framework::Tensor* x_grad_tensor =
+        GetMutableLoDTensorOrSelectedRowsValueFromVar(x_grad);
+    x_grad_tensor->mutable_data<T>(context.GetPlace());
+    if (x_grad_tensor->numel() == 0) {
       return;
     }
@@ -101,8 +113,8 @@ class TransposeGradKernel : public framework::OpKernel<T> {
     int ndims = axis.size();
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad, x_grad,
-                                   reversed_axis);
+    TransCompute<DeviceContext, T>(ndims, dev_ctx, *out_grad_tensor,
+                                   x_grad_tensor, reversed_axis);
   }
 };

@@ -42,7 +42,7 @@ class EmbeddingLayer(object):
         # causes crush in dy2stat. Set it to True after fixing it.
         emb = Embedding(
             size=[self.dict_size, self.emb_dim],
-            is_sparse=False,
+            is_sparse=True,
             padding_idx=self.padding_idx,
             param_attr=attr.ParamAttr(
                 name=self.name, initializer=fluid.initializer.Xavier()))

@@ -149,7 +149,6 @@ def train(conf_dict, to_static):
             pred = pos_score
             _, neg_score = net(left, neg_right)
             avg_cost = loss.compute(pos_score, neg_score)
-            #avg_cost = loss.compute(pos_score, pos_score)
             losses.append(np.mean(avg_cost.numpy()))
             avg_cost.backward()
             optimizer.minimize(avg_cost)

@@ -186,7 +186,8 @@ class TestDygraphSimpleNet(unittest.TestCase):
                     k - 1]] = out[k]
         self.assertTrue(
-            np.array_equal(static_loss_value, dy_loss_value))
+            np.allclose(
+                static_loss_value, dy_loss_value, rtol=1e-3))
         for key, value in six.iteritems(static_param_init):
             self.assertTrue(np.array_equal(value, dy_param_init[key]))
         for key, value in six.iteritems(static_param_updated):
