"""Source code for ``visualkeras.graph``: graph-style (topology-first) architecture rendering."""

from typing import Any, Dict, Mapping, Optional, Union, List, Sequence, Tuple
import aggdraw
from PIL import Image, ImageDraw, ImageFont
from math import ceil
import warnings
from .utils import *
from .layer_utils import *
from .options import GraphOptions, GRAPH_PRESETS

class _DummyLayer:

    def __init__(self, name, units=None):
        if units:
            self.units = units
        self.name = name


def graph_view(
    model,
    to_file: str = None,
    color_map: dict = None,
    node_size: int = 50,
    background_fill: Any = 'white',
    padding: int = 10,
    layer_spacing: int = 250,
    node_spacing: int = 10,
    connector_fill: Any = 'gray',
    connector_width: int = 1,
    ellipsize_after: int = 10,
    inout_as_tensor: bool = True,
    show_neurons: bool = True,
    styles: Optional[Mapping[Union[str, type], Dict[str, Any]]] = None,
    image_fit: str = 'contain',
    circular_crop: bool = True,
    layered_groups: Optional[Sequence[Dict[str, Any]]] = None,
    *,
    options: Union[GraphOptions, Mapping[str, Any], None] = None,
    preset: Union[str, None] = None,
) -> Image:
    """Render a Keras model as a graph-style architecture diagram.

    This renderer emphasizes connectivity rather than tensor volume. It is a
    good fit for models with branching, merges, skip connections, or other
    topologies where the main question is how layers connect rather than how
    tensor shapes evolve from left to right.

    Parameters
    ----------
    model : Any
        Keras model instance to visualize. Graph view works across sequential,
        functional, and many subclassed models. It is usually the clearest
        choice when the architecture has multiple paths or when you need a
        compact topology-first diagram.
    to_file : str, optional
        Path to save the rendered image. The image format is inferred from the
        file extension. The rendered ``PIL.Image`` is returned whether or not
        this value is supplied.
    color_map : dict, optional
        Mapping from layer class to broad style values such as ``fill`` and
        ``outline``. This is the simplest way to assign consistent colors by
        layer type; ``styles`` is better for per-layer overrides or
        image-based customization.
    node_size : int, default=50
        Default node diameter or box size in pixels. Increase this for
        presentation-sized figures; smaller values produce denser diagrams.
    background_fill : Any, default='white'
        Background color for the final image (any Pillow-compatible color).
    padding : int, default=10
        Outer padding around the full diagram in pixels.
    layer_spacing : int, default=250
        Horizontal spacing between layer columns.
    node_spacing : int, default=10
        Vertical spacing between nodes within the same column.
    connector_fill : Any, default='gray'
        Color used for connector lines between nodes.
    connector_width : int, default=1
        Line width used for connectors.
    ellipsize_after : int, default=10
        Maximum number of neuron markers to draw for a layer. The
        second-to-last drawn marker becomes an ellipsis. ``None`` or a value
        ``<= 0`` draws every marker.
    inout_as_tensor : bool, default=True
        If ``True``, represent each input or output tensor as a single node.
        If ``False``, flatten tensor shapes into scalar-like node counts when
        possible.
    show_neurons : bool, default=True
        If ``True``, draw individual neuron markers for layers exposing
        ``units`` or ``filters``. If ``False``, represent each layer as a
        single box.
    styles : mapping, optional
        Fine-grained style overrides keyed by layer name or layer class.
        Graph styles can override ``fill``, ``outline``, ``node_size``,
        ``node_spacing``, ``layer_spacing``, ``connector_fill``,
        ``connector_width``, ``ellipsize_after``, ``show_neurons``,
        ``box_scale``, ``image``, ``image_indices``, ``image_fit`` and
        ``circular_crop``.
    image_fit : {'contain', 'cover', 'fill', 'match_aspect'}, default='contain'
        Default fit mode for images attached through ``styles``.
    circular_crop : bool, default=True
        If ``True``, crop node images to a circle (ellipse for non-square
        nodes).
    layered_groups : sequence of dict, optional
        Group definitions used to draw labeled background regions behind sets
        of nodes. Each group's ``layers`` entries may be layer names or layer
        objects. Groups do not alter the graph layout itself.
    options : GraphOptions or mapping, optional
        Configuration bundle applied after ``preset`` and before explicit
        keyword arguments.
    preset : str, optional
        Name of a preset from ``visualkeras.GRAPH_PRESETS`` (for example
        ``default``, ``compact`` and ``presentation``).

    Returns
    -------
    PIL.Image.Image
        Rendered graph diagram.

    Raises
    ------
    ValueError
        If ``preset`` is not a key of ``GRAPH_PRESETS``.
    TypeError
        If ``options`` is neither a ``GraphOptions`` nor a mapping.
    RuntimeError
        If an input layer has an unsupported ``input_shape`` layout while
        ``inout_as_tensor=False``.

    Notes
    -----
    Configuration precedence is ``preset`` followed by ``options`` followed
    by explicit keyword arguments. When ``preset`` or ``options`` is given,
    an explicit keyword argument only takes effect if its value differs from
    the corresponding ``GraphOptions()`` default. Passing a value equal to
    that default is indistinguishable from omitting it.

    Full documentation: https://visualkeras.readthedocs.io/en/latest/api/graph.html
    """
    explicit_values = {
        "to_file": to_file,
        "color_map": color_map,
        "node_size": node_size,
        "background_fill": background_fill,
        "padding": padding,
        "layer_spacing": layer_spacing,
        "node_spacing": node_spacing,
        "connector_fill": connector_fill,
        "connector_width": connector_width,
        "ellipsize_after": ellipsize_after,
        "inout_as_tensor": inout_as_tensor,
        "show_neurons": show_neurons,
        "styles": styles,
        "image_fit": image_fit,
        "circular_crop": circular_crop,
        "layered_groups": layered_groups,
    }

    if options is None and preset is None:
        # Legacy keyword-only call: the arguments are used verbatim. Callers
        # overriding many of them are nudged towards presets / GraphOptions.
        signature_defaults = {
            "to_file": None,
            "color_map": None,
            "node_size": 50,
            "background_fill": 'white',
            "padding": 10,
            "layer_spacing": 250,
            "node_spacing": 10,
            "connector_fill": 'gray',
            "connector_width": 1,
            "ellipsize_after": 10,
            "inout_as_tensor": True,
            "show_neurons": True,
            "styles": None,
            "image_fit": 'contain',
            "circular_crop": True,
            "layered_groups": None,
        }
        custom_keys = [
            key for key, default in signature_defaults.items()
            if explicit_values[key] != default
        ]
        if len(custom_keys) >= 5:
            warnings.warn(
                "graph_view received many custom keyword arguments. "
                "Consider using visualkeras.show(..., mode='graph', preset=...) "
                "and the GraphOptions dataclass for a simpler workflow.",
                UserWarning,
                stacklevel=2,
            )
    else:
        resolved = _resolve_graph_settings(explicit_values, options, preset)
        to_file = resolved["to_file"]
        color_map = resolved["color_map"]
        node_size = resolved["node_size"]
        background_fill = resolved["background_fill"]
        padding = resolved["padding"]
        layer_spacing = resolved["layer_spacing"]
        node_spacing = resolved["node_spacing"]
        connector_fill = resolved["connector_fill"]
        connector_width = resolved["connector_width"]
        ellipsize_after = resolved["ellipsize_after"]
        inout_as_tensor = resolved["inout_as_tensor"]
        show_neurons = resolved["show_neurons"]
        styles = resolved["styles"]
        image_fit = resolved["image_fit"]
        circular_crop = resolved["circular_crop"]
        layered_groups = resolved["layered_groups"]

    if color_map is None:
        color_map = {}
    elif not isinstance(color_map, dict):
        color_map = dict(color_map)
    if styles is None:
        styles = {}
    elif not isinstance(styles, dict):
        styles = dict(styles)

    global_defaults = {
        "fill": None,
        "outline": "black",
        "node_size": node_size,
        "node_spacing": node_spacing,
        "layer_spacing": layer_spacing,
        "connector_fill": connector_fill,
        "connector_width": connector_width,
        "ellipsize_after": ellipsize_after,
        "show_neurons": show_neurons,
        "box_scale": 3,
        "image_fit": image_fit,
        "circular_crop": circular_crop,
    }

    output_names = _graph_output_names(model)

    id_to_num_mapping, adj_matrix = model_to_adj_matrix(model)
    model_layers = model_to_hierarchy_lists(model, id_to_num_mapping, adj_matrix)

    # With inout_as_tensor=False, synthetic output nodes are only added when
    # some output-producing layer is not already in the last hierarchy level
    # (otherwise the outputs would be drawn twice).
    should_add_dummy_outputs = inout_as_tensor
    if not inout_as_tensor:
        last_level_layers = model_layers[-1] if model_layers else []
        layers_producing_outputs = []
        for output_tensor in model.outputs:
            for layer in model.layers:
                if hasattr(layer, 'output') and layer.output is output_tensor:
                    layers_producing_outputs.append(layer)
                    break
        should_add_dummy_outputs = not all(
            layer in last_level_layers for layer in layers_producing_outputs
        )

    if should_add_dummy_outputs:
        # Single-output models report one shape tuple; multi-output models a list.
        output_shapes = model.output_shape
        if isinstance(output_shapes, tuple):
            output_shapes = [output_shapes]
        model_layers.append([
            _DummyLayer(
                output_names[i],
                None if inout_as_tensor else self_multiply(output_shapes[i]),
            )
            for i in range(len(model.outputs))
        ])
        id_to_num_mapping, adj_matrix = augment_output_layers(
            model, model_layers[-1], id_to_num_mapping, adj_matrix
        )

    # Layout: one column per hierarchy level, nodes stacked top to bottom.
    layers = []   # per column: the drawable nodes in that column
    layer_y = []  # per column: occupied height, used for vertical centring
    id_to_node_list_map = {}
    current_x = padding
    max_right = padding
    layer_counter = 0
    for layer_list in model_layers:
        current_y = 0
        nodes = []
        column_width = 0
        column_spacing = layer_spacing
        last_node_spacing = node_spacing
        last_node_size = node_size
        for layer in layer_list:
            layer_name = getattr(layer, "name", None)
            if not layer_name:
                layer_name = f"layer_{layer_counter}"
                layer_counter += 1

            current_defaults = global_defaults.copy()
            current_defaults.update(color_map.get(type(layer), {}))
            style = resolve_style(layer, layer_name, styles, current_defaults)

            image_path = style.get("image")
            image_indices = style.get("image_indices")
            node_image = _load_graph_node_image(image_path, layer_name) if image_path else None

            local_node_size = style.get("node_size", node_size)
            local_node_spacing = style.get("node_spacing", node_spacing)
            local_layer_spacing = style.get("layer_spacing", layer_spacing)
            local_ellipsize_after = style.get("ellipsize_after", ellipsize_after)
            local_show_neurons = style.get("show_neurons", show_neurons)
            box_scale = style.get("box_scale", 3)
            if local_ellipsize_after is not None:
                local_ellipsize_after = int(local_ellipsize_after)

            column_width = max(column_width, local_node_size)
            column_spacing = max(column_spacing, local_layer_spacing)
            last_node_spacing = local_node_spacing
            last_node_size = local_node_size

            is_box = True
            units = 1
            if local_show_neurons:
                if hasattr(layer, 'units'):
                    is_box = False
                    units = layer.units
                elif hasattr(layer, 'filters'):
                    is_box = False
                    units = layer.filters
                elif is_internal_input(layer) and not inout_as_tensor:
                    is_box = False
                    # input_shape may be a bare tuple or a one-element list.
                    input_shape = layer.input_shape
                    if isinstance(input_shape, tuple):
                        shape = input_shape
                    elif isinstance(input_shape, list) and len(input_shape) == 1:
                        shape = input_shape[0]
                    else:
                        raise RuntimeError(f"not supported input shape {input_shape}")
                    units = self_multiply(shape)

            if local_ellipsize_after is None or local_ellipsize_after <= 0:
                n = units
            else:
                n = min(units, local_ellipsize_after)

            layer_nodes = []
            for i in range(n):
                if is_box:
                    node = Box()
                    scale = box_scale
                else:
                    scale = 1
                    # The second-to-last marker stands in for the elided neurons.
                    if local_ellipsize_after and local_ellipsize_after > 1 and i == local_ellipsize_after - 2:
                        node = Ellipses()
                    else:
                        node = Circle()
                node.x1 = current_x
                node.y1 = current_y
                node.x2 = node.x1 + local_node_size
                node.y2 = node.y1 + local_node_size * scale
                current_y = node.y2 + local_node_spacing
                max_right = max(max_right, node.x2)

                fill = style.get('fill')
                node.fill = fill if fill is not None else 'orange'
                outline = style.get('outline')
                node.outline = outline if outline is not None else 'black'
                node.style = style
                if node_image is not None and not isinstance(node, Ellipses):
                    if image_indices is None or i in image_indices:
                        node.image = node_image
                        node.image_fit = style.get("image_fit", image_fit)
                        node.circular_crop = style.get("circular_crop", circular_crop)
                layer_nodes.append(node)

            id_to_node_list_map[id(layer)] = layer_nodes
            nodes.extend(layer_nodes)
            # Gap between consecutive layers stacked in the same column.
            current_y += 2 * local_node_size

        # Drop the trailing spacing and inter-layer gap added after the last layer.
        layer_y.append(current_y - last_node_spacing - 2 * last_node_size)
        layers.append(nodes)
        current_x += column_width + column_spacing

    img_width = max_right + padding
    # max(iterable) also works for a single column; max(*[v]) would not.
    img_height = max(layer_y, default=0) + 2 * padding

    # Vertically centre every column within the image.
    for column_nodes, column_height in zip(layers, layer_y):
        y_off = (img_height - column_height) / 2
        for node in column_nodes:
            node.y1 += y_off
            node.y2 += y_off

    # Grow the canvas so group boxes and their captions fit.
    min_x, min_y = 0, 0
    max_x, max_y = img_width, img_height
    if layered_groups:
        measure_draw = ImageDraw.Draw(Image.new("RGBA", (1, 1)))
        for group in layered_groups:
            group_nodes = _get_graph_group_nodes(id_to_node_list_map, group, model)
            if not group_nodes:
                continue
            group_padding = group.get("padding", 10)
            g_min_x = min(node.x1 for node in group_nodes) - group_padding
            g_max_x = max(node.x2 for node in group_nodes) + group_padding
            g_min_y = min(node.y1 for node in group_nodes) - group_padding
            g_max_y = max(node.y2 for node in group_nodes) + group_padding
            min_x = min(min_x, g_min_x)
            max_x = max(max_x, g_max_x)
            min_y = min(min_y, g_min_y)
            max_y = max(max_y, g_max_y)

            caption = group.get("name", group.get("caption"))
            if caption:
                font = _get_font(group)
                text_w, text_h = _measure_text(measure_draw, caption, font)
                center_x = (g_min_x + g_max_x) / 2
                min_x = min(min_x, center_x - text_w / 2)
                max_x = max(max_x, center_x + text_w / 2)
                max_y = max(max_y, g_max_y + group.get("text_spacing", 5) + text_h)

    img = Image.new(
        'RGBA',
        (int(ceil(max_x - min_x)), int(ceil(max_y - min_y))),
        background_fill,
    )
    draw = aggdraw.Draw(img)

    # Groups may extend left of / above the origin; shift everything into view.
    shift_x = -min_x
    shift_y = -min_y
    if shift_x != 0 or shift_y != 0:
        for column_nodes in layers:
            for node in column_nodes:
                node.x1 += shift_x
                node.x2 += shift_x
                node.y1 += shift_y
                node.y2 += shift_y

    if layered_groups:
        _draw_graph_group_boxes(draw, id_to_node_list_map, layered_groups, model)

    for start_idx, end_idx in zip(*np.where(adj_matrix > 0)):
        start_id = next(get_keys_by_value(id_to_num_mapping, start_idx))
        end_id = next(get_keys_by_value(id_to_num_mapping, end_idx))
        start_nodes = id_to_node_list_map[start_id]
        end_nodes = id_to_node_list_map[end_id]
        for start_node in start_nodes:
            if isinstance(start_node, Ellipses):
                continue
            for end_node in end_nodes:
                if not isinstance(end_node, Ellipses):
                    _draw_connector(draw, start_node, end_node, color=connector_fill, width=connector_width)

    for column_nodes in layers:
        for node in column_nodes:
            if not getattr(node, 'image', None):
                node.draw(draw)
                continue
            # Flush pending aggdraw output before Pillow pastes into the same image.
            draw.flush()
            _paste_graph_node_image(img, node)
            draw = aggdraw.Draw(img)
            # Redraw only the outline on top of the image. _fill is set directly
            # to bypass the fill setter, and restored even if drawing fails.
            original_fill = node._fill
            node._fill = (0, 0, 0, 0)
            try:
                node.draw(draw)
            finally:
                node._fill = original_fill
    draw.flush()

    if layered_groups:
        _draw_graph_group_captions(img, id_to_node_list_map, layered_groups, model)

    if to_file is not None:
        img.save(to_file)

    return img


