def plot_points(points, ax, **kwargs):
    """Draw ``points`` on ``ax`` as circular ("o") markers.

    Column 0 of ``points`` is used for x and column 1 for y. The default
    style is black, marker size 7, alpha 0.5; any keyword argument overrides
    or extends these defaults and is forwarded to ``ax.plot``.
    """
    style = {"color": "black", "markersize": 7, "alpha": 0.5, **kwargs}
    xs, ys = points[:, 0], points[:, 1]
    ax.plot(xs, ys, "o", **style)
def plot_bounds(lower, upper, ax, **kwargs):
    """Draw the axis-aligned rectangle spanning ``lower`` to ``upper`` on ``ax``.

    ``lower`` is used as the rectangle's anchor and ``upper - lower`` supplies
    its width (component 0) and height (component 1), so both arguments must
    support element-wise subtraction. The default style is an unfilled,
    black, dashed outline (line width 1, alpha 0.5); keyword arguments
    override or extend it and are forwarded to ``plt.Rectangle``.

    After adding the patch, the axes data limits are extended to the
    rectangle's corners and the view is autoscaled.
    """
    style = {
        "fill": False,
        "edgecolor": "black",
        "linestyle": "--",
        "linewidth": 1,
        "alpha": 0.5,
        **kwargs,
    }
    extent = upper - lower
    rect = plt.Rectangle(lower, extent[0], extent[1], **style)
    ax.add_patch(rect)
    # Rectangle.get_path() is the untransformed unit rectangle, so feeding its
    # vertices to update_datalim would stretch the limits to (0,0)-(1,1)
    # rather than to the drawn box. Pass the real corners instead.
    corners = [(lower[0], lower[1]), (upper[0], upper[1])]
    ax.update_datalim(corners)
    ax.autoscale_view()
def plot_region(points, codes, mask, value, ax):
    """Plot all ``points`` on ``ax`` and highlight those belonging to one region.

    Every point is drawn as a small marker together with the bounding box of
    the whole set (per-column min/max of ``points``). Points whose code
    satisfies ``(codes & mask) == value`` are then overdrawn as hollow crimson
    markers; nothing extra is drawn when no point matches.
    """
    # Backdrop: the full point set and its bounding box.
    plot_points(points, ax, markersize=3)
    plot_bounds(np.min(points, axis=0), np.max(points, axis=0), ax)

    in_region = (codes & mask) == value
    selected = points[in_region]
    if selected.shape[0] == 0:
        return

    plot_points(
        selected,
        ax,
        color="crimson",
        markersize=6,
        markerfacecolor="none",
        markeredgewidth=1,
    )
def generate_regions(mask, width=8):
    """Enumerate the values the bits selected by ``mask`` can take.

    Args:
        mask: Integer bit mask (Python int or NumPy integer scalar).
        width: Minimum number of binary digits in each label. Labels are
            widened automatically when ``mask`` needs more digits than this.

    Returns:
        A list of ``(value, label)`` tuples in ascending order of value.
        ``value`` is an ``np.uint32`` for which ``(code & mask) == value`` is
        satisfiable. ``label`` is the zero-padded binary form of ``value``
        with the bits below the mask's lowest set bit shown as ``"x"``.
        Returns an empty list when ``mask`` is 0.
    """
    # Work on a plain Python int so the bit tricks below cannot overflow or
    # warn the way fixed-width NumPy scalars do.
    mask = int(mask)
    if mask == 0:
        return []

    # Index of the lowest set bit: ``mask & -mask`` isolates that bit.
    pos = (mask & -mask).bit_length() - 1
    field = mask >> pos
    # A mask wider than ``width`` would otherwise make the slice below cut
    # the wrong end of the digit string (or use a negative index).
    label_width = max(width, mask.bit_length())

    results = []
    for i in range(field + 1):
        if i & ~field:
            # Non-contiguous mask: ``i`` sets a bit the mask clears, so
            # ``code & mask`` can never equal it.
            continue
        shifted = i << pos
        digits = f"{shifted:0{label_width}b}"
        label = digits[: label_width - pos] + "x" * pos
        results.append((np.uint32(shifted), label))
    return results
def plot_regeions(points, codes, mask, cols=None):
    """Show a grid of subplots, one per region produced by ``generate_regions(mask)``.

    Each subplot is drawn by ``plot_region`` and titled with the region's
    label. Every subplot uses the same fixed view: limits of -1..16 on both
    axes, integer ticks 0..15, equal aspect ratio and a grid.

    Note: the function name keeps its existing spelling (``plot_regeions``)
    so current callers continue to work.

    Args:
        points: Point array passed through to ``plot_region``.
        codes: Per-point codes passed through to ``plot_region``.
        mask: Bit mask selecting the code bits that define a region.
        cols: Number of subplot columns. Defaults to the ceiling of the
            square root of the region count; an explicit value is clamped to
            the range 1..region count.

    Returns:
        None. Returns without creating a figure when ``mask`` yields no
        regions.
    """
    regions = generate_regions(mask)
    if not regions:
        # Nothing to draw; this also avoids a zero column count below, which
        # would make the row computation divide by zero.
        return
    num_regions = len(regions)

    if cols is None:
        cols = math.ceil(math.sqrt(num_regions))
    else:
        # Clamp so a zero or negative request cannot break the grid layout.
        cols = max(1, min(cols, num_regions))
    rows = math.ceil(num_regions / cols)

    _, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), constrained_layout=True)
    # plt.subplots returns a bare Axes, a 1-D array or a 2-D array depending
    # on the grid shape; flatten all cases to a 1-D array.
    axes = np.array(axes).ravel()

    ticks = np.arange(16)
    for ax, (value, label) in zip(axes, regions):
        plot_region(points, codes, mask, value, ax)
        ax.set_title(label)
        ax.set_xlim(-1, 16)
        ax.set_ylim(-1, 16)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.tick_params(axis="both", labelsize=8)
        ax.set_aspect("equal", adjustable="box")
        ax.grid(True)

    # The grid can have more cells than regions; hide the leftovers instead
    # of showing blank default axes.
    for spare in axes[num_regions:]:
        spare.set_visible(False)

    plt.show()