summaryrefslogtreecommitdiffstats
path: root/20/py/d07.py
blob: 2658774c711e954cc19fee4ca4f7c4540bfdf16a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
import sys
import re


class Node:
    """A bag rule: a named bag plus the bags it directly contains.

    ``children`` is a list of ``(child, amount)`` pairs. During parsing the
    ``child`` element is the contained bag's name (a ``str``); the parsing
    code later replaces it with the corresponding ``Node``.
    """

    def __init__(self, name, children=None):
        self.name = name
        # Not written anywhere in the visible code; kept for compatibility.
        self.parents = []
        # A fresh list per node: a ``[]`` default would be shared by every
        # Node created without an explicit ``children`` argument.
        self.children = [] if children is None else children

    def bag_children(self):
        """Yield the first element (the contained bag) of each child pair."""
        for child in self.children:
            yield child[0]

    def __repr__(self):
        return self.name


def graph(bags):
    """Write the containment graph to stdout as a Graphviz ``digraph``.

    Each ``(child, amount)`` pair of every bag in ``bags.values()`` becomes an
    edge from the containing bag to ``child``, labelled with ``amount``. The
    children must already be objects with a ``name`` attribute, not strings.
    """
    print("digraph G {")
    print('rankdir="LR";')  # lay the graph out left-to-right
    edge_line = '"{}" -> "{}" [ label="{}" ];'
    for parent in bags.values():
        for target, count in parent.children:
            print(edge_line.format(parent.name, target.name, count))
    print("}")


# Memo for can_contain, keyed by (bag, target) so that a result computed for
# one target is never reused when asking about a different target.
mem = dict()


def can_contain(bag, target):
    """Return True if ``target`` is reachable from ``bag`` via contained bags.

    ``bag`` (and every bag reachable from it) must provide ``bag_children()``.
    ``target`` is compared against each child with ``in``. Results are cached
    in the module-level ``mem`` dict. There is no visited-set, so a cyclic
    containment graph would recurse without bound.
    """
    key = (bag, target)
    if key in mem:
        return mem[key]
    kids = list(bag.bag_children())
    # Check the direct children first, then search depth-first.
    result = target in kids or any(can_contain(kid, target) for kid in kids)
    mem[key] = result
    return result


def pt1(_in, target="shiny gold"):
    """Count the bag colours that can eventually contain a ``target`` bag.

    ``_in`` is an iterable of rule lines such as
    ``"light red bags contain 1 bright white bag, 2 muted yellow bags."``.
    Blank lines are skipped. Any other line that does not match the rule
    format raises ``ValueError``.
    """
    rule_re = re.compile(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.")
    child_re = re.compile(r"(\d+) (\w+ \w+) bags?(, )?")

    rules = {}  # bag name -> Node
    for rule in _in:
        if not rule.strip():
            continue  # e.g. a trailing empty line in the input file
        match = rule_re.match(rule)
        if match is None:
            raise ValueError(f"unrecognised rule: {rule!r}")
        contents = []
        if match[2] != "no other bags":
            for amount, child_name, _sep in child_re.findall(match[2]):
                contents.append((child_name, int(amount)))
        rules[match[1]] = Node(match[1], contents)

    # Second pass: swap child names for Node objects now that every rule has
    # been parsed (a rule can mention a bag whose own rule appears later).
    for node in rules.values():
        node.children = [(rules[child], amount) for child, amount in node.children]
    # graph(rules)  # uncomment to print the rule graph in Graphviz DOT format

    goal = rules[target]
    return sum(1 for node in rules.values() if can_contain(node, goal))


# Memo for count_children: maps a bag to the total number of bags it
# represents (the bag itself plus everything nested inside it).
children = dict()


def count_children(bag):
    """Return the number of bags in ``bag``, counting ``bag`` itself.

    ``bag.children`` must be a list of ``(child_bag, amount)`` pairs whose
    ``child_bag`` elements have the same shape. Results are cached in the
    module-level ``children`` dict.
    """
    if bag in children:
        return children[bag]
    total = 1 + sum(amount * count_children(child) for child, amount in bag.children)
    children[bag] = total
    return total


def pt2(_in, target="shiny gold"):
    """Return how many bags a single ``target`` bag must contain in total.

    ``_in`` is an iterable of rule lines in the same format accepted by
    ``pt1``. Blank lines are skipped. Any other line that does not match the
    rule format raises ``ValueError``.
    """
    rule_re = re.compile(r"(\w+ \w+) bags contain (no other bags|[^\.]*)\.")
    child_re = re.compile(r"(\d+) (\w+ \w+) bags?(, )?")

    rules = {}  # bag name -> Node
    for rule in _in:
        if not rule.strip():
            continue  # e.g. a trailing empty line in the input file
        match = rule_re.match(rule)
        if match is None:
            raise ValueError(f"unrecognised rule: {rule!r}")
        contents = []
        if match[2] != "no other bags":
            for amount, child_name, _sep in child_re.findall(match[2]):
                contents.append((child_name, int(amount)))
        rules[match[1]] = Node(match[1], contents)

    # Second pass: swap child names for Node objects now that every rule has
    # been parsed (a rule can mention a bag whose own rule appears later).
    for node in rules.values():
        node.children = [(rules[child], amount) for child, amount in node.children]

    # count_children includes the outer bag itself, which is not "inside" it.
    return count_children(rules[target]) - 1


if __name__ == "__main__":
    # Input path: first CLI argument, else the default location (resolved
    # relative to the current working directory).
    path = sys.argv[1] if len(sys.argv) > 1 else "../input/07"
    # Close the file deterministically and avoid shadowing the input() builtin.
    with open(path, "r") as f:
        lines = f.readlines()
    # pt1(lines)  # with graph(rules) enabled inside pt1, dumps the DOT graph

    print(pt1(lines))
    print(pt2(lines))