Skip to content

Commit

Permalink
add draft for specific fuse
Browse files Browse the repository at this point in the history
  • Loading branch information
effrey-liu committed Sep 8, 2024
1 parent d75332c commit 140eb41
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 1 deletion.
2 changes: 1 addition & 1 deletion frontend/Python/graph/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@
# ===---------------------------------------------------------------------------

from .fuse_ops import simply_fuse, my_fuse_ops_test
from .useless_op_eliminate import maxpool2d_simplify
from .useless_op_eliminate import maxpool2d_simplify, varmean_simpplify
45 changes: 45 additions & 0 deletions frontend/Python/graph/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,32 @@ def fcond0(kind, issink):
graph_node = graph.post_dfs_order[i]
dom_node = tree.tree_nodes[i]
group_node = tree.groups[i]


# if group_node.pattern == OpType.Unfusable:
# continue
# if dom_node.parent == None:
# continue
# dom_parent_gindex = dom_node.parent_gnode.index
# if phase == 2:
# if group_node.pattern > OpType.ElementwiseType:
# continue
# dom_parent_group = tree.groups[dom_parent_gindex]
# dom_root_group = dom_parent_group.FindRoot()
# if dom_root_group.pattern == OpType.GetItemType:
# continue
# if dom_parent_group.pattern == OpType.GetItemType and dom_root_group.pattern == OpType.ElementwiseType:
# def fcond1(kind, is_sink):
# return kind.value <= OpType.ElementwiseType.value
# if self.CheckPath(graph_node, dom_node.parent_gnode, fcond=fcond1, tree=tree):
# self.CommitFuse(graph_node, dom_node.parent_gnode, tree)
# continue
# if tree.groups[dom_parent_gindex] != None and group_node.FindRoot() == tree.groups[dom_parent_gindex].FindRoot():
# continue
# if tree.groups[dom_parent_gindex].pattern == OpType.GetItemType:
# continue


if dom_node != None and group_node.pattern == OpType.ReduceType:
if phase != 0:
continue
Expand All @@ -381,6 +407,25 @@ def fcond0(kind, issink):
self.CommitFuse(
graph_node, dom_node.parent_gnode, tree
)
# elif group_node.pattern.value <= OpType.BroadcastType.value:
# if dom_node.parent != None and (dom_node.pattern.value <= OpType.ElementwiseType.value or dom_node.pattern == OpType.ReduceType):
# def fcond2(kind, is_sink):
# if is_sink is False:
# return kind.value <= OpType.ElementwiseType.value
# else:
# return (kind.value <= OpType.BroadcastType.value or kind == OpType.ReduceType or kind == OpType.ElementwiseType or kind == OpType.ReduceType)
# if self.CheckPath(graph_node, dom_node.parent_gnode, fcond2, tree):
# self.CommitFuse(graph_node, dom_node.parent_gnode, tree)
# elif group_node.pattern == OpType.ElementwiseType or group_node.pattern == OpType.GetItemType:
# if phase != 1:
# continue
# def fcond3(kind, is_sink):
# return kind.value <= OpType.ElementwiseType.value
# if self.CheckPath(graph_node, dom_node.parent_gnode, fcond3, tree):
# self.CommitFuse(graph_node, dom_node.parent_gnode, tree)
# else:
# pass

for node in tree.groups:
if node.master_ref is not None:
logger.info(
Expand Down
25 changes: 25 additions & 0 deletions frontend/Python/graph/transform/useless_op_eliminate.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,28 @@ def maxpool2d_simplify(graph: Graph):
if op == getitem_node:
graph.body[j] = new_node
break

def varmean_simpplify(graph: Graph):
"""
Fuse the varmean op and getitem op to simpllify graph.
Args:
graph (Graph): The Graph to be simplified.
"""
keys_to_remove = []

for i, key in enumerate(list(graph.op_groups.keys())):
if key.startswith("var_mean"):
# getitem_key1 = f"getitem_{int(key.split('var_mean')[1].split('_')[-1]) * 2}" if '_' in key.split('var_mean')[1] else "getitem"
# getitem_key2 = f"getitem_{int(key.split('var_mean')[1].split('_')[-1]) * 2 + 1}" if '_' in key.split('var_mean')[1] else "getitem_1"
getitem_key1 = list(graph.op_groups.keys())[i + 1]
getitem_key2 = list(graph.op_groups.keys())[i + 2]

if getitem_key1 in graph.op_groups and getitem_key2 in graph.op_groups:
graph.op_groups[key].extend(graph.op_groups[getitem_key1])
graph.op_groups[key].extend(graph.op_groups[getitem_key2])

keys_to_remove.extend([getitem_key1, getitem_key2])

for key in keys_to_remove:
del graph.op_groups[key]

0 comments on commit 140eb41

Please sign in to comment.