diff options
| -rw-r--r-- | 20/py/d07.py | 88 |
1 files changed, 27 insertions, 61 deletions
diff --git a/20/py/d07.py b/20/py/d07.py index 2658774..0252e6e 100644 --- a/20/py/d07.py +++ b/20/py/d07.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 import sys -import re +import regex as re +import functools class Node: def __init__(self, name, children=[]): self.name = name - self.parents = [] self.children = children - def bag_children(self): + def __iter__(self): for child in self.children: yield child[0] @@ -26,69 +26,35 @@ def graph(bags): print("}") -mem = dict() -def can_contain(bag, target): - if bag in mem: - return mem[bag] - if target in bag.bag_children(): - mem[bag] = True - return True - for child in bag.bag_children(): - if can_contain(child, target): - mem[bag] = True - return True - mem[bag] = False - return False +def parse(_in): + bags = {} # bag: node + for bag in _in: + match = re.match(r"(\w+ \w+) bags contain (no other bags|(((\d+) (\w+ \w+)) bags?(, )?)+)\.", + bag) + children = [(kind, int(count)) for count, kind in zip(match.captures(5), + match.captures(6))] + bags[match[1]] = Node(match[1], children) + for name, bag in bags.items(): + bags[name].children = [(bags[bag], amount) for bag, amount in bags[name].children] + return bags def pt1(_in): - rules = {} # bag: node - for rule in _in: - match = re.match(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.", - rule) - children = [] - if match[2] != "no other bags": - child_matches = re.findall(r"(\d+) (\w+ \w+) bags?(, )?", match[2]) - for child_match in child_matches: - children.append((child_match[1], int(child_match[0]))) - rules[match[1]] = Node(match[1], children) - for name, bag in rules.items(): - rules[name].children = [(rules[bag], amount) for bag, amount in rules[name].children] - #graph(rules) - - ans = 0 - for _, bag in rules.items(): - if can_contain(bag, rules["shiny gold"]): - ans += 1 - return ans - - -children = dict() -def count_children(bag): - if bag in children: - #print("mem", bag) - return mem[bag] - amount = 0 - for child in bag.children: - amount += child[1] * count_children(child[0]) - return 1 + amount + @functools.cache + def can_contain(bag, target): + return target in bag or any(can_contain(child, target) for child in bag) + + bags = parse(_in) + return len([bag for bag in bags.values() if can_contain(bag, bags["shiny gold"])]) def pt2(_in): - rules = {} - for rule in _in: - match = re.match(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.", - rule) - children = [] - if match[2] != "no other bags": - child_matches = re.findall(r"(\d+) (\w+ \w+) bags?(, )?", match[2]) - for child_match in child_matches: - children.append((child_match[1], int(child_match[0]))) - rules[match[1]] = Node(match[1], children) - for name, bag in rules.items(): - rules[name].children = [(rules[bag], amount) for bag, amount in rules[name].children] - - return count_children(rules["shiny gold"]) - 1 + @functools.cache + def count_children(bag): + return 1 + sum([child[1] * count_children(child[0]) for child in bag.children]) + + bags = parse(_in) + return count_children(bags["shiny gold"]) - 1 if __name__ == "__main__": @@ -96,7 +62,7 @@ if __name__ == "__main__": input = open(sys.argv[1], "r").readlines() else: input = open("../input/07", "r").readlines() - #pt1(input) # for graph + # graph(parse(input)) print(pt1(input)) print(pt2(input)) |