def _resolve_graph_settings(explicit_values, options, preset):
    """Merge ``GraphOptions()`` defaults, a preset, an options bundle and explicit kwargs.

    Precedence (lowest to highest): ``GraphOptions()`` defaults, ``preset``,
    ``options``, then each explicit value that differs from its
    ``GraphOptions()`` default.

    Raises
    ------
    ValueError
        If ``preset`` is not a key of ``GRAPH_PRESETS``.
    TypeError
        If ``options`` is neither a ``GraphOptions`` nor a mapping.
    """
    defaults = GraphOptions().to_kwargs()
    defaults["color_map"] = None
    defaults["styles"] = None
    defaults["layered_groups"] = None
    resolved = dict(defaults)

    if preset is not None:
        # Keep the try narrow so only the lookup is reported as an unknown preset.
        try:
            preset_options = GRAPH_PRESETS[preset]
        except KeyError as exc:
            available = ", ".join(sorted(GRAPH_PRESETS.keys()))
            raise ValueError(
                f"Unknown graph preset '{preset}'. Available presets: {available}"
            ) from exc
        resolved.update(preset_options.to_kwargs())

    if options is not None:
        if isinstance(options, GraphOptions):
            resolved.update(options.to_kwargs())
        elif isinstance(options, Mapping):
            resolved.update(dict(options))
        else:
            raise TypeError(
                "options must be a GraphOptions instance or a mapping of keyword arguments."
            )

    for key, value in explicit_values.items():
        if key in defaults and value != defaults[key]:
            resolved[key] = value
    return resolved


