import pyutilib.misc
import pyutilib.th as unittest
from pyutilib.misc.visitor import SimpleVisitor, ValueVisitor


class Node(object):

    def __init__(self):
        self.children = []
        self.num = 0

    def __str__(self):      #pragma: no cover
        return str(self.num)

class CountVisitor(SimpleVisitor):

    def __init__(self):
        self.count = 0

    def visit(self, node):
        self.count += 1
        node.num = self.count

    def finalize(self):
        return self.count

class CollectVisitor(SimpleVisitor):

    def __init__(self):
        self.ans = []

    def visit(self, node):
        self.ans.append(node.num)
        
    def finalize(self):
        return self.ans

class SumVisitor(ValueVisitor):

    def __init__(self):
        self.count = 0

    def visit(self, node, values):
        if values is None or len(values) == 0:
            self.count = node.num
        else:
            self.count = node.num + sum(values)
        return self.count
    
    def finalize(self, ans):
        return self.count


class Test(unittest.TestCase):

    def setUp(self):
        root = Node()
        root.children = [Node(), Node(), Node()]
        root.children[0].children = [Node(), Node(), Node()]
        root.children[0].children[0].children = [Node(), Node(), Node()]
        root.children[0].children[1].children = [Node(), Node(), Node()]
        root.children[0].children[2].children = [Node(), Node(), Node()]
        root.children[1].children = [Node(), Node(), Node()]
        root.children[1].children[0].children = [Node(), Node(), Node()]
        root.children[1].children[1].children = [Node(), Node(), Node()]
        root.children[1].children[2].children = [Node(), Node(), Node()]
        root.children[2].children = [Node(), Node(), Node()]
        root.children[2].children[0].children = [Node(), Node(), Node()]
        root.children[2].children[1].children = [Node(), Node(), Node()]
        root.children[2].children[2].children = [Node(), Node(), Node()]

        cvisitor = CountVisitor()
        cvisitor.bfs(root)
        self.root = root
        
    def test_bfs(self):
        visitor = CollectVisitor()
        ans = visitor.bfs(self.root)
        self.assertEqual(ans, [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40])

    def test_dfs_preorder(self):
        visitor = CollectVisitor()
        ans = visitor.dfs(self.root)
        self.assertEqual(ans, [1,2,5,14,15,16,6,17,18,19,7,20,21,22,3,8,23,24,25,9,26,27,28,10,29,30,31,4,11,32,33,34,12,35,36,37,13,38,39,40])

    def test_dfs_inorder(self):
        visitor = CollectVisitor()
        ans = visitor.dfs_inorder(self.root)
        self.assertEqual(ans, [14,5,15,5,16,2,17,6,18,6,19,2,20,7,21,7,22,1,23,8,24,8,25,3,26,9,27,9,28,3,29,10,30,10,31,1,32,11,33,11,34,4,35,12,36,12,37,4,38,13,39,13,40])

    def test_dfs_postorder(self):
        visitor = CollectVisitor()
        ans = visitor.dfs_postorder(self.root)
        self.assertEqual(ans, [14,15,16,5,17,18,19,6,20,21,22,7,2,23,24,25,8,26,27,28,9,29,30,31,10,3,32,33,34,11,35,36,37,12,38,39,40,13,4,1])

    def test_retval_dfs_postorder_tree(self):
        visitor = SumVisitor()
        ans = visitor.dfs_postorder_deque(self.root)
        self.assertEqual(ans, 820)
        visitor = SumVisitor()
        ans = visitor.dfs_postorder_stack(self.root)
        self.assertEqual(ans, 820)

    def test_retval_dfs_postorder_trivial(self):
        root = Node()
        root.num = 1
        visitor = SumVisitor()
        ans = visitor.dfs_postorder_deque(root)
        self.assertEqual(ans, 1)
        visitor = SumVisitor()
        ans = visitor.dfs_postorder_stack(root)
        self.assertEqual(ans, 1)

    def test_count_bfs(self):
        cvisitor = CountVisitor()
        ans = cvisitor.bfs(self.root)
        self.assertEqual(ans,40)
        
    def test_count_xbfs(self):
        cvisitor = CountVisitor()
        ans = cvisitor.xbfs(self.root)
        self.assertEqual(ans,40)
        

if __name__ == "__main__":
    unittest.main()