summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--20/py/d07.py88
1 files changed, 27 insertions, 61 deletions
diff --git a/20/py/d07.py b/20/py/d07.py
index 2658774..0252e6e 100644
--- a/20/py/d07.py
+++ b/20/py/d07.py
@@ -1,15 +1,15 @@
#!/usr/bin/env python3
import sys
-import re
+import regex as re
+import functools
class Node:
def __init__(self, name, children=[]):
self.name = name
- self.parents = []
self.children = children
- def bag_children(self):
+ def __iter__(self):
for child in self.children:
yield child[0]
@@ -26,69 +26,35 @@ def graph(bags):
print("}")
-mem = dict()
-def can_contain(bag, target):
- if bag in mem:
- return mem[bag]
- if target in bag.bag_children():
- mem[bag] = True
- return True
- for child in bag.bag_children():
- if can_contain(child, target):
- mem[bag] = True
- return True
- mem[bag] = False
- return False
+def parse(_in):
+ bags = {} # bag: node
+ for bag in _in:
+ match = re.match(r"(\w+ \w+) bags contain (no other bags|(((\d+) (\w+ \w+)) bags?(, )?)+)\.",
+ bag)
+ children = [(kind, int(count)) for count, kind in zip(match.captures(5),
+ match.captures(6))]
+ bags[match[1]] = Node(match[1], children)
+ for name, bag in bags.items():
+ bags[name].children = [(bags[bag], amount) for bag, amount in bags[name].children]
+ return bags
def pt1(_in):
- rules = {} # bag: node
- for rule in _in:
- match = re.match(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.",
- rule)
- children = []
- if match[2] != "no other bags":
- child_matches = re.findall(r"(\d+) (\w+ \w+) bags?(, )?", match[2])
- for child_match in child_matches:
- children.append((child_match[1], int(child_match[0])))
- rules[match[1]] = Node(match[1], children)
- for name, bag in rules.items():
- rules[name].children = [(rules[bag], amount) for bag, amount in rules[name].children]
- #graph(rules)
-
- ans = 0
- for _, bag in rules.items():
- if can_contain(bag, rules["shiny gold"]):
- ans += 1
- return ans
-
-
-children = dict()
-def count_children(bag):
- if bag in children:
- #print("mem", bag)
- return mem[bag]
- amount = 0
- for child in bag.children:
- amount += child[1] * count_children(child[0])
- return 1 + amount
+ @functools.cache
+ def can_contain(bag, target):
+ return target in bag or any(can_contain(child, target) for child in bag)
+
+ bags = parse(_in)
+ return len([bag for bag in bags.values() if can_contain(bag, bags["shiny gold"])])
def pt2(_in):
- rules = {}
- for rule in _in:
- match = re.match(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.",
- rule)
- children = []
- if match[2] != "no other bags":
- child_matches = re.findall(r"(\d+) (\w+ \w+) bags?(, )?", match[2])
- for child_match in child_matches:
- children.append((child_match[1], int(child_match[0])))
- rules[match[1]] = Node(match[1], children)
- for name, bag in rules.items():
- rules[name].children = [(rules[bag], amount) for bag, amount in rules[name].children]
-
- return count_children(rules["shiny gold"]) - 1
+ @functools.cache
+ def count_children(bag):
+ return 1 + sum([child[1] * count_children(child[0]) for child in bag.children])
+
+ bags = parse(_in)
+ return count_children(bags["shiny gold"]) - 1
if __name__ == "__main__":
@@ -96,7 +62,7 @@ if __name__ == "__main__":
input = open(sys.argv[1], "r").readlines()
else:
input = open("../input/07", "r").readlines()
- #pt1(input) # for graph
+ # graph(parse(input))
print(pt1(input))
print(pt2(input))