def _graph_output_names(model) -> List[str]:
    """Return one display name per model output, across Keras versions."""
    if hasattr(model, 'output_names'):
        # Older Keras exposes the names directly.
        return model.output_names
    names = []
    for output in model.outputs:
        if hasattr(output, '_keras_history'):
            # First entry of _keras_history is the producing layer.
            names.append(output._keras_history[0].name)
        else:
            names.append(getattr(output, 'name', f'output_{len(names)}'))
    return names


def _load_graph_node_image(image_path, layer_name):
    """Load ``image_path`` as an RGBA image, or warn and return ``None`` on failure."""
    try:
        # convert() returns a loaded copy, so the file can be closed right away.
        with Image.open(image_path) as source:
            return source.convert("RGBA")
    except Exception as e:
        warnings.warn(f"Could not load image {image_path} for layer {layer_name}: {e}")
        return None


def _paste_graph_node_image(img, node):
    """Resize ``node.image`` into the node's box and paste it onto ``img``.

    The image is fitted with ``resize_image_to_fit`` using ``node.image_fit``.
    When ``node.circular_crop`` is truthy, it is masked to an ellipse. The mask
    is drawn at 4x resolution and downsampled with LANCZOS for anti-aliasing.
    """
    width = int(node.x2 - node.x1)
    height = int(node.y2 - node.y1)
    resized = resize_image_to_fit(node.image, width, height, node.image_fit)

    if getattr(node, 'circular_crop', False):
        super_scale = 4
        mask_w = width * super_scale
        mask_h = height * super_scale
        mask = Image.new("L", (mask_w, mask_h), 0)
        ImageDraw.Draw(mask).ellipse((0, 0, mask_w, mask_h), fill=255)
        mask = mask.resize((width, height), Image.LANCZOS)
        if resized.mode == 'RGBA':
            cropped = Image.new("RGBA", (width, height), (0, 0, 0, 0))
            cropped.paste(resized, (0, 0), mask=mask)
            resized = cropped
        else:
            resized.putalpha(mask)

    img.paste(resized, (int(node.x1), int(node.y1)), resized)
