from typing import Any, Callable, Mapping, Optional, Union, List, Dict, Sequence, Tuple
import aggdraw
from PIL import ImageFont
from math import ceil
from .utils import *
from .layer_utils import *
from .options import LayeredOptions, LAYERED_PRESETS, LAYERED_TEXT_CALLABLES
import warnings
# Resolve the Keras ``layers`` namespace. The candidates are tried in order:
# ``tensorflow.keras``, then ``tensorflow.python.keras``, then standalone
# ``keras``. If none can be imported, a minimal placeholder exposing an empty
# ``Layer`` class is installed so that references such as ``layers.Layer`` in
# annotations still resolve.
#
# ``except Exception`` (rather than a bare ``except``) keeps the fallback
# best-effort for any import-time failure while letting ``KeyboardInterrupt``
# and ``SystemExit`` propagate instead of being silently swallowed.
try:
    from tensorflow.keras import layers
except Exception:  # noqa: BLE001 - any import-time failure triggers the fallback
    try:
        from tensorflow.python.keras import layers
    except Exception:  # noqa: BLE001
        try:
            from keras import layers
        except Exception:  # noqa: BLE001
            class _LayerNamespace:
                """Stand-in for the Keras ``layers`` module when Keras is unavailable."""

                class Layer:
                    """Empty placeholder for ``layers.Layer``."""

            layers = _LayerNamespace()
# Tuple snapshot of the callables registered in LAYERED_TEXT_CALLABLES.
# layered_view checks membership in this tuple to decide whether a supplied
# ``text_callable`` is a custom one (and should trigger its UserWarning).
_BUILT_IN_TEXT_CALLABLES = tuple(LAYERED_TEXT_CALLABLES.values())
def _resolve_layer_output_shape(layer) -> Any:
    """
    Best-effort lookup of a layer's output shape.

    Three sources are consulted in order and the first one that yields a
    value wins:

    1. the layer's ``output_shape`` attribute;
    2. the ``shape`` attribute of the layer's ``output``;
    3. ``layer.compute_output_shape(layer.input_shape)``, when the method is
       callable and ``input_shape`` is available.

    Whatever is found is normalised through ``_shape_to_tuple``. ``None`` is
    returned when no source produces a shape, including when
    ``compute_output_shape`` raises.
    """
    declared = getattr(layer, "output_shape", None)
    if declared is not None:
        return _shape_to_tuple(declared)

    from_tensor = getattr(getattr(layer, "output", None), "shape", None)
    if from_tensor is not None:
        return _shape_to_tuple(from_tensor)

    compute = getattr(layer, "compute_output_shape", None)
    if not callable(compute):
        return None

    known_input = getattr(layer, "input_shape", None)
    if known_input is None:
        return None

    try:
        return _shape_to_tuple(compute(known_input))
    except Exception:  # noqa: BLE001
        # Shape inference is optional; a failure here simply means "unknown".
        return None
def _shape_to_tuple(shape: Any) -> Any:
    """
    Normalise a shape-like value to a plain tuple where possible.

    * ``None`` and tuples are returned unchanged.
    * Objects exposing ``as_list()`` are converted via ``tuple(as_list())``;
      if that fails, the object itself is iterated into a tuple.
    * Lists become tuples.
    * Anything else is returned as-is.
    """
    if shape is None or isinstance(shape, tuple):
        return shape

    if hasattr(shape, "as_list"):
        try:
            dims = shape.as_list()
            return tuple(dims)
        except Exception:  # noqa: BLE001
            # as_list() was unusable; fall back to iterating the shape itself.
            return tuple(shape)

    return tuple(shape) if isinstance(shape, list) else shape
def _get_group_boxes(boxes: List[Box], group: Dict[str, Any]) -> List[Box]:
    """
    Select the boxes whose layer is referenced by ``group["layers"]``.

    A reference matches a box when it is the very same object as the box's
    ``layer`` (identity comparison), or when it is a string equal to the
    layer's ``name``. Boxes without a ``layer`` attribute are skipped. Each
    box appears at most once and the input order of ``boxes`` is preserved.
    An empty or missing ``"layers"`` entry yields an empty list.
    """
    refs = group.get("layers", [])
    if not refs:
        return []

    def _is_referenced(layer, ref):
        # Identity wins first; otherwise only string references can match,
        # and they do so by layer name.
        if layer is ref:
            return True
        name = getattr(layer, 'name', '')
        return isinstance(ref, str) and (name == ref)

    return [
        candidate
        for candidate in boxes
        if hasattr(candidate, 'layer')
        and any(_is_referenced(candidate.layer, ref) for ref in refs)
    ]
def _get_logo_boxes(boxes: List[Box], group: Dict[str, Any]) -> List[Box]:
    """
    Collect the boxes targeted by ``group["layers"]`` for logo placement.

    String references select boxes by their layer's ``name``; class
    references select boxes whose layer is exactly of that type (subclasses
    do not match). Other kinds of reference are ignored. Results follow the
    order of the references, so a box can appear more than once when several
    references select it. Boxes without a ``layer`` attribute are skipped.
    """
    refs = group.get("layers", [])
    if not refs:
        return []

    # Index the boxes once, by layer name and by exact layer type.
    by_name = {}
    by_type = {}
    for candidate in boxes:
        if not hasattr(candidate, 'layer'):
            continue
        name = getattr(candidate.layer, 'name', None)
        if name:
            by_name.setdefault(name, []).append(candidate)
        by_type.setdefault(type(candidate.layer), []).append(candidate)

    matched = []
    for ref in refs:
        if isinstance(ref, str):
            matched.extend(by_name.get(ref, ()))
        elif isinstance(ref, type):
            matched.extend(by_type.get(ref, ()))
    return matched
