summaryrefslogtreecommitdiffstats
path: root/20/py/d14.py
blob: 647c1c8179f6fde223ebea03449108234d00af14 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/env python3
import aoc20
import sys


def to_bin(n: int) -> str:
    """Return the binary digits of *n* with Python's ``0b`` prefix removed.

    Intended for non-negative *n*; no zero padding is applied.
    """
    prefixed = bin(n)
    return prefixed[len("0b"):]


def from_bin(b: str) -> int:
    """Parse a string of binary digits (leading zeros allowed) into an int."""
    return int(b, base=2)


def mask(b: str, m: str) -> str:
    """Apply a part-1 value mask to a binary string.

    A mask character of "0" or "1" forces that output bit. Any other mask
    character (e.g. "X") keeps the corresponding bit of *b*. *b* is
    right-aligned against *m* and zero-padded on the left, so the result
    always has ``len(m)`` characters. If *b* is longer than *m*, only its
    low ``len(m)`` bits are used.

    Args:
        b: binary digits of the value, e.g. as produced by ``to_bin``.
        m: the mask, one character per output bit.

    Returns:
        The masked binary string, ``len(m)`` characters long.
    """
    offset = len(m) - len(b)  # shift that right-aligns b under m
    out = []
    for mi, mc in enumerate(m):
        # Tuple membership (not `in "01"`) so non-str placeholder mask
        # entries simply fall through to "keep the bit".
        if mc in ("0", "1"):
            out.append(mc)
        else:
            bi = mi - offset
            out.append(str(b[bi]) if bi >= 0 else "0")
    return "".join(out)


def pt1(_in):
    """Solve part 1: mask each written value, then sum final memory.

    Args:
        _in: iterable of instruction lines, each either
            ``mask = <mask>`` or ``mem[<addr>] = <value>``.

    Returns:
        The sum of every value left in memory (0 if nothing was written).
    """
    mem = {}
    # An all-"X" mask leaves values unchanged until the first mask line.
    cur_mask = "X" * 36
    for inst in _in:
        inst = inst.strip()
        if not inst:
            continue  # tolerate blank lines, e.g. a trailing newline
        if inst.startswith("mask"):
            cur_mask = inst.split(" = ")[1]
        else:
            key = int(inst[inst.index("[") + 1:inst.index("]")])
            bits = to_bin(int(inst.split(" = ")[1]))
            mem[key] = mask(bits, cur_mask)
    return sum(from_bin(v) for v in mem.values())


def str_set_at(s, c, i):
    """Return a copy of *s* with the character at index *i* replaced by *c*.

    If *i* is past the end of *s*, *c* is appended instead.
    """
    head, tail = s[:i], s[i + 1:]
    return head + c + tail


def mask2(b: str, m: str) -> list:
    """Apply a part-2 address mask and expand its floating bits.

    A mask "1" forces the bit to 1. A mask "0" keeps the corresponding bit
    of *b* (right-aligned, zero-padded on the left). Any other character
    marks a floating bit that takes both values.

    Args:
        b: binary digits of the address, e.g. as produced by ``to_bin``.
        m: the mask, one character per output bit.

    Returns:
        A list (of str) of all ``2**k`` resulting binary addresses, where k
        is the number of floating bits. Each is ``len(m)`` characters long.
    """
    offset = len(m) - len(b)  # shift that right-aligns b under m
    addresses = [""]
    for mi, mc in enumerate(m):
        if mc == "1":
            options = ("1",)
        elif mc == "0":
            bi = mi - offset
            options = (str(b[bi]) if bi >= 0 else "0",)
        else:
            options = ("0", "1")  # floating: branch on both values
        # Extend every partial address by each allowed bit at this position.
        addresses = [a + bit for a in addresses for bit in options]
    return addresses


def pt2(_in):
    """Solve part 2: mask each write address (with floating bits), sum memory.

    Args:
        _in: iterable of instruction lines, each either
            ``mask = <mask>`` or ``mem[<addr>] = <value>``.

    Returns:
        The sum of every value left in memory (0 if nothing was written).
    """
    mem = {}
    # mask2 treats every character other than "0"/"1" as floating, so the
    # placeholder used before the first mask line must be all "0" (which
    # leaves the address unchanged) to avoid expanding into 2**36 addresses.
    cur_mask = "0" * 36
    for inst in _in:
        inst = inst.strip()
        if not inst:
            continue  # tolerate blank lines, e.g. a trailing newline
        if inst.startswith("mask"):
            cur_mask = inst.split(" = ")[1]
            continue
        addr = to_bin(int(inst[inst.index("[") + 1:inst.index("]")]))
        val = int(inst.split(" = ")[1])
        for key in mask2(addr, cur_mask):
            mem[key] = val
    return sum(mem.values())


if __name__ == "__main__":
    # Materialize once: both parts iterate over the lines, and read_input's
    # return type isn't visible here. If it were a one-shot iterator, pt2
    # would silently see no input.
    lines = list(aoc20.read_input(sys.argv[1:], 14))
    print(pt1(lines))
    print(pt2(lines))