def _draw_connector(draw, start_node, end_node, color, width):
    """Draw a straight connector from the right edge of ``start_node`` to the left edge of ``end_node``.

    ``connector_fill`` / ``connector_width`` entries in the start node's
    ``style`` dict override the ``color`` / ``width`` arguments.
    """
    style = getattr(start_node, "style", {}) or {}
    pen = aggdraw.Pen(
        get_rgba_tuple(style.get("connector_fill", color)),
        style.get("connector_width", width),
    )
    x1 = start_node.x2
    y1 = start_node.y1 + (start_node.y2 - start_node.y1) / 2
    x2 = end_node.x1
    y2 = end_node.y1 + (end_node.y2 - end_node.y1) / 2
    draw.line([x1, y1, x2, y2], pen)


def _get_graph_group_nodes(id_to_node_list_map: Dict[int, List[Any]], group: Dict[str, Any], model) -> List[Any]:
    """Collect the drawn nodes of every layer referenced by ``group["layers"]``.

    References may be layer names (looked up in ``model.layers``) or layer
    objects (looked up by ``id``). Unknown references are skipped.
    """
    layers_ref = group.get("layers", [])
    if not layers_ref:
        return []

    name_to_layer = {}
    for layer in model.layers:
        name = getattr(layer, 'name', None)
        if name:
            name_to_layer[name] = layer

    group_nodes = []
    for ref in layers_ref:
        layer = name_to_layer.get(ref) if isinstance(ref, str) else ref
        if layer is None:
            continue
        group_nodes.extend(id_to_node_list_map.get(id(layer)) or [])
    return group_nodes