def _draw_layered_group_boxes(draw, boxes, groups, draw_reversed):
    """
    Draw one padded rectangle around the boxes of every group.

    For each group, the member boxes are looked up with ``_get_group_boxes``
    and their combined bounds are computed. The depth offset ``de`` extends
    the bounds upwards, and to the left when ``draw_reversed`` is true or to
    the right otherwise. The bounds are grown by the group's ``padding``
    (default 10) and the rectangle is drawn on ``draw`` with the group's
    ``outline`` (default "black"), ``width`` (default 1) and ``fill``
    (default ``(200, 200, 200, 100)``). Groups without member boxes are
    skipped.
    """
    pos_inf = float('inf')
    neg_inf = float('-inf')

    for group in groups:
        members = _get_group_boxes(boxes, group)
        if not members:
            continue

        # The depth offset widens the footprint on the side facing the viewer.
        if draw_reversed:
            lefts = [member.x1 - member.de for member in members]
            rights = [member.x2 for member in members]
        else:
            lefts = [member.x1 for member in members]
            rights = [member.x2 + member.de for member in members]
        tops = [member.y1 - member.de for member in members]
        bottoms = [member.y2 for member in members]

        margin = group.get("padding", 10)
        # Seeding with +/-inf keeps the fold identical to a running min/max.
        left = min([pos_inf, *lefts]) - margin
        right = max([neg_inf, *rights]) + margin
        top = min([pos_inf, *tops]) - margin
        bottom = max([neg_inf, *bottoms]) + margin

        fill_color = group.get("fill", (200, 200, 200, 100))
        outline_color = group.get("outline", "black")
        line_width = group.get("width", 1)

        pen = aggdraw.Pen(get_rgba_tuple(outline_color), line_width)
        brush = aggdraw.Brush(get_rgba_tuple(fill_color))
        draw.rectangle([left, top, right, bottom], pen, brush)
def _draw_layered_group_captions(img, boxes, groups, draw_reversed):
    """
    Draw a centred caption underneath every captioned group.

    Parameters
    ----------
    img
        Image to draw on; wrapped with ``ImageDraw.Draw``.
    boxes
        Rendered boxes, matched to each group via ``_get_group_boxes``.
    groups
        Iterable of group dicts. The caption text is ``group["name"]``, or
        ``group["caption"]`` when ``name`` is missing or empty. Optional keys:
        ``padding`` (default 10), ``text_spacing`` (default 5),
        ``font_color`` (default "black"), plus the font keys read by
        ``_get_font``.
    draw_reversed
        When true, the depth offset ``de`` extends the group bounds to the
        left; otherwise it extends them to the right.

    Groups without a caption or without member boxes are skipped.
    """
    draw = ImageDraw.Draw(img)

    for group in groups:
        # `or` rather than a .get() default: a present-but-empty "name" must
        # not hide an explicit "caption".
        caption = group.get("name") or group.get("caption")
        if not caption:
            continue

        members = _get_group_boxes(boxes, group)
        if not members:
            continue

        # Only the horizontal extent and the bottom edge are needed to place
        # the caption; the top edge is irrelevant here.
        if draw_reversed:
            left = min(member.x1 - member.de for member in members)
            right = max(member.x2 for member in members)
        else:
            left = min(member.x1 for member in members)
            right = max(member.x2 + member.de for member in members)
        bottom = max(member.y2 for member in members)

        margin = group.get("padding", 10)
        left -= margin
        right += margin
        bottom += margin

        font = _get_font(group)
        color = group.get("font_color", "black")
        gap = group.get("text_spacing", 5)

        text_w, _ = _measure_text(draw, caption, font)
        text_x = (left + right) / 2 - text_w / 2
        text_y = bottom + gap
        draw.text((text_x, text_y), caption, fill=color, font=font)
def _get_font(group: Dict[str, Any]) -> ImageFont.ImageFont:
    """
    Resolve the caption font for a group definition.

    ``group["font"]`` may be:

    * missing or ``None`` - ``"arial.ttf"`` is loaded at ``group["font_size"]``
      (default 15);
    * a string - treated as a font file and loaded at that size;
    * a font object - returned unchanged;
    * anything else - the default PIL font is used.

    If a font file cannot be loaded (``IOError``), the default PIL font is
    returned instead.
    """
    font_src = group.get("font", None)
    font_size = group.get("font_size", 15)

    if font_src is None or isinstance(font_src, str):
        font_file = "arial.ttf" if font_src is None else font_src
        try:
            return ImageFont.truetype(font_file, font_size)
        except IOError:
            return ImageFont.load_default()

    # Recognise font objects by their ``getmask`` method rather than by
    # ``isinstance(..., ImageFont.ImageFont)``: Pillow's FreeTypeFont (what
    # ``ImageFont.truetype`` returns) and TransposedFont are not subclasses of
    # ``ImageFont.ImageFont``, so the isinstance test silently discarded them
    # in favour of the default font.
    if callable(getattr(font_src, "getmask", None)):
        return font_src

    return ImageFont.load_default()
def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: Any) -> Tuple[int, int]:
    """
    Return the ``(width, height)`` of ``text`` rendered with ``font``.

    Fonts exposing ``getbbox`` are measured from that bounding box
    (``right - left``, ``bottom - top``). For fonts without it, the
    measurement is delegated to ``draw.textsize``.
    """
    try:
        bbox_of = font.getbbox
    except AttributeError:
        # No getbbox on this font object: use the draw-level measurement.
        return draw.textsize(text, font=font)

    x0, y0, x1, y1 = bbox_of(text)
    return x1 - x0, y1 - y0
