|
|
@ -82,16 +82,23 @@ $$A[i] = T$$
|
|
|
|
class WriteToArrayInferShape : public framework::InferShapeBase {
|
|
|
|
class WriteToArrayInferShape : public framework::InferShapeBase {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
PADDLE_ENFORCE(context->HasInput("I"), "Must set the subscript index");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
if (context->IsRuntime()) {
|
|
|
|
context->HasInput("I"), true,
|
|
|
|
PADDLE_ENFORCE_EQ(framework::product(context->GetInputDim("I")), 1,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"The number of element of subscript index must be 1");
|
|
|
|
"Read/Write array operation must set the subscript index."));
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// TODO(wangchaochaohu) control flow Op do not support runtime infer shape
|
|
|
|
|
|
|
|
// Later we add [ontext->GetInputDim("I")) == 1] check when it's supported
|
|
|
|
|
|
|
|
|
|
|
|
if (!context->HasInput("X")) {
|
|
|
|
if (!context->HasInput("X")) {
|
|
|
|
return;
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(context->HasOutput("Out"), NotHasOutError());
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
|
|
|
context->HasOutput("Out"), true,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Read/Write array operation must set the output Tensor "
|
|
|
|
|
|
|
|
"to get the result."));
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
context->SetOutputDim("Out", context->GetInputDim("X"));
|
|
|
|
|
|
|
|
|
|
|
|
// When compile time, we need to:
|
|
|
|
// When compile time, we need to:
|
|
|
@ -106,13 +113,6 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
|
|
|
|
context->ShareLoD("X", /*->*/ "Out");
|
|
|
|
context->ShareLoD("X", /*->*/ "Out");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
|
|
|
virtual const char *NotHasXError() const { return "Must set the lod tensor"; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
virtual const char *NotHasOutError() const {
|
|
|
|
|
|
|
|
return "Must set the lod tensor array";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class WriteToArrayInferVarType : public framework::VarTypeInference {
|
|
|
|
class WriteToArrayInferVarType : public framework::VarTypeInference {
|
|
|
@ -140,10 +140,15 @@ class ReadFromArrayOp : public ArrayOp {
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
void RunImpl(const framework::Scope &scope,
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
const platform::Place &place) const override {
|
|
|
|
auto *x = scope.FindVar(Input("X"));
|
|
|
|
auto *x = scope.FindVar(Input("X"));
|
|
|
|
PADDLE_ENFORCE(x != nullptr, "X must be set");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
|
|
|
x,
|
|
|
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"X(Input Variable) must be set when we call read array operation"));
|
|
|
|
auto &x_array = x->Get<framework::LoDTensorArray>();
|
|
|
|
auto &x_array = x->Get<framework::LoDTensorArray>();
|
|
|
|
auto *out = scope.FindVar(Output("Out"));
|
|
|
|
auto *out = scope.FindVar(Output("Out"));
|
|
|
|
PADDLE_ENFORCE(out != nullptr, "Out must be set");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(out, platform::errors::InvalidArgument(
|
|
|
|
|
|
|
|
"Out(Output Varibale) must be set when we "
|
|
|
|
|
|
|
|
"call read array operation"));
|
|
|
|
size_t offset = GetOffset(scope, place);
|
|
|
|
size_t offset = GetOffset(scope, place);
|
|
|
|
if (offset < x_array.size()) {
|
|
|
|
if (offset < x_array.size()) {
|
|
|
|
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
|
|
|
|
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
|
|
|
@ -199,15 +204,7 @@ $$T = A[i]$$
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class ReadFromArrayInferShape : public WriteToArrayInferShape {
|
|
|
|
class ReadFromArrayInferShape : public WriteToArrayInferShape {};
|
|
|
|
protected:
|
|
|
|
|
|
|
|
const char *NotHasXError() const override {
|
|
|
|
|
|
|
|
return "The input array X must be set";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
const char *NotHasOutError() const override {
|
|
|
|
|
|
|
|
return "The output tensor out must be set";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
class WriteToArrayGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|