def _graph_group_bounds(nodes, padding):
    """Return ``(min_x, min_y, max_x, max_y)`` enclosing ``nodes``, grown by ``padding`` on each side.

    ``nodes`` must be non-empty.
    """
    min_x = min(node.x1 for node in nodes) - padding
    min_y = min(node.y1 for node in nodes) - padding
    max_x = max(node.x2 for node in nodes) + padding
    max_y = max(node.y2 for node in nodes) + padding
    return min_x, min_y, max_x, max_y


def _draw_graph_group_boxes(draw, id_to_node_list_map, groups, model):
    """Draw a filled, outlined rectangle behind the nodes of each group."""
    for group in groups:
        nodes = _get_graph_group_nodes(id_to_node_list_map, group, model)
        if not nodes:
            continue
        min_x, min_y, max_x, max_y = _graph_group_bounds(nodes, group.get("padding", 10))
        fill = group.get("fill", (200, 200, 200, 100))
        outline = group.get("outline", "black")
        width = group.get("width", 1)
        pen = aggdraw.Pen(get_rgba_tuple(outline), width)
        brush = aggdraw.Brush(get_rgba_tuple(fill))
        draw.rectangle([min_x, min_y, max_x, max_y], pen, brush)


def _draw_graph_group_captions(img, id_to_node_list_map, groups, model):
    """Draw each group's caption centred below its (padded) box, using Pillow text rendering."""
    draw = ImageDraw.Draw(img)
    for group in groups:
        caption = group.get("name", group.get("caption"))
        if not caption:
            continue
        nodes = _get_graph_group_nodes(id_to_node_list_map, group, model)
        if not nodes:
            continue
        min_x, _, max_x, max_y = _graph_group_bounds(nodes, group.get("padding", 10))
        font = _get_font(group)
        color = group.get("font_color", "black")
        gap = group.get("text_spacing", 5)
        text_w, _ = _measure_text(draw, caption, font)
        text_x = (min_x + max_x) / 2 - text_w / 2
        text_y = max_y + gap
        draw.text((text_x, text_y), caption, fill=color, font=font)


