# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: MIT-0
"""An ipycanvas-based interactive widget for drawing PIL-compatible doodles in JupyterLab
"""
# Python Built-Ins:
from math import floor
from typing import Tuple, Union
# External Dependencies:
import numpy as np
from ipycanvas import Canvas, hold_canvas
from IPython.display import display
from ipywidgets import HTML, Button, Layout, Output, VBox
from matplotlib.colors import to_hex, to_rgb
from PIL import Image, ImageDraw
class ValidatedColor:
    """Store one color in both of the representations this module needs

    The ipycanvas Canvas takes a hex color string for its `fill_style`, whereas the pixel data
    buffer holds 0-255 integer RGB values. This class keeps the two in sync.

    Attributes
    ----------
    hexa :
        "#rrggbb" hex string, as produced by `matplotlib.colors.to_hex`.
    np_8bit :
        Length-3 integer numpy array of 0-255 RGB channel values.
    """

    hexa: str
    np_8bit: np.ndarray

    def __init__(self, color: Union[str, Tuple[float, ...], np.ndarray]):
        self.set_color(color)

    def set_color(self, color: Union[str, Tuple[float, ...], np.ndarray]):
        """Use this method to update all stored representations at once

        Parameters
        ----------
        color :
            Any color spec accepted by matplotlib's `to_hex` / `to_rgb`: for example an
            (r, g, b) tuple of floats in [0, 1], or a named or hex color string.
        """
        self.hexa = to_hex(color)
        # Round (rather than truncate) to the nearest 8-bit level, matching how `to_hex`
        # quantizes each channel. Otherwise e.g. 0.5 would give 127 here but 0x80 (=128) in
        # `hexa`, and the pixel buffer would disagree with what is drawn on the canvas.
        self.np_8bit = np.round(255 * np.array(to_rgb(color))).astype(int)
