summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--20/README2
-rw-r--r--20/py/d14.py94
2 files changed, 96 insertions, 0 deletions
diff --git a/20/README b/20/README
index 976822f..4de8366 100644
--- a/20/README
+++ b/20/README
@@ -15,6 +15,7 @@ Day Time Ans Time Ans
11 432.343 2338 580.579 2134
12 0.257 1645 0.279 35292
13 0.006 2095 0.018 598411311431841
+ 14 4.357 15018100062885 55.561 5724245857696
------- -------
tot TBD ms TBD ms
@@ -23,6 +24,7 @@ Stats:
-------Part 1-------- -------Part 2--------
Day Time Rank Score Time Rank Score
+ 14 00:43:31 4037 0 01:00:44 2402 0
13 01:16:45 7549 0 02:27:40 3806 0
12 00:09:08 645 0 00:16:03 373 0
11 00:55:39 4799 0 01:20:27 4040 0
diff --git a/20/py/d14.py b/20/py/d14.py
new file mode 100644
index 0000000..647c1c8
--- /dev/null
+++ b/20/py/d14.py
@@ -0,0 +1,94 @@
+#!/usr/bin/env python3
+import aoc20
+import sys
+
+
+def to_bin(n: int) -> str:
+ return bin(n)[2:] # 0b
+
+
+def from_bin(b: str) -> int:
+ return int(b, 2)
+
+
+def mask(b: [str], m: [str]) -> [str]:
+ # print(f"{b:>36}\n{m}")
+ ret = ""
+ for mi in range(len(m))[::-1]:
+ bi = mi - (len(m) - len(b))
+ if m[mi] == "0":
+ ret = "0" + ret
+ elif m[mi] == "1":
+ ret = "1" + ret
+ else:
+ ret = (str(b[bi]) if bi >= 0 else "0") + ret
+ # print(b, b[bi] if bi >= 0 else " ", m[mi], ret)
+ # print(f"{ret:>36}")
+ # print()
+ return ret
+
+
+def pt1(_in):
+ mem = {}
+ cur_mask = [-1 for _ in range(36)]
+ for inst in _in:
+ inst = inst.strip()
+ if inst[:4] == "mask":
+ cur_mask = inst.split(" = ")[1]
+ else:
+ key = int(inst[inst.index("[")+1:inst.index("]")])
+ bits = to_bin(int(inst.split(" = ")[1]))
+ bits = mask(bits, cur_mask)
+ mem[key] = bits
+ s = 0
+ for k, v in mem.items():
+ s += from_bin(v)
+ return s
+
+
+def str_set_at(s, c, i):
+ return s[:i] + c + s[i+1:]
+
+
+def mask2(b: [str], m: [str]) -> [[str]]:
+ adr = ""
+ for mi in range(len(m))[::-1]:
+ bi = mi - (len(m) - len(b))
+ if m[mi] == "0":
+ adr = (str(b[bi]) if bi >= 0 else "0") + adr
+ elif m[mi] == "1":
+ adr = "1" + adr
+ else:
+ adr = "X" + adr
+ adrs = [adr]
+ for bit in range(len(adr)):
+ if adr[bit] == "X":
+ new_adrs = []
+ while adrs:
+ adr = adrs.pop()
+ new_adrs.append(str_set_at(adr, "0", bit))
+ new_adrs.append(str_set_at(adr, "1", bit))
+ adrs = new_adrs
+ return adrs
+
+
+def pt2(_in):
+ mem = {}
+ cur_mask = [-1 for _ in range(36)]
+ for inst in _in:
+ inst = inst.strip()
+ if inst[:4] == "mask":
+ cur_mask = inst.split(" = ")[1]
+ else:
+ key = to_bin(int(inst[inst.index("[")+1:inst.index("]")]))
+ keys = mask2(key, cur_mask)
+ val = int(inst.split(" = ")[1])
+ for key in keys:
+ mem[key] = val
+ return sum(mem.values())
+
+
+if __name__ == "__main__":
+ input = aoc20.read_input(sys.argv[1:], 14)
+ print(pt1(input))
+ print(pt2(input))