# [docs]  (residual source-link marker from the rendered documentation; not code)
def layered_view(model,
to_file: str = None,
min_z: int = 20,
min_xy: int = 20,
max_z: int = 400,
max_xy: int = 2000,
scale_z: float = 1.5,
scale_xy: float = 4,
type_ignore: list = None,
index_ignore: list = None,
color_map: dict = None,
one_dim_orientation: str = 'z',
index_2D: list = [],
background_fill: Any = 'white',
draw_volume: bool = True,
draw_reversed: bool = False,
padding: int = 10,
text_callable: Callable[[int, layers.Layer], tuple] = None,
text_vspacing: int = 4,
spacing: int = 10,
draw_funnel: bool = True,
shade_step=10,
legend: bool = False,
legend_text_spacing_offset = 15,
font: ImageFont = None,
font_color: Any = 'black',
show_dimension=False,
sizing_mode: str = 'accurate',
dimension_caps: dict = None,
relative_base_size: int = 20,
connector_fill: Any = 'gray',
connector_width: int = 1,
image_fit: str = "fill",
image_axis: str = "z",
layered_groups: Optional[Sequence[Dict[str, Any]]] = None,
logo_groups: Optional[Sequence[Dict[str, Any]]] = None,
logos_legend: Union[bool, Dict[str, Any]] = False,
styles: Optional[Mapping[Union[str, type], Dict[str, Any]]] = None,
*,
options: Union[LayeredOptions, Mapping[str, Any], None] = None,
preset: Union[str, None] = None) -> Image:
"""Render a Keras model as a layered architecture diagram.
This renderer is best suited to sequential or effectively linear models
where layer order and tensor shape progression are the main story.
Parameters
----------
model : Any
Keras model instance to visualize.
Layered view works best when the model can be understood as a left to
right sequence of transformations. It is usually the clearest choice
for CNN style architectures and other models where tensor size changes
are part of the explanation.
to_file : str, optional
Path to save the rendered image. The image format is inferred from the
file extension.
The rendered ``PIL.Image`` is returned whether or not this value is
provided. Use this when you want the renderer to both write an image to
disk and keep the in-memory result for further processing.
min_z : int, default=20
Minimum rendered depth in pixels for a layer box.
This lower bound is applied after scaling. It prevents layers with very
small channel counts from collapsing into thin slivers that are hard to
see or compare.
min_xy : int, default=20
Minimum rendered width and height in pixels for a layer box.
This is especially useful when a model mixes small tensors with much
larger ones. A reasonable minimum keeps every layer visible without
letting a few large layers define the entire layout.
max_z : int, default=400
Maximum rendered depth in pixels for a layer box.
Use this to keep channel-heavy layers from dominating the visual depth
of the diagram. It is most useful for deep convolutional stacks where
late layers would otherwise become excessively thick.
max_xy : int, default=2000
Maximum rendered width and height in pixels for a layer box.
This cap protects the layout from becoming impractically large when the
model contains very large spatial dimensions or long sequences. It acts
as a safety rail after scaling has been applied.
scale_z : float, default=1.5
Multiplier applied to the depth dimension before clamping.
Increase this value when channel depth should read more strongly in the
diagram. Reduce it when depth cues feel exaggerated or when channel rich
layers overshadow the rest of the architecture.
scale_xy : float, default=4
Multiplier applied to the width and height dimensions before clamping.
This is one of the main controls for the overall apparent size of the
rendered layers. Lower values usually make crowded diagrams easier to
fit, while higher values make individual layers easier to inspect.
type_ignore : list, optional
Sequence of layer classes to exclude from rendering.
This is the simplest way to hide utility or low-information layers such
as dropout, padding, or normalization layers without modifying the
model itself. Every instance of a matching class is skipped.
index_ignore : list, optional
Sequence of layer indices to exclude from rendering.
Use this when you want precise control over individual layers rather
than entire layer types. Indices refer to positions in the model's
layer list before rendering-time filtering is applied.
color_map : dict, optional
Mapping from layer class to style values such as ``fill`` and
``outline``.
This provides broad styling by layer type and is the quickest way to
create a consistent color language across the diagram. It is best suited
to coarse styling rules, while ``styles`` is better for fine-grained
per-layer overrides.
one_dim_orientation : {'x', 'y', 'z'}, default='z'
Axis used when rendering one dimensional layers such as dense or
flattened outputs.
Dense and flattened layers do not naturally have both width and height,
so this setting controls how they are represented visually. Changing it
can make mixed CNN and MLP models much easier to read.
index_2D : list, optional
Layer indices that should be forced into flat 2D rendering even when
``draw_volume`` is enabled.
This is useful when most of the model benefits from 3D boxes but a few
layers read better as flat blocks. Common cases include classifier heads
and summary style terminal layers.
background_fill : Any, default='white'
Background color for the final image.
This value accepts any Pillow-compatible color form, including named
colors and RGBA tuples. Darker backgrounds often pair well with bright
fills, while neutral backgrounds keep the focus on shape and layout.
draw_volume : bool, default=True
If ``True``, render boxes with 3D depth cues. If ``False``, render flat
2D rectangles.
The volumetric mode is usually the signature layered-view look. Flat
mode is simpler and often preferable for documentation, compact figures,
or models where depth would add noise rather than clarity.
draw_reversed : bool, default=False
Reverse the 3D viewing direction when ``draw_volume`` is enabled.
This changes which faces of a 3D layer box are visible. It can be
helpful for decoder style networks or for diagrams where the default
perspective makes the flow feel visually backward.
padding : int, default=10
Outer padding around the full diagram in pixels.
Increase this when legends, labels, or grouped overlays are too close
to the image boundary. Padding affects the whole canvas rather than
spacing between individual layers.
text_callable : callable, optional
Callable receiving ``(layer_index, layer)`` and returning
``(text, above)`` to annotate a layer. Built-in helpers are available in
``visualkeras.options.LAYERED_TEXT_CALLABLES``.
This is the main hook for custom per-layer text. Use it when you want
labels such as layer names, tensor shapes, block roles, or any other
model-specific notes placed above or below each rendered box.
text_vspacing : int, default=4
Vertical spacing between lines produced by ``text_callable``.
Increase this for multiline labels that feel cramped. Smaller values
help conserve vertical space when you are already using generous
padding or wide layer spacing.
spacing : int, default=10
Horizontal spacing between consecutive rendered layers.
This is the main control for how tightly packed the diagram feels.
Increasing it can make grouped stages and labels easier to follow, while
smaller values produce more compact figures.
draw_funnel : bool, default=True
If ``True``, draw tapered transitions between consecutive layers.
Funnels emphasize size changes between adjacent layers. They can be
visually helpful in CNN diagrams, but turning them off produces a
cleaner and more schematic look.
shade_step : int, default=10
Amount of shading variation used for 3D faces.
Larger values create stronger contrast between faces and make the depth
effect more pronounced. Smaller values produce a flatter and more subtle
appearance.
legend : bool, default=False
If ``True``, add a legend describing rendered layer types and colors.
A legend is useful when the diagram uses custom colors or contains many
repeated layer classes. For small internal diagrams it may be unnecessary,
but for external readers it often improves readability.
legend_text_spacing_offset : int, default=15
Extra width reserved for legend labels.
Increase this when legend text is clipped or when long layer names need
more room. This setting affects legend layout only.
font : PIL.ImageFont.ImageFont, optional
Font used for legend and annotation text. If omitted, the default PIL
font is used.
Use a custom font when you need the figure to match a publication or
presentation style. The font choice can have a noticeable effect on the
final layout, especially when legends or custom labels are enabled.
font_color : Any, default='black'
Text color used for legends and annotations.
This should contrast clearly with ``background_fill`` and any other
styling applied to the figure. In practice, it is often adjusted together
with ``font`` and ``legend`` settings.
show_dimension : bool, default=False
If ``True`` and ``legend`` is enabled, include output dimensions in the
legend entries.
This adds shape information to the legend without requiring custom text
on every layer. It is useful when you want to preserve a clean diagram
while still exposing the underlying tensor sizes.
sizing_mode : {'accurate', 'balanced', 'capped', 'logarithmic', 'relative'}, default='accurate'
Strategy used to convert tensor dimensions into rendered sizes.
``accurate`` stays closest to the underlying tensor dimensions after
scaling and clamping. ``balanced`` reduces extreme differences so the
whole figure remains readable. ``capped`` respects ``dimension_caps`` to
constrain large dimensions. ``logarithmic`` compresses very large ranges.
``relative`` scales layers directly from their dimensions using
``relative_base_size``.
dimension_caps : dict, optional
Custom caps used by ``capped`` mode. Supported keys are ``channels``,
``sequence``, and ``general``.
This is useful when a small number of very large layers distort the
overall layout. By capping specific dimension groups, you can keep the
diagram readable while still preserving meaningful differences.
relative_base_size : int, default=20
Base pixel unit used by ``relative`` sizing mode.
In ``relative`` mode, visual size is driven directly by the actual
tensor dimensions. This value defines the pixel size associated with a
dimension of one and therefore controls the overall scale of the figure.
connector_fill : Any, default='gray'
Color used for connector and transition elements.
This should usually complement rather than compete with the layer fills.
Neutral colors tend to work best when the boxes themselves already carry
most of the semantic styling.
connector_width : int, default=1
Line width used for connector and transition elements.
Increase this for presentation sized figures or when connectors are hard
to distinguish at your chosen output resolution.
image_fit : str, default='fill'
Default fit mode for images injected through ``styles``. Individual
style entries can override this.
This controls how images are resized within the layer face they occupy.
Choose a mode based on whether you prefer full coverage, preserved aspect
ratio, or exact fill behavior.
image_axis : {'x', 'y', 'z'}, default='z'
Default axis used when rendering per-layer images in 3D mode.
This matters when images are applied to volumetric boxes. It determines
which face or orientation should be treated as the primary image plane
unless a per-layer override is supplied through ``styles``.
layered_groups : sequence of dict, optional
Group definitions used to draw labeled background regions behind sets of
layers.
Groups are useful for separating architectural stages such as feature
extraction, bottlenecks, and classifier heads. They add visual structure
without changing the rendered layers themselves.
logo_groups : sequence of dict, optional
Logo placement definitions used to add icons or other overlays to
selected layers.
This is mainly intended for annotated presentation graphics and other
highly styled diagrams. It is less common than ``layered_groups`` but can
be useful when you want compact visual markers on specific layers.
logos_legend : bool or dict, default=False
If truthy, render a legend describing entries supplied through
``logo_groups``.
A simple boolean enables the default legend behavior. A mapping allows
more control over how the legend is rendered when logo overlays are part
of the figure.
styles : mapping, optional
Fine-grained per-layer style overrides keyed by layer name or layer
class.
This is the most flexible styling mechanism in layered mode. Use it for
per-layer images, detailed text and outline overrides, or any case where
``color_map`` is too coarse. See the layered API reference for supported
style keys and examples.
options : LayeredOptions or mapping, optional
Configuration bundle applied after ``preset`` and before explicit keyword
arguments.
This is the preferred way to reuse a consistent layered style across
multiple models. It also keeps larger examples and application code much
easier to read than passing many keyword arguments inline.
preset : str, optional
Name of a preset from ``visualkeras.LAYERED_PRESETS``. Layered mode
currently provides ``default``, ``compact``, and ``presentation``.
Presets are useful starting points rather than strict modes. They can be
combined with ``options`` and explicit overrides when you want the
convenience of a predefined style without giving up control.
Returns
-------
PIL.Image.Image
Rendered layered diagram.
Notes
-----
Configuration precedence is ``preset`` followed by ``options`` followed by
explicit keyword arguments.
Full documentation:
https://visualkeras.readthedocs.io/en/latest/api/layered.html
"""
using_presets = options is not None or preset is not None
if not using_presets:
defaults = LayeredOptions().to_kwargs()
defaults.update({
"to_file": None,
"type_ignore": None,
"index_ignore": None,
"color_map": None,
"one_dim_orientation": 'z',
"index_2D": [],
"background_fill": 'white',
"draw_volume": True,
"draw_reversed": False,
"padding": 10,
"text_callable": None,
"text_vspacing": 4,
"spacing": 10,
"draw_funnel": True,
"shade_step": 10,
"legend": False,
"legend_text_spacing_offset": 15,
"font": None,
"font_color": 'black',
"show_dimension": False,
"sizing_mode": 'accurate',
"dimension_caps": None,
"relative_base_size": 20,
"connector_fill": "gray",
"connector_width": 1,
"image_fit": "fill",
"image_axis": "z",
"layered_groups": None,
"styles": None,
})
current_params = {
"to_file": to_file,
"min_z": min_z,
"min_xy": min_xy,
"max_z": max_z,
"max_xy": max_xy,
"scale_z": scale_z,
"scale_xy": scale_xy,
"type_ignore": type_ignore,
"index_ignore": index_ignore,
"color_map": color_map,
"one_dim_orientation": one_dim_orientation,
"index_2D": index_2D,
"background_fill": background_fill,
"draw_volume": draw_volume,
"draw_reversed": draw_reversed,
"padding": padding,
"text_callable": text_callable,
"text_vspacing": text_vspacing,
"spacing": spacing,
"draw_funnel": draw_funnel,
"shade_step": shade_step,
"legend": legend,
"legend_text_spacing_offset": legend_text_spacing_offset,
"font": font,
"font_color": font_color,
"show_dimension": show_dimension,
"sizing_mode": sizing_mode,
"dimension_caps": dimension_caps,
"relative_base_size": relative_base_size,
"connector_fill": connector_fill,
"connector_width": connector_width,
"image_fit": image_fit,
"image_axis": image_axis,
"layered_groups": layered_groups,
"styles": styles,
}
custom_keys = [
key for key, value in current_params.items()
if key in defaults and value != defaults[key]
]
if len(custom_keys) >= 5:
warnings.warn(
"layered_view received many custom keyword arguments. "
"Consider using visualkeras.show(..., preset=...) for a simpler workflow.",
UserWarning,
stacklevel=2,
)
if preset is not None or options is not None:
defaults = LayeredOptions().to_kwargs()
defaults["type_ignore"] = None
defaults["index_ignore"] = None
defaults["color_map"] = None
defaults["text_callable"] = None
defaults["dimension_caps"] = None
defaults["font"] = None
defaults["index_2D"] = []
defaults["layered_groups"] = None
defaults["styles"] = None
resolved = dict(defaults)
if preset is not None:
try:
resolved.update(LAYERED_PRESETS[preset].to_kwargs())
except KeyError as exc:
available = ", ".join(sorted(LAYERED_PRESETS.keys()))
raise ValueError(
f"Unknown layered preset '{preset}'. Available presets: {available}"
) from exc
if options is not None:
if isinstance(options, LayeredOptions):
option_values = options.to_kwargs()
elif isinstance(options, Mapping):
option_values = dict(options)
else:
raise TypeError(
"options must be a LayeredOptions instance or a mapping of keyword arguments."
)
resolved.update(option_values)
explicit_values = {
"to_file": to_file,
"min_z": min_z,
"min_xy": min_xy,
"max_z": max_z,
"max_xy": max_xy,
"scale_z": scale_z,
"scale_xy": scale_xy,
"type_ignore": type_ignore,
"index_ignore": index_ignore,
"color_map": color_map,
"one_dim_orientation": one_dim_orientation,
"index_2D": index_2D,
"background_fill": background_fill,
"draw_volume": draw_volume,
"draw_reversed": draw_reversed,
"padding": padding,
"text_callable": text_callable,
"text_vspacing": text_vspacing,
"spacing": spacing,
"draw_funnel": draw_funnel,
"shade_step": shade_step,
"legend": legend,
"legend_text_spacing_offset": legend_text_spacing_offset,
"font": font,
"font_color": font_color,
"show_dimension": show_dimension,
"sizing_mode": sizing_mode,
"dimension_caps": dimension_caps,
"relative_base_size": relative_base_size,
"connector_fill": connector_fill,
"connector_width": connector_width,
"layered_groups": layered_groups,
"styles": styles,
}
for key, value in explicit_values.items():
if key not in defaults:
continue
if value != defaults[key]:
resolved[key] = value
to_file = resolved["to_file"]
min_z = resolved["min_z"]
min_xy = resolved["min_xy"]
max_z = resolved["max_z"]
max_xy = resolved["max_xy"]
scale_z = resolved["scale_z"]
scale_xy = resolved["scale_xy"]
type_ignore = resolved["type_ignore"]
index_ignore = resolved["index_ignore"]
color_map = resolved["color_map"]
one_dim_orientation = resolved["one_dim_orientation"]
index_2D = resolved["index_2D"]
background_fill = resolved["background_fill"]
draw_volume = resolved["draw_volume"]
draw_reversed = resolved["draw_reversed"]
padding = resolved["padding"]
text_callable = resolved["text_callable"]
text_vspacing = resolved["text_vspacing"]
spacing = resolved["spacing"]
draw_funnel = resolved["draw_funnel"]
shade_step = resolved["shade_step"]
legend = resolved["legend"]
legend_text_spacing_offset = resolved["legend_text_spacing_offset"]
font = resolved["font"]
font_color = resolved["font_color"]
show_dimension = resolved["show_dimension"]
sizing_mode = resolved["sizing_mode"]
dimension_caps = resolved["dimension_caps"]
relative_base_size = resolved["relative_base_size"]
connector_fill = resolved["connector_fill"]
connector_width = resolved["connector_width"]
image_fit = resolved["image_fit"]
image_axis = resolved["image_axis"]
layered_groups = resolved["layered_groups"]
styles = resolved["styles"]
if styles is not None and not isinstance(styles, dict):
styles = dict(styles)
if styles is None:
styles = {}
global_defaults = {
'fill': None,
'outline': 'black',
'padding': padding,
'spacing': spacing,
'scale_z': scale_z,
'scale_xy': scale_xy,
'min_z': min_z,
'max_z': max_z,
'min_xy': min_xy,
'max_xy': max_xy,
'shade_step': shade_step,
'font_color': font_color,
'image_fit': image_fit,
'image_axis': image_axis
}
if type_ignore is not None and not isinstance(type_ignore, list):
type_ignore = list(type_ignore)
if index_ignore is not None and not isinstance(index_ignore, list):
index_ignore = list(index_ignore)
if index_2D is None:
index_2D = []
elif not isinstance(index_2D, list):
index_2D = list(index_2D)
if color_map is not None and not isinstance(color_map, dict):
color_map = dict(color_map)
if dimension_caps is not None and not isinstance(dimension_caps, dict):
dimension_caps = dict(dimension_caps)
if isinstance(text_callable, str):
try:
text_callable = LAYERED_TEXT_CALLABLES[text_callable]
except KeyError as exc:
available = ", ".join(sorted(LAYERED_TEXT_CALLABLES))
raise ValueError(
f"Unknown text_callable preset '{text_callable}'. "
f"Available presets: {available}"
) from exc
if callable(text_callable) and text_callable not in _BUILT_IN_TEXT_CALLABLES:
warnings.warn(
"Custom text_callable detected. Built-in caption templates are available "
"via visualkeras.show(..., text_callable='name').",
UserWarning,
stacklevel=2,
)
# Deprecation warning for legend_text_spacing_offset
if legend_text_spacing_offset != 0:
warnings.warn("The legend_text_spacing_offset parameter is deprecated and will be removed in a future release.")
boxes = list()
layer_y = list()
color_wheel = ColorWheel()
current_z = padding
x_off = -1
layer_types = list()
dimension_list = []
img_height = 0
max_right = 0
if type_ignore is None:
type_ignore = list()
if index_ignore is None:
index_ignore = list()
if color_map is None:
color_map = dict()
# Pre-process groups to map layers to their groups
layer_to_groups = {}
if layered_groups:
# Create a map of layer name to index for faster lookup
name_to_index = {}
for i, layer in enumerate(model.layers):
name = getattr(layer, 'name', None)
if name:
name_to_index[name] = i
for group in layered_groups:
for ref in group.get("layers", []):
idx = -1
if isinstance(ref, str):
idx = name_to_index.get(ref, -1)
else:
# Assume it's a layer object
try:
idx = model.layers.index(ref)
except ValueError:
pass
if idx != -1:
layer_to_groups.setdefault(idx, []).append(group)
last_rendered_index = -1
for index, layer in enumerate(model.layers):
# Ignore layers that the use has opted out to
if type(layer) in type_ignore or index in index_ignore:
continue
# Do not render the SpacingDummyLayer, just increase the pointer
if type(layer) == SpacingDummyLayer:
current_z += layer.spacing
continue
# Adjust spacing to prevent group overlap
if last_rendered_index != -1 and layered_groups:
prev_groups = layer_to_groups.get(last_rendered_index, [])
curr_groups = layer_to_groups.get(index, [])
exiting = [g for g in prev_groups if g not in curr_groups]
entering = [g for g in curr_groups if g not in prev_groups]
clearance = max([g.get('padding', 10) for g in exiting] + [0]) + \
max([g.get('padding', 10) for g in entering] + [0])
current_z += clearance
layer_type = type(layer)
if legend and show_dimension:
layer_types.append(layer_type)
elif layer_type not in layer_types:
layer_types.append(layer_type)
# Resolve Layer Name
try:
layer_name = getattr(layer, 'name', None) or f'{layer.__class__.__name__}_{index}'
except AttributeError:
layer_name = f'unknown_{index}'
# Resolve Styles
# Merge legacy color_map into the defaults for backward compatibility.
legacy_color = color_map.get(type(layer), {})
current_defaults = global_defaults.copy()
current_defaults.update(legacy_color)
style = resolve_style(layer, layer_name, styles, current_defaults)
# Get the primary shape of the layer's output
raw_shape = _resolve_layer_output_shape(layer)
shape = extract_primary_shape(raw_shape, layer_name)
# Use Styles for Dimensions
# We pass the specific constraints and scalers from the style instead of the global args.
x, y, z = calculate_layer_dimensions(
shape,
style['scale_z'],
style['scale_xy'],
style['max_z'],
style['max_xy'],
style['min_z'],
style['min_xy'],
one_dim_orientation, sizing_mode,
dimension_caps, relative_base_size
)
# --- Image Handling ---
image_path = style.get("image")
node_image = None
if image_path:
try:
node_image = Image.open(image_path).convert("RGBA")
fit_mode = style.get("image_fit", image_fit)
axis = style.get("image_axis", image_axis)
if fit_mode == "match_aspect":
img_w, img_h = node_image.size
img_ratio = img_w / img_h
if axis == 'z': # Front (Width x Height) -> (z x y)
surf_ratio = z / y if y > 0 else 1
if img_ratio > surf_ratio:
z = int(y * img_ratio)
else:
y = int(z / img_ratio)
elif axis == 'y': # Top (Width x Depth) -> (z x de)
# de = x / 3. We adjust x to achieve target de.
# Ratio = Width / Depth = z / de
# de = z / Ratio
if img_ratio > 0:
de_target = int(z / img_ratio)
x = de_target * 3
elif axis == 'x': # Side (Depth x Height) -> (de x y)
# Ratio = Depth / Height = de / y
# de = y * Ratio
de_target = int(y * img_ratio)
x = de_target * 3
# Apply scale_image
scale_factor = style.get("scale_image")
if scale_factor is not None:
try:
scale_factor = float(scale_factor)
if scale_factor < 0: scale_factor = 0.0
except (ValueError, TypeError):
scale_factor = 1.0
if axis == 'z':
z = int(z * scale_factor)
y = int(y * scale_factor)
elif axis == 'y':
z = int(z * scale_factor)
x = int(x * scale_factor)
elif axis == 'x':
x = int(x * scale_factor)
y = int(y * scale_factor)
except Exception as e:
warnings.warn(f"Failed to load image for layer '{layer_name}': {e}")
image_path = None
node_image = None
if legend and show_dimension:
dimension_string = str(shape)
dimension_string = dimension_string[1:len(dimension_string)-1].split(", ")
dimension = []
for i in range(0, len(dimension_string)):
if dimension_string[i].isnumeric():
dimension.append(dimension_string[i])
dimension_list.append(dimension)
box = Box()
box.layer = layer # Store layer for grouping
box.style = style # Store style for later use
box.image = node_image
if node_image:
box.image_fit = style.get("image_fit", image_fit)
box.image_axis = style.get("image_axis", image_axis)
# Use styles for visual properties
# If fill is None (default), fallback to the color wheel
if style.get('fill') is None:
box.fill = color_wheel.get_color(layer_type)
else:
box.fill = style.get('fill')
box.outline = style.get('outline', 'black')
box.shade = style.get('shade_step', shade_step)
# Update the color_map so the legend reflects this layer's appearance
color_map[layer_type] = {'fill': box.fill, 'outline': box.outline}
box.de = 0
if draw_volume and index not in index_2D:
box.de = x / 3
if x_off == -1:
x_off = box.de / 2
# top left coordinate
box.x1 = current_z - box.de / 2
box.y1 = box.de
# bottom right coordinate
box.x2 = box.x1 + z
box.y2 = box.y1 + y
boxes.append(box)
layer_y.append(box.y2 - (box.y1 - box.de))
# Update image bounds
hh = box.y2 - (box.y1 - box.de)
if hh > img_height:
img_height = hh
if box.x2 + box.de > max_right:
max_right = box.x2 + box.de
# Use style-based spacing
layer_spacing = style.get('spacing', spacing)
current_z += z + layer_spacing
last_rendered_index = index
# Generate image
min_scene_x = float('inf')
max_scene_x = float('-inf')
max_top_extent = 0
max_bottom_extent = 0
for i, box in enumerate(boxes):
h = layer_y[i]
half_h = h / 2
max_top_extent = max(max_top_extent, half_h)
max_bottom_extent = max(max_bottom_extent, half_h)
visual_x1 = box.x1 + x_off
visual_x2 = box.x2 + x_off
if draw_reversed:
visual_x1 += box.de
visual_x2 += box.de
min_scene_x = min(min_scene_x, visual_x1)
max_scene_x = max(max_scene_x, visual_x2)
if text_callable is not None:
if font is None:
font = ImageFont.load_default()
box_idx = -1
for index, layer in enumerate(model.layers):
if type(layer) in type_ignore or type(layer) == SpacingDummyLayer or index in index_ignore:
continue
box_idx += 1
box = boxes[box_idx]
local_font = box.style.get('font', font)
local_vspacing = box.style.get('text_vspacing', text_vspacing)
text, above = text_callable(box_idx, layer)
text_w = 0
text_h = 0
lines = text.split('\n')
for line in lines:
if hasattr(local_font, 'getsize'):
line_w, line_h = local_font.getsize(line)
else:
bbox = local_font.getbbox(line)
line_w = bbox[2]
line_h = bbox[3]
text_w = max(text_w, line_w)
text_h += line_h
text_h += (len(lines) - 1) * local_vspacing
width = box.x2 - box.x1
base_x = box.x1 + x_off
if draw_reversed:
base_x += box.de
if above:
center_x = base_x + box.de + width / 2
max_top_extent = max(max_top_extent, (layer_y[box_idx] / 2) + text_h)
else:
center_x = base_x + width / 2
max_bottom_extent = max(max_bottom_extent, (layer_y[box_idx] / 2) + text_h)
t_x1 = center_x - text_w / 2
t_x2 = center_x + text_w / 2
min_scene_x = min(min_scene_x, t_x1)
max_scene_x = max(max_scene_x, t_x2)
if layered_groups:
dummy_img = Image.new("RGBA", (1, 1))
dummy_draw = ImageDraw.Draw(dummy_img)
for group in layered_groups:
group_boxes = _get_group_boxes(boxes, group)
if not group_boxes: continue
g_min_x = float('inf')
g_max_x = float('-inf')
g_min_y = float('inf') # relative to center (negative is up)
g_max_y = float('-inf')
for box in group_boxes:
# Reconstruct visual bounds in scene space
idx = boxes.index(box)
h = layer_y[idx]
# Y bounds (relative to center)
# Top is -h/2, Bottom is h/2
g_min_y = min(g_min_y, -h/2)
g_max_y = max(g_max_y, h/2)
# X bounds
visual_x1 = box.x1 + x_off
visual_x2 = box.x2 + x_off
if draw_reversed:
visual_x1 += box.de
visual_x2 += box.de
# Back face extends to left/up
g_min_x = min(g_min_x, visual_x1 - box.de)
g_max_x = max(g_max_x, visual_x2)
else:
# Normal mode
# Back face extends to right/up
g_min_x = min(g_min_x, visual_x1)
g_max_x = max(g_max_x, visual_x2 + box.de)
# Apply padding
padding_val = group.get("padding", 10)
g_min_x -= padding_val
g_max_x += padding_val
g_min_y -= padding_val
g_max_y += padding_val
# Update scene extents
min_scene_x = min(min_scene_x, g_min_x)
max_scene_x = max(max_scene_x, g_max_x)
max_top_extent = max(max_top_extent, -g_min_y)
max_bottom_extent = max(max_bottom_extent, g_max_y)
# Caption
caption = group.get("name", group.get("caption"))
if caption:
font = _get_font(group)
text_w, text_h = _measure_text(dummy_draw, caption, font)
center_x = (g_min_x + g_max_x) / 2
text_x1 = center_x - text_w / 2
text_x2 = center_x + text_w / 2
gap = group.get("text_spacing", 5)
text_bottom = g_max_y + gap + text_h
min_scene_x = min(min_scene_x, text_x1)
max_scene_x = max(max_scene_x, text_x2)
max_bottom_extent = max(max_bottom_extent, text_bottom)
total_content_height = max_top_extent + max_bottom_extent
img_height = total_content_height
center_y_pos = max_top_extent
x_shift = padding - min_scene_x
img_width = max_scene_x + x_shift + padding
img = Image.new('RGBA', (int(ceil(img_width)), int(ceil(img_height))), background_fill)
for i, node in enumerate(boxes):
h = layer_y[i]
node_top = center_y_pos - h / 2
node.y1 = node_top + node.de
node.y2 = node_top + h
node.x1 += x_shift + x_off
node.x2 += x_shift + x_off
draw = aggdraw.Draw(img)
# Prepare logos
box_logos = {}
if logo_groups:
for group in logo_groups:
path = group.get("file")
if not path: continue
try:
logo_img = Image.open(path)
except:
continue
target_boxes = _get_logo_boxes(boxes, group)
for box in target_boxes:
if id(box) not in box_logos:
box_logos[id(box)] = []
box_logos[id(box)].append((group, logo_img))
# Correct x positions of reversed boxes
if draw_reversed:
for box in boxes:
offset = box.de
# offset = 0
box.x1 = box.x1 + offset
box.x2 = box.x2 + offset
if layered_groups:
_draw_layered_group_boxes(draw, boxes, layered_groups, draw_reversed)
# Draw created boxes
last_box = None
if draw_reversed:
for box in boxes:
pen = aggdraw.Pen(get_rgba_tuple(box.outline))
if last_box is not None and draw_funnel:
# Top connection back
draw.line([last_box.x2 - last_box.de, last_box.y1 - last_box.de,
box.x1 - box.de, box.y1 - box.de], pen)
# Bottom connection back
draw.line([last_box.x2 - last_box.de, last_box.y2 - last_box.de,
box.x1 - box.de, box.y2 - box.de], pen)
last_box = box
last_box = None
for box in reversed(boxes):
pen = aggdraw.Pen(get_rgba_tuple(box.outline))
if last_box is not None and draw_funnel:
# Top connection front
draw.line([last_box.x1, last_box.y1,
box.x2, box.y1], pen)
# Bottom connection front
draw.line([last_box.x1, last_box.y2,
box.x2, box.y2], pen)
box.draw(draw, draw_reversed=True)
if id(box) in box_logos:
draw.flush()
for group, logo_img in box_logos[id(box)]:
draw_node_logo(img, box, logo_img, group, draw_volume, draw_reversed=True)
draw = aggdraw.Draw(img)
if getattr(box, 'image', None):
draw.flush()
image = box.image
fit = box.image_fit
axis = box.image_axis
x1, y1, x2, y2 = box.x1, box.y1, box.x2, box.y2
de = box.de
if axis == 'z': # Front
w = x2 - x1
h = y2 - y1
resized = resize_image_to_fit(image, int(w), int(h), fit)
img.paste(resized, (int(x1), int(y1)), resized)
elif axis == 'y': # Top
# Reversed Top Face: TL(x1-de, y1-de), TR(x2-de, y1-de), BR(x2, y1), BL(x1, y1)
p1 = (x1 - de, y1 - de)
p2 = (x2 - de, y1 - de)
p3 = (x2, y1)
p4 = (x1, y1)
apply_affine_transform(img, image, [p1, p2, p3, p4], fit)
elif axis == 'x': # Side
# Reversed Side Face (Left): TL(x1-de, y1-de), TR(x1, y1), BR(x1, y2), BL(x1-de, y2-de)
p1 = (x1 - de, y1 - de)
p2 = (x1, y1)
p3 = (x1, y2)
p4 = (x1 - de, y2 - de)
apply_affine_transform(img, image, [p1, p2, p3, p4], fit)
draw = aggdraw.Draw(img)
last_box = box
else:
for box in boxes:
pen = aggdraw.Pen(get_rgba_tuple(box.outline))
if last_box is not None and draw_funnel:
draw.line([last_box.x2 + last_box.de, last_box.y1 - last_box.de,
box.x1 + box.de, box.y1 - box.de], pen)
draw.line([last_box.x2 + last_box.de, last_box.y2 - last_box.de,
box.x1 + box.de, box.y2 - box.de], pen)
draw.line([last_box.x2, last_box.y2,
box.x1, box.y2], pen)
draw.line([last_box.x2, last_box.y1,
box.x1, box.y1], pen)
box.draw(draw, draw_reversed=False)
if id(box) in box_logos:
draw.flush()
for group, logo_img in box_logos[id(box)]:
draw_node_logo(img, box, logo_img, group, draw_volume, draw_reversed=False)
draw = aggdraw.Draw(img)
if getattr(box, 'image', None):
draw.flush()
image = box.image
fit = box.image_fit
axis = box.image_axis
x1, y1, x2, y2 = box.x1, box.y1, box.x2, box.y2
de = box.de
if axis == 'z': # Front
w = x2 - x1
h = y2 - y1
resized = resize_image_to_fit(image, int(w), int(h), fit)
img.paste(resized, (int(x1), int(y1)), resized)
elif axis == 'y': # Top
# Normal Top Face: TL(x1, y1), TR(x2, y1), BR(x2+de, y1-de), BL(x1+de, y1-de)
# Wait, Box.draw normal top:
# draw.polygon([self.x1, self.y1,
# self.x1 + self.de, self.y1 - self.de,
# self.x2 + self.de, self.y1 - self.de,
# self.x2, self.y1
# ], pen, brush_s1)
# Order: BL, TL, TR, BR (relative to face?)
# Let's map to TL, TR, BR, BL.
# TL: (x1+de, y1-de)
# TR: (x2+de, y1-de)
# BR: (x2, y1)
# BL: (x1, y1)
p1 = (x1 + de, y1 - de)
p2 = (x2 + de, y1 - de)
p3 = (x2, y1)
p4 = (x1, y1)
apply_affine_transform(img, image, [p1, p2, p3, p4], fit)
elif axis == 'x': # Side
# Normal Side Face (Right): TL(x2, y1), TR(x2+de, y1-de), BR(x2+de, y2-de), BL(x2, y2)
p1 = (x2, y1)
p2 = (x2 + de, y1 - de)
p3 = (x2 + de, y2 - de)
p4 = (x2, y2)
apply_affine_transform(img, image, [p1, p2, p3, p4], fit)
draw = aggdraw.Draw(img)
last_box = box
draw.flush()
if text_callable is not None:
draw_text = ImageDraw.Draw(img)
i = -1
for index, layer in enumerate(model.layers):
if type(layer) in type_ignore or type(layer) == SpacingDummyLayer or index in index_ignore:
continue
i += 1
# Retrieve Styles
box = boxes[i]
local_font = box.style.get('font', font)
local_font_color = box.style.get('font_color', font_color)
local_vspacing = box.style.get('text_vspacing', text_vspacing)
text, above = text_callable(i, layer)
text_height = 0
text_x_adjust = []
for line in text.split('\n'):
# Use local_font for measurements
if hasattr(local_font, 'getsize'):
line_height = local_font.getsize(line)[1]
text_x_adjust.append(local_font.getsize(line)[0])
else:
line_height = local_font.getbbox(line)[3]
text_x_adjust.append(local_font.getbbox(line)[2])
text_height += line_height
# Use local_vspacing
text_height += (len(text.split('\n')) - 1) * local_vspacing
text_x = box.x1 + (box.x2 - box.x1) / 2
text_y = box.y2
if above:
text_x = box.x1 + box.de + (box.x2 - box.x1) / 2
text_y = box.y1 - box.de - text_height
# Use max width of the specific font
text_x -= max(text_x_adjust or [0]) / 2
anchor = 'la'
if above:
anchor = 'la'
draw_text.multiline_text(
(text_x, text_y),
text,
font=local_font,
fill=local_font_color,
anchor=anchor,
align='center',
spacing=local_vspacing
)
# Create layer color legend
if legend:
if font is None:
font = ImageFont.load_default()
if hasattr(font, 'getsize'):
text_height = font.getsize("Ag")[1]
else:
text_height = font.getbbox("Ag")[3]
cube_size = text_height
de = 0
if draw_volume:
de = cube_size // 2
patches = list()
if show_dimension:
counter = 0
for layer_type in layer_types:
if show_dimension:
label = layer_type.__name__ + "(" + str(dimension_list[counter]) + ")"
counter += 1
else:
label = layer_type.__name__
if hasattr(font, 'getsize'):
text_size = font.getsize(label)
else:
# Get last two values of the bounding box
# getbbox returns 4 dimensions in total, where the first two are always zero,
# So we fetch the last two dimensions to match the behavior of getsize
text_size = font.getbbox(label)[2:]
label_patch_size = (2 * cube_size + de + spacing + text_size[0], cube_size + de)
# this only works if cube_size is bigger than text height
img_box = Image.new('RGBA', label_patch_size, background_fill)
img_text = Image.new('RGBA', label_patch_size, (0, 0, 0, 0))
draw_box = aggdraw.Draw(img_box)
draw_text = ImageDraw.Draw(img_text)
box = Box()
box.x1 = cube_size
box.x2 = box.x1 + cube_size
box.y1 = de
box.y2 = box.y1 + cube_size
box.de = de
box.shade = shade_step
box.fill = color_map.get(layer_type, {}).get('fill', "#000000")
box.outline = color_map.get(layer_type, {}).get('outline', "#000000")
box.draw(draw_box, draw_reversed)
text_x = box.x2 + box.de + spacing
text_y = (label_patch_size[1] - text_height) / 2 # 2D center; use text_height and not the current label!
draw_text.text((text_x, text_y), label, font=font, fill=font_color)
draw_box.flush()
img_box.paste(img_text, mask=img_text)
patches.append(img_box)
legend_image = linear_layout(patches, max_width=img.width, max_height=img.height, padding=padding,
spacing=spacing,
background_fill=background_fill, horizontal=True)
img = vertical_image_concat(img, legend_image, background_fill=background_fill)
if layered_groups:
_draw_layered_group_captions(img, boxes, layered_groups, draw_reversed)
if logos_legend:
if font is None:
font = ImageFont.load_default()
img = draw_logos_legend(img, logo_groups, logos_legend, background_fill, font, font_color)
if to_file is not None:
img.save(to_file)
return img