Source code for visualkeras.functional

"""Functional (graph-aware) layered renderer for Keras/TensorFlow models.

This renderer targets functional graphs with branches, merges, and multi-input/
output structures. The pipeline is:
1) Graph extraction
2) Rank assignment (longest path)
3) Edge normalization (virtual nodes for long edges)
4) Crossing reduction (barycentric ordering)
5) Component layout + rendering
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union
from collections import deque
import warnings
import re

import aggdraw
import numpy as np
from PIL import Image, ImageDraw, ImageFont

from .layer_utils import (
    calculate_layer_dimensions,
    extract_primary_shape,
    find_output_layers,
    get_incoming_layers,
    get_layers,
)
from .options import FunctionalOptions, FUNCTIONAL_PRESETS, LAYERED_TEXT_CALLABLES
from .utils import (
    Box, 
    ColorWheel, 
    fade_color, 
    get_rgba_tuple, 
    resize_image_to_fit, 
    apply_affine_transform,
    draw_node_logo,
    draw_logos_legend
)



@dataclass
class FunctionalNode:
    """Internal node record used by the functional renderer.

    Instances of this dataclass represent the normalized nodes that move
    through the functional layout pipeline. They are exposed in the API
    reference because the module is documented with ``automodule``, but most
    users interact with them indirectly through :func:`functional_view`.
    """
    layer: Any
    node_id: int
    name: str
    layer_type: type
    shape: Optional[Tuple[Any, ...]]
    dims: Tuple[int, int, int]
    width: int
    height: int
    order: int
    rank: int = 0
    rank_order: int = 0
    x: int = 0
    y: int = 0
    kind: str = "layer"  # layer, input, output, virtual, collapsed
    component: int = 0
    style: Dict[str, Any] = field(default_factory=dict)
    de: int = 0
    shade: int = 0
    image: Optional[Image.Image] = None


@dataclass(frozen=True)
class FunctionalEdge:
    """Directed edge connecting two :class:`FunctionalNode` objects."""
    src: int
    dst: int


@dataclass
class FunctionalGraph:
    """Normalized graph container used during functional layout and rendering."""
    nodes: Dict[int, FunctionalNode]
    edges: List[FunctionalEdge]
    inputs: List[int]
    outputs: List[int]


class _SyntheticLayer:
    """Lightweight placeholder for synthetic input/output/virtual anchors."""

    def __init__(self, name: str, output_shape: Optional[Tuple[Any, ...]] = None) -> None:
        self.name = name
        self.output_shape = output_shape


def _normalize_collapse_selector(kind: str, selector: Any, rule_index: int) -> Union[str, type, Tuple[Union[str, type], ...]]:
    if kind == "layer":
        if isinstance(selector, (str, type)):
            return selector
        raise TypeError(
            f"collapse_rules[{rule_index}]['selector'] must be a layer name (str) or layer type for kind='layer'."
        )

    if kind == "block":
        if not isinstance(selector, Sequence) or isinstance(selector, (str, bytes)):
            raise TypeError(
                f"collapse_rules[{rule_index}]['selector'] must be a sequence of layer names/types for kind='block'."
            )
        normalized: List[Union[str, type]] = []
        for item_index, item in enumerate(selector):
            if isinstance(item, (str, type)):
                normalized.append(item)
                continue
            raise TypeError(
                f"collapse_rules[{rule_index}]['selector'][{item_index}] must be a layer name (str) or layer type."
            )
        if len(normalized) < 2:
            raise ValueError(
                f"collapse_rules[{rule_index}]['selector'] for kind='block' must contain at least 2 entries."
            )
        return tuple(normalized)

    raise ValueError(f"Unsupported collapse rule kind '{kind}'.")


def _validate_and_normalize_collapse_rules(
    collapse_rules: Optional[Sequence[Mapping[str, Any]]]
) -> List[Dict[str, Any]]:
    """Validate and normalize explicit collapse rule definitions.

    Rules must be mappings with:
    - ``kind``: ``"layer"`` or ``"block"``
    - ``selector``: a layer name/type (for ``layer``) or sequence of names/types (for ``block``)
    - ``repeat_count``: integer >= 2

    Optional fields:
    - ``label``: string, defaults to ``"{repeat_count}x"``
    - ``annotation_position``: ``"above"`` or ``"below"``, defaults to ``"above"``

    Args:
        collapse_rules: User-provided collapse rules from ``functional_view`` options/kwargs.

    Returns:
        A normalized list of plain dict rules with canonical keys and validated values.

    Raises:
        TypeError: If rule container/items or typed fields have invalid types.
        ValueError: If required keys are missing or values are outside accepted ranges.
    """
    if collapse_rules is None:
        return []
    if not isinstance(collapse_rules, Sequence) or isinstance(collapse_rules, (str, bytes, Mapping)):
        raise TypeError("collapse_rules must be a sequence of mapping rules.")

    normalized_rules: List[Dict[str, Any]] = []
    for rule_index, raw_rule in enumerate(collapse_rules):
        if not isinstance(raw_rule, Mapping):
            raise TypeError(f"collapse_rules[{rule_index}] must be a mapping.")

        kind = str(raw_rule.get("kind", "")).strip().lower()
        if kind not in {"layer", "block"}:
            raise ValueError(
                f"collapse_rules[{rule_index}]['kind'] must be one of: 'layer', 'block'."
            )

        if "selector" not in raw_rule:
            raise ValueError(f"collapse_rules[{rule_index}] is missing required key 'selector'.")
        selector = _normalize_collapse_selector(kind, raw_rule["selector"], rule_index)

        repeat_count = raw_rule.get("repeat_count")
        if not isinstance(repeat_count, int) or repeat_count < 2:
            raise ValueError(
                f"collapse_rules[{rule_index}]['repeat_count'] must be an integer >= 2."
            )

        annotation_position = str(raw_rule.get("annotation_position", "above")).strip().lower()
        if annotation_position not in {"above", "below"}:
            raise ValueError(
                f"collapse_rules[{rule_index}]['annotation_position'] must be 'above' or 'below'."
            )

        label = raw_rule.get("label")
        if label is None:
            label = f"{repeat_count}x"
        elif not isinstance(label, str):
            raise TypeError(f"collapse_rules[{rule_index}]['label'] must be a string when provided.")

        normalized_rules.append(
            {
                "kind": kind,
                "selector": selector,
                "repeat_count": repeat_count,
                "label": label,
                "annotation_position": annotation_position,
            }
        )

    return normalized_rules


[docs] def functional_view( model, to_file: Optional[str] = None, color_map: Optional[Mapping[type, Mapping[str, Any]]] = None, background_fill: Any = "white", padding: int = 20, column_spacing: int = 80, row_spacing: int = 40, component_spacing: int = 80, connector_fill: Any = "gray", connector_width: int = 2, connector_arrow: bool = False, connector_padding: int = 5, min_z: int = 20, min_xy: int = 20, max_z: int = 400, max_xy: int = 2000, scale_z: float = 1.5, scale_xy: float = 4.0, one_dim_orientation: str = "z", sizing_mode: str = "balanced", dimension_caps: Optional[Mapping[str, int]] = None, relative_base_size: int = 20, text_callable: Optional[Callable[[int, Any], Tuple[str, bool]]] = None, text_vspacing: int = 4, font: Optional[ImageFont.ImageFont] = None, font_color: Any = "black", add_output_nodes: bool = False, layout_iterations: int = 4, virtual_node_size: int = 12, render_virtual_nodes: bool = False, draw_volume: bool = False, orientation_rotation: Optional[float] = None, shade_step: int = 10, image_fit: str = "fill", image_axis: str = "z", layered_groups: Optional[Sequence[Dict[str, Any]]] = None, logo_groups: Optional[Sequence[Dict[str, Any]]] = None, logos_legend: Union[bool, Dict[str, Any]] = False, styles: Optional[Mapping[Union[str, type], Dict[str, Any]]] = None, *, simple_text_visualization: bool = False, simple_text_label_mode: str = "below", collapse_enabled: bool = False, collapse_rules: Optional[Sequence[Mapping[str, Any]]] = None, collapse_annotations: bool = True, options: Union[FunctionalOptions, Mapping[str, Any], None] = None, preset: Union[str, None] = None, ) -> Image.Image: """Render a functional model using a graph-aware layered layout. This renderer sits between layered view and graph view. It preserves more layer-level geometry than a pure topology diagram while still handling branches, merges, multi-input paths, and other functional-model structure. Parameters ---------- model : Any Keras model instance to visualize. Functional view is most useful when the model has a meaningful graph structure that would be lost in a strictly sequential rendering, but you still want the output to feel like an architecture diagram rather than a generic node-link graph. to_file : str, optional Path to save the rendered image. The image format is inferred from the file extension. The rendered ``PIL.Image`` is returned whether or not this value is provided. Use this when you want to save the output and keep working with the in-memory image in the same call. color_map : mapping, optional Mapping from layer class to broad style values such as ``fill`` and ``outline``. This is the quickest way to create a consistent color language by layer type. It is best for coarse styling rules, while ``styles`` is better for per-layer control. background_fill : Any, default='white' Background color for the final image. This accepts any Pillow-compatible color value. Choose a background that keeps layer boxes, connectors, and annotation text easy to read. padding : int, default=20 Outer padding around the full diagram in pixels. Increase this when labels, groups, or legends feel too close to the canvas edge. Padding affects the whole composition rather than the gaps between nodes inside the layout. column_spacing : int, default=80 Horizontal spacing between layout ranks. This is one of the main controls for how open the diagram feels from left to right. Larger values improve readability in wide branching graphs, while smaller values make the figure more compact. row_spacing : int, default=40 Vertical spacing between nodes within the same rank. Use this to manage crowded parallel branches. It works together with node size, text labels, and connector routing. component_spacing : int, default=80 Spacing between disconnected graph components. This matters when the renderer splits a model into separate connected subgraphs or when synthetic nodes create distinct visual blocks. connector_fill : Any, default='gray' Color used for connector paths between nodes. Neutral connector colors tend to work well because the boxes themselves already communicate most of the layer semantics. connector_width : int, default=2 Line width used for connectors. Increase this for exported figures, large canvases, or diagrams where the connector paths need more visual weight. connector_arrow : bool, default=False If ``True``, draw directional arrowheads on connectors. Arrowheads can be helpful for teaching material or graphs where the direction of flow is not already obvious from the layout. connector_padding : int, default=5 Padding reserved around nodes when routing connectors. This helps prevent connectors from appearing glued to the box edges and gives routed paths a cleaner look. min_z : int, default=20 Minimum rendered depth in pixels for a layer box when volumetric rendering is used. This prevents channel-light layers from collapsing into thin slivers. It matters most when ``draw_volume`` is enabled. min_xy : int, default=20 Minimum rendered width and height in pixels for a layer box. A reasonable minimum keeps small layers visible even when the model also contains very large tensors. max_z : int, default=400 Maximum rendered depth in pixels for a layer box. Use this to keep channel-heavy layers from dominating the visual depth of the diagram. max_xy : int, default=2000 Maximum rendered width and height in pixels for a layer box. This cap protects the layout from becoming impractically large when a model contains very large spatial dimensions or long sequences. scale_z : float, default=1.5 Multiplier applied to the depth dimension before clamping. Increase this when channel depth should read more strongly in the figure. Reduce it when depth cues feel exaggerated. scale_xy : float, default=4.0 Multiplier applied to width and height dimensions before clamping. This is a main control for the apparent size of rendered layer boxes. Lower values keep dense graphs manageable, while higher values make individual nodes easier to inspect. one_dim_orientation : {'x', 'y', 'z'}, default='z' Axis used when rendering one-dimensional layers. This affects how dense or flattened outputs are represented visually in mixed architectures that combine convolutional and vector-like stages. sizing_mode : {'accurate', 'balanced', 'capped', 'logarithmic', 'relative'}, default='balanced' Strategy used to convert tensor dimensions into rendered sizes. ``balanced`` is the default because it usually produces readable functional diagrams without letting a few extreme layers dominate the canvas. The other modes trade realism against compactness in different ways. dimension_caps : mapping, optional Custom caps used by sizing modes that support them. Supported keys are ``channels``, ``sequence``, and ``general``. This is useful when a small number of large layers would otherwise make the rest of the graph difficult to compare. relative_base_size : int, default=20 Base pixel unit used by ``relative`` sizing mode. In relative mode, a dimension of one maps directly to this many pixels, subject to minimum and maximum bounds. text_callable : callable, optional Callable receiving ``(layer_index, layer)`` and returning ``(text, above)`` for per-layer labels. This is the main hook for custom annotations such as layer names, block roles, or tensor shapes. Built-in helpers are available in ``visualkeras.options.LAYERED_TEXT_CALLABLES``. text_vspacing : int, default=4 Vertical spacing between lines produced by ``text_callable``. Increase this for multiline labels that feel cramped. Smaller values help conserve vertical space in dense graphs. font : PIL.ImageFont.ImageFont, optional Font used for labels, annotations, and legends where applicable. A custom font is useful when the figure needs to match an existing visual style for documentation or publication. font_color : Any, default='black' Text color used for labels and related annotations. This should contrast clearly with the background and any group overlays. add_output_nodes : bool, default=False If ``True``, add explicit output nodes even when the graph can end on a real layer. This can make complex multi-output diagrams easier to read because the termination points become visually explicit. layout_iterations : int, default=4 Number of refinement passes used by parts of the layout pipeline. More iterations can improve ordering or collision handling in difficult graphs, but they also increase render time. virtual_node_size : int, default=12 Size used for virtual nodes inserted during long-edge normalization. This matters only when virtual nodes are rendered or when their size influences layout spacing. render_virtual_nodes : bool, default=False If ``True``, draw virtual routing nodes that are otherwise only used internally by the layout algorithm. This is mainly useful for debugging or for highly explicit topology diagrams where routing helpers should remain visible. draw_volume : bool, default=False If ``True``, render layer boxes with 3D depth cues. If ``False``, use flat 2D rectangles. Flat mode is usually easier to read in complex functional graphs. Volumetric mode can be effective for presentation graphics or models where tensor depth is an important part of the explanation. orientation_rotation : float, optional Optional rotation applied to volumetric boxes. This lets you change the apparent viewing angle of 3D nodes when the default perspective does not suit the figure. shade_step : int, default=10 Amount of shading variation used for 3D faces and related effects. Larger values create stronger contrast between faces. Smaller values produce a flatter and subtler look. image_fit : {'fill', 'contain', 'cover', 'match_aspect'}, default='fill' Default fit mode for images injected through ``styles``. Choose a mode based on whether you prefer full coverage, preserved aspect ratio, or exact fill behavior. image_axis : {'x', 'y', 'z'}, default='z' Default axis used when rendering per-layer images in volumetric mode. This determines which face of a 3D node should receive an embedded image unless a per-layer override is supplied through ``styles``. layered_groups : sequence of dict, optional Group definitions used to draw labeled background regions behind sets of nodes. Groups are useful for separating architectural stages or conceptual blocks without changing the graph structure itself. logo_groups : sequence of dict, optional Logo placement definitions used to add icons or other overlays to selected nodes. This is mainly intended for presentation graphics and other highly styled figures. logos_legend : bool or dict, default=False If truthy, render a legend describing entries supplied through ``logo_groups``. A simple boolean enables the default legend behavior. A mapping allows more control over legend layout. styles : mapping, optional Fine-grained style overrides keyed by layer name or layer class. Use this for per-layer images, local color overrides, box text styling, volumetric settings, and other adjustments that are too specific for ``color_map``. simple_text_visualization : bool, default=False If ``True``, render nodes primarily as text blocks instead of sized boxes. This mode is useful when a compact textual diagram communicates the architecture better than geometric layer boxes. simple_text_label_mode : {'below', 'inside'}, default='below' Placement mode for labels in simple-text visualization. ``below`` keeps text outside the box area, while ``inside`` produces a tighter and more schematic layout. collapse_enabled : bool, default=False If ``True``, enable collapsing of repeated layers or repeated blocks according to ``collapse_rules``. Collapsing is useful for models with repeated motifs such as residual stacks, repeated convolution blocks, or Transformer-style repetition. collapse_rules : sequence of mapping, optional Explicit rules describing which repeated patterns may be collapsed. Each rule declares a ``kind``, a ``selector``, and a ``repeat_count``. Optional label and annotation fields let you control how the collapsed block is described in the rendered figure. collapse_annotations : bool, default=True If ``True``, draw annotations that describe collapsed regions. Disable this when you want the cleaner layout benefits of collapsing without additional explanatory text. options : FunctionalOptions or mapping, optional Configuration bundle applied after ``preset`` and before explicit keyword arguments. This is the preferred way to reuse a functional style across multiple models or notebooks. preset : str, optional Name of a preset from ``visualkeras.FUNCTIONAL_PRESETS``. Functional mode currently provides ``default``, ``compact``, and ``presentation``. Presets are intended as convenient starting points. They can be refined further with ``options`` and explicit overrides. Returns ------- PIL.Image.Image Rendered functional diagram. Notes ----- Configuration precedence is ``preset`` followed by ``options`` followed by explicit keyword arguments. Full documentation: https://visualkeras.readthedocs.io/en/latest/api/functional.html """ using_presets = options is not None or preset is not None if not using_presets: defaults = FunctionalOptions().to_kwargs() current_params = { "to_file": to_file, "color_map": color_map, "background_fill": background_fill, "padding": padding, "column_spacing": column_spacing, "row_spacing": row_spacing, "component_spacing": component_spacing, "connector_fill": connector_fill, "connector_width": connector_width, "connector_arrow": connector_arrow, "connector_padding": connector_padding, "min_z": min_z, "min_xy": min_xy, "max_z": max_z, "max_xy": max_xy, "scale_z": scale_z, "scale_xy": scale_xy, "one_dim_orientation": one_dim_orientation, "sizing_mode": sizing_mode, "dimension_caps": dimension_caps, "relative_base_size": relative_base_size, "text_callable": text_callable, "text_vspacing": text_vspacing, "font": font, "font_color": font_color, "add_output_nodes": add_output_nodes, "layout_iterations": layout_iterations, "virtual_node_size": virtual_node_size, "render_virtual_nodes": render_virtual_nodes, "draw_volume": draw_volume, "orientation_rotation": orientation_rotation, "shade_step": shade_step, "image_fit": image_fit, "image_axis": image_axis, "layered_groups": layered_groups, "logo_groups": logo_groups, "logos_legend": logos_legend, "simple_text_visualization": simple_text_visualization, "simple_text_label_mode": simple_text_label_mode, "collapse_enabled": collapse_enabled, "collapse_rules": collapse_rules, "collapse_annotations": collapse_annotations, "styles": styles, } custom_keys = [ key for key, value in current_params.items() if key in defaults and value != defaults[key] ] if len(custom_keys) >= 5: warnings.warn( "functional_view received many custom keyword arguments. " "Consider using visualkeras.show(..., mode='functional', preset=...) " "and the FunctionalOptions dataclass for a simpler workflow.", UserWarning, stacklevel=2, ) if preset is not None or options is not None: defaults = FunctionalOptions().to_kwargs() defaults["color_map"] = None defaults["dimension_caps"] = None defaults["font"] = None defaults["styles"] = None defaults["orientation_rotation"] = None defaults["image_fit"] = "fill" defaults["image_axis"] = "z" defaults["layered_groups"] = None defaults["logo_groups"] = None defaults["logos_legend"] = False defaults["simple_text_visualization"] = False resolved = dict(defaults) if preset is not None: try: resolved.update(FUNCTIONAL_PRESETS[preset].to_kwargs()) except KeyError as exc: available = ", ".join(sorted(FUNCTIONAL_PRESETS.keys())) raise ValueError( f"Unknown functional preset '{preset}'. Available presets: {available}" ) from exc if options is not None: if isinstance(options, FunctionalOptions): option_values = options.to_kwargs() elif isinstance(options, Mapping): option_values = dict(options) else: raise TypeError( "options must be a FunctionalOptions instance or a mapping of keyword arguments." ) resolved.update(option_values) explicit_values = { "to_file": to_file, "color_map": color_map, "background_fill": background_fill, "padding": padding, "column_spacing": column_spacing, "row_spacing": row_spacing, "component_spacing": component_spacing, "connector_fill": connector_fill, "connector_width": connector_width, "connector_arrow": connector_arrow, "connector_padding": connector_padding, "min_z": min_z, "min_xy": min_xy, "max_z": max_z, "max_xy": max_xy, "scale_z": scale_z, "scale_xy": scale_xy, "one_dim_orientation": one_dim_orientation, "sizing_mode": sizing_mode, "dimension_caps": dimension_caps, "relative_base_size": relative_base_size, "text_callable": text_callable, "text_vspacing": text_vspacing, "font": font, "font_color": font_color, "add_output_nodes": add_output_nodes, "layout_iterations": layout_iterations, "virtual_node_size": virtual_node_size, "render_virtual_nodes": render_virtual_nodes, "draw_volume": draw_volume, "orientation_rotation": orientation_rotation, "shade_step": shade_step, "image_fit": image_fit, "image_axis": image_axis, "layered_groups": layered_groups, "logo_groups": logo_groups, "logos_legend": logos_legend, "simple_text_visualization": simple_text_visualization, "simple_text_label_mode": simple_text_label_mode, "collapse_enabled": collapse_enabled, "collapse_rules": collapse_rules, "collapse_annotations": collapse_annotations, "styles": styles, } for key, value in explicit_values.items(): if key not in defaults: continue if value != defaults[key]: resolved[key] = value to_file = resolved["to_file"] color_map = resolved["color_map"] background_fill = resolved["background_fill"] padding = resolved["padding"] column_spacing = resolved["column_spacing"] row_spacing = resolved["row_spacing"] component_spacing = resolved["component_spacing"] connector_fill = resolved["connector_fill"] connector_width = resolved["connector_width"] connector_arrow = resolved["connector_arrow"] connector_padding = resolved["connector_padding"] min_z = resolved["min_z"] min_xy = resolved["min_xy"] max_z = resolved["max_z"] max_xy = resolved["max_xy"] scale_z = resolved["scale_z"] scale_xy = resolved["scale_xy"] one_dim_orientation = resolved["one_dim_orientation"] sizing_mode = resolved["sizing_mode"] dimension_caps = resolved["dimension_caps"] relative_base_size = resolved["relative_base_size"] text_callable = resolved["text_callable"] text_vspacing = resolved["text_vspacing"] font = resolved["font"] font_color = resolved["font_color"] add_output_nodes = resolved["add_output_nodes"] layout_iterations = resolved["layout_iterations"] virtual_node_size = resolved["virtual_node_size"] render_virtual_nodes = resolved["render_virtual_nodes"] draw_volume = resolved["draw_volume"] orientation_rotation = resolved.get("orientation_rotation", orientation_rotation) shade_step = resolved.get("shade_step", shade_step) image_fit = resolved.get("image_fit", image_fit) image_axis = resolved.get("image_axis", image_axis) layered_groups = resolved.get("layered_groups", layered_groups) logo_groups = resolved.get("logo_groups", logo_groups) logos_legend = resolved.get("logos_legend", logos_legend) simple_text_visualization = resolved.get("simple_text_visualization", simple_text_visualization) simple_text_label_mode = resolved.get("simple_text_label_mode", simple_text_label_mode) collapse_enabled = resolved.get("collapse_enabled", collapse_enabled) collapse_rules = resolved.get("collapse_rules", collapse_rules) collapse_annotations = resolved.get("collapse_annotations", collapse_annotations) styles = resolved["styles"] if simple_text_visualization: draw_volume = False orientation_rotation = None if color_map is not None and not isinstance(color_map, dict): color_map = dict(color_map) if dimension_caps is not None and not isinstance(dimension_caps, dict): dimension_caps = dict(dimension_caps) simple_text_label_mode = str(simple_text_label_mode or "below").strip().lower() if simple_text_label_mode not in {"inside", "below"}: raise ValueError( "simple_text_label_mode must be one of: 'inside', 'below'." ) if isinstance(text_callable, str): try: text_callable = LAYERED_TEXT_CALLABLES[text_callable] except KeyError as exc: available = ", ".join(sorted(LAYERED_TEXT_CALLABLES)) raise ValueError( f"Unknown text callable preset '{text_callable}'. " f"Available presets: {available}" ) from exc if color_map is None: color_map = {} if styles is not None and not isinstance(styles, dict): styles = dict(styles) if styles is None: styles = {} if simple_text_visualization and simple_text_label_mode == "below": if text_callable is None: text_callable = LAYERED_TEXT_CALLABLES["name_shape"] root_style = styles.get(object) if root_style is None: styles[object] = {"box_text_enabled": False} elif isinstance(root_style, Mapping): root_style_copy = dict(root_style) root_style_copy.setdefault("box_text_enabled", False) styles[object] = root_style_copy normalized_collapse_rules = _validate_and_normalize_collapse_rules(collapse_rules) global_defaults = { "connector_fill": connector_fill, "connector_width": connector_width, "connector_arrow": connector_arrow, "connector_padding": connector_padding, "draw_volume": draw_volume, "orientation_rotation": orientation_rotation, "shade_step": shade_step, "image_fit": image_fit, "image_axis": image_axis, "padding": 0, # separate from global image padding "box_orientation": "vertical", "box_text_rotation": None, "box_text_color": font_color, "box_text_font": None, "box_text_font_size": 14, "box_text_padding": 8, "box_text_wrap": "words", "box_text_autoshrink": True, "box_text_min_font_size": 8, "box_text_align": "center", "box_text_valign": "middle", "box_outline_width": 2, "box_fill": None, "box_outline": None, "box_text_enabled": True, "collapse_badge_enabled": True, "collapse_badge_fill": "white", "collapse_badge_outline": "black", "collapse_badge_text_color": font_color, "collapse_badge_font": None, "collapse_badge_font_size": 12, "collapse_badge_padding": (4, 2), "collapse_annotation_color": connector_fill, "collapse_annotation_font": None, "collapse_annotation_font_size": 12, "collapse_annotation_offset": 10, "collapse_annotation_width": 2, "collapse_annotation_head_size": 6, } graph = _build_graph( model, styles=styles, global_defaults=global_defaults, min_z=min_z, min_xy=min_xy, max_z=max_z, max_xy=max_xy, scale_z=scale_z, scale_xy=scale_xy, one_dim_orientation=one_dim_orientation, sizing_mode=sizing_mode, dimension_caps=dimension_caps, relative_base_size=relative_base_size, add_output_nodes=add_output_nodes, virtual_node_size=virtual_node_size, draw_volume=draw_volume, shade_step=shade_step, image_fit=image_fit, image_axis=image_axis, simple_text_visualization=simple_text_visualization, ) if collapse_enabled and normalized_collapse_rules: graph, rule_apply_counts = _collapse_graph_with_rules( graph, normalized_collapse_rules, collapse_annotations=collapse_annotations, ) for rule_index, apply_count in enumerate(rule_apply_counts): if apply_count == 0: warnings.warn( f"collapse_rules[{rule_index}] did not match any collapsible linear chain.", UserWarning, stacklevel=2, ) ranks = _assign_ranks(graph.nodes, graph.edges) graph, ranks = _expand_long_edges(graph, ranks, virtual_node_size) if graph.nodes: _mark_inputs_outputs(graph) text_top_padding: Dict[int, int] = {} text_bottom_padding: Dict[int, int] = {} if simple_text_visualization and simple_text_label_mode == "below" and text_callable is not None: text_top_padding, text_bottom_padding = _compute_external_text_padding( graph, text_callable=text_callable, text_vspacing=text_vspacing, font=font, ) components = _split_components(graph.nodes, graph.edges) components.sort(key=lambda comp: _component_sort_key(graph, comp)) column_widths = _column_widths(graph.nodes, ranks) x_positions = _column_positions(column_widths, padding, column_spacing) y_offset = padding for component_index, node_ids in enumerate(components): rank_nodes = _order_by_barycenter( graph, node_ids, ranks, iterations=layout_iterations, ) for node_id in node_ids: graph.nodes[node_id].component = component_index component_height = _assign_component_positions( graph, node_ids, rank_nodes, x_positions, column_widths, y_offset, row_spacing, node_top_padding=text_top_padding, node_bottom_padding=text_bottom_padding, ) if component_height <= 0: continue y_offset += component_height + component_spacing _straighten_layout( graph, ranks, row_spacing, node_top_padding=text_top_padding, node_bottom_padding=text_bottom_padding, ) img = _render_graph( graph, color_map=color_map, background_fill=background_fill, padding=padding, connector_fill=connector_fill, connector_width=connector_width, connector_arrow=connector_arrow, connector_padding=connector_padding, text_callable=text_callable, text_vspacing=text_vspacing, font=font, font_color=font_color, render_virtual_nodes=render_virtual_nodes, draw_volume=draw_volume, orientation_rotation=orientation_rotation, layered_groups=layered_groups, logo_groups=logo_groups, logos_legend=logos_legend, simple_text_visualization=simple_text_visualization, external_text_bottom_padding=text_bottom_padding, ) if to_file is not None: img.save(to_file) return img
def _build_graph( model, *, styles: Mapping[Union[str, type], Dict[str, Any]], global_defaults: Dict[str, Any], min_z: int, min_xy: int, max_z: int, max_xy: int, scale_z: float, scale_xy: float, one_dim_orientation: str, sizing_mode: str, dimension_caps: Optional[Mapping[str, int]], relative_base_size: int, add_output_nodes: bool, virtual_node_size: int, draw_volume: bool, shade_step: int, image_fit: str, image_axis: str, simple_text_visualization: bool = False, ) -> FunctionalGraph: """Build the intermediate functional graph from a Keras model. This stage creates ``FunctionalNode`` objects for each model layer, applies style/default resolution, computes visual dimensions, optionally loads image textures, and collects directed graph edges from inbound layer relations. When requested, synthetic output nodes are appended. Args: model: Keras/TensorFlow model. styles: Style overrides keyed by layer type or layer name. global_defaults: Base style values merged into each node style. min_z: Minimum depth dimension used in dimension scaling. min_xy: Minimum width/height dimension used in scaling. max_z: Maximum depth dimension used in scaling. max_xy: Maximum width/height dimension used in scaling. scale_z: Depth scaling factor. scale_xy: Width/height scaling factor. one_dim_orientation: Orientation hint for one-dimensional shapes. sizing_mode: Dimension scaling strategy. dimension_caps: Optional dimension caps for capped/balanced sizing modes. relative_base_size: Base pixel size used by relative sizing mode. add_output_nodes: Whether to append synthetic output marker nodes. virtual_node_size: Size used for synthetic helper nodes. draw_volume: Global default for volumetric rendering depth. shade_step: Global shade delta for box rendering. image_fit: Default image fit mode for textured nodes. image_axis: Default box face axis for image projection. simple_text_visualization: If true, force flat 2D boxes. Returns: A ``FunctionalGraph`` containing nodes, edges, and inferred input/output ids. """ def resolve_style(layer, name) -> Dict[str, Any]: final_style = global_defaults.copy() for cls in type(layer).__mro__: if cls in styles: final_style.update(styles[cls]) if name in styles: final_style.update(styles[name]) return final_style layers = list(get_layers(model)) order_map = {id(layer): index for index, layer in enumerate(layers)} nodes: Dict[int, FunctionalNode] = {} for layer in layers: node_id = id(layer) name = getattr(layer, "name", None) or f"layer_{order_map[node_id]}" node_style = resolve_style(layer, name) if simple_text_visualization: node_style = dict(node_style) node_style["draw_volume"] = False image_path = node_style.get("image") node_image = None shape = extract_primary_shape(_resolve_layer_output_shape(layer), name) dims = calculate_layer_dimensions( shape, scale_z, scale_xy, max_z, max_xy, min_z, min_xy, one_dim_orientation=one_dim_orientation, sizing_mode=sizing_mode, dimension_caps=dimension_caps, relative_base_size=relative_base_size, ) width = max(min_xy, int(dims[2])) height = max(min_xy, int(dims[1])) if simple_text_visualization: box_size = node_style.get("box_size") if isinstance(box_size, (tuple, list)) and len(box_size) == 2: try: width = int(box_size[0]) height = int(box_size[1]) except Exception: pass box_min = node_style.get("box_min_size") if isinstance(box_min, (tuple, list)) and len(box_min) == 2: try: width = max(width, int(box_min[0])) height = max(height, int(box_min[1])) except Exception: pass use_volume = node_style.get('draw_volume', draw_volume) if simple_text_visualization: use_volume = False de = 0 if use_volume: de = int(width / 3) if image_path: try: node_image = Image.open(image_path).convert("RGBA") fit_mode = node_style.get("image_fit", image_fit) axis = ("z" if simple_text_visualization else node_style.get("image_axis", image_axis)) if fit_mode == "match_aspect": img_w, img_h = node_image.size img_ratio = img_w / img_h if axis == 'z': surf_ratio = width / height if img_ratio > surf_ratio: width = int(height * img_ratio) else: height = int(width / img_ratio) elif axis == 'y': if img_ratio > 0: de = int(width / img_ratio) elif axis == 'x': de = int(height * img_ratio) scale_factor = node_style.get("scale_image") if scale_factor is not None: try: scale_factor = float(scale_factor) if scale_factor < 0: scale_factor = 0.0 except (ValueError, TypeError): scale_factor = 1.0 if axis == 'z': # Front (Width x Height) width = int(width * scale_factor) height = int(height * scale_factor) elif axis == 'y': # Top (Width x Depth) width = int(width * scale_factor) de = int(de * scale_factor) elif axis == 'x': # Side (Depth x Height) de = int(de * scale_factor) height = int(height * scale_factor) except Exception as e: warnings.warn(f"Failed to load image for layer '{name}': {e}. Reverting to default visualization.") image_path = None # Fallback to standard logic below total_width = width + de total_height = height + de shade = node_style.get('shade_step', shade_step) nodes[node_id] = FunctionalNode( layer=layer, node_id=node_id, name=name, layer_type=type(layer), shape=extract_primary_shape(_resolve_layer_output_shape(layer), name) if not image_path else None, dims=(int(dims[0]), int(dims[1]), int(dims[2])), width=total_width, height=total_height, order=order_map[node_id], style=node_style, de=de, shade=shade, image=node_image ) edges = _collect_edges(nodes) if add_output_nodes: nodes, edges = _attach_output_nodes(model, nodes, edges, virtual_node_size) inputs, outputs = _find_inputs_outputs(nodes, edges) return FunctionalGraph(nodes=nodes, edges=edges, inputs=inputs, outputs=outputs) def _collect_edges(nodes: Dict[int, FunctionalNode]) -> List[FunctionalEdge]: edges: List[FunctionalEdge] = [] seen = set() for node in nodes.values(): for inbound in get_incoming_layers(node.layer): inbound_id = id(inbound) if inbound_id not in nodes: continue edge = (inbound_id, node.node_id) if edge in seen: continue edges.append(FunctionalEdge(*edge)) seen.add(edge) return edges def _attach_output_nodes( model, nodes: Dict[int, FunctionalNode], edges: List[FunctionalEdge], virtual_node_size: int, ) -> Tuple[Dict[int, FunctionalNode], List[FunctionalEdge]]: order_base = max((node.order for node in nodes.values()), default=0) + 1 new_nodes = dict(nodes) new_edges = list(edges) outputs = list(find_output_layers(model)) output_size = max(virtual_node_size * 2, 16) for index, layer in enumerate(outputs): name = f"output_{index}" synthetic = _SyntheticLayer(name=name, output_shape=getattr(layer, "output_shape", None)) node_id = id(synthetic) new_nodes[node_id] = FunctionalNode( layer=synthetic, node_id=node_id, name=name, layer_type=type(synthetic), shape=extract_primary_shape(getattr(synthetic, "output_shape", None), name), dims=(output_size, output_size, output_size), width=output_size, height=output_size, order=order_base + index, kind="output", ) if id(layer) in nodes: new_edges.append(FunctionalEdge(id(layer), node_id)) return new_nodes, new_edges def _find_inputs_outputs( nodes: Dict[int, FunctionalNode], edges: Sequence[FunctionalEdge], ) -> Tuple[List[int], List[int]]: incoming = {node_id: 0 for node_id in nodes} outgoing = {node_id: 0 for node_id in nodes} for edge in edges: incoming[edge.dst] += 1 outgoing[edge.src] += 1 inputs = [node_id for node_id, count in incoming.items() if count == 0] outputs = [node_id for node_id, count in outgoing.items() if count == 0] return inputs, outputs def _build_edge_index( nodes: Mapping[int, FunctionalNode], edges: Sequence[FunctionalEdge], ) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]: outgoing: Dict[int, List[int]] = {node_id: [] for node_id in nodes} incoming: Dict[int, List[int]] = {node_id: [] for node_id in nodes} for edge in edges: if edge.src not in nodes or edge.dst not in nodes: continue outgoing[edge.src].append(edge.dst) incoming[edge.dst].append(edge.src) return outgoing, incoming def _node_matches_collapse_selector(node: FunctionalNode, selector: Union[str, type]) -> bool: if isinstance(selector, str): return node.name == selector try: return issubclass(node.layer_type, selector) except TypeError: return False def _find_first_collapse_sequence( graph: FunctionalGraph, rule: Mapping[str, Any], ) -> Optional[List[int]]: """Find the first collapsible linear node sequence matching one rule. Matching is strict and linear: each internal hop must traverse a node with exactly one outgoing edge and a successor with exactly one incoming edge from that predecessor. This avoids collapsing ambiguous branch/merge paths. Args: graph: Current graph state. rule: One normalized collapse rule. Returns: A list of node ids for the first valid match in node order, or ``None`` if no sequence satisfies the rule. """ kind = str(rule["kind"]) repeat_count = int(rule["repeat_count"]) if kind == "layer": selector_pattern: List[Union[str, type]] = [rule["selector"]] * repeat_count else: block = list(rule["selector"]) selector_pattern = block * repeat_count if len(selector_pattern) < 2: return None outgoing, incoming = _build_edge_index(graph.nodes, graph.edges) candidate_ids = sorted(graph.nodes, key=lambda node_id: graph.nodes[node_id].order) for start_id in candidate_ids: start_node = graph.nodes[start_id] if start_node.kind != "layer": continue if not _node_matches_collapse_selector(start_node, selector_pattern[0]): continue sequence = [start_id] cursor = start_id valid = True for expected_selector in selector_pattern[1:]: out_nodes = outgoing.get(cursor, []) if len(out_nodes) != 1: valid = False break next_id = out_nodes[0] if next_id in sequence: valid = False break next_node = graph.nodes.get(next_id) if next_node is None or next_node.kind != "layer": valid = False break in_nodes = incoming.get(next_id, []) if len(in_nodes) != 1 or in_nodes[0] != cursor: valid = False break if not _node_matches_collapse_selector(next_node, expected_selector): valid = False break sequence.append(next_id) cursor = next_id if valid: return sequence return None def _collapse_node_sequence( graph: FunctionalGraph, sequence: Sequence[int], *, rule: Mapping[str, Any], collapse_annotations: bool, ) -> FunctionalGraph: """Collapse one matched node sequence into a synthetic collapsed node. The function removes all sequence members, inserts one synthetic node with merged metadata/style markers, and rewires boundary edges so incoming edges target the collapsed node and outgoing edges originate from it. Args: graph: Source graph. sequence: Ordered node ids to collapse. rule: Normalized rule that produced the match. collapse_annotations: Whether block-level annotation rendering is enabled. Returns: A new ``FunctionalGraph`` with collapsed topology and recomputed input/output sets. """ if not sequence: return graph seq_ids = list(sequence) seq_set = set(seq_ids) old_nodes = graph.nodes member_nodes = [old_nodes[node_id] for node_id in seq_ids] first_node = member_nodes[0] last_node = member_nodes[-1] collapse_label = str(rule.get("label", f"{rule['repeat_count']}x")) collapse_kind = str(rule["kind"]) synthetic_name = f"collapsed_{first_node.name}_{collapse_label}" synthetic_layer = _SyntheticLayer( name=synthetic_name, output_shape=getattr(last_node.layer, "output_shape", None), ) collapsed_node_id = id(synthetic_layer) while collapsed_node_id in old_nodes: synthetic_layer = _SyntheticLayer( name=f"{synthetic_name}_{collapsed_node_id}", output_shape=getattr(last_node.layer, "output_shape", None), ) collapsed_node_id = id(synthetic_layer) collapsed_style = dict(first_node.style or {}) collapsed_style["collapsed"] = True collapsed_style["collapse_kind"] = collapse_kind collapsed_style["collapse_repeat_count"] = int(rule["repeat_count"]) collapsed_style["collapse_label"] = collapse_label collapsed_style["collapse_annotation_position"] = rule.get("annotation_position", "above") collapsed_style["collapse_annotation_enabled"] = bool(collapse_annotations) collapsed_style["collapse_members"] = tuple(node.name for node in member_nodes) if collapse_kind == "block": selector = rule.get("selector", ()) collapsed_style["collapse_block_size"] = len(selector) if isinstance(selector, tuple) else 0 collapsed_node = FunctionalNode( layer=synthetic_layer, node_id=collapsed_node_id, name=synthetic_name, layer_type=first_node.layer_type, shape=last_node.shape, dims=( max(node.dims[0] for node in member_nodes), max(node.dims[1] for node in member_nodes), max(node.dims[2] for node in member_nodes), ), width=max(node.width for node in member_nodes), height=max(node.height for node in member_nodes), order=first_node.order, rank=first_node.rank, rank_order=first_node.rank_order, kind="collapsed", component=first_node.component, style=collapsed_style, de=max(node.de for node in member_nodes), shade=first_node.shade, image=first_node.image if collapse_kind == "layer" else None, ) new_nodes = {node_id: node for node_id, node in old_nodes.items() if node_id not in seq_set} new_nodes[collapsed_node_id] = collapsed_node first_id = seq_ids[0] last_id = seq_ids[-1] seen_edges = set() new_edges: List[FunctionalEdge] = [] for edge in graph.edges: src = edge.src dst = edge.dst if src in seq_set and dst in seq_set: continue if src in seq_set: if src != last_id: continue src = collapsed_node_id if dst in seq_set: if dst != first_id: continue dst = collapsed_node_id if src == dst: continue edge_key = (src, dst) if edge_key in seen_edges: continue seen_edges.add(edge_key) new_edges.append(FunctionalEdge(src, dst)) inputs, outputs = _find_inputs_outputs(new_nodes, new_edges) return FunctionalGraph(nodes=new_nodes, edges=new_edges, inputs=inputs, outputs=outputs) def _collapse_graph_with_rules( graph: FunctionalGraph, rules: Sequence[Mapping[str, Any]], *, collapse_annotations: bool, ) -> Tuple[FunctionalGraph, List[int]]: """Apply collapse rules repeatedly until no more matches remain. Rules are processed in order. For each rule, the first valid match is collapsed repeatedly until that rule has no additional matches in the updated graph. Args: graph: Source graph before collapse. rules: Normalized collapse rules. collapse_annotations: Whether collapsed block annotations are enabled. Returns: Tuple of ``(collapsed_graph, applied_counts)`` where ``applied_counts`` tracks how many collapses each rule produced. """ if not rules: return graph, [] collapsed_graph = graph applied_counts = [0 for _ in rules] for rule_index, rule in enumerate(rules): while True: sequence = _find_first_collapse_sequence(collapsed_graph, rule) if not sequence: break collapsed_graph = _collapse_node_sequence( collapsed_graph, sequence, rule=rule, collapse_annotations=collapse_annotations, ) applied_counts[rule_index] += 1 return collapsed_graph, applied_counts def _assign_ranks( nodes: Dict[int, FunctionalNode], edges: Sequence[FunctionalEdge], ) -> Dict[int, int]: outgoing: Dict[int, List[int]] = {node_id: [] for node_id in nodes} incoming_count = {node_id: 0 for node_id in nodes} for edge in edges: outgoing[edge.src].append(edge.dst) incoming_count[edge.dst] += 1 queue = deque(sorted( (node_id for node_id, count in incoming_count.items() if count == 0), key=lambda node_id: nodes[node_id].order, )) ranks = {node_id: 0 for node_id in queue} while queue: node_id = queue.popleft() for child_id in outgoing[node_id]: ranks[child_id] = max(ranks.get(child_id, 0), ranks[node_id] + 1) incoming_count[child_id] -= 1 if incoming_count[child_id] == 0: queue.append(child_id) if len(ranks) != len(nodes): missing = [node_id for node_id in nodes if node_id not in ranks] for node_id in missing: ranks[node_id] = 0 warnings.warn( "Functional graph contains cycles or disconnected nodes. " "Assigning rank 0 to unprocessed nodes.", UserWarning, stacklevel=2, ) for node_id, rank in ranks.items(): nodes[node_id].rank = rank return ranks def _expand_long_edges( graph: FunctionalGraph, ranks: Dict[int, int], virtual_node_size: int, ) -> Tuple[FunctionalGraph, Dict[int, int]]: if not graph.edges: return graph, ranks new_nodes = dict(graph.nodes) new_edges: List[FunctionalEdge] = [] order_base = max((node.order for node in new_nodes.values()), default=0) + 1 for edge in graph.edges: src_rank = ranks.get(edge.src, 0) dst_rank = ranks.get(edge.dst, src_rank + 1) rank_delta = dst_rank - src_rank if rank_delta <= 1: new_edges.append(edge) continue prev_id = edge.src for step in range(1, rank_delta): rank = src_rank + step name = f"virtual_{order_base}" synthetic = _SyntheticLayer(name=name) node_id = id(synthetic) new_nodes[node_id] = FunctionalNode( layer=synthetic, node_id=node_id, name=name, layer_type=type(synthetic), shape=None, dims=(virtual_node_size, virtual_node_size, virtual_node_size), width=virtual_node_size, height=virtual_node_size, order=order_base, rank=rank, kind="virtual", de=0, # Virtual nodes are always flat shade=0 ) ranks[node_id] = rank new_edges.append(FunctionalEdge(prev_id, node_id)) prev_id = node_id order_base += 1 new_edges.append(FunctionalEdge(prev_id, edge.dst)) inputs, outputs = _find_inputs_outputs(new_nodes, new_edges) return FunctionalGraph(nodes=new_nodes, edges=new_edges, inputs=inputs, outputs=outputs), ranks def _mark_inputs_outputs(graph: FunctionalGraph) -> None: for node_id in graph.inputs: node = graph.nodes.get(node_id) if node and node.kind == "layer": node.kind = "input" for node_id in graph.outputs: node = graph.nodes.get(node_id) if node and node.kind == "layer": node.kind = "output" def _split_components( nodes: Mapping[int, FunctionalNode], edges: Sequence[FunctionalEdge], ) -> List[List[int]]: adjacency: Dict[int, List[int]] = {node_id: [] for node_id in nodes} for edge in edges: adjacency[edge.src].append(edge.dst) adjacency[edge.dst].append(edge.src) components: List[List[int]] = [] visited = set() for node_id in nodes: if node_id in visited: continue queue = deque([node_id]) visited.add(node_id) component = [] while queue: current = queue.popleft() component.append(current) for neighbor in adjacency.get(current, []): if neighbor not in visited: visited.add(neighbor) queue.append(neighbor) components.append(component) return components def _component_sort_key(graph: FunctionalGraph, node_ids: Sequence[int]) -> Tuple[int, int]: orders = [graph.nodes[node_id].order for node_id in node_ids] ranks = [graph.nodes[node_id].rank for node_id in node_ids] return (min(ranks or [0]), min(orders or [0])) def _order_by_barycenter( graph: FunctionalGraph, node_ids: Sequence[int], ranks: Dict[int, int], *, iterations: int, ) -> Dict[int, List[int]]: node_set = set(node_ids) sub_edges = [ edge for edge in graph.edges if edge.src in node_set and edge.dst in node_set and abs(ranks.get(edge.dst, 0) - ranks.get(edge.src, 0)) == 1 ] max_rank = max((ranks.get(node_id, 0) for node_id in node_ids), default=0) rank_nodes: Dict[int, List[int]] = {rank: [] for rank in range(max_rank + 1)} for node_id in node_ids: rank_nodes[ranks.get(node_id, 0)].append(node_id) for node_list in rank_nodes.values(): node_list.sort(key=lambda node_id: graph.nodes[node_id].order) positions = _positions_from_rank_nodes(rank_nodes) incoming = _incoming_map(sub_edges) outgoing = _outgoing_map(sub_edges) if iterations <= 0: for rank, node_list in rank_nodes.items(): for index, node_id in enumerate(node_list): graph.nodes[node_id].rank_order = index return rank_nodes for _ in range(iterations): for rank in range(1, max_rank + 1): rank_nodes[rank].sort( key=lambda node_id: _barycenter_key(node_id, incoming, positions, graph) ) positions = _positions_from_rank_nodes(rank_nodes) for rank in range(max_rank - 1, -1, -1): rank_nodes[rank].sort( key=lambda node_id: _barycenter_key(node_id, outgoing, positions, graph) ) positions = _positions_from_rank_nodes(rank_nodes) for rank, node_list in rank_nodes.items(): for index, node_id in enumerate(node_list): graph.nodes[node_id].rank_order = index return rank_nodes def _barycenter_key( node_id: int, neighbor_map: Mapping[int, List[int]], positions: Mapping[int, int], graph: FunctionalGraph, ) -> Tuple[float, int]: neighbors = neighbor_map.get(node_id, []) values = [positions[n] for n in neighbors if n in positions] if values: barycenter = sum(values) / len(values) else: barycenter = float(positions.get(node_id, 0)) return (barycenter, graph.nodes[node_id].order) def _compute_external_text_padding( graph: FunctionalGraph, *, text_callable: Callable[[int, Any], Tuple[str, bool]], text_vspacing: int, font: Optional[ImageFont.ImageFont], ) -> Tuple[Dict[int, int], Dict[int, int]]: """Estimate vertical label extents for external text labels. The returned maps are keyed by node id and contain pixel extents that should be reserved above or below node rectangles to avoid overlap with neighboring rows when labels are drawn outside boxes. """ text_top_padding: Dict[int, int] = {} text_bottom_padding: Dict[int, int] = {} active_font = font or ImageFont.load_default() visible_nodes = [ node for node in sorted(graph.nodes.values(), key=lambda n: n.order) if node.kind != "virtual" ] for index, node in enumerate(visible_nodes): text, above = text_callable(index, node.layer) text_value = "" if text is None else str(text) if not text_value: continue text_height = 0 for line in text_value.split("\n"): if hasattr(active_font, "getsize"): text_height += active_font.getsize(line)[1] else: text_height += active_font.getbbox(line)[3] text_height += (len(text_value.split("\n")) - 1) * text_vspacing extent = max(0, int(text_height) + 4) if above: text_top_padding[node.node_id] = extent else: text_bottom_padding[node.node_id] = extent return text_top_padding, text_bottom_padding def _resolve_external_label_x_collisions( labels: List[Dict[str, Any]], *, image_width: int, edge_padding: int, min_gap: int = 8, y_tolerance: int = 2, ) -> None: """Shift external labels horizontally to reduce same-row overlap. Labels are grouped by overlapping y-ranges (with a small tolerance). Within each group, labels are ordered by preferred x and nudged to satisfy a minimum horizontal gap where possible. """ if len(labels) < 2: return sorted_indices = sorted(range(len(labels)), key=lambda idx: labels[idx]["y"]) groups: List[List[int]] = [] current_group = [sorted_indices[0]] current_bottom = labels[sorted_indices[0]]["y"] + labels[sorted_indices[0]]["h"] for idx in sorted_indices[1:]: label = labels[idx] y1 = label["y"] y2 = y1 + label["h"] if y1 <= current_bottom + y_tolerance: current_group.append(idx) current_bottom = max(current_bottom, y2) else: groups.append(current_group) current_group = [idx] current_bottom = y2 groups.append(current_group) left_bound = float(edge_padding) right_bound = float(max(edge_padding, image_width - edge_padding)) available = max(1.0, right_bound - left_bound) for group in groups: if len(group) < 2: continue ordered = sorted(group, key=lambda idx: labels[idx]["x_pref"]) widths = [float(max(1, labels[idx]["w"])) for idx in ordered] required = sum(widths) + float(min_gap * (len(ordered) - 1)) xs = [float(labels[idx]["x_pref"]) for idx in ordered] if required > available: x_cursor = left_bound for i in range(len(xs)): xs[i] = x_cursor x_cursor += widths[i] + min_gap else: xs[0] = max(xs[0], left_bound) for i in range(1, len(xs)): xs[i] = max(xs[i], xs[i - 1] + widths[i - 1] + min_gap) overflow = (xs[-1] + widths[-1]) - right_bound if overflow > 0: xs = [x - overflow for x in xs] if xs[0] < left_bound: x_cursor = left_bound for i in range(len(xs)): xs[i] = x_cursor x_cursor += widths[i] + min_gap for i, idx in enumerate(ordered): labels[idx]["x"] = int(round(xs[i])) def _positions_from_rank_nodes(rank_nodes: Mapping[int, List[int]]) -> Dict[int, int]: positions: Dict[int, int] = {} for node_list in rank_nodes.values(): for index, node_id in enumerate(node_list): positions[node_id] = index return positions def _assign_component_positions( graph: FunctionalGraph, node_ids: Sequence[int], rank_nodes: Mapping[int, List[int]], x_positions: Mapping[int, int], column_widths: Mapping[int, int], base_y: int, row_spacing: int, *, node_top_padding: Optional[Mapping[int, int]] = None, node_bottom_padding: Optional[Mapping[int, int]] = None, ) -> int: node_top_padding = node_top_padding or {} node_bottom_padding = node_bottom_padding or {} node_set = set(node_ids) max_height = 0 column_heights: Dict[int, int] = {} for rank, ordered_ids in rank_nodes.items(): filtered = [node_id for node_id in ordered_ids if node_id in node_set] if not filtered: continue column_width = column_widths.get(rank, 0) if column_width <= 0: column_width = max(graph.nodes[node_id].width for node_id in filtered) column_height = 0 for node_id in filtered: node = graph.nodes[node_id] column_height += ( node.height + int(node_top_padding.get(node_id, 0)) + int(node_bottom_padding.get(node_id, 0)) ) if len(filtered) > 1: column_height += row_spacing * (len(filtered) - 1) column_heights[rank] = column_height max_height = max(max_height, column_height) for rank, ordered_ids in rank_nodes.items(): filtered = [node_id for node_id in ordered_ids if node_id in node_set] if not filtered: continue column_width = column_widths.get(rank, 0) if column_width <= 0: column_width = max(graph.nodes[node_id].width for node_id in filtered) column_height = column_heights.get(rank, 0) y_cursor = base_y + int((max_height - column_height) / 2) for node_id in filtered: node = graph.nodes[node_id] top_pad = int(node_top_padding.get(node_id, 0)) bottom_pad = int(node_bottom_padding.get(node_id, 0)) node.x = x_positions.get(rank, 0) + int((column_width - node.width) / 2) node.y = y_cursor + top_pad y_cursor = node.y + node.height + bottom_pad + row_spacing if column_height: max_height = max(max_height, y_cursor - base_y - row_spacing) return max_height def _column_widths(nodes: Mapping[int, FunctionalNode], ranks: Mapping[int, int]) -> Dict[int, int]: widths: Dict[int, int] = {} for node_id, node in nodes.items(): rank = ranks.get(node_id, 0) widths[rank] = max(widths.get(rank, 0), node.width) return widths def _column_positions(column_widths: Mapping[int, int], padding: int, column_spacing: int) -> Dict[int, int]: if not column_widths: return {} max_rank = max(column_widths) positions: Dict[int, int] = {} x_cursor = padding for rank in range(max_rank + 1): positions[rank] = x_cursor x_cursor += column_widths.get(rank, 0) + column_spacing return positions def _get_font(group: Dict[str, Any]) -> ImageFont.ImageFont: font_src = group.get("font", None) font_size = group.get("font_size", 15) if font_src is None: try: return ImageFont.truetype("arial.ttf", font_size) except IOError: return ImageFont.load_default() elif isinstance(font_src, str): try: return ImageFont.truetype(font_src, font_size) except IOError: return ImageFont.load_default() elif isinstance(font_src, ImageFont.ImageFont): return font_src else: return ImageFont.load_default() def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: Any) -> Tuple[int, int]: if hasattr(font, "getbbox"): left, top, right, bottom = font.getbbox(text) return right - left, bottom - top else: return draw.textsize(text, font=font) def _prettify_layer_name(name: str) -> str: if not name: return "" name = name.replace("_", " ").strip() # CamelCase -> spaced name = re.sub(r"(?<!^)(?=[A-Z])", " ", name) name = re.sub(r"\s+", " ", name).strip() return name def _resolve_box_label(node: FunctionalNode) -> str: style = node.style or {} if style.get("box_text") is not None: return str(style.get("box_text")) cb = style.get("box_text_callable") if callable(cb): try: out = cb(node.layer) if out is not None: return str(out) except Exception: pass return _prettify_layer_name(getattr(node.layer_type, "__name__", node.name)) def _try_load_font(path: str, size: int) -> Optional[ImageFont.ImageFont]: try: return ImageFont.truetype(path, size) except Exception: return None def _is_font_like(value: Any) -> bool: if value is None: return False return any(hasattr(value, attr) for attr in ("getbbox", "getsize", "getmask")) def _resolve_box_font(style: Mapping[str, Any], fallback: Optional[Any]) -> Tuple[Any, Optional[str], int]: size = int(style.get("box_text_font_size", 14) or 14) src = style.get("box_text_font") if _is_font_like(src): return src, None, size if isinstance(src, str): f = _try_load_font(src, size) if f is not None: return f, src, size if _is_font_like(fallback): return fallback, None, size for cand in ("DejaVuSans.ttf", "arial.ttf"): f = _try_load_font(cand, size) if f is not None: return f, cand, size return ImageFont.load_default(), None, size def _multiline_bbox(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont, spacing: int) -> Tuple[int, int, int, int]: # Prefer PIL's bbox helper (handles glyph bearings correctly) if hasattr(draw, "multiline_textbbox"): return draw.multiline_textbbox((0, 0), text, font=font, spacing=spacing, align="left") # Fallback: union per-line bbox lefts, tops, rights, bottoms = [], [], [], [] y = 0 for line in text.split("\n"): if hasattr(draw, "textbbox"): l, t, r, b = draw.textbbox((0, y), line, font=font) else: w, h = draw.textsize(line, font=font) l, t, r, b = 0, y, w, y + h lefts.append(l); tops.append(t); rights.append(r); bottoms.append(b) y += (b - t) + spacing if not lefts: return (0, 0, 0, 0) return (min(lefts), min(tops), max(rights), max(bottoms)) def _render_text_image_bbox(text: str, font: ImageFont.ImageFont, color: Any, spacing: int, margin: int = 2) -> Image.Image: # IMPORTANT: account for negative bearings by offsetting by -left/-top dummy = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) d = ImageDraw.Draw(dummy) l, t, r, b = _multiline_bbox(d, text, font, spacing) w = max(1, int(r - l)) h = max(1, int(b - t)) img = Image.new("RGBA", (w + 2 * margin, h + 2 * margin), (0, 0, 0, 0)) d2 = ImageDraw.Draw(img) d2.multiline_text((margin - l, margin - t), text, font=font, fill=color, spacing=spacing, align="center") return img def _wrap_text_to_width(draw: ImageDraw.ImageDraw, text: str, font: ImageFont.ImageFont, max_width: int, mode: str, max_lines: Optional[int]) -> str: mode = (mode or "words").lower() if mode == "none" or max_width <= 0: return text if mode not in {"words", "chars"}: mode = "words" tokens = text.split() if mode == "words" else list(text) lines = [] cur = "" for tok in tokens: cand = tok if not cur else (cur + (" " if mode == "words" else "") + tok) l, t, r, b = _multiline_bbox(draw, cand, font, 0) if (r - l) <= max_width or not cur: cur = cand else: lines.append(cur) cur = tok if max_lines is not None and len(lines) >= max_lines: break if cur and (max_lines is None or len(lines) < max_lines): lines.append(cur) if max_lines is not None: lines = lines[:max_lines] return "\n".join(lines) def _draw_box_text_in_rect( base: Image.Image, rect: Tuple[int, int, int, int], text: str, *, style: Mapping[str, Any], fallback_font: Optional[ImageFont.ImageFont], fallback_color: Any, fallback_spacing: int, ) -> None: """Render style-driven text inside a rectangular node region. The text is optionally wrapped, autoshrunk, rotated, and aligned according to ``box_text_*`` style keys. Rendering is performed via an intermediate RGBA text image to avoid clipping from glyph bearings. Args: base: Destination image. rect: Target rectangle ``(x1, y1, x2, y2)``. text: Raw label text. style: Node style mapping with text layout keys. fallback_font: Global fallback font when per-node font is absent. fallback_color: Global fallback text color. fallback_spacing: Global fallback line spacing. """ if not text: return x1, y1, x2, y2 = rect w = max(1, x2 - x1) h = max(1, y2 - y1) pad = style.get("box_text_padding", 8) if isinstance(pad, (tuple, list)) and len(pad) == 2: pad_x, pad_y = int(pad[0]), int(pad[1]) else: pad_x = pad_y = int(pad) avail_w = max(1, w - 2 * pad_x) avail_h = max(1, h - 2 * pad_y) orientation = str(style.get("box_orientation", "vertical") or "vertical").lower() rot = style.get("box_text_rotation") if rot is None: rot = 90 if orientation == "vertical" else 0 try: rot = int(rot) % 360 except Exception: rot = 0 if rot not in (0, 90, 180, 270): rot = 0 # For 90/270, swap fit constraints pre-rotation pre_w, pre_h = (avail_h, avail_w) if rot in (90, 270) else (avail_w, avail_h) spacing = int(style.get("box_text_line_spacing", fallback_spacing) or fallback_spacing) color = style.get("box_text_color", fallback_color) wrap_mode = str(style.get("box_text_wrap", "words") or "words").lower() max_lines = style.get("box_text_max_lines") try: max_lines = int(max_lines) if max_lines is not None else None except Exception: max_lines = None autoshrink = bool(style.get("box_text_autoshrink", True)) min_size = int(style.get("box_text_min_font_size", 8) or 8) font, font_path, size0 = _resolve_box_font(style, fallback_font) dummy = Image.new("RGBA", (1, 1), (0, 0, 0, 0)) d = ImageDraw.Draw(dummy) def measure_fit(fnt: ImageFont.ImageFont) -> Tuple[str, int, int, bool]: wrapped = _wrap_text_to_width(d, text, fnt, pre_w, wrap_mode, max_lines) l, t, r, b = _multiline_bbox(d, wrapped, fnt, spacing) tw = int(r - l) th = int(b - t) return wrapped, tw, th, (tw <= pre_w and th <= pre_h) wrapped, tw, th, ok = measure_fit(font) if autoshrink and not ok and font_path: cur = size0 while cur > min_size: cur -= 1 f2 = _try_load_font(font_path, cur) if f2 is None: break wrapped2, tw2, th2, ok2 = measure_fit(f2) if ok2: font = f2 wrapped, tw, th, ok = wrapped2, tw2, th2, ok2 break txt_img = _render_text_image_bbox(wrapped, font, color, spacing, margin=2) if rot: txt_img = txt_img.rotate(rot, expand=True) tw, th = txt_img.size align = str(style.get("box_text_align", "center") or "center").lower() valign = str(style.get("box_text_valign", "middle") or "middle").lower() if align == "left": px = x1 + pad_x elif align == "right": px = x2 - pad_x - tw else: px = x1 + (w - tw) // 2 if valign == "top": py = y1 + pad_y elif valign == "bottom": py = y2 - pad_y - th else: py = y1 + (h - th) // 2 base.alpha_composite(txt_img, (int(px), int(py))) def _resolve_annotation_font( style_font: Any, fallback_font: Optional[Any], size: int, ) -> Any: if _is_font_like(style_font): return style_font if isinstance(style_font, str): resolved = _try_load_font(style_font, size) if resolved is not None: return resolved if _is_font_like(fallback_font): return fallback_font for candidate in ("DejaVuSans.ttf", "arial.ttf"): resolved = _try_load_font(candidate, size) if resolved is not None: return resolved return ImageFont.load_default() def _draw_collapse_badge( draw: ImageDraw.ImageDraw, *, rect: Tuple[int, int, int, int], label: str, font: Any, fill: Any, outline: Any, text_color: Any, padding: Union[int, Tuple[int, int], List[int]], ) -> None: """Draw a compact collapse-count badge (for example ``"4x"``) on a node. Args: draw: PIL drawing context. rect: Node bounds ``(x1, y1, x2, y2)``. label: Badge text. font: Font used for label text. fill: Badge background color. outline: Badge border color. text_color: Badge text color. padding: Horizontal/vertical text padding in the badge. """ if not label: return x1, y1, x2, _ = rect text_w, text_h = _measure_text(draw, label, font) if isinstance(padding, (tuple, list)) and len(padding) == 2: pad_x, pad_y = int(padding[0]), int(padding[1]) else: pad_x = pad_y = int(padding) badge_w = max(1, text_w + 2 * pad_x) badge_h = max(1, text_h + 2 * pad_y) badge_x2 = x2 - 4 badge_x1 = max(x1 + 2, badge_x2 - badge_w) badge_y1 = y1 + 4 badge_y2 = badge_y1 + badge_h if hasattr(draw, "rounded_rectangle"): try: draw.rounded_rectangle( (badge_x1, badge_y1, badge_x2, badge_y2), radius=4, fill=fill, outline=outline, width=1, ) except TypeError: draw.rounded_rectangle( (badge_x1, badge_y1, badge_x2, badge_y2), radius=4, fill=fill, outline=outline, ) else: draw.rectangle((badge_x1, badge_y1, badge_x2, badge_y2), fill=fill, outline=outline) draw.text((badge_x1 + pad_x, badge_y1 + pad_y), label, font=font, fill=text_color) def _draw_collapse_block_annotation( draw: ImageDraw.ImageDraw, *, rect: Tuple[int, int, int, int], label: str, position: str, color: Any, line_width: int, head_size: int, offset: int, font: Any, image_size: Tuple[int, int], ) -> None: """Draw a block-collapse double-headed arrow with an optional count label. Args: draw: PIL drawing context. rect: Collapsed node bounds ``(x1, y1, x2, y2)``. label: Annotation text shown near the arrow. position: ``"above"`` or ``"below"`` relative to the node. color: Arrow and text color. line_width: Arrow line thickness. head_size: Arrowhead size. offset: Pixel distance from node edge to arrow baseline. font: Font used for annotation text. image_size: Destination image size used for clipping guards. """ x1, y1, x2, y2 = rect if x2 <= x1: return left = int(x1 + 8) right = int(x2 - 8) if right - left < 12: left = x1 + 2 right = x2 - 2 if right - left < 8: return if position == "below": line_y = y2 + offset else: line_y = y1 - offset line_y = max(2, min(image_size[1] - 3, int(line_y))) draw.line((left, line_y, right, line_y), fill=color, width=max(1, line_width)) head = max(3, head_size) draw.polygon( [(left, line_y), (left + head, line_y - head // 2), (left + head, line_y + head // 2)], fill=color, ) draw.polygon( [(right, line_y), (right - head, line_y - head // 2), (right - head, line_y + head // 2)], fill=color, ) if not label: return text_w, text_h = _measure_text(draw, label, font) text_x = int((left + right - text_w) / 2) if position == "below": text_y = line_y + 4 else: text_y = line_y - text_h - 4 text_y = max(0, min(image_size[1] - max(1, text_h), text_y)) bg = (255, 255, 255, 220) draw.rectangle( (text_x - 2, text_y - 1, text_x + text_w + 2, text_y + text_h + 1), fill=bg, ) draw.text((text_x, text_y), label, font=font, fill=color) def _draw_collapsed_annotations( img: Image.Image, collapsed_nodes: Sequence[Tuple[FunctionalNode, Tuple[int, int, int, int]]], *, fallback_font: Optional[Any], fallback_font_color: Any, default_annotation_color: Any, ) -> None: """Render all collapsed-node overlays after the main node pass. For each collapsed node this can draw: - an ``Nx`` badge - an optional double-headed block annotation line with label Args: img: Destination RGBA image. collapsed_nodes: ``(node, rect)`` entries gathered during rendering. fallback_font: Global fallback font. fallback_font_color: Global fallback text color. default_annotation_color: Global fallback arrow/annotation color. """ if not collapsed_nodes: return draw = ImageDraw.Draw(img) for node, rect in collapsed_nodes: style = node.style or {} if not bool(style.get("collapsed", False)): continue label = str(style.get("collapse_label", "")) if bool(style.get("collapse_badge_enabled", True)): badge_font_size = int(style.get("collapse_badge_font_size", 12) or 12) badge_font = _resolve_annotation_font( style.get("collapse_badge_font"), fallback_font, badge_font_size, ) _draw_collapse_badge( draw, rect=rect, label=label, font=badge_font, fill=style.get("collapse_badge_fill", "white"), outline=style.get("collapse_badge_outline", "black"), text_color=style.get("collapse_badge_text_color", fallback_font_color), padding=style.get("collapse_badge_padding", (4, 2)), ) if ( style.get("collapse_kind") == "block" and bool(style.get("collapse_annotation_enabled", True)) ): annotation_font_size = int(style.get("collapse_annotation_font_size", 12) or 12) annotation_font = _resolve_annotation_font( style.get("collapse_annotation_font"), fallback_font, annotation_font_size, ) position = str(style.get("collapse_annotation_position", "above") or "above").lower() if position not in {"above", "below"}: position = "above" _draw_collapse_block_annotation( draw, rect=rect, label=label, position=position, color=style.get("collapse_annotation_color", default_annotation_color), line_width=int(style.get("collapse_annotation_width", 2) or 2), head_size=int(style.get("collapse_annotation_head_size", 6) or 6), offset=int(style.get("collapse_annotation_offset", 10) or 10), font=annotation_font, image_size=img.size, ) def _render_graph( graph: FunctionalGraph, *, color_map: Mapping[type, Mapping[str, Any]], background_fill: Any, padding: int, connector_fill: Any, connector_width: int, connector_arrow: bool, connector_padding: int, text_callable: Optional[Callable[[int, Any], Tuple[str, bool]]], text_vspacing: int, font: Optional[ImageFont.ImageFont], font_color: Any, render_virtual_nodes: bool, draw_volume: bool, orientation_rotation: Optional[float] = None, layered_groups: Optional[Sequence[Dict[str, Any]]] = None, logo_groups: Optional[Sequence[Dict[str, Any]]] = None, logos_legend: Union[bool, Dict[str, Any]] = False, simple_text_visualization: bool = False, external_text_bottom_padding: Optional[Mapping[int, int]] = None, ) -> Image.Image: """Render a positioned ``FunctionalGraph`` to a PIL image. Rendering order: 1. optional group backgrounds 2. connectors 3. nodes (flat or volumetric), node images, and logos 4. in-node text and external annotations (collapsed markers) 5. optional group captions and logo legend Args: graph: Graph with already-computed node positions. color_map: Optional per-layer fill/outline overrides. background_fill: Canvas background color. padding: Outer image padding. connector_fill: Connector color. connector_width: Default connector width. connector_arrow: Whether connectors include arrowheads by default. connector_padding: Anchor offset from node faces. text_callable: Optional callable for above/below external labels. text_vspacing: Vertical spacing for multiline text labels. font: Optional global font for text labels. font_color: Global text color. render_virtual_nodes: Whether virtual routing nodes are visible. draw_volume: Whether default node rendering uses volumetric boxes. orientation_rotation: Optional rotation applied to volumetric boxes. layered_groups: Optional background highlight groups. logo_groups: Optional node-logo overlay groups. logos_legend: Legend toggle/config for logo groups. simple_text_visualization: Switch to flat 2D box rendering mode. Returns: A rendered ``PIL.Image.Image``. """ max_right = padding max_bottom = padding for node in graph.nodes.values(): max_right = max(max_right, node.x + node.width + padding) max_bottom = max(max_bottom, node.y + node.height + padding) if ( node.kind == "collapsed" and node.style.get("collapse_kind") == "block" and bool(node.style.get("collapse_annotation_enabled", True)) and str(node.style.get("collapse_annotation_position", "above")).lower() == "below" ): annotation_offset = int(node.style.get("collapse_annotation_offset", 10) or 10) max_bottom = max(max_bottom, node.y + node.height + annotation_offset + padding + 24) if external_text_bottom_padding: for node_id, extra in external_text_bottom_padding.items(): node = graph.nodes.get(node_id) if node is None: continue max_bottom = max(max_bottom, node.y + node.height + int(extra) + padding) if layered_groups: dummy_img = Image.new("RGBA", (1, 1)) dummy_draw = ImageDraw.Draw(dummy_img) for group in layered_groups: group_nodes = _get_group_nodes(graph, group) if not group_nodes: continue g_min_x = min(n.x for n in group_nodes) g_max_x = max(n.x + n.width for n in group_nodes) g_min_y = min(n.y for n in group_nodes) g_max_y = max(n.y + n.height for n in group_nodes) g_padding = group.get("padding", 10) g_min_x -= g_padding g_max_x += g_padding g_min_y -= g_padding g_max_y += g_padding max_right = max(max_right, g_max_x + padding) max_bottom = max(max_bottom, g_max_y + padding) caption = group.get("name", group.get("caption")) if caption: font = _get_font(group) text_w, text_h = _measure_text(dummy_draw, caption, font) center_x = (g_min_x + g_max_x) / 2 text_x = center_x - text_w / 2 gap = group.get("text_spacing", 5) text_y = g_max_y + gap max_right = max(max_right, text_x + text_w + padding) max_bottom = max(max_bottom, text_y + text_h + padding) node_logos = {} # node_id -> list of (group, image) if logo_groups: for group in logo_groups: path = group.get("file") if not path: continue try: logo_img = Image.open(path) except: continue target_nodes = _get_logo_nodes(graph, group) for node in target_nodes: if node.node_id not in node_logos: node_logos[node.node_id] = [] node_logos[node.node_id].append((group, logo_img)) img = Image.new("RGBA", (int(max_right), int(max_bottom)), background_fill) draw = aggdraw.Draw(img) color_wheel = ColorWheel() if layered_groups: _draw_group_boxes(draw, graph, layered_groups) _draw_connectors( draw, graph.edges, graph.nodes, render_virtual_nodes, connector_fill, connector_width, connector_arrow, connector_padding, ) draw.flush() pending_simple_text: List[Tuple[FunctionalNode, Tuple[int, int, int, int]]] = [] pending_collapsed_annotations: List[Tuple[FunctionalNode, Tuple[int, int, int, int]]] = [] for node in graph.nodes.values(): if node.kind == "virtual" and not render_virtual_nodes: continue if simple_text_visualization: x1 = int(node.x) y1 = int(node.y) x2 = int(node.x + node.width) y2 = int(node.y + node.height) fill = node.style.get("box_fill") if fill is None: fill = node.style.get("fill") if fill is None: fill = color_map.get(node.layer_type, {}).get("fill") if fill is None: fill = color_wheel.get_color(node.layer_type) outline = node.style.get("box_outline") if outline is None: outline = node.style.get("outline") if outline is None: outline = color_map.get(node.layer_type, {}).get("outline") if outline is None: outline = "black" outline_w = int(node.style.get("box_outline_width", 2) or 2) pen = aggdraw.Pen(get_rgba_tuple(outline), outline_w) brush = aggdraw.Brush(get_rgba_tuple(fill)) draw.rectangle((x1, y1, x2, y2), pen, brush) if node.image is not None: draw.flush() fit_mode = node.style.get("image_fit", "fill") quad = [(x1, y1), (x2, y1), (x2, y2), (x1, y2)] try: apply_affine_transform(img, node.image, quad, fit_mode) except Exception: pass draw = aggdraw.Draw(img) if node.node_id in node_logos: draw.flush() box_tmp = Box() box_tmp.de = 0 box_tmp.shade = 0 box_tmp.rotation = None box_tmp.x1 = x1 box_tmp.y1 = y1 box_tmp.x2 = x2 box_tmp.y2 = y2 for group, logo_img in node_logos[node.node_id]: draw_node_logo(img, box_tmp, logo_img, group, draw_volume=False) draw = aggdraw.Draw(img) if node.kind != "virtual" and bool(node.style.get("box_text_enabled", True)): pending_simple_text.append((node, (x1, y1, x2, y2))) if node.kind == "collapsed": pending_collapsed_annotations.append((node, (x1, y1, x2, y2))) continue box = Box() box.de = getattr(node, 'de', 0) box.shade = getattr(node, 'shade', 0) box.rotation = orientation_rotation box.x1 = node.x box.y1 = node.y + box.de real_width = node.width - box.de real_height = node.height - box.de box.x2 = box.x1 + real_width box.y2 = box.y1 + real_height fill = color_map.get(node.layer_type, {}).get("fill") outline = color_map.get(node.layer_type, {}).get("outline") if node.kind == "virtual": bg = get_rgba_tuple(background_fill) fill = fade_color(bg, 10) outline = fade_color(get_rgba_tuple(connector_fill), 10) box.fill = fill if fill is not None else color_wheel.get_color(node.layer_type) box.outline = outline if outline is not None else "black" if node.style.get('fill'): box.fill = node.style.get('fill') if node.style.get('outline'): box.outline = node.style.get('outline') box.draw(draw, draw_reversed=False) if node.image is not None: draw.flush() fit_mode = node.style.get("image_fit", "fill") axis = node.style.get("image_axis", "z") target_face_idx = 0 # Front if axis == 'y': target_face_idx = 4 # Top elif axis == 'x': target_face_idx = 2 # Right / Side quad = box.get_face_quad(target_face_idx) if quad: apply_affine_transform(img, node.image, quad, fit_mode) draw = aggdraw.Draw(img) if node.node_id in node_logos: draw.flush() for group, logo_img in node_logos[node.node_id]: draw_node_logo(img, box, logo_img, group, draw_volume) draw = aggdraw.Draw(img) if node.kind == "collapsed": pending_collapsed_annotations.append( (node, (int(node.x), int(node.y), int(node.x + node.width), int(node.y + node.height))) ) draw.flush() if simple_text_visualization and pending_simple_text: for node, rect in pending_simple_text: label = _resolve_box_label(node) _draw_box_text_in_rect( img, rect, label, style=node.style or {}, fallback_font=font, fallback_color=font_color, fallback_spacing=text_vspacing, ) if text_callable is not None: if font is None: font = ImageFont.load_default() draw_text = ImageDraw.Draw(img) external_labels: List[Dict[str, Any]] = [] visible_nodes = [ node for node in sorted(graph.nodes.values(), key=lambda n: n.order) if node.kind != "virtual" ] for index, node in enumerate(visible_nodes): text, above = text_callable(index, node.layer) text_value = "" if text is None else str(text) if not text_value: continue text_height = 0 text_widths = [] for line in text_value.split("\n"): if hasattr(font, "getsize"): text_widths.append(font.getsize(line)[0]) text_height += font.getsize(line)[1] else: bbox = font.getbbox(line) text_widths.append(bbox[2]) text_height += bbox[3] text_height += (len(text_value.split("\n")) - 1) * text_vspacing de = getattr(node, 'de', 0) real_width = node.width - de real_height = node.height - de face_x = node.x face_y = node.y + de text_x = face_x + real_width / 2 - max(text_widths or [0]) / 2 if above: text_y = face_y - text_height - 4 else: text_y = face_y + real_height + 4 text_width = max(text_widths or [0]) external_labels.append( { "x": float(text_x), "x_pref": float(text_x), "y": int(text_y), "w": int(text_width), "h": int(text_height), "text": text_value, } ) _resolve_external_label_x_collisions( external_labels, image_width=img.size[0], edge_padding=padding, min_gap=8, ) for label in external_labels: draw_text.multiline_text( (label["x"], label["y"]), label["text"], font=font, fill=font_color, spacing=text_vspacing, ) if pending_collapsed_annotations: _draw_collapsed_annotations( img, pending_collapsed_annotations, fallback_font=font, fallback_font_color=font_color, default_annotation_color=connector_fill, ) if layered_groups: _draw_group_captions(img, graph, layered_groups) if logos_legend: if font is None: font = ImageFont.load_default() img = draw_logos_legend(img, logo_groups, logos_legend, background_fill, font, font_color) return img def _straighten_layout( graph: FunctionalGraph, ranks: Dict[int, int], row_spacing: int, *, node_top_padding: Optional[Mapping[int, int]] = None, node_bottom_padding: Optional[Mapping[int, int]] = None, ) -> None: """ Adjusts y positions to align linear connections straight. Accounts for 3D depth (de) to ensure visual centers align. """ nodes_by_rank = {} node_top_padding = node_top_padding or {} node_bottom_padding = node_bottom_padding or {} for node in graph.nodes.values(): r = ranks.get(node.node_id, 0) nodes_by_rank.setdefault(r, []).append(node) sorted_ranks = sorted(nodes_by_rank.keys()) outgoing = {n: [] for n in graph.nodes} incoming = {n: [] for n in graph.nodes} for edge in graph.edges: outgoing[edge.src].append(edge.dst) incoming[edge.dst].append(edge.src) def get_visual_center(node: FunctionalNode) -> float: de = getattr(node, 'de', 0) return node.y + (node.height + de) / 2.0 def set_visual_center(node: FunctionalNode, center_y: float): de = getattr(node, 'de', 0) new_y = center_y - (node.height + de) / 2.0 node.y = int(new_y) def resolve_collisions(col_nodes: List[FunctionalNode]): col_nodes.sort(key=lambda n: n.y - int(node_top_padding.get(n.node_id, 0))) if not col_nodes: return current_y_map = {n.node_id: n.y for n in col_nodes} for i in range(1, len(col_nodes)): prev_node = col_nodes[i-1] curr_node = col_nodes[i] prev_bottom = int(node_bottom_padding.get(prev_node.node_id, 0)) curr_top = int(node_top_padding.get(curr_node.node_id, 0)) min_y = ( current_y_map[prev_node.node_id] + prev_node.height + prev_bottom + row_spacing + curr_top ) if current_y_map[curr_node.node_id] < min_y: current_y_map[curr_node.node_id] = min_y for node in col_nodes: node.y = int(current_y_map[node.node_id]) for rank in sorted_ranks: col_nodes = nodes_by_rank[rank] for node in col_nodes: parents = incoming[node.node_id] if len(parents) == 1: parent_id = parents[0] if len(outgoing[parent_id]) == 1: parent = graph.nodes[parent_id] target_center = get_visual_center(parent) set_visual_center(node, target_center) resolve_collisions(col_nodes) for rank in reversed(sorted_ranks): col_nodes = nodes_by_rank[rank] for node in col_nodes: children = outgoing[node.node_id] if len(children) == 1: child_id = children[0] if len(incoming[child_id]) == 1: child = graph.nodes[child_id] target_center = get_visual_center(child) set_visual_center(node, target_center) resolve_collisions(col_nodes) def _get_group_nodes(graph: FunctionalGraph, group: Dict[str, Any]) -> List[FunctionalNode]: """Resolve all graph nodes referenced by one layered-group configuration. Group ``layers`` entries may reference layer objects directly or layer names. Name matching checks both the rendered node name and the underlying Keras layer ``name`` attribute. """ layers = group.get("layers", []) if not layers: return [] group_nodes = [] for node in graph.nodes.values(): for layer_ref in layers: if node.layer is layer_ref: group_nodes.append(node) break node_layer_name = getattr(node.layer, 'name', '') if isinstance(layer_ref, str) and (node.name == layer_ref or node_layer_name == layer_ref): group_nodes.append(node) break return group_nodes def _draw_group_boxes( draw: aggdraw.Draw, graph: FunctionalGraph, groups: Sequence[Dict[str, Any]], ) -> None: """Draw configured group highlight rectangles behind matching nodes. Args: draw: Aggdraw context for the destination image. graph: Positioned graph. groups: Group style dictionaries containing layer selectors and colors. """ for group in groups: group_nodes = _get_group_nodes(graph, group) if not group_nodes: continue min_x = min(n.x for n in group_nodes) max_x = max(n.x + n.width for n in group_nodes) min_y = min(n.y for n in group_nodes) max_y = max(n.y + n.height for n in group_nodes) padding = group.get("padding", 10) min_x -= padding max_x += padding min_y -= padding max_y += padding fill = group.get("fill", (200, 200, 200, 100)) outline = group.get("outline", "black") width = group.get("width", 1) fill_rgba = get_rgba_tuple(fill) outline_rgba = get_rgba_tuple(outline) pen = aggdraw.Pen(outline_rgba, width) brush = aggdraw.Brush(fill_rgba) draw.rectangle([min_x, min_y, max_x, max_y], pen, brush) def _draw_group_captions( img: Image.Image, graph: FunctionalGraph, groups: Sequence[Dict[str, Any]], ) -> None: """Draw text captions for configured layer groups. Captions are centered horizontally below each group's padded bounding box. Args: img: Destination image. graph: Positioned graph. groups: Group configuration dictionaries. """ draw = ImageDraw.Draw(img) for group in groups: caption = group.get("name", group.get("caption")) if not caption: continue group_nodes = _get_group_nodes(graph, group) if not group_nodes: continue min_x = min(n.x for n in group_nodes) max_x = max(n.x + n.width for n in group_nodes) min_y = min(n.y for n in group_nodes) max_y = max(n.y + n.height for n in group_nodes) padding = group.get("padding", 10) box_min_x = min_x - padding box_max_x = max_x + padding box_max_y = max_y + padding font = _get_font(group) color = group.get("font_color", "black") gap = group.get("text_spacing", 5) text_w, text_h = _measure_text(draw, caption, font) center_x = (box_min_x + box_max_x) / 2 text_x = center_x - text_w / 2 text_y = box_max_y + gap draw.text((text_x, text_y), caption, fill=color, font=font) def _draw_connectors( draw: aggdraw.Draw, edges: Iterable[FunctionalEdge], nodes: Mapping[int, FunctionalNode], render_virtual_nodes: bool, connector_fill: Any, connector_width: int, connector_arrow: bool, connector_padding: int, ) -> None: """Draw orthogonal connector polylines between graph nodes. The routing pass consolidates virtual-node chains, computes shared elbows for branch/merge readability, and supports per-node connector overrides for width, arrowheads, and padding. Args: draw: Aggdraw context. edges: Directed graph edges. nodes: Node lookup by id. render_virtual_nodes: Whether virtual nodes are visually rendered. connector_fill: Default connector color. connector_width: Default connector line width. connector_arrow: Default arrowhead toggle. connector_padding: Default anchor offset from node faces. """ pen = aggdraw.Pen(connector_fill, connector_width) brush = aggdraw.Brush(connector_fill) outgoing = _outgoing_map(edges) incoming = _incoming_map(edges) visited: set[Tuple[int, int]] = set() paths_by_src: Dict[int, List[List[int]]] = {} def anchor(node: FunctionalNode, role: str) -> Tuple[int, int]: padding = node.style.get("connector_padding", connector_padding) de = getattr(node, 'de', 0) real_width = node.width - de real_height = node.height - de face_x = node.x face_y = node.y + de if node.kind == "virtual" and not render_virtual_nodes: x = face_x + real_width / 2 elif role == "start": x = face_x + real_width + padding elif role == "end": x = face_x - padding else: x = face_x + real_width / 2 y = face_y + real_height / 2 return int(round(x)), int(round(y)) shared_merge_x: Dict[int, int] = {} for dst_node_id, src_node_ids in incoming.items(): if len(src_node_ids) > 1: max_start_x = -1 valid_merge = True for src_id in src_node_ids: if src_id not in nodes: valid_merge = False break s_node = nodes[src_id] sx, _ = anchor(s_node, "start") if sx > max_start_x: max_start_x = sx if valid_merge and dst_node_id in nodes: d_node = nodes[dst_node_id] dx, _ = anchor(d_node, "end") shared_merge_x[dst_node_id] = int(round(max_start_x + (dx - max_start_x) / 2)) def append_path(start_id: int, next_id: int) -> None: path = [start_id] prev_id = start_id current_id = next_id while True: path.append(current_id) visited.add((prev_id, current_id)) node = nodes.get(current_id) if node is None: break if node.kind != "virtual" or render_virtual_nodes: break if len(outgoing.get(current_id, [])) != 1 or len(incoming.get(current_id, [])) != 1: break next_ids = outgoing.get(current_id, []) if not next_ids: break prev_id = current_id current_id = next_ids[0] if (prev_id, current_id) in visited: break paths_by_src.setdefault(start_id, []).append(path) for edge in edges: if edge.src not in nodes or edge.dst not in nodes: continue if (edge.src, edge.dst) in visited: continue src_node = nodes.get(edge.src) if src_node is None: continue if src_node.kind == "virtual" and not render_virtual_nodes and len(incoming.get(edge.src, [])) == 1: continue append_path(edge.src, edge.dst) def add_point(points: List[Tuple[int, int]], x: int, y: int) -> None: if not points or points[-1] != (x, y): points.append((x, y)) for src_id, paths in paths_by_src.items(): start_node = nodes.get(src_id) if start_node is None: continue paths.sort(key=lambda path: nodes[path[-1]].y + nodes[path[-1]].height / 2) count = len(paths) shared_branch_x: Optional[int] = None if count > 1: x_start, _ = anchor(start_node, "start") next_xs: List[int] = [] for path in paths: if len(path) < 2: continue next_node = nodes.get(path[1]) if next_node: next_role = "end" if len(path) == 2 else "mid" x_next, _ = anchor(next_node, next_role) next_xs.append(x_next) if next_xs: min_x_next = min(next_xs) shared_branch_x = int(round(x_start + (min_x_next - x_start) / 2)) min_mid = x_start + 2 max_mid = min_x_next - 2 if min_mid < max_mid: shared_branch_x = max(min_mid, min(max_mid, shared_branch_x)) else: shared_branch_x = int(round((x_start + min_x_next) / 2)) if shared_branch_x <= x_start + 1: shared_branch_x = None for index, path in enumerate(paths): points: List[Tuple[int, int]] = [] for idx, node_id in enumerate(path): node = nodes.get(node_id) if node is None: continue role = "mid" if idx == 0: role = "start" elif idx == len(path) - 1: role = "end" x, y = anchor(node, role) if not points: add_point(points, x, y) continue x1, y1 = points[-1] x2, y2 = x, y if x2 <= x1 + 1: add_point(points, x2, y2) continue mid_x = 0 if idx == len(path) - 1 and node_id in shared_merge_x: mid_x = shared_merge_x[node_id] elif idx == 1 and shared_branch_x is not None: mid_x = shared_branch_x else: mid_x = int(round(x1 + (x2 - x1) / 2)) min_mid = x1 + 2 max_mid = x2 - 2 if min_mid < max_mid: mid_x = max(min_mid, min(max_mid, mid_x)) else: mid_x = int(round((x1 + x2) / 2)) add_point(points, mid_x, y1) add_point(points, mid_x, y2) add_point(points, x2, y2) src_node = nodes[path[0]] use_arrow = src_node.style.get("connector_arrow", connector_arrow) use_width = src_node.style.get("connector_width", connector_width) current_pen = aggdraw.Pen(connector_fill, use_width) if len(points) >= 2: draw.line([coord for point in points for coord in point], current_pen) if use_arrow: end_x, end_y = points[-1] prev_x, prev_y = points[-2] arrow_size = max(8, use_width * 3) if end_x > prev_x: # Right p1 = (end_x, end_y) p2 = (end_x - arrow_size, end_y - arrow_size // 2) p3 = (end_x - arrow_size, end_y + arrow_size // 2) elif end_x < prev_x: # Left p1 = (end_x, end_y) p2 = (end_x + arrow_size, end_y - arrow_size // 2) p3 = (end_x + arrow_size, end_y + arrow_size // 2) elif end_y > prev_y: # Down p1 = (end_x, end_y) p2 = (end_x - arrow_size // 2, end_y - arrow_size) p3 = (end_x + arrow_size // 2, end_y - arrow_size) else: # Up p1 = (end_x, end_y) p2 = (end_x - arrow_size // 2, end_y + arrow_size) p3 = (end_x + arrow_size // 2, end_y + arrow_size) draw.polygon([p1[0], p1[1], p2[0], p2[1], p3[0], p3[1]], current_pen, brush) def _incoming_map(edges: Sequence[FunctionalEdge]) -> Dict[int, List[int]]: incoming: Dict[int, List[int]] = {} for edge in edges: incoming.setdefault(edge.dst, []).append(edge.src) return incoming def _outgoing_map(edges: Sequence[FunctionalEdge]) -> Dict[int, List[int]]: outgoing: Dict[int, List[int]] = {} for edge in edges: outgoing.setdefault(edge.src, []).append(edge.dst) return outgoing def _resolve_layer_output_shape(layer: Any) -> Optional[Any]: shape = getattr(layer, "output_shape", None) if shape is not None: return _shape_to_tuple(shape) output = getattr(layer, "output", None) tensor_shape = getattr(output, "shape", None) if tensor_shape is not None: return _shape_to_tuple(tensor_shape) input_shape = getattr(layer, "input_shape", None) if input_shape is not None: return _shape_to_tuple(input_shape) compute_output_shape = getattr(layer, "compute_output_shape", None) if callable(compute_output_shape): if input_shape is not None: try: return _shape_to_tuple(compute_output_shape(input_shape)) except Exception: # noqa: BLE001 pass return None def _shape_to_tuple(shape: Any) -> Any: if shape is None: return None if isinstance(shape, tuple): return shape if hasattr(shape, "as_list"): try: return tuple(shape.as_list()) except Exception: # noqa: BLE001 return tuple(shape) if isinstance(shape, list): return tuple(shape) return shape def _get_logo_nodes(graph: FunctionalGraph, group: Dict[str, Any]) -> List[FunctionalNode]: """Resolve nodes targeted by one logo-group configuration. String entries in ``layers`` target layer names. Type entries target all nodes whose layer is an instance of that type. """ layers_ref = group.get("layers", []) if not layers_ref: return [] target_nodes = [] name_to_nodes = {} type_to_nodes = {} for node in graph.nodes.values(): if node.kind == "virtual": continue layer_name = getattr(node.layer, 'name', None) if layer_name: if layer_name not in name_to_nodes: name_to_nodes[layer_name] = [] name_to_nodes[layer_name].append(node) layer_type = type(node.layer) if layer_type not in type_to_nodes: type_to_nodes[layer_type] = [] type_to_nodes[layer_type].append(node) for ref in layers_ref: if isinstance(ref, str): if ref in name_to_nodes: target_nodes.extend(name_to_nodes[ref]) elif isinstance(ref, type): if ref in type_to_nodes: target_nodes.extend(type_to_nodes[ref]) return target_nodes