!475 bug fix for lstm

Merge pull request !475 from liubuyu/lstm
pull/7893/head
mindspore-ci-bot 5 years ago committed by Gitee
commit 830f97de33

@ -446,8 +446,8 @@ std::vector<size_t> TransShapeToDevice(const std::vector<size_t> &shape, const s
return device_shape;
} else if (format == kOpFormat_FRACTAL_ZN_LSTM) {
const size_t c0 = 4;
const size_t h = shape.at(kC) / c0;
const size_t i = shape.at(kN) - h;
const size_t h = shape.at(kN) / c0;
const size_t i = shape.at(kC) - h;
const size_t first = DivCeil(i, kCubeSize) + DivCeil(h, kCubeSize);
const size_t second = c0 * DivCeil(h, kCubeSize);
device_shape.push_back(first);

@ -30,7 +30,7 @@ basic_lstm_cell_op_info = TBERegOp("BasicLSTMCell") \
.input(0, "x", False, "required", "all") \
.input(1, "h", False, "required", "all") \
.input(2, "c", False, "required", "all") \
.input(3, "w", False, "required", "all", reshape_type="NC") \
.input(3, "w", False, "required", "all", reshape_type="CN") \
.input(4, "b", False, "required", "all") \
.input(5, "mask", False, "optional", "all") \
.output(0, "ct", False, "required", "all") \

Loading…
Cancel
Save