class PixelDrawCanvas:
    """JupyterLab widget to interactively draw on a canvas and export the pixel data to Python

    This widget maintains a buffer of pixel values and, on each mouse event, draws individual
    pixel rects to the canvas (in batches, at least)... More toy/demo than an optimized design!

    Usage
    -----
    After creating the PixelDrawCanvas you can either call `.display()` to directly display it in
    the notebook, or access the `.widget` property if you want to embed the UI in another
    ipywidgets widget.

    Draw on the canvas by clicking and dragging, or press the "Clear" button to start again.

    You can read the 0-255, 3-channel (height, width, 3) pixel data numpy array from `.data`.
    `matplotlib.pyplot.imshow(data)` should confirm that what you see in the widget matches this.

    You can also programmatically `.clear()` the drawing from Python if you like.
    """

    def __init__(
        self,
        width: int = 28,
        height: int = 28,
        color_bg: Tuple[float, float, float] = (0, 0, 0),
        color_fg: Tuple[float, float, float] = (1.0, 1.0, 1.0),
        pen_size: int = 3,
        title_html: str = "<h3>Draw a digit!</h3>",
    ):
        """Create a PixelDrawCanvas

        Parameters
        ----------
        width :
            Width of the drawing (and of `.data`) in pixels.
        height :
            Height of the drawing (and of `.data`) in pixels.
        color_bg :
            Background color: any matplotlib-compatible color (e.g. RGB floats in [0, 1]).
        color_fg :
            Pen/foreground color: any matplotlib-compatible color.
        pen_size :
            Diameter in pixels of the circular pen (see `set_pen()`).
        title_html :
            HTML content shown above the canvas.
        """
        self.col_bg = ValidatedColor(color_bg)
        self.col_fg = ValidatedColor(color_fg)

        # -- Create individual widget components:
        self.canvas = Canvas(width=width, height=height, image_smoothing_enabled=False)
        # Display the canvas at between 200 and 1000 CSS pixels per side, whatever its pixel
        # resolution. (Without explicit canvas.layout width, VBox/HBox fills full available width.)
        self.canvas.layout.height = f"{max(200, min(1000, height))}px"
        self.canvas.layout.width = f"{max(200, min(1000, width))}px"

        self._clear_button = Button(
            description="Clear",
            icon="eraser",
            tooltip="Clear the drawing to a blank image",
        )
        self._console = Output(
            layout=Layout(
                max_height="140px",
                overflow_y="auto",
            )
        )
        self._title = HTML(title_html)

        # -- Initialize state:
        self.is_drawing = False
        # (Temporary data __init__ to be overridden by clear() shortly:)
        self.data = np.zeros((height, width, 3))
        self.set_pen(pen_size=pen_size)

        # -- Set up listeners:
        # Wrap widget event listener member functions so they have access to this `self` instance
        # when called and are also able to `print()` to the console output if needed.
        @self._console.capture()
        def on_mouse_down(*args, **kwargs):
            return self._on_mouse_down(*args, **kwargs)

        @self._console.capture()
        def on_mouse_move(*args, **kwargs):
            return self._on_mouse_move(*args, **kwargs)

        @self._console.capture()
        def on_mouse_out(*args, **kwargs):
            return self._on_mouse_out(*args, **kwargs)

        @self._console.capture()
        def on_mouse_up(*args, **kwargs):
            return self._on_mouse_up(*args, **kwargs)

        @self._console.capture()
        def on_clear_click(*args, **kwargs):
            return self.clear()

        self.canvas.on_mouse_down(on_mouse_down)
        self.canvas.on_mouse_move(on_mouse_move)
        self.canvas.on_mouse_out(on_mouse_out)
        self.canvas.on_mouse_up(on_mouse_up)
        self._clear_button.on_click(on_clear_click)

        # Set up composite view with the different widget components. (Box sizing can only be
        # set via `layout=`: a bare `width=` kwarg is not a VBox trait, so it would just be
        # dropped with a DeprecationWarning. The canvas' explicit layout size is set above.)
        self.widget = VBox([self._title, self._clear_button, self.canvas, self._console])

        # Finally initialize to clear state ready to use:
        with self._console:
            self.clear()

    def _paint_background(self):
        """Wipe the canvas to a solid background fill, then select the foreground fill style

        Opens no `hold_canvas` batch of its own: callers wrap it together with any following
        drawing commands in a single batch, to avoid flicker.
        """
        self.canvas.clear()
        self.canvas.fill_style = self.col_bg.hexa
        self.canvas.fill_rect(0, 0, self.canvas.width, self.canvas.height)
        self.canvas.fill_style = self.col_fg.hexa

    def clear(self):
        """Clear the drawing: blank the canvas and reset `.data` to the background color"""
        height = self.canvas.height
        width = self.canvas.width
        with hold_canvas(self.canvas):
            self._paint_background()
        self.data = np.tile(self.col_bg.np_8bit, (height, width, 1))
        print("Cleared drawing")

    def draw_from_buffer(self):
        """Draw the contents of the .data buffer to the canvas

        This re-paints the background itself rather than calling clear() (which would also reset
        `.data`), all within one `hold_canvas` batch to avoid flicker. Only pixels of the current
        col_fg in the buffer will be drawn (doesn't support changing col_fg dynamically or
        drawing multiple colors).
        """
        fg_mask = (self.data == self.col_fg.np_8bit).all(-1)
        # N entries of (row, col) = (y, x) pairs, one per foreground pixel:
        fg_coords = np.argwhere(fg_mask)
        with hold_canvas(self.canvas):
            self._paint_background()
            if len(fg_coords):  # Skip sending a no-op draw command for an empty drawing
                self.canvas.fill_rects(fg_coords[:, 1], fg_coords[:, 0], 1, 1)

    def display(self):
        """Display the widget (in a Jupyter/Lab notebook)"""
        display(self.widget)

    def _on_mouse_down(self, x, y):
        """Start a stroke, marking the pressed location"""
        self.is_drawing = True
        self.paint(x, y)

    def _on_mouse_move(self, x, y):
        """Continue the current stroke, if one is in progress"""
        if self.is_drawing:
            self.paint(x, y)

    def _on_mouse_out(self, x, y):
        """Re-draw from data buffer on each mouse-out in case anything weird happened"""
        self.is_drawing = False
        self.draw_from_buffer()

    def _on_mouse_up(self, x, y):
        """End the current stroke"""
        self.is_drawing = False

    def set_pen(self, pen_size: int = 15) -> np.ndarray:
        """Set up the pen/brush (define pen_mask matrix)

        We pre-calculate and store a boolean `.pen_mask` matrix for the requested brush size (and
        assumed circular shape). If you wanted, you could set other wacky shapes by replacing it
        with your own boolean matrix (True where the pen marks, False where it doesn't).

        Parameters
        ----------
        pen_size :
            Width and height of the (square) mask, in pixels. Must be at least 1.

        Returns
        -------
        pen_mask :
            The same boolean 2D matrix this function saves to `self.pen_mask`.

        Raises
        ------
        ValueError
            If `pen_size` is less than 1.
        """
        if pen_size < 1:
            raise ValueError(f"pen_size must be a positive integer, got {pen_size}")
        # No sense re-inventing the "pixellated circle" wheel, so use PIL:
        mask_img = Image.new("1", (pen_size, pen_size))
        draw = ImageDraw.Draw(mask_img)
        draw.ellipse((0, 0, pen_size - 1, pen_size - 1), fill="white")
        self.pen_mask = np.array(mask_img)  # (pen_size, pen_size) boolean array
        return self.pen_mask

    def paint(self, x, y):
        """Mark the given location with the current pen

        The pen mask is positioned roughly centered on (x, y), which are used directly as
        (fractional) pixel coordinates into `.data`. The mask is truncated wherever it overhangs
        the image edges. Masked pixels are set to the foreground color in `.data`, and only
        those which weren't already foreground are drawn to the canvas.
        """
        pen_mask = self.pen_mask
        img_height, img_width = self.data.shape[:2]

        # Top-left corner of the (untruncated) pen, in image pixel coordinates:
        x_maskstart = floor(x - (pen_mask.shape[1] / 2))
        y_maskstart = floor(y - (pen_mask.shape[0] / 2))

        # Truncate the pen where it overhangs the left/top edges of the image...
        if x_maskstart < 0:
            pen_mask = pen_mask[:, -x_maskstart:]
            x_maskstart = 0
        if y_maskstart < 0:
            pen_mask = pen_mask[-y_maskstart:, :]
            y_maskstart = 0
        # ...and the right/bottom edges (down to nothing, if the point is beyond the image):
        pen_mask = pen_mask[
            : max(0, img_height - y_maskstart),
            : max(0, img_width - x_maskstart),
        ]
        y_maskend = y_maskstart + pen_mask.shape[0]
        x_maskend = x_maskstart + pen_mask.shape[1]

        fg = self.col_fg.np_8bit
        region = self.data[y_maskstart:y_maskend, x_maskstart:x_maskend]
        # Only draw canvas rects for pixels that actually change. A pixel still needs painting
        # if it differs from col_fg in *any* channel. (Requiring *all* channels to differ would
        # skip e.g. a (255, 0, 0) background pixel under a white pen.)
        new_fg_pixels_offset = np.argwhere(pen_mask & (region != fg).any(-1))

        # Update the data buffer. Write to a copy and re-bind it, so arrays previously read from
        # `.data` remain snapshots instead of changing underneath the caller.
        new_data = self.data.copy()
        new_data[y_maskstart:y_maskend, x_maskstart:x_maskend][pen_mask] = fg
        self.data = new_data

        # Draw the canvas updates. Only pixels just set in `.data` are drawn, so the canvas never
        # shows marks that the buffer doesn't contain.
        if len(new_fg_pixels_offset):
            with hold_canvas(self.canvas):
                self.canvas.fill_style = self.col_fg.hexa
                self.canvas.fill_rects(
                    new_fg_pixels_offset[:, 1] + x_maskstart,
                    new_fg_pixels_offset[:, 0] + y_maskstart,
                    1,
                    1,
                )