summaryrefslogtreecommitdiffstats
path: root/20/py/d07.py
blob: fc39ad0b896057c4871a017552184218ea900400 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/usr/bin/env python3
import sys
import regex as re
import functools


class Node:
    """One bag colour in the containment graph.

    ``children`` is a list of pairs whose first element is the contained bag
    (another ``Node`` once the graph is linked) and whose second element is
    how many of that bag are contained.
    """

    def __init__(self, name, children=None):
        self.name = name
        # Use a fresh list per instance: a shared ``[]`` default would make
        # every default-constructed node alias the same children list.
        self.children = [] if children is None else children

    def __iter__(self):
        """Yield the contained bags (the first element of each child pair)."""
        for child in self.children:
            yield child[0]

    def __repr__(self):
        return self.name


def graph(nodes):
    """Print the bag graph to stdout as a Graphviz DOT digraph.

    One edge is emitted per (parent, child) pair, drawn left-to-right and
    labelled with the second element of the child pair (the bag count).
    """
    print("digraph G {")
    print('rankdir="LR";')
    # Generator keeps output incremental: each edge is printed as it is built.
    edges = (
        (parent.name, pair[0].name, pair[1])
        for parent in nodes.values()
        for pair in parent.children
    )
    for source, target, label in edges:
        print(f'"{source}" -> "{target}" [ label="{label}" ];')
    print("}")


def parse(_in):
    """Build the bag-containment graph from rule lines.

    Args:
        _in: iterable of lines such as
            ``"light red bags contain 1 bright white bag, 2 muted yellow bags."``
            or ``"faded blue bags contain no other bags."``. Blank lines
            (e.g. a trailing empty line) are skipped.

    Returns:
        dict mapping each colour (e.g. ``"light red"``) to its ``Node``; each
        node's ``children`` is a list of ``(Node, amount)`` pairs.

    Raises:
        ValueError: if a non-blank line does not match the rule format.
        KeyError: if a rule names a contained colour that has no rule of its own.
    """
    rule_re = re.compile(
        r"(\w+ \w+) bags contain (no other bags|(((\d+) (\w+ \w+)) bags?(, )?)+)\."
    )
    # Pull (amount, colour) pairs out of the contents group with findall
    # instead of the third-party ``regex`` module's ``captures``; the result is
    # the same and works with either ``regex`` or the stdlib ``re``.
    content_re = re.compile(r"(\d+) (\w+ \w+) bags?")

    nodes = {}
    for line in _in:
        if not line.strip():
            continue
        match = rule_re.match(line)
        if match is None:
            raise ValueError(f"unrecognised bag rule: {line!r}")
        children = [
            (colour, int(amount)) for amount, colour in content_re.findall(match[2])
        ]
        nodes[match[1]] = Node(match[1], children)
    # Second pass: every rule is known now, so replace child names with nodes.
    for node in nodes.values():
        node.children = [(nodes[name], amount) for name, amount in node.children]
    return nodes


def pt1(_in):
    """Return how many bag colours can eventually contain a "shiny gold" bag."""
    bags = parse(_in)
    goal = bags["shiny gold"]

    @functools.cache
    def holds(outer, wanted):
        # Check direct containment first, then search each contained bag.
        if wanted in outer:
            return True
        return any(holds(inner, wanted) for inner in outer)

    return sum(1 for bag in bags.values() if holds(bag, goal))


def pt2(_in):
    """Return how many individual bags a single "shiny gold" bag must hold."""
    @functools.cache
    def contents(bag):
        # Each of the `qty` inner bags counts once itself, plus all it holds.
        return sum(qty * (1 + contents(inner)) for inner, qty in bag.children)

    return contents(parse(_in)["shiny gold"])


if __name__ == "__main__":
    # Input path: first CLI argument, otherwise the default puzzle input,
    # resolved relative to the current working directory.
    path = sys.argv[1] if len(sys.argv) > 1 else "../input/07"
    # ``with`` guarantees the file is closed; ``lines`` avoids shadowing the
    # builtin ``input``.
    with open(path, "r", encoding="utf-8") as f:
        lines = f.readlines()

    # Debug aid: graph(parse(lines)) prints the bag graph in Graphviz DOT format.
    print(pt1(lines))
    print(pt2(lines))