from typing import Optional
from functools import cmp_to_key

# Neighbour offsets as (row delta, col delta). The index of each offset is the
# side a fence piece faces and is used to bucket ``fence_positions``:
# 0 = up, 1 = down, 2 = left, 3 = right.
DIRECTIONS = ((-1, 0), (1, 0), (0, -1), (0, 1))

use_discount = True  # for part two


def chart_region(
    plots: list[list[Optional[str]]],
    current_pos: tuple[int, int],
    charted_positions: list[tuple[int, int]],
    fence_positions: list[list[tuple[int, int]]],
) -> tuple[int, int]:
    """Flood-fill the region containing *current_pos* and measure it.

    A region is the set of orthogonally connected cells whose value equals
    ``plots[current_pos[0]][current_pos[1]]``.

    Args:
        plots: Grid of plot labels (rows may differ in length).
        current_pos: ``(row, col)`` of the cell to start from.
        charted_positions: Output list; every cell of the region is appended.
            Positions already present are treated as visited and not entered.
        fence_positions: Four output buckets indexed like ``DIRECTIONS``; a
            cell is appended to bucket ``i`` when its side ``i`` is fenced
            (grid edge or a neighbour with a different value).

    Returns:
        ``(plot_count, fence_count)``: the region's area and perimeter.
    """
    start_plot = plots[current_pos[0]][current_pos[1]]
    charted_positions.append(current_pos)
    visited = set(charted_positions)  # O(1) membership instead of a list scan
    # Explicit stack instead of recursion: a long region would otherwise
    # exceed Python's default recursion limit.
    pending = [current_pos]
    plot_count = 0
    fence_count = 0
    while pending:
        row, col = pending.pop()
        plot_count += 1
        for side, (d_row, d_col) in enumerate(DIRECTIONS):
            n_row, n_col = row + d_row, col + d_col
            inside = 0 <= n_row < len(plots) and 0 <= n_col < len(plots[n_row])
            if inside and plots[n_row][n_col] == start_plot:
                neighbour = (n_row, n_col)
                if neighbour not in visited:
                    visited.add(neighbour)
                    charted_positions.append(neighbour)
                    pending.append(neighbour)
            else:
                # Grid edge or a different plot: this side of the cell is fenced.
                fence_positions[side].append((row, col))
                fence_count += 1
    return plot_count, fence_count


def compare_row_first(item1: tuple[int, int], item2: tuple[int, int]) -> int:
    """``cmp``-style comparator ordering positions by row, then column."""
    row_diff = item1[0] - item2[0]
    if row_diff != 0:
        return row_diff
    return item1[1] - item2[1]


def compare_col_first(item1: tuple[int, int], item2: tuple[int, int]) -> int:
    """``cmp``-style comparator ordering positions by column, then row."""
    col_diff = item1[1] - item2[1]
    if col_diff != 0:
        return col_diff
    return item1[0] - item2[0]


def _count_joins(positions, compare, step: tuple[int, int]) -> int:
    """Count neighbours under *compare* ordering that are exactly *step* apart."""
    ordered = sorted(positions, key=cmp_to_key(compare))  # copy; input untouched
    return sum(
        1
        for prev, cur in zip(ordered, ordered[1:])
        if prev[0] + step[0] == cur[0] and prev[1] + step[1] == cur[1]
    )


def calculate_discount(fence_positions: list[list[tuple[int, int]]]) -> int:
    """Return how many fence pieces merely continue a neighbouring piece.

    Within each direction bucket, two pieces on horizontally adjacent cells
    (same row, consecutive columns) or vertically adjacent cells (same column,
    consecutive rows) belong to one straight side. Subtracting the returned
    count from the perimeter gives the number of sides. The input lists are
    not modified.
    """
    adjacents = 0
    for positions in fence_positions:
        # Row-major order brings horizontally adjacent pieces next to each other.
        adjacents += _count_joins(positions, compare_row_first, (0, 1))
        # Column-major order brings vertically adjacent pieces next to each other.
        adjacents += _count_joins(positions, compare_col_first, (1, 0))
    return adjacents


def total_price(plots: list[list[Optional[str]]], discount: bool = False) -> int:
    """Return the summed fencing price of every region in *plots*.

    Each region costs ``area * perimeter``. When *discount* is true, the
    perimeter is replaced by the number of straight sides. *plots* is not
    modified; the function works on a copy.
    """
    grid = [list(row) for row in plots]
    total = 0
    for row_idx, row in enumerate(grid):
        for col_idx, plot in enumerate(row):
            if plot is None:
                continue  # already charted as part of an earlier region
            charted: list[tuple[int, int]] = []
            fences: list[list[tuple[int, int]]] = [[] for _ in DIRECTIONS]
            area, perimeter = chart_region(grid, (row_idx, col_idx), charted, fences)
            if discount:
                perimeter -= calculate_discount(fences)
            total += area * perimeter
            # Blank the region so its remaining cells are not charted again;
            # None never equals a plot label, so neighbours still see a fence.
            for r, c in charted:
                grid[r][c] = None
    return total


def main() -> None:
    """Read the grid from the file ``input`` and print its total price."""
    with open("input", encoding="utf-8") as f:
        # Skip blank lines (e.g. a trailing empty line) so every row is real.
        data = [list(stripped) for line in f if (stripped := line.strip())]
    print(total_price(data, discount=use_discount))


if __name__ == "__main__":
    main()