diff --git a/frontend/Python/graph/transform/__init__.py b/frontend/Python/graph/transform/__init__.py index 265f05501..fb2760592 100644 --- a/frontend/Python/graph/transform/__init__.py +++ b/frontend/Python/graph/transform/__init__.py @@ -19,4 +19,4 @@ # ===--------------------------------------------------------------------------- from .fuse_ops import simply_fuse, my_fuse_ops_test -from .useless_op_eliminate import maxpool2d_simplify +from .useless_op_eliminate import maxpool2d_simplify, varmean_simpplify diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py index c9ddb7ff7..1b24798f9 100644 --- a/frontend/Python/graph/transform/fuse_ops.py +++ b/frontend/Python/graph/transform/fuse_ops.py @@ -368,6 +368,32 @@ def fcond0(kind, issink): graph_node = graph.post_dfs_order[i] dom_node = tree.tree_nodes[i] group_node = tree.groups[i] + + + # if group_node.pattern == OpType.Unfusable: + # continue + # if dom_node.parent == None: + # continue + # dom_parent_gindex = dom_node.parent_gnode.index + # if phase == 2: + # if group_node.pattern > OpType.ElementwiseType: + # continue + # dom_parent_group = tree.groups[dom_parent_gindex] + # dom_root_group = dom_parent_group.FindRoot() + # if dom_root_group.pattern == OpType.GetItemType: + # continue + # if dom_parent_group.pattern == OpType.GetItemType and dom_root_group.pattern == OpType.ElementwiseType: + # def fcond1(kind, is_sink): + # return kind.value <= OpType.ElementwiseType.value + # if self.CheckPath(graph_node, dom_node.parent_gnode, fcond=fcond1, tree=tree): + # self.CommitFuse(graph_node, dom_node.parent_gnode, tree) + # continue + # if tree.groups[dom_parent_gindex] != None and group_node.FindRoot() == tree.groups[dom_parent_gindex].FindRoot(): + # continue + # if tree.groups[dom_parent_gindex].pattern == OpType.GetItemType: + # continue + + if dom_node != None and group_node.pattern == OpType.ReduceType: if phase != 0: continue @@ -381,6 +407,25 @@ def fcond0(kind, issink): self.CommitFuse( graph_node, dom_node.parent_gnode, tree ) + # elif group_node.pattern.value <= OpType.BroadcastType.value: + # if dom_node.parent != None and (dom_node.pattern.value <= OpType.ElementwiseType.value or dom_node.pattern == OpType.ReduceType): + # def fcond2(kind, is_sink): + # if is_sink is False: + # return kind.value <= OpType.ElementwiseType.value + # else: + # return (kind.value <= OpType.BroadcastType.value or kind == OpType.ReduceType or kind == OpType.ElementwiseType or kind == OpType.ReduceType) + # if self.CheckPath(graph_node, dom_node.parent_gnode, fcond2, tree): + # self.CommitFuse(graph_node, dom_node.parent_gnode, tree) + # elif group_node.pattern == OpType.ElementwiseType or group_node.pattern == OpType.GetItemType: + # if phase != 1: + # continue + # def fcond3(kind, is_sink): + # return kind.value <= OpType.ElementwiseType.value + # if self.CheckPath(graph_node, dom_node.parent_gnode, fcond3, tree): + # self.CommitFuse(graph_node, dom_node.parent_gnode, tree) + # else: + # pass + for node in tree.groups: if node.master_ref is not None: logger.info( diff --git a/frontend/Python/graph/transform/useless_op_eliminate.py b/frontend/Python/graph/transform/useless_op_eliminate.py index a99dbe02c..25376a892 100644 --- a/frontend/Python/graph/transform/useless_op_eliminate.py +++ b/frontend/Python/graph/transform/useless_op_eliminate.py @@ -64,3 +64,28 @@ def maxpool2d_simplify(graph: Graph): if op == getitem_node: graph.body[j] = new_node break + +def varmean_simpplify(graph: Graph): + """ + Fuse the varmean op and getitem op to simpllify graph. + + Args: + graph (Graph): The Graph to be simplified. + """ + keys_to_remove = [] + + for i, key in enumerate(list(graph.op_groups.keys())): + if key.startswith("var_mean"): + # getitem_key1 = f"getitem_{int(key.split('var_mean')[1].split('_')[-1]) * 2}" if '_' in key.split('var_mean')[1] else "getitem" + # getitem_key2 = f"getitem_{int(key.split('var_mean')[1].split('_')[-1]) * 2 + 1}" if '_' in key.split('var_mean')[1] else "getitem_1" + getitem_key1 = list(graph.op_groups.keys())[i + 1] + getitem_key2 = list(graph.op_groups.keys())[i + 2] + + if getitem_key1 in graph.op_groups and getitem_key2 in graph.op_groups: + graph.op_groups[key].extend(graph.op_groups[getitem_key1]) + graph.op_groups[key].extend(graph.op_groups[getitem_key2]) + + keys_to_remove.extend([getitem_key1, getitem_key2]) + + for key in keys_to_remove: + del graph.op_groups[key]