"""LeNet-style (feature-map stack) renderer.
This renderer aims to reproduce the classic "stack of feature maps" diagrams
commonly used to illustrate convolutional neural networks, while still being
able to represent dense (vector) layers and other non-convolutional components.
It is intentionally separate from ``functional_view``: it renders a mostly
sequential pipeline left-to-right (skipping ignored layers) and draws
connections based on the destination layer's kernel/pool parameters.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Tuple, Union
import hashlib
import random
import os
import math
import aggdraw
from PIL import Image, ImageDraw, ImageFont
from .layer_utils import get_layers, extract_primary_shape
from .options import LenetOptions, LENET_PRESETS
from .utils import get_rgba_tuple, fade_color, resolve_style
# ---------------------------------------------------------------------------
# Shape helpers (copied/trimmed from layered.py for version-robustness)
# ---------------------------------------------------------------------------
def _shape_to_tuple(shape: Any) -> Any:
if shape is None:
return None
if isinstance(shape, tuple):
return shape
if hasattr(shape, "as_list"):
try:
return tuple(shape.as_list())
except Exception: # noqa: BLE001
pass
if isinstance(shape, list):
return tuple(shape)
return shape
def _resolve_layer_output_shape(layer: Any) -> Any:
shape = getattr(layer, "output_shape", None)
if shape is not None:
return _shape_to_tuple(shape)
output = getattr(layer, "output", None)
tensor_shape = getattr(output, "shape", None)
if tensor_shape is not None:
return _shape_to_tuple(tensor_shape)
compute_output_shape = getattr(layer, "compute_output_shape", None)
if callable(compute_output_shape):
input_shape = getattr(layer, "input_shape", None)
if input_shape is not None:
try:
return _shape_to_tuple(compute_output_shape(input_shape))
except Exception: # noqa: BLE001
pass
return None
def _clamp_int(value: float, low: int, high: int) -> int:
return int(max(low, min(high, round(value))))
def _as_tuple2(value: Any) -> Tuple[int, int]:
if value is None:
return (1, 1)
if isinstance(value, int):
return (value, value)
if isinstance(value, (tuple, list)) and len(value) >= 2:
return (int(value[0]), int(value[1]))
if isinstance(value, (tuple, list)) and len(value) == 1:
return (1, int(value[0]))
return (1, 1)
def _clamp_rect_to_face(
center_x: float,
center_y: float,
rect_w: float,
rect_h: float,
face_rect: Tuple[float, float, float, float],
*,
margin: float = 1.0,
) -> Tuple[float, float, float, float]:
"""Clamp a rectangle (defined by center + size) to lie within a face rect."""
fx1, fy1, fx2, fy2 = face_rect
# available size inside margins
avail_w = max(0.0, (fx2 - fx1) - 2.0 * margin)
avail_h = max(0.0, (fy2 - fy1) - 2.0 * margin)
rw = min(float(rect_w), avail_w) if avail_w > 0 else 0.0
rh = min(float(rect_h), avail_h) if avail_h > 0 else 0.0
half_w = rw / 2.0
half_h = rh / 2.0
if avail_w <= 0 or avail_h <= 0:
# Face too small; collapse to center
cx = (fx1 + fx2) / 2.0
cy = (fy1 + fy2) / 2.0
return (cx, cy, cx, cy)
cx = min(max(center_x, fx1 + margin + half_w), fx2 - margin - half_w)
cy = min(max(center_y, fy1 + margin + half_h), fy2 - margin - half_h)
return (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
def _with_alpha(rgba: Tuple[int, int, int, int], alpha: int) -> Tuple[int, int, int, int]:
return (rgba[0], rgba[1], rgba[2], int(max(0, min(255, alpha))))
def _effective_patch_alpha_for_layer(
style: Mapping[str, Any],
*,
has_face_image: bool,
base_alpha: int,
default_on_image: int,
) -> int:
"""Return an alpha (0-255) for patch fill on this layer.
Rules:
- If the style explicitly sets patch alpha (patch_fill_alpha/patch_alpha), use it.
- Otherwise, if the layer has a face image configured, default to a semi-transparent
alpha (min(base_alpha, default_on_image)) so the underlying image remains visible.
- Otherwise, keep the base alpha from the patch fill color.
The level of opacity can be controlled per-layer via style keys:
- patch_fill_alpha (preferred) or patch_alpha
- patch_fill_alpha_on_image / patch_alpha_on_image (only when a face image is present)
"""
def _clamp(v: Any) -> Optional[int]:
try:
return int(max(0, min(255, int(v))))
except Exception: # noqa: BLE001
return None
# Explicit per-layer alpha always wins.
for key in ("patch_fill_alpha", "patch_alpha"):
if key in style:
vv = _clamp(style.get(key))
if vv is not None:
return vv
if has_face_image:
for key in ("patch_fill_alpha_on_image", "patch_alpha_on_image"):
if key in style:
vv = _clamp(style.get(key))
if vv is not None:
return vv
return int(min(int(base_alpha), int(default_on_image)))
return int(base_alpha)
def _stable_seed(base_seed: Optional[int], *parts: Any) -> int:
"""Create a stable 64-bit seed from an optional base seed and arbitrary parts."""
h = hashlib.sha256()
if base_seed is not None:
h.update(str(int(base_seed)).encode("utf-8"))
for p in parts:
h.update(b"|")
h.update(str(p).encode("utf-8"))
return int.from_bytes(h.digest()[:8], "big", signed=False)
# ---------------------------------------------------------------------------
# Face image helpers (front-face textures)
# ---------------------------------------------------------------------------
_FACE_IMAGE_CACHE: Dict[str, Image.Image] = {}
def _load_face_image(spec: Any) -> Optional[Image.Image]:
"""Load a face image from a path or return a provided PIL image.
Returns an RGBA image or None on failure.
"""
if spec is None:
return None
if isinstance(spec, Image.Image):
try:
return spec.convert("RGBA")
except Exception: # noqa: BLE001
return None
if isinstance(spec, str):
path = os.path.expanduser(spec)
# Cache the decoded image; callers get a copy to avoid mutation issues.
if path in _FACE_IMAGE_CACHE:
return _FACE_IMAGE_CACHE[path].copy()
try:
img = Image.open(path).convert("RGBA")
except Exception: # noqa: BLE001
return None
_FACE_IMAGE_CACHE[path] = img
return img.copy()
return None
def _parse_face_image_style(style: Mapping[str, Any]) -> Tuple[Optional[Any], str, int, Optional[int]]:
"""Extract face-image parameters from a style dict.
Supports either:
- face_image = "/path/to.png"
- face_image = {"path": "...", "fit": "...", "alpha": 200, "inset": 1}
and optional top-level overrides:
- face_image_fit, face_image_alpha, face_image_inset
"""
spec = style.get("face_image", None)
fit = str(style.get("face_image_fit", "cover")).strip().lower()
alpha = int(style.get("face_image_alpha", 255))
inset = style.get("face_image_inset", None)
if inset is not None:
try:
inset = int(inset)
except Exception: # noqa: BLE001
inset = None
if isinstance(spec, Mapping):
# Nested spec dict for convenience
spec_map: Mapping[str, Any] = spec
spec_path = spec_map.get("path", spec_map.get("src", spec_map.get("file", None)))
if spec_path is not None:
spec = spec_path
if "fit" in spec_map:
fit = str(spec_map.get("fit")).strip().lower()
if "alpha" in spec_map:
try:
alpha = int(spec_map.get("alpha"))
except Exception: # noqa: BLE001
pass
if "inset" in spec_map:
try:
inset = int(spec_map.get("inset"))
except Exception: # noqa: BLE001
pass
# Normalize + clamp
if fit not in {"cover", "contain", "match_aspect", "fill"}:
fit = "cover"
alpha = int(max(0, min(255, alpha)))
return spec, fit, alpha, inset
def _adjust_wh_for_image_aspect(
w_px: int,
h_px: int,
img: Image.Image,
*,
min_xy: int,
max_xy: int,
) -> Tuple[int, int]:
"""Adjust (w,h) so the face matches the image aspect ratio (best-effort)."""
try:
iw, ih = img.size
except Exception: # noqa: BLE001
return (w_px, h_px)
if iw <= 0 or ih <= 0:
return (w_px, h_px)
aspect = float(iw) / float(ih)
def _cl(v: float) -> int:
return int(max(min_xy, min(max_xy, round(v))))
# Prefer keeping height stable (less vertical jitter) when possible.
w1 = _cl(h_px * aspect)
if min_xy <= w1 <= max_xy:
return (w1, _cl(h_px))
h1 = _cl(w_px / aspect)
if min_xy <= h1 <= max_xy:
return (_cl(w_px), h1)
# Fallback: clamp both with minimal distortion.
return (_cl(w1), _cl(h1))
def _fit_image_to_rect(
img: Image.Image,
w: int,
h: int,
*,
fit: str,
background: Tuple[int, int, int, int],
) -> Image.Image:
"""Fit an image into a (w,h) rect using the requested mode."""
fit = (fit or "cover").strip().lower()
if fit == "match_aspect":
# After aspect-matching, 'contain' shows the full image with no crop.
fit = "contain"
if w <= 0 or h <= 0:
return Image.new("RGBA", (max(1, w), max(1, h)), background)
try:
iw, ih = img.size
except Exception: # noqa: BLE001
iw, ih = (0, 0)
resample = getattr(getattr(Image, "Resampling", Image), "LANCZOS", getattr(Image, "LANCZOS", Image.BICUBIC))
if iw <= 0 or ih <= 0:
return Image.new("RGBA", (w, h), background)
if fit == "fill":
return img.resize((w, h), resample=resample)
if fit == "contain":
scale = min(float(w) / float(iw), float(h) / float(ih))
nw = max(1, int(round(iw * scale)))
nh = max(1, int(round(ih * scale)))
im2 = img.resize((nw, nh), resample=resample)
canvas = Image.new("RGBA", (w, h), background)
canvas.paste(im2, ((w - nw) // 2, (h - nh) // 2), im2)
return canvas
# cover (default)
scale = max(float(w) / float(iw), float(h) / float(ih))
nw = max(1, int(round(iw * scale)))
nh = max(1, int(round(ih * scale)))
im2 = img.resize((nw, nh), resample=resample)
left = max(0, (nw - w) // 2)
top = max(0, (nh - h) // 2)
return im2.crop((left, top, left + w, top + h))
def _apply_face_images(img: Image.Image, stacks: Sequence[Dict[str, Any]]) -> None:
"""Paste configured face images onto the front face of stacks."""
for obj in stacks:
style = obj.get("style", {}) or {}
st: FeatureMapStack = obj.get("stack")
if st is None:
continue
spec, fit, alpha, inset = _parse_face_image_style(style)
if not spec:
continue
face_img = style.get("_face_image_obj", None)
if not isinstance(face_img, Image.Image):
face_img = _load_face_image(spec)
if face_img is None:
continue
# Save for later reuse in the same render call.
try:
style["_face_image_obj"] = face_img
except Exception: # noqa: BLE001
pass
x1f, y1f, x2f, y2f = st.front_rect()
ix1 = int(round(x1f))
iy1 = int(round(y1f))
ix2 = int(round(x2f))
iy2 = int(round(y2f))
# Inset so we don't paint over the outline
inset_px = st.line_width if inset is None else int(max(0, inset))
ix1 += inset_px
iy1 += inset_px
ix2 -= inset_px
iy2 -= inset_px
w = max(1, ix2 - ix1)
h = max(1, iy2 - iy1)
bg = get_rgba_tuple(style.get("fill", st.fill))
fitted = _fit_image_to_rect(face_img, w, h, fit=fit, background=bg)
if alpha < 255:
# Multiply alpha channel
r, g, b, a = fitted.split()
a = a.point(lambda v: int(v * (alpha / 255.0)))
fitted = Image.merge("RGBA", (r, g, b, a))
# Composite onto the base image
try:
img.alpha_composite(fitted, (ix1, iy1))
except Exception: # noqa: BLE001
# Fallback
img.paste(fitted, (ix1, iy1), fitted)
def _sample_patch_ratios(rng: random.Random, *, x_lo: float, x_hi: float, y_lo: float = 0.12, y_hi: float = 0.88) -> Tuple[float, float]:
"""Sample (rx, ry) in [0,1] normalized coordinates with bounded ranges."""
rx = x_lo + (x_hi - x_lo) * rng.random()
ry = y_lo + (y_hi - y_lo) * rng.random()
# clamp just in case
rx = max(0.0, min(1.0, rx))
ry = max(0.0, min(1.0, ry))
return rx, ry
def _place_patch_by_ratio(
face: Tuple[float, float, float, float],
patch_w: float,
patch_h: float,
rx: float,
ry: float,
*,
margin: float = 1.0,
) -> Tuple[float, float, float, float]:
"""Place a patch within a face using normalized ratios, clamped to the face."""
fx1, fy1, fx2, fy2 = face
fw = max(0.0, fx2 - fx1)
fh = max(0.0, fy2 - fy1)
# Available travel range for the patch's top-left (after margins)
avail_w = max(0.0, fw - 2 * margin - patch_w)
avail_h = max(0.0, fh - 2 * margin - patch_h)
# If the patch is too large, fall back to center ratios.
rx_eff = rx if avail_w > 0 else 0.5
ry_eff = ry if avail_h > 0 else 0.5
cx = fx1 + margin + patch_w / 2.0 + rx_eff * avail_w
cy = fy1 + margin + patch_h / 2.0 + ry_eff * avail_h
return _clamp_rect_to_face(cx, cy, patch_w, patch_h, face, margin=margin)
def _rects_overlap_1d(a1: float, a2: float, b1: float, b2: float, *, eps: float = 0.0) -> bool:
"""Return True if [a1,a2] overlaps/touches [b1,b2] within eps."""
return (a1 <= b2 + eps) and (a2 >= b1 - eps)
def _clamp_rect_topleft(
rect: Tuple[float, float, float, float],
face: Tuple[float, float, float, float],
*,
margin: float = 1.0,
) -> Tuple[float, float, float, float]:
"""Clamp rect (x1,y1,x2,y2) inside face, preserving its size."""
x1, y1, x2, y2 = rect
fx1, fy1, fx2, fy2 = face
w = max(0.0, x2 - x1)
h = max(0.0, y2 - y1)
# protect against negative available space
max_x1 = fx2 - margin - w
max_y1 = fy2 - margin - h
x1 = min(max(x1, fx1 + margin), max_x1)
y1 = min(max(y1, fy1 + margin), max_y1)
return (x1, y1, x1 + w, y1 + h)
def _set_rect_x1(rect: Tuple[float, float, float, float], new_x1: float) -> Tuple[float, float, float, float]:
x1, y1, x2, y2 = rect
w = x2 - x1
return (new_x1, y1, new_x1 + w, y2)
def _set_rect_y1(rect: Tuple[float, float, float, float], new_y1: float) -> Tuple[float, float, float, float]:
x1, y1, x2, y2 = rect
h = y2 - y1
return (x1, new_y1, x2, new_y1 + h)
def _choose_y_from_ranges(
rng: random.Random,
prefer: float,
ranges: Sequence[Tuple[float, float]],
) -> Optional[float]:
if not ranges:
return None
candidates = []
for lo, hi in ranges:
y = max(lo, min(prefer, hi))
candidates.append((abs(prefer - y), y, lo, hi))
mind = min(c[0] for c in candidates)
best = [c for c in candidates if abs(c[0] - mind) <= 1e-6]
_, y, lo, hi = rng.choice(best)
# tiny jitter for variety while staying close
if hi - lo >= 1.0:
jitter = (rng.random() - 0.5) * min(6.0, hi - lo)
y = max(lo, min(hi, y + jitter))
return y
def _enforce_in_out_patch_separation(
face: Tuple[float, float, float, float],
incoming: Tuple[float, float, float, float],
outgoing: Tuple[float, float, float, float],
*,
rng: random.Random,
x_eps: float = 1.0,
v_gap: float = 2.0,
margin: float = 1.0,
) -> Tuple[Tuple[float, float, float, float], Tuple[float, float, float, float]]:
"""Ensure ordering + avoid 'touching' for incoming/outgoing patches on the same face.
Requirements:
- The outgoing patch (to the next layer) should not have a left edge further left than the incoming patch.
- If patches are horizontally close (overlap/touch), enforce enough vertical separation so they don't touch.
"""
inc = _clamp_rect_topleft(incoming, face, margin=margin)
out = _clamp_rect_topleft(outgoing, face, margin=margin)
# Enforce ordering by left edge: out.x1 >= inc.x1
if out[0] < inc[0]:
fx1, fy1, fx2, fy2 = face
out_w = out[2] - out[0]
max_out_x1 = fx2 - margin - out_w
out = _set_rect_x1(out, min(max_out_x1, inc[0]))
out = _clamp_rect_topleft(out, face, margin=margin)
if out[0] < inc[0]:
# If still violated due to clamping, move incoming left as needed.
inc_w = inc[2] - inc[0]
inc = _set_rect_x1(inc, out[0])
inc = _clamp_rect_topleft(inc, face, margin=margin)
# If horizontally close, enforce vertical separation (disallow overlap/touch)
x_close = out[0] <= inc[2] + x_eps # outgoing is on the right (or close)
if x_close and _rects_overlap_1d(inc[1], inc[3], out[1], out[3], eps=v_gap):
fx1, fy1, fx2, fy2 = face
# First, try moving outgoing away from incoming (above or below).
out_h = out[3] - out[1]
min_y = fy1 + margin
max_y = fy2 - margin - out_h
above_max = min(max_y, inc[1] - v_gap - out_h)
below_min = max(min_y, inc[3] + v_gap)
ranges: list[Tuple[float, float]] = []
if min_y <= above_max:
ranges.append((min_y, above_max))
if below_min <= max_y:
ranges.append((below_min, max_y))
new_y = _choose_y_from_ranges(rng, out[1], ranges)
if new_y is not None:
out = _set_rect_y1(out, new_y)
out = _clamp_rect_topleft(out, face, margin=margin)
# If still overlapping (very tight face), try moving incoming instead.
if _rects_overlap_1d(inc[1], inc[3], out[1], out[3], eps=v_gap):
inc_h = inc[3] - inc[1]
min_y2 = fy1 + margin
max_y2 = fy2 - margin - inc_h
above_max2 = min(max_y2, out[1] - v_gap - inc_h)
below_min2 = max(min_y2, out[3] + v_gap)
ranges2: list[Tuple[float, float]] = []
if min_y2 <= above_max2:
ranges2.append((min_y2, above_max2))
if below_min2 <= max_y2:
ranges2.append((below_min2, max_y2))
new_y2 = _choose_y_from_ranges(rng, inc[1], ranges2)
if new_y2 is not None:
inc = _set_rect_y1(inc, new_y2)
inc = _clamp_rect_topleft(inc, face, margin=margin)
return inc, out
@dataclass
class RenderShape:
"""Canonical shape record used by the LeNet renderer.
The renderer normalizes layer outputs into this structure so that spatial
feature maps and vector-like outputs can be handled through one drawing
pipeline.
"""
kind: str # "spatial" or "vector"
h: int
w: int
c: int
# Original-ish dims (best effort; used for kernel->patch ratio)
h_dim: int
w_dim: int
c_dim: int
def _canonicalize_shape(layer: Any, shape: Any) -> RenderShape:
"""Return a canonical (H,W,C) representation + kind."""
if shape is None:
return RenderShape("vector", 1, 1, 1, 1, 1, 1)
# Multi-output: best effort pick first output.
if isinstance(shape, (list, tuple)) and shape and isinstance(shape[0], (list, tuple)):
shape = shape[0]
shape = _shape_to_tuple(shape)
if not isinstance(shape, (list, tuple)):
return RenderShape("vector", 1, 1, 1, 1, 1, 1)
# Drop batch dim if present.
dims = list(shape[1:]) if len(shape) > 1 else list(shape)
rank = len(dims)
data_format = getattr(layer, "data_format", None)
def num(x: Any) -> int:
try:
return int(x) if x is not None else 1
except Exception: # noqa: BLE001
return 1
# 2D conv/pool typical: (H,W,C) or (C,H,W)
if rank == 3:
if data_format == "channels_first":
c_dim, h_dim, w_dim = (num(dims[0]), num(dims[1]), num(dims[2]))
else:
h_dim, w_dim, c_dim = (num(dims[0]), num(dims[1]), num(dims[2]))
return RenderShape("spatial", h_dim, w_dim, c_dim, h_dim, w_dim, c_dim)
# 1D conv typical: (L,C) after batch removal
if rank == 2:
# Treat as vector for most cases, but Conv1D/Pooling1D looks nicer as 1×L×C.
layer_name = type(layer).__name__.lower()
if "conv1d" in layer_name or "pool1d" in layer_name:
w_dim, c_dim = (num(dims[0]), num(dims[1]))
return RenderShape("spatial", 1, w_dim, c_dim, 1, w_dim, c_dim)
units = 1
for d in dims:
units *= num(d)
return RenderShape("vector", 1, 1, units, 1, 1, units)
# rank 1 or unknown: vector
units = 1
for d in dims:
units *= num(d)
return RenderShape("vector", 1, 1, units, 1, 1, units)
# ---------------------------------------------------------------------------
# Drawing primitives
# ---------------------------------------------------------------------------
class FeatureMapStack:
"""Draw a LeNet-style stack of offset feature-map rectangles."""
def __init__(
self,
x: float,
y: float,
width: int,
height: int,
channels: int,
*,
map_spacing: int,
max_visual_channels: int,
fill: Any,
outline: Any,
line_width: int = 1,
shade_step: int = 6,
) -> None:
self.x = float(x)
self.y = float(y)
self.width = int(width)
self.height = int(height)
self.channels = int(max(1, channels))
self.map_spacing = int(map_spacing)
self.max_visual_channels = int(max(1, max_visual_channels))
self.fill = fill
self.outline = outline
self.line_width = int(max(1, line_width))
self.shade_step = int(max(0, shade_step))
@property
def visible_count(self) -> int:
return min(self.channels, self.max_visual_channels)
@property
def offset(self) -> int:
return (self.visible_count - 1) * self.map_spacing if self.visible_count > 0 else 0
def bounds(self) -> Tuple[float, float, float, float]:
left = self.x - self.offset
top = self.y - self.offset
right = self.x + self.width
bottom = self.y + self.height
return (left, top, right, bottom)
def front_rect(self) -> Tuple[float, float, float, float]:
return (self.x, self.y, self.x + self.width, self.y + self.height)
def front_anchor(self) -> Tuple[float, float]:
x1, y1, x2, y2 = self.front_rect()
return ((x1 + x2) / 2.0, (y1 + y2) / 2.0)
def left_mid(self) -> Tuple[float, float]:
x1, y1, x2, y2 = self.front_rect()
return (x1, (y1 + y2) / 2.0)
def right_mid(self) -> Tuple[float, float]:
x1, y1, x2, y2 = self.front_rect()
return (x2, (y1 + y2) / 2.0)
def draw(self, draw: aggdraw.Draw) -> None:
pen = aggdraw.Pen(get_rgba_tuple(self.outline), self.line_width)
base_fill = get_rgba_tuple(self.fill)
# Draw from back to front.
for i in range(self.visible_count - 1, -1, -1):
ox = i * self.map_spacing
oy = i * self.map_spacing
x1 = self.x - ox
y1 = self.y - oy
x2 = x1 + self.width
y2 = y1 + self.height
if self.shade_step > 0:
brush_color = fade_color(base_fill, (self.visible_count - 1 - i) * self.shade_step)
else:
brush_color = base_fill
brush = aggdraw.Brush(brush_color)
draw.rectangle([x1, y1, x2, y2], pen, brush)
class PyramidConnection:
"""Draw a receptive-field style connection between two spatial stacks."""
def __init__(
self,
src: FeatureMapStack,
dst: FeatureMapStack,
*,
src_patch: Tuple[float, float, float, float],
dst_patch: Tuple[float, float, float, float],
connector_fill: Any,
connector_width: int,
src_patch_fill: Any,
dst_patch_fill: Any,
patch_outline: Any,
draw_patches: bool = True,
polygon_alpha: int = 90,
) -> None:
self.src = src
self.dst = dst
self.src_patch = src_patch
self.dst_patch = dst_patch
self.connector_fill = connector_fill
self.connector_width = int(max(1, connector_width))
self.src_patch_fill = src_patch_fill
self.dst_patch_fill = dst_patch_fill
self.patch_outline = patch_outline
self.draw_patches = bool(draw_patches)
self.polygon_alpha = int(max(0, min(255, polygon_alpha)))
def draw(self, draw: aggdraw.Draw) -> None:
pen = aggdraw.Pen(get_rgba_tuple(self.connector_fill), self.connector_width)
# Wedge polygon between patch edges
sx1, sy1, sx2, sy2 = self.src_patch
dx1, dy1, dx2, dy2 = self.dst_patch
poly = [sx2, sy1, dx1, dy1, dx1, dy2, sx2, sy2]
brush = aggdraw.Brush(_with_alpha(get_rgba_tuple(self.connector_fill), self.polygon_alpha))
draw.polygon(poly, pen, brush)
# Edge lines (helps readability)
draw.line([sx2, sy1, dx1, dy1], pen)
draw.line([sx2, sy2, dx1, dy2], pen)
if self.draw_patches:
ppen = aggdraw.Pen(get_rgba_tuple(self.patch_outline), 1)
pbrush_src = aggdraw.Brush(get_rgba_tuple(self.src_patch_fill))
pbrush_dst = aggdraw.Brush(get_rgba_tuple(self.dst_patch_fill))
draw.rectangle(list(self.src_patch), ppen, pbrush_src)
draw.rectangle(list(self.dst_patch), ppen, pbrush_dst)
class FunnelConnection:
"""Draw a funnel-style connector from a spatial stack to a vector stack."""
def __init__(
self,
src: FeatureMapStack,
dst: FeatureMapStack,
*,
connector_fill: Any,
connector_width: int,
polygon_alpha: int = 70,
) -> None:
self.src = src
self.dst = dst
self.connector_fill = connector_fill
self.connector_width = int(max(1, connector_width))
self.polygon_alpha = int(max(0, min(255, polygon_alpha)))
def draw(self, draw: aggdraw.Draw) -> None:
pen = aggdraw.Pen(get_rgba_tuple(self.connector_fill), self.connector_width)
brush = aggdraw.Brush(_with_alpha(get_rgba_tuple(self.connector_fill), self.polygon_alpha))
sx1, sy1, sx2, sy2 = self.src.front_rect()
dx1, dy1, dx2, dy2 = self.dst.front_rect()
poly = [sx2, sy1, dx1, dy1, dx1, dy2, sx2, sy2]
draw.polygon(poly, pen, brush)
draw.line([sx2, sy1, dx1, dy1], pen)
draw.line([sx2, sy2, dx1, dy2], pen)
class FullConnection:
"""Draw a dense-style connector between two vector-like stacks."""
def __init__(
self,
src: FeatureMapStack,
dst: FeatureMapStack,
*,
connector_fill: Any,
connector_width: int,
polygon_alpha: int = 55,
) -> None:
self.src = src
self.dst = dst
self.connector_fill = connector_fill
self.connector_width = int(max(1, connector_width))
self.polygon_alpha = int(max(0, min(255, polygon_alpha)))
def draw(self, draw: aggdraw.Draw) -> None:
pen = aggdraw.Pen(get_rgba_tuple(self.connector_fill), self.connector_width)
brush = aggdraw.Brush(_with_alpha(get_rgba_tuple(self.connector_fill), self.polygon_alpha))
sx1, sy1, sx2, sy2 = self.src.front_rect()
dx1, dy1, dx2, dy2 = self.dst.front_rect()
# Narrower trapezoid looks cleaner for vectors.
sy_top = sy1 + (sy2 - sy1) * 0.2
sy_bot = sy2 - (sy2 - sy1) * 0.2
dy_top = dy1 + (dy2 - dy1) * 0.2
dy_bot = dy2 - (dy2 - dy1) * 0.2
poly = [sx2, sy_top, dx1, dy_top, dx1, dy_bot, sx2, sy_bot]
draw.polygon(poly, pen, brush)
draw.line([sx2, sy_top, dx1, dy_top], pen)
draw.line([sx2, sy_bot, dx1, dy_bot], pen)
# ---------------------------------------------------------------------------
# Text helpers
# ---------------------------------------------------------------------------
def _get_text_size(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont) -> Tuple[int, int]:
"""Robust text sizing for newer and older Pillow versions."""
if hasattr(draw, "textbbox"):
bbox = draw.textbbox((0, 0), text, font=font)
return (bbox[2] - bbox[0], bbox[3] - bbox[1])
if hasattr(font, "getbbox"):
bbox = font.getbbox(text)
return (bbox[2] - bbox[0], bbox[3] - bbox[1])
return draw.textsize(text, font=font)
def _get_multiline_text_size(
draw: ImageDraw.ImageDraw,
text: str,
font: ImageFont.ImageFont,
*,
spacing: int = 2,
) -> Tuple[int, int]:
"""Measure multiline text (\n-delimited) in a Pillow-version-safe way."""
if not text:
return (0, 0)
parts = str(text).splitlines()
widths: list[int] = []
heights: list[int] = []
for part in parts:
w, h = _get_text_size(draw, part, font)
widths.append(int(w))
heights.append(int(h))
total_h = int(sum(heights) + spacing * max(0, len(parts) - 1))
return (int(max(widths) if widths else 0), total_h)
def _default_top_label(layer: Any, rshape: RenderShape) -> str:
return type(layer).__name__
def _default_bottom_label(layer: Any, rshape: RenderShape) -> str:
if rshape.kind == "spatial":
return f"{rshape.c_dim}@{rshape.h_dim}×{rshape.w_dim}"
return f"{rshape.c_dim}"
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
[docs]
def lenet_view(
model: Any,
to_file: Optional[str] = None,
min_xy: int = 20,
max_xy: int = 220,
scale_xy: float = 4.0,
type_ignore: Optional[Sequence[type]] = None,
index_ignore: Optional[Sequence[int]] = None,
color_map: Optional[Mapping[type, Mapping[str, Any]]] = None,
background_fill: Any = "black",
padding: int = 20,
layer_spacing: int = 40,
map_spacing: int = 4,
max_visual_channels: int = 12,
connector_fill: Any = "gray",
connector_width: int = 1,
patch_fill: Any = "#7db7ff",
patch_outline: Any = "black",
patch_scale: float = 1.0,
patch_alpha_on_image: int = 140,
seed: Optional[int] = None,
draw_connections: bool = True,
draw_patches: bool = True,
font: Optional[ImageFont.ImageFont] = None,
font_color: Any = "white",
top_label_callable: Optional[Callable[[Any, RenderShape], Optional[str]]] = None,
bottom_label_callable: Optional[Callable[[Any, RenderShape], Optional[str]]] = None,
top_label: bool = True,
bottom_label: bool = True,
top_label_padding: int = 6,
bottom_label_padding: int = 6,
styles: Optional[Mapping[Union[str, type], Dict[str, Any]]] = None,
*,
options: Union[LenetOptions, Mapping[str, Any], None] = None,
preset: Optional[str] = None,
) -> Image.Image:
"""Render a Keras model using a LeNet-style feature-map diagram.
This renderer emphasizes stacked feature maps and left-to-right progression.
It is especially useful for CNN-focused figures, publication graphics, and
teaching material that should resemble classic LeNet-style architecture
diagrams.
Parameters
----------
model : Any
Keras model instance to visualize.
LeNet view works best for sequential or mostly sequential models where
channel progression and stage-by-stage feature-map flow are the main
story.
to_file : str, optional
Path to save the rendered image. The image format is inferred from the
file extension.
The rendered ``PIL.Image`` is returned whether or not this value is
supplied. Use this when you want to save the figure and keep working
with the in-memory image.
min_xy : int, default=20
Minimum rendered width and height for a feature-map face in pixels.
This prevents small feature maps from becoming visually insignificant in
diagrams that also contain much larger spatial layers.
max_xy : int, default=220
Maximum rendered width and height for a feature-map face in pixels.
Use this to keep very large early feature maps from dominating the
overall composition.
scale_xy : float, default=4.0
Multiplier applied to feature-map width and height before clamping.
This is one of the main controls for the overall apparent scale of the
stacks in the figure.
type_ignore : sequence of type, optional
Layer classes to exclude from rendering.
This is useful for hiding utility layers that add noise to the diagram
without changing the main architectural story.
index_ignore : sequence of int, optional
Layer indices to exclude from rendering.
Use this when you need precise control over individual layers rather
than excluding every instance of a class.
color_map : mapping, optional
Mapping from layer class to broad style values such as ``fill`` and
``outline``.
This is the quickest way to define a consistent color language by layer
type. Use ``styles`` when you need per-layer overrides.
background_fill : Any, default='black'
Background color for the final image.
Dark backgrounds often work well in LeNet mode because they increase
contrast with the stacked feature maps and connectors.
padding : int, default=20
Outer padding around the full diagram in pixels.
Increase this when labels or wide stacks feel too close to the image
boundary.
layer_spacing : int, default=40
Horizontal spacing between successive stacks.
This is the main control for how compact or open the left-to-right flow
feels.
map_spacing : int, default=4
Offset between visible feature maps within a stack.
Larger values emphasize the layered stack effect. Smaller values create
a tighter and more compact look.
max_visual_channels : int, default=12
Maximum number of feature maps to draw for a single layer.
This keeps high-channel layers readable. Additional channels are still
represented conceptually, but they are not all drawn individually.
connector_fill : Any, default='gray'
Color used for connections between stacks.
Neutral connector colors are usually best because the stack fills and
labels already carry most of the semantic styling.
connector_width : int, default=1
Line width used for connections between stacks.
Increase this for large exported figures or diagrams intended for
presentation screens.
patch_fill : Any, default='#7db7ff'
Default fill color used for receptive-field or projection patches.
Patches help explain how one stack maps into the next. Choose a color
that remains visible against both the stack fill and the background.
patch_outline : Any, default='black'
Outline color used for receptive-field or projection patches.
A clear outline helps patches remain visible even when the fill color is
partially transparent.
patch_scale : float, default=1.0
Relative size multiplier applied to connector patches.
Increase this when patches should read more prominently. Reduce it when
they distract from the stacks themselves.
patch_alpha_on_image : int, default=140
Alpha value used when a patch is drawn over an embedded face image.
Lower values let the underlying image remain more visible. Higher values
make the patch read more strongly.
seed : int, optional
Seed used for deterministic placement of randomized patch elements.
Set this when you want repeated renders to remain visually consistent.
draw_connections : bool, default=True
If ``True``, draw connections between successive stacks.
Disable this when you want a cleaner figure that focuses only on the
stacks and labels.
draw_patches : bool, default=True
If ``True``, draw receptive-field and projection patches where
applicable.
Patches are often useful in explanatory figures, but turning them off
can simplify the diagram considerably.
font : PIL.ImageFont.ImageFont, optional
Font used for top and bottom labels.
A custom font is useful when the figure needs to match an existing
document or slide style.
font_color : Any, default='white'
Text color used for labels.
This should contrast clearly with ``background_fill`` and remain legible
against the rendered stacks.
top_label_callable : callable, optional
Callable receiving ``(layer, render_shape)`` and returning the label to
place above a stack.
This is the main hook for custom top annotations such as tensor sizes or
stage identifiers.
bottom_label_callable : callable, optional
Callable receiving ``(layer, render_shape)`` and returning the label to
place below a stack.
This is often used for layer names or layer types while the top label
communicates tensor shape.
top_label : bool, default=True
If ``True``, render top labels.
Disable this when the figure should stay compact or when only one label
position is needed.
bottom_label : bool, default=True
If ``True``, render bottom labels.
This is often paired with ``top_label`` to split different kinds of
information across two lines of annotation.
top_label_padding : int, default=6
Vertical padding between a stack and its top label.
Increase this when multiline labels or large fonts feel cramped.
bottom_label_padding : int, default=6
Vertical padding between a stack and its bottom label.
Increase this when bottom labels collide with other elements or need
more breathing room.
styles : mapping, optional
Fine-grained per-layer style overrides keyed by layer name or layer
class.
Use this for embedded images, per-layer patch settings, stack spacing
overrides, and other local adjustments that are too specific for
``color_map``. Supported keys include ``face_image``,
``face_image_fit``, ``face_image_alpha``, and ``face_image_inset``.
options : LenetOptions or mapping, optional
Configuration bundle applied after ``preset`` and before explicit
keyword arguments.
Use this when you want to reuse a LeNet-style configuration across
multiple models or examples.
preset : str, optional
Name of a preset from ``visualkeras.LENET_PRESETS``. LeNet mode
currently provides ``default``, ``compact``, and ``presentation``.
Presets are intended as convenient starting points rather than fixed
modes. They can be refined further with ``options`` and explicit
overrides.
Returns
-------
PIL.Image.Image
Rendered LeNet-style diagram.
Notes
-----
Configuration precedence is ``preset`` followed by ``options`` followed by
explicit keyword arguments.
Full documentation:
https://visualkeras.readthedocs.io/en/latest/api/lenet_style.html
"""
# --- preset/options resolution (match layered/graph/functional behavior) ---
if preset is not None or options is not None:
defaults = LenetOptions().to_kwargs()
defaults["color_map"] = None
defaults["styles"] = None
defaults["font"] = None
defaults["top_label_callable"] = None # not in LenetOptions but allow compare
defaults["bottom_label_callable"] = None
resolved = dict(defaults)
if preset is not None:
try:
resolved.update(LENET_PRESETS[preset].to_kwargs())
except KeyError as exc:
available = ", ".join(sorted(LENET_PRESETS.keys()))
raise ValueError(
f"Unknown lenet preset '{preset}'. Available presets: {available}"
) from exc
if options is not None:
if isinstance(options, LenetOptions):
option_values = options.to_kwargs()
elif isinstance(options, Mapping):
option_values = dict(options)
else:
raise TypeError("options must be a LenetOptions instance or a mapping of keyword arguments.")
resolved.update(option_values)
explicit_values: Dict[str, Any] = {
"to_file": to_file,
"min_xy": min_xy,
"max_xy": max_xy,
"scale_xy": scale_xy,
"type_ignore": type_ignore,
"index_ignore": index_ignore,
"color_map": color_map,
"background_fill": background_fill,
"padding": padding,
"layer_spacing": layer_spacing,
"map_spacing": map_spacing,
"max_visual_channels": max_visual_channels,
"connector_fill": connector_fill,
"connector_width": connector_width,
"patch_fill": patch_fill,
"patch_outline": patch_outline,
"patch_scale": patch_scale,
"patch_alpha_on_image": patch_alpha_on_image,
"seed": seed,
"draw_connections": draw_connections,
"draw_patches": draw_patches,
"font": font,
"font_color": font_color,
"top_label_padding": top_label_padding,
"bottom_label_padding": bottom_label_padding,
"top_label": top_label,
"bottom_label": bottom_label,
"styles": styles,
}
for key, value in explicit_values.items():
if key not in defaults:
continue
if value != defaults[key]:
resolved[key] = value
# write back
to_file = resolved["to_file"]
min_xy = resolved["min_xy"]
max_xy = resolved["max_xy"]
scale_xy = resolved["scale_xy"]
type_ignore = resolved["type_ignore"]
index_ignore = resolved["index_ignore"]
color_map = resolved["color_map"]
background_fill = resolved["background_fill"]
padding = resolved["padding"]
layer_spacing = resolved["layer_spacing"]
map_spacing = resolved["map_spacing"]
max_visual_channels = resolved["max_visual_channels"]
connector_fill = resolved["connector_fill"]
connector_width = resolved["connector_width"]
patch_fill = resolved["patch_fill"]
patch_outline = resolved["patch_outline"]
patch_scale = resolved["patch_scale"]
patch_alpha_on_image = resolved.get("patch_alpha_on_image", patch_alpha_on_image)
seed = resolved.get("seed", None)
draw_connections = resolved["draw_connections"]
draw_patches = resolved["draw_patches"]
font = resolved["font"]
font_color = resolved["font_color"]
top_label_padding = resolved.get("top_label_padding", 6)
bottom_label_padding = resolved.get("bottom_label_padding", 6)
top_label = resolved["top_label"]
bottom_label = resolved["bottom_label"]
styles = resolved["styles"]
if color_map is None:
color_map = {}
if styles is None:
styles = {}
if top_label_callable is None:
top_label_callable = _default_top_label
if bottom_label_callable is None:
bottom_label_callable = _default_bottom_label
type_ignore = set(type_ignore or [])
index_ignore = set(index_ignore or [])
layers = list(get_layers(model))
# Build renderable layers list
render_layers: list[Tuple[int, Any, RenderShape, Dict[str, Any]]] = []
global_defaults = {
"fill": "#d9d9d9",
"outline": "black",
"line_width": 1,
"shade_step": 6,
# allow per-layer overrides
"map_spacing": map_spacing,
"max_visual_channels": max_visual_channels,
"connector_fill": connector_fill,
"connector_width": connector_width,
"patch_fill": patch_fill,
"patch_outline": patch_outline,
"patch_scale": patch_scale,
"face_image": None,
"face_image_fit": "cover",
"face_image_alpha": 255,
"face_image_inset": None,
"top_label_padding": top_label_padding,
"bottom_label_padding": bottom_label_padding,
}
for idx, layer in enumerate(layers):
if idx in index_ignore:
continue
if type(layer) in type_ignore:
continue
name = getattr(layer, "name", f"layer_{idx}")
legacy_color = color_map.get(type(layer), {})
current_defaults = dict(global_defaults)
current_defaults.update(legacy_color)
style = resolve_style(layer, name, styles, current_defaults)
raw_shape = _resolve_layer_output_shape(layer)
primary = extract_primary_shape(raw_shape, name)
rshape = _canonicalize_shape(layer, primary)
render_layers.append((idx, layer, rshape, style))
if not render_layers:
# empty canvas
img = Image.new("RGBA", (max(1, padding * 2), max(1, padding * 2)), get_rgba_tuple(background_fill))
if to_file:
img.save(to_file)
return img
# First pass: compute stack sizes (front face)
stacks: list[Dict[str, Any]] = []
max_total_h = 0
for idx, layer, rshape, style in render_layers:
w_px = _clamp_int(rshape.w * float(scale_xy), min_xy, max_xy)
h_px = _clamp_int(rshape.h * float(scale_xy), min_xy, max_xy)
# vectors: keep a pleasant aspect; small square works best
if rshape.kind == "vector":
w_px = min(w_px, max_xy // 2) if max_xy > 0 else w_px
h_px = min(h_px, max_xy // 2) if max_xy > 0 else h_px
w_px = max(min_xy, min(w_px, max_xy))
h_px = max(min_xy, min(h_px, max_xy))
# Optional: texture the front face with an image.
# If fit == 'match_aspect', adjust the face dimensions to match the image aspect ratio.
spec, fit_mode, _, _inset = _parse_face_image_style(style)
if spec:
face_img = style.get("_face_image_obj", None)
if not isinstance(face_img, Image.Image):
face_img = _load_face_image(spec)
if face_img is not None:
style["_face_image_obj"] = face_img
if isinstance(face_img, Image.Image) and fit_mode == "match_aspect":
w_px, h_px = _adjust_wh_for_image_aspect(w_px, h_px, face_img, min_xy=min_xy, max_xy=max_xy)
ms = int(style.get("map_spacing", map_spacing))
mvc = int(style.get("max_visual_channels", max_visual_channels))
temp_stack = FeatureMapStack(
0, 0, w_px, h_px, rshape.c,
map_spacing=ms,
max_visual_channels=mvc,
fill=style.get("fill", global_defaults["fill"]),
outline=style.get("outline", global_defaults["outline"]),
line_width=int(style.get("line_width", 1)),
shade_step=int(style.get("shade_step", 6)),
)
_, top, _, bottom = temp_stack.bounds()
total_h = bottom - top
max_total_h = max(max_total_h, int(math.ceil(total_h)))
stacks.append({
"layer": layer,
"layer_index": idx,
"rshape": rshape,
"style": style,
"stack": temp_stack,
"has_face_image": bool(spec),
})
# Second pass: assign positions using bounding boxes
cursor = float(padding)
global_top = float(padding)
for obj in stacks:
stack: FeatureMapStack = obj["stack"]
total_w = (stack.width + stack.offset)
total_h = (stack.height + stack.offset)
# Place so that the bounding-left starts at cursor
stack.x = cursor + stack.offset
# Center vertically: set bounding-top = global_top + (max_total_h - total_h)/2
stack.y = global_top + (max_total_h - total_h) / 2.0 + stack.offset
cursor += total_w + float(layer_spacing)
# Connections are built after final shifting (so patches align to faces).
connections: list[Any] = []
# Determine image bounds (include optional caption text)
min_x = float('inf')
min_y = float('inf')
max_x = float('-inf')
max_y = float('-inf')
# Resolve a font early so we can size/allocate space for captions.
if (top_label or bottom_label) and font is None:
try:
font = ImageFont.load_default()
except Exception: # noqa: BLE001
font = None
measure_draw = None
if font is not None and (top_label or bottom_label):
_dummy = Image.new('RGB', (10, 10))
measure_draw = ImageDraw.Draw(_dummy)
for obj in stacks:
st: FeatureMapStack = obj['stack']
l, t, r, b = st.bounds()
min_x = min(min_x, l)
min_y = min(min_y, t)
max_x = max(max_x, r)
max_y = max(max_y, b)
if measure_draw is None:
continue
layer = obj['layer']
rshape: RenderShape = obj['rshape']
b_left, _b_top, b_right, _b_bottom = st.bounds()
cx = (b_left + b_right) / 2.0
x1, y1, x2, y2 = st.front_rect()
top_pad = int(obj.get("style", {}).get("top_label_padding", top_label_padding))
bottom_pad = int(obj.get("style", {}).get("bottom_label_padding", bottom_label_padding))
if bottom_label:
text = bottom_label_callable(layer, rshape)
if text:
tw, th = _get_multiline_text_size(measure_draw, str(text), font)
tx0 = cx - tw / 2.0
ty0 = y2 + bottom_pad
tx1 = tx0 + tw
ty1 = ty0 + th
min_x = min(min_x, tx0)
min_y = min(min_y, ty0)
max_x = max(max_x, tx1)
max_y = max(max_y, ty1)
if top_label:
text = top_label_callable(layer, rshape)
if text:
tw, th = _get_multiline_text_size(measure_draw, str(text), font)
vis_top = y1 - st.offset
tx0 = cx - tw / 2.0
ty0 = vis_top - th - top_pad
tx1 = tx0 + tw
ty1 = ty0 + th
min_x = min(min_x, tx0)
min_y = min(min_y, ty0)
max_x = max(max_x, tx1)
max_y = max(max_y, ty1)
# Pad (additional safety, but avoid negative)
img_w = int(max(1, math.ceil(max_x - min_x + padding)))
img_h = int(max(1, math.ceil(max_y - min_y + padding)))
# Shift everything if needed so min coords are inside padding/2
shift_x = float(padding) / 2.0 - min_x
shift_y = float(padding) / 2.0 - min_y
for obj in stacks:
st: FeatureMapStack = obj['stack']
st.x += shift_x
st.y += shift_y
# Pre-sample per-layer patch placement ratios so boxes appear randomly placed on faces,
# while keeping outgoing boxes on a layer to the right of incoming boxes.
patch_ratios_in: Dict[int, Tuple[float, float]] = {}
patch_ratios_out: Dict[int, Tuple[float, float]] = {}
for li, obj in enumerate(stacks):
lname = getattr(obj.get("layer"), "name", f"layer_{li}")
if li > 0:
rng_in = random.Random(_stable_seed(seed, "in", li, lname))
patch_ratios_in[li] = _sample_patch_ratios(rng_in, x_lo=0.06, x_hi=0.44)
if li < len(stacks) - 1:
rng_out = random.Random(_stable_seed(seed, "out", li, lname))
patch_ratios_out[li] = _sample_patch_ratios(rng_out, x_lo=0.56, x_hi=0.94)
# Build connections between consecutive rendered layers (after shift)
connections = []
if draw_connections and len(stacks) >= 2:
# First pass: compute edge definitions (patch rectangles can be post-processed per-layer).
edge_defs: list[Dict[str, Any]] = []
for i in range(len(stacks) - 1):
src_obj = stacks[i]
dst_obj = stacks[i + 1]
src_stack: FeatureMapStack = src_obj['stack']
dst_stack: FeatureMapStack = dst_obj['stack']
src_shape: RenderShape = src_obj['rshape']
dst_shape: RenderShape = dst_obj['rshape']
# Use destination op params (receptive field / pooling window).
k = getattr(dst_obj['layer'], 'kernel_size', None)
p = getattr(dst_obj['layer'], 'pool_size', None)
kernel = _as_tuple2(k if k is not None else p)
# Style from destination layer (more intuitive).
dst_style = dst_obj['style']
conn_fill = dst_style.get('connector_fill', connector_fill)
conn_w = int(dst_style.get('connector_width', connector_width))
pfill_base = dst_style.get('patch_fill', patch_fill)
pout = dst_style.get('patch_outline', patch_outline)
pscale = float(dst_style.get('patch_scale', patch_scale))
# Patch box opacity: by default make patches semi-transparent on layers with face images.
base_patch_rgba = get_rgba_tuple(pfill_base)
src_alpha = _effective_patch_alpha_for_layer(
src_obj.get('style', {}),
has_face_image=bool(src_obj.get('has_face_image', False)),
base_alpha=base_patch_rgba[3],
default_on_image=int(patch_alpha_on_image),
)
dst_alpha = _effective_patch_alpha_for_layer(
dst_style,
has_face_image=bool(dst_obj.get('has_face_image', False)),
base_alpha=base_patch_rgba[3],
default_on_image=int(patch_alpha_on_image),
)
src_patch_fill = _with_alpha(base_patch_rgba, src_alpha)
dst_patch_fill = _with_alpha(base_patch_rgba, dst_alpha)
if src_shape.kind == 'spatial' and dst_shape.kind == 'spatial' and (k is not None or p is not None):
sx1, sy1, sx2, sy2 = src_stack.front_rect()
dx1, dy1, dx2, dy2 = dst_stack.front_rect()
# src patch size based on kernel/pool ratio
src_w_dim = max(1, int(src_shape.w_dim))
src_h_dim = max(1, int(src_shape.h_dim))
kh, kw = kernel
patch_w = max(4, int((src_stack.width * (kw / src_w_dim)) * pscale))
patch_h = max(4, int((src_stack.height * (kh / src_h_dim)) * pscale))
patch_w = min(patch_w, int(src_stack.width * 0.6))
patch_h = min(patch_h, int(src_stack.height * 0.6))
# Place patch randomly on the source front face (outgoing patch is biased to the right).
rx, ry = patch_ratios_out.get(i, (0.80, 0.50))
sp = _place_patch_by_ratio((sx1, sy1, sx2, sy2), patch_w, patch_h, rx, ry, margin=1.0)
# Destination activation patch: small, randomly placed (incoming patch is biased to the left).
dsz = max(3, int(min(dst_stack.width, dst_stack.height) * 0.22))
rx2, ry2 = patch_ratios_in.get(i + 1, (0.20, 0.50))
dp = _place_patch_by_ratio((dx1, dy1, dx2, dy2), float(dsz), float(dsz), rx2, ry2, margin=1.0)
edge_defs.append(
dict(
type='pyramid',
src=src_stack,
dst=dst_stack,
src_patch=sp,
dst_patch=dp,
connector_fill=conn_fill,
connector_width=conn_w,
src_patch_fill=src_patch_fill,
dst_patch_fill=dst_patch_fill,
patch_outline=pout,
draw_patches=draw_patches,
)
)
elif src_shape.kind == 'spatial' and dst_shape.kind == 'vector':
edge_defs.append(
dict(
type='funnel',
src=src_stack,
dst=dst_stack,
connector_fill=conn_fill,
connector_width=conn_w,
)
)
else:
edge_defs.append(
dict(
type='full',
src=src_stack,
dst=dst_stack,
connector_fill=conn_fill,
connector_width=conn_w,
)
)
# Second pass: if a layer has both an incoming and outgoing patch, enforce ordering + separation.
# Incoming patch lives on layer j as dst_patch of edge (j-1). Outgoing patch lives on layer j as src_patch of edge j.
for j in range(1, len(stacks) - 1):
left_edge = edge_defs[j - 1]
right_edge = edge_defs[j]
if left_edge.get('type') == 'pyramid' and right_edge.get('type') == 'pyramid':
layer_obj = stacks[j]
lname = getattr(layer_obj.get('layer'), 'name', f'layer_{j}')
rng_sep = random.Random(_stable_seed(seed, 'sep', j, lname))
face = layer_obj['stack'].front_rect()
inc = left_edge['dst_patch']
out = right_edge['src_patch']
inc2, out2 = _enforce_in_out_patch_separation(
face,
inc,
out,
rng=rng_sep,
x_eps=1.0,
v_gap=2.0,
margin=1.0,
)
left_edge['dst_patch'] = inc2
right_edge['src_patch'] = out2
# Final: instantiate connection objects.
for ed in edge_defs:
if ed['type'] == 'pyramid':
connections.append(
PyramidConnection(
ed['src'],
ed['dst'],
src_patch=ed['src_patch'],
dst_patch=ed['dst_patch'],
connector_fill=ed['connector_fill'],
connector_width=ed['connector_width'],
src_patch_fill=ed['src_patch_fill'],
dst_patch_fill=ed['dst_patch_fill'],
patch_outline=ed['patch_outline'],
draw_patches=ed.get('draw_patches', True),
)
)
elif ed['type'] == 'funnel':
connections.append(
FunnelConnection(
ed['src'],
ed['dst'],
connector_fill=ed['connector_fill'],
connector_width=ed['connector_width'],
)
)
else:
connections.append(
FullConnection(
ed['src'],
ed['dst'],
connector_fill=ed['connector_fill'],
connector_width=ed['connector_width'],
)
)
# Create canvas and draw
img = Image.new("RGBA", (img_w, img_h), get_rgba_tuple(background_fill))
# Pass 1: draw stacks (geometry + outlines)
draw = aggdraw.Draw(img)
for obj in stacks:
obj["stack"].draw(draw)
draw.flush()
# Pass 1.5: paste optional face images onto each stack's front face
_apply_face_images(img, stacks)
# Pass 2: draw connections (polygons + patch boxes) over the faces
draw = aggdraw.Draw(img)
for conn in connections:
conn.draw(draw)
draw.flush()
# Text pass
if font is None:
try:
font = ImageFont.load_default()
except Exception: # noqa: BLE001
font = None
if font is not None and (top_label or bottom_label):
dtext = ImageDraw.Draw(img)
for obj in stacks:
layer = obj["layer"]
rshape = obj["rshape"]
stack: FeatureMapStack = obj["stack"]
x1, y1, x2, y2 = stack.front_rect()
b_left, _b_top, b_right, _b_bottom = stack.bounds()
cx = (b_left + b_right) / 2.0
style = obj.get("style", {}) or {}
top_pad = int(style.get("top_label_padding", top_label_padding))
bottom_pad = int(style.get("bottom_label_padding", bottom_label_padding))
if bottom_label:
text = bottom_label_callable(layer, rshape)
if text:
tw, th = _get_multiline_text_size(dtext, str(text), font)
dtext.multiline_text((cx - tw / 2.0, y2 + bottom_pad), str(text), font=font, fill=font_color, spacing=2, align='center')
if top_label:
text = top_label_callable(layer, rshape)
if text:
# Account for stack offset: top of visible stack is y1 - offset
vis_top = y1 - stack.offset
tw, th = _get_multiline_text_size(dtext, str(text), font)
dtext.multiline_text((cx - tw / 2.0, vis_top - th - top_pad), str(text), font=font, fill=font_color, spacing=2, align='center')
if to_file:
img.save(to_file)
return img