refine tensor_array_write_read (#14643)

test=develop
local_add_cudnn_lstm
chengduo 7 years ago committed by GitHub
parent dfd4a111cb
commit 6776e92846
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -81,13 +81,35 @@ class CompileTimeInferShapeContext : public InferShapeContext {
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
if (in_var->GetType() != proto::VarType::LOD_TENSOR) {
VLOG(3) << "input " << in << " is not LodTensor";
if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
VLOG(3) << "input " << in << " is not LodTensor or LodTensorArray.";
return;
}
out_var->SetLoDLevel(in_var->GetLoDLevel());
}
void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const override {
PADDLE_ENFORCE_LT(i, Inputs(in).size());
PADDLE_ENFORCE_LT(j, Outputs(out).size());
PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", in, i);
PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
"The %s[%d] is @EMPTY@", out, j);
auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
out_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensorArray or LodTensor.",
out_var->Name());
PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
"The input %s should be LodTensor.", in_var->Name());
if (in_var->GetLoDLevel() > 0) {
out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
}
}
bool IsRuntime() const override;
protected:

@ -623,6 +623,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout());
}
void DecreaseLoDLevel(const std::string& in, const std::string& out,
size_t i = 0, size_t j = 0) const override {
PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
}
bool IsRuntime() const override { return true; }
protected:

@ -62,6 +62,9 @@ class InferShapeContext {
virtual void ShareLoD(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual void DecreaseLoDLevel(const std::string &in, const std::string &out,
size_t i = 0, size_t j = 0) const = 0;
virtual bool IsRuntime() const = 0;
std::vector<InferShapeVarPtr> GetInputVarPtrs(const std::string &name);

@ -167,6 +167,19 @@ $$T = A[i]$$
};
class ReadFromArrayInferShape : public WriteToArrayInferShape {
public:
void operator()(framework::InferShapeContext *context) const override {
WriteToArrayInferShape::operator()(context);
if (!context->HasInput("X")) {
return;
}
// FIXME: just for compile time.
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
}
protected:
const char *NotHasXError() const override {
return "The input array X must be set";

@ -192,6 +192,10 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
// The first dim of each LoDTensor in Output can only be set at run-time.;
// We still have to Resize each LoDTensor in Output.
context->SetOutputDim("Out", x_dim);
// The lod level should be passed to out in compile time.
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
}
}
};

@ -201,6 +201,9 @@ class IdentityInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
context->SetOutputDim("Out", context->GetInputDim("X"));
if (!context->IsRuntime()) {
context->ShareLoD("X", /*->*/ "Out");
}
}
};

@ -100,6 +100,9 @@ class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
PADDLE_ENFORCE(context->HasInput("I"));
PADDLE_ENFORCE(context->HasInput("RankTable"));
context->SetOutputDim("Out", context->GetInputDim("X"));
if (!context->IsRuntime()) {
context->DecreaseLoDLevel("X", /*->*/ "Out");
}
}
};

@ -172,6 +172,7 @@ class TestDynRNN(unittest.TestCase):
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sentence)
assert in_.lod_level == 1, "the lod level of in_ should be 1"
sent_emb = fluid.layers.embedding(
input=in_, size=[len(word_dict), 32], dtype='float32')
out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
@ -179,6 +180,7 @@ class TestDynRNN(unittest.TestCase):
rnn1 = fluid.layers.DynamicRNN()
with rnn1.block():
in_1 = rnn1.step_input(out_)
assert in_1.lod_level == 0, "the lod level of in_1 should be 0"
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1)

Loading…
Cancel
Save