1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
|
#!/usr/bin/env python3
import aoc20
import sys
import regex as re
import functools
class Node:
    """A bag colour in the containment graph.

    ``children`` is a plain mutable list of pairs. ``parse`` first fills it
    with ``(name, amount)`` pairs and later rewrites it to ``(Node, amount)``
    pairs. Instances keep the default identity-based ``__eq__``/``__hash__``,
    which is what allows them to be used as ``functools.cache`` keys.
    """

    def __init__(self, name, children=None):
        self.name = name
        # A fresh list per instance: a ``[]`` default would be a single list
        # shared by every Node created without explicit children.
        self.children = [] if children is None else children

    def __iter__(self):
        """Yield the child nodes, dropping the amounts."""
        for child, _amount in self.children:
            yield child

    def __repr__(self):
        return f"{type(self).__name__}({self.name!r})"
def parse(_in):
    """Build the bag-containment graph from rule lines.

    Each non-blank line is expected to look like
    ``"<adj> <colour> bags contain N <adj> <colour> bag(s), ... ."`` or
    ``"<adj> <colour> bags contain no other bags."``.

    Args:
        _in: iterable of rule lines (surrounding whitespace and trailing
            newlines are tolerated; blank lines are skipped).

    Returns:
        dict mapping each bag name (e.g. ``"shiny gold"``) to its ``Node``,
        whose ``children`` is a list of ``(Node, amount)`` pairs.

    Raises:
        ValueError: if a non-blank line does not start with
            ``"<adj> <colour> bags contain "``.
        KeyError: if a rule lists a contained bag that has no rule of its own.
    """
    nodes = {}
    for line in _in:
        # Skip blank lines (e.g. a trailing newline at the end of the input)
        # instead of crashing on a failed match.
        if not line.strip():
            continue
        header = re.match(r"(\w+ \w+) bags contain ", line)
        if header is None:
            raise ValueError(f"unparseable rule: {line!r}")
        name = header[1]
        # One (amount, name) tuple per contained bag; "no other bags" has no
        # digits and therefore yields nothing. findall is available in both
        # the `regex` module and stdlib `re`.
        children = [
            (child_name, int(amount))
            for amount, child_name in re.findall(
                r"(\d+) (\w+ \w+) bags?", line[header.end():]
            )
        ]
        nodes[name] = Node(name, children)
    # Second pass: now that every bag exists, swap child names for Nodes.
    for node in nodes.values():
        node.children = [
            (nodes[child_name], amount) for child_name, amount in node.children
        ]
    return nodes
def pt1(_in):
    """Count the bags that can directly or transitively hold a shiny gold bag.

    Args:
        _in: iterable of rule lines, handed straight to ``parse``.

    Returns:
        The number of bags in the parsed graph from which a shiny gold bag
        is reachable through one or more containment steps.
    """

    @functools.cache
    def holds(bag, wanted):
        # Direct containment first, then search each child; memoised per
        # (bag, wanted) pair so shared sub-bags are only explored once.
        if wanted in bag:
            return True
        return any(holds(inner, wanted) for inner in bag)

    bags = parse(_in)
    total = 0
    for bag in bags.values():
        if holds(bag, bags["shiny gold"]):
            total += 1
    return total
def pt2(_in):
    """Count how many bags a single shiny gold bag must contain in total.

    Args:
        _in: iterable of rule lines, handed straight to ``parse``.

    Returns:
        Total number of bags nested inside one shiny gold bag; the shiny
        gold bag itself is not counted.
    """

    @functools.cache
    def bag_total(bag):
        # Bags in the subtree rooted at ``bag``, counting ``bag`` itself.
        total = 1
        for inner, amount in bag.children:
            total += amount * bag_total(inner)
        return total

    bags = parse(_in)
    return bag_total(bags["shiny gold"]) - 1
if __name__ == "__main__":
    # Materialise the input once: pt1 and pt2 each iterate it fully, so a
    # one-shot iterator from read_input would leave pt2 with no rules.
    # (Also avoids shadowing the builtin ``input``.)
    puzzle_input = list(aoc20.read_input(sys.argv[1:], 7))
    # graph(parse(puzzle_input))  # optional debug view; `graph` is presumably defined elsewhere
    print(pt1(puzzle_input))
    print(pt2(puzzle_input))
|