2022-11-10 10:49:38 +08:00
|
|
|
"""
|
|
|
|
Author : Alexander Pantyukhin
|
|
|
|
Date : November 7, 2022
|
|
|
|
|
|
|
|
Task:
|
|
|
|
You are given the root of a binary tree with n nodes, where each node has
|
|
|
|
node.data coins. There are exactly n coins in the whole tree.
|
|
|
|
|
|
|
|
In one move, we may choose two adjacent nodes and move one coin from one node
|
|
|
|
to another. A move may be from parent to child, or from child to parent.
|
|
|
|
|
|
|
|
Return the minimum number of moves required to make every node have exactly one coin.
|
|
|
|
|
|
|
|
Example 1:
|
|
|
|
|
|
|
|
3
|
|
|
|
/ \
|
|
|
|
0 0
|
|
|
|
|
|
|
|
Result: 2
|
|
|
|
|
|
|
|
Example 2:
|
|
|
|
|
|
|
|
0
|
|
|
|
/ \
|
|
|
|
3 0
|
|
|
|
|
|
|
|
Result: 3
|
|
|
|
|
|
|
|
leetcode: https://leetcode.com/problems/distribute-coins-in-binary-tree/
|
|
|
|
|
|
|
|
Implementation notes:
|
|
|
|
Use a depth-first search approach.
|
|
|
|
|
|
|
|
Let n be the number of nodes in the tree.
|
|
|
|
Runtime: O(n)
|
|
|
|
Space: O(h), where h is the height of the tree (depth-first traversal depth)
|
|
|
|
"""
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
from dataclasses import dataclass
|
2023-08-09 15:55:30 +08:00
|
|
|
from typing import NamedTuple
|
2022-11-10 10:49:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class TreeNode:
    """
    A node of a binary tree that currently holds ``data`` coins.

    >>> TreeNode(2)
    TreeNode(data=2, left=None, right=None)
    """

    data: int  # number of coins placed on this node
    left: TreeNode | None = None  # left child, or None when there is none
    right: TreeNode | None = None  # right child, or None when there is none
|
|
|
|
|
|
|
|
|
2023-08-09 15:55:30 +08:00
|
|
|
class CoinsDistribResult(NamedTuple):
    """
    Pair describing the coin distribution of one subtree.

    >>> CoinsDistribResult(moves=2, excess=1)
    CoinsDistribResult(moves=2, excess=1)
    """

    moves: int  # number of coin moves performed inside the subtree
    excess: int  # coin-excess figure the subtree reports to its parent
|
2022-11-10 10:49:38 +08:00
|
|
|
|
|
|
|
|
|
|
|
def distribute_coins(root: TreeNode | None) -> int:
    """
    Return the minimum number of moves (one coin between two adjacent nodes
    per move) needed so that every node ends up holding exactly one coin.

    Define the balance of a subtree as (coins in it) - (nodes in it). Every
    coin that has to enter or leave a subtree crosses the single edge above
    it, so the answer is the sum of |balance| over every subtree except the
    whole tree. All values are gathered in one iterative post-order pass, so
    deep (skewed) trees do not hit Python's recursion limit.

    Runtime: O(n); extra space: O(h), with n nodes and tree height h.

    :param root: root of the tree, or None for an empty tree
    :return: the minimum number of moves (0 for an empty tree)
    :raises ValueError: if the number of coins differs from the number of
        nodes, or if the structure contains a cycle

    >>> distribute_coins(TreeNode(3, TreeNode(0), TreeNode(0)))
    2
    >>> distribute_coins(TreeNode(0, TreeNode(3), TreeNode(0)))
    3
    >>> distribute_coins(TreeNode(0, TreeNode(0), TreeNode(3)))
    3
    >>> distribute_coins(None)
    0
    >>> distribute_coins(TreeNode(0, TreeNode(0), TreeNode(0)))
    Traceback (most recent call last):
        ...
    ValueError: The nodes number should be same as the number of coins
    >>> distribute_coins(TreeNode(0, TreeNode(1), TreeNode(1)))
    Traceback (most recent call last):
        ...
    ValueError: The nodes number should be same as the number of coins
    >>> chain = TreeNode(0)
    >>> for _ in range(2999):
    ...     chain = TreeNode(0, chain)
    >>> chain.data = 3000
    >>> distribute_coins(chain)  # deeper than the default recursion limit
    4498500
    >>> loop = TreeNode(1)
    >>> loop.left = loop
    >>> distribute_coins(loop)
    Traceback (most recent call last):
        ...
    ValueError: The tree must not contain cycles
    """
    if root is None:
        return 0

    node_count = 0
    coin_count = 0
    moves = 0
    # Balances of finished subtrees. Each finished subtree leaves exactly one
    # entry; a parent pops its children's entries (right child's on top).
    balances: list[int] = []
    # ids of the nodes on the current root-to-node path. Meeting one of them
    # again means a cycle. A subtree shared by two parents is still walked
    # once per reference, exactly as a recursive walk would.
    on_path: set[int] = set()
    # (node, children_done) frames emulating the recursive post-order walk.
    stack: list[tuple[TreeNode, bool]] = [(root, False)]

    while stack:
        node, children_done = stack.pop()
        if not children_done:
            if id(node) in on_path:
                raise ValueError("The tree must not contain cycles")
            on_path.add(id(node))
            stack.append((node, True))
            # Push right before left so the left subtree is finished first.
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
            continue

        on_path.discard(id(node))
        right_balance = balances.pop() if node.right is not None else 0
        left_balance = balances.pop() if node.left is not None else 0
        # Each surplus/deficit coin crosses the edge to that child once.
        moves += abs(left_balance) + abs(right_balance)
        node_count += 1
        coin_count += node.data
        # This subtree's balance; the node itself needs to keep one coin.
        balances.append(node.data + left_balance + right_balance - 1)

    if node_count != coin_count:
        raise ValueError("The nodes number should be same as the number of coins")

    return moves
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Run every doctest reachable from this module's namespace.
    from doctest import testmod

    testmod()
|