support negative index

pull/14145/head
yangwei 4 years ago
parent ac5371b38f
commit 21686def1a

@ -155,7 +155,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_after_grad = opt::OptPassConfig({ opt::OptPassConfig a_after_grad = opt::OptPassConfig({
irpass.inline_without_move_, irpass.inline_without_move_,
}); });
opt::OptPassConfig a_3 = opt::OptPassConfig({ opt::OptPassConfig a_3 = opt::OptPassConfig(
{
irpass.arithmetic_simplify2_, irpass.arithmetic_simplify2_,
irpass.same_eliminate_, irpass.same_eliminate_,
irpass.check_bprop_eliminate_, irpass.check_bprop_eliminate_,
@ -165,7 +166,8 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.virtual_add_elim_, irpass.virtual_add_elim_,
irpass.row_tensor_add_zeros_like_, irpass.row_tensor_add_zeros_like_,
irpass.mini_step_allgather_replace_, irpass.mini_step_allgather_replace_,
}); },
false, true);
opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_}); opt::OptPassConfig virtual_dataset = opt::OptPassConfig({irpass.virtual_dataset_eliminate_});
opt::irpass::ResolveIRPassLib resolve_irpass; opt::irpass::ResolveIRPassLib resolve_irpass;

@ -113,6 +113,7 @@ def _tuple_getitem_by_tensor(data, tensor_index):
Outputs: Outputs:
Type, is the same as the element type of data. Type, is the same as the element type of data.
""" """
tensor_index = F.select(tensor_index >= 0, tensor_index, tensor_index + len(data))
return _tuple_get_item_tensor(data, tensor_index) return _tuple_get_item_tensor(data, tensor_index)

Loading…
Cancel
Save