def _get_font(group: Dict[str, Any]) -> ImageFont.ImageFont:
    """Resolve the caption font for ``group``.

    ``group["font"]`` may be ``None`` (try ``arial.ttf``), a font file path,
    or an already-loaded Pillow font (bitmap, FreeType or transposed).
    ``group["font_size"]`` (default 15) applies to paths. Any load failure,
    including a missing FreeType library, falls back to Pillow's default font.
    """
    font_src = group.get("font", None)
    font_size = group.get("font_size", 15)
    if font_src is None or isinstance(font_src, str):
        path = "arial.ttf" if font_src is None else font_src
        try:
            return ImageFont.truetype(path, font_size)
        except (OSError, ImportError):
            return ImageFont.load_default()
    # FreeTypeFont (returned by truetype) is not a subclass of ImageFont.ImageFont,
    # so all loaded-font classes must be accepted explicitly.
    if isinstance(font_src, (ImageFont.ImageFont, ImageFont.FreeTypeFont, ImageFont.TransposedFont)):
        return font_src
    return ImageFont.load_default()


def _measure_text(draw: ImageDraw.ImageDraw, text: str, font: Any) -> Tuple[int, int]:
    """Return the ``(width, height)`` of ``text`` rendered with ``font``.

    Uses ``font.getbbox`` when available and falls back to the legacy
    ``draw.textsize`` for older Pillow fonts.
    """
    if hasattr(font, "getbbox"):
        left, top, right, bottom = font.getbbox(text)
        return right - left, bottom - top
    return draw.textsize(text, font=font)