{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Extract a fastai U-Net into plain PyTorch\n",
    "\n",
    "This notebook loads an exported fastai learner, saves the weights of its underlying model, rebuilds the network as a standalone module (`DynamicUnetDIY`), loads the saved weights into it, and compares the predictions of both models on a sample image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# fastai is built around its star import; the class below relies on it for names\n",
    "# such as SequentialEx, create_body, ConvLayer and hook_outputs. It is kept first\n",
    "# so that the explicit imports underneath win on any name clash.\n",
    "from fastai.vision.all import *\n",
    "from fastai.vision.learner import _default_meta\n",
    "from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig\n",
    "\n",
    "import base64\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torch import nn\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths and sizes shared by several cells.\n",
    "LEARNER_PATH = Path.home() / \".fastai\" / \"data\" / \"camvid_tiny\" / \"fastai_unet.pkl\"\n",
    "# The file name is kept exactly as it was (including the `fasti` spelling) in case\n",
    "# other tooling refers to it.\n",
    "WEIGHTS_PATH = Path(\"../model_store/fasti_unet_weights.pth\")\n",
    "# (height, width); must be consistent with model training\n",
    "IMG_SIZE = (96, 128)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the fastai Model and Save Its Trained Weights"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholder definitions: presumably the exported learner refers to these two\n",
    "# names, so they have to exist before `load_learner` is called (TODO confirm\n",
    "# against the training notebook).\n",
    "def acc_camvid(*_):\n",
    "    pass\n",
    "\n",
    "\n",
    "def get_y(*_):\n",
    "    pass\n",
    "\n",
    "\n",
    "learn = load_learner(LEARNER_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = \"../sample/street_view_of_a_small_neighborhood.png\"\n",
    "Image.open(image_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "pred_fastai = learn.predict(image_path)\n",
    "plt.imshow(pred_fastai[0].numpy());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save Torch Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the target directory on a fresh checkout so that `torch.save` does not fail.\n",
    "WEIGHTS_PATH.parent.mkdir(parents=True, exist_ok=True)\n",
    "torch.save(learn.model.state_dict(), WEIGHTS_PATH)\n",
    "learn.model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Extract FastAI Model in PyTorch\n",
    "\n",
    "`DynamicUnetDIY` below is modelled on fastai's `DynamicUnet` as assembled by `unet_learner`. To compare it against the library source, run `??unet_learner` or `??DynamicUnet` in a scratch cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DynamicUnetDIY(SequentialEx):\n",
    "    \"\"\"Create a U-Net from a given architecture.\n",
    "\n",
    "    The encoder is built from `arch` with `create_body` (no pretrained weights).\n",
    "    The decoder gets one `UnetBlock` per encoder index returned by\n",
    "    `_get_sz_change_idxs`, each fed by a forward hook on that encoder layer.\n",
    "\n",
    "    Args:\n",
    "        arch: architecture function handed to `create_body`.\n",
    "        n_classes: number of output channels of the final `ks=1` `ConvLayer`.\n",
    "        img_size: input size used for the dummy forward passes that trace the\n",
    "            encoder's feature-map shapes.\n",
    "        blur: forwarded to the `UnetBlock`s; the last block only receives it\n",
    "            when `blur_final` is also true.\n",
    "        blur_final: whether `blur` also applies to the last `UnetBlock`.\n",
    "        y_range: if not None, a `SigmoidRange(*y_range)` layer is appended.\n",
    "        last_cross: if true, a dense `MergeLayer` and a `ResBlock` are added\n",
    "            before the final conv.\n",
    "        bottle: if true, that `ResBlock` is created with `ni // 2` output channels.\n",
    "        init: initialiser passed to `apply_init` and to every `UnetBlock`.\n",
    "        norm_type: forwarded to the conv and decoder layers.\n",
    "        self_attention: if truthy, switched on for the third-from-last `UnetBlock`.\n",
    "        act_cls: activation class forwarded to the layers.\n",
    "        n_in: number of input channels handed to `create_body`.\n",
    "        cut: where to cut `arch`; falls back to the `cut` entry of its metadata.\n",
    "        **kwargs: forwarded to `ConvLayer`, `UnetBlock` and `ResBlock`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        arch=resnet50,\n",
    "        n_classes=32,\n",
    "        img_size=(96, 128),\n",
    "        blur=False,\n",
    "        blur_final=True,\n",
    "        y_range=None,\n",
    "        last_cross=True,\n",
    "        bottle=False,\n",
    "        init=nn.init.kaiming_normal_,\n",
    "        norm_type=None,\n",
    "        self_attention=None,\n",
    "        act_cls=defaults.activation,\n",
    "        n_in=3,\n",
    "        cut=None,\n",
    "        **kwargs\n",
    "    ):\n",
    "        meta = model_meta.get(arch, _default_meta)\n",
    "        encoder = create_body(\n",
    "            arch, n_in, pretrained=False, cut=ifnone(cut, meta[\"cut\"])\n",
    "        )\n",
    "        imsize = img_size\n",
    "\n",
    "        # Trace the encoder once to find the layers to hook, deepest first.\n",
    "        sizes = model_sizes(encoder, size=imsize)\n",
    "        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))\n",
    "        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)\n",
    "        x = dummy_eval(encoder, imsize).detach()\n",
    "\n",
    "        ni = sizes[-1][1]\n",
    "        middle_conv = nn.Sequential(\n",
    "            ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type, **kwargs),\n",
    "            ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),\n",
    "        ).eval()\n",
    "        x = middle_conv(x)\n",
    "        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]\n",
    "\n",
    "        # `x` is pushed through each new block so the next block can read its\n",
    "        # input channel count from the actual tensor shape.\n",
    "        n_blocks = len(sz_chg_idxs)\n",
    "        for i, idx in enumerate(sz_chg_idxs):\n",
    "            not_final = i != n_blocks - 1\n",
    "            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])\n",
    "            do_blur = blur and (not_final or blur_final)\n",
    "            sa = self_attention and (i == n_blocks - 3)\n",
    "            unet_block = UnetBlock(\n",
    "                up_in_c,\n",
    "                x_in_c,\n",
    "                self.sfs[i],\n",
    "                final_div=not_final,\n",
    "                blur=do_blur,\n",
    "                self_attention=sa,\n",
    "                act_cls=act_cls,\n",
    "                init=init,\n",
    "                norm_type=norm_type,\n",
    "                **kwargs\n",
    "            ).eval()\n",
    "            layers.append(unet_block)\n",
    "            x = unet_block(x)\n",
    "\n",
    "        ni = x.shape[1]\n",
    "        if imsize != sizes[0][-2:]:\n",
    "            layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))\n",
    "        layers.append(ResizeToOrig())\n",
    "        if last_cross:\n",
    "            layers.append(MergeLayer(dense=True))\n",
    "            ni += in_channels(encoder)\n",
    "            # NOTE(review): with bottle=True this block is created with ni // 2\n",
    "            # output channels while the final ConvLayer below is still created\n",
    "            # with ni inputs - looks inconsistent; confirm before using bottle=True.\n",
    "            layers.append(\n",
    "                ResBlock(\n",
    "                    1,\n",
    "                    ni,\n",
    "                    ni // 2 if bottle else ni,\n",
    "                    act_cls=act_cls,\n",
    "                    norm_type=norm_type,\n",
    "                    **kwargs\n",
    "                )\n",
    "            )\n",
    "        layers += [\n",
    "            ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)\n",
    "        ]\n",
    "        # layers[3] is the middle conv; layers[-2] is whatever precedes the final conv.\n",
    "        apply_init(nn.Sequential(layers[3], layers[-2]), init)\n",
    "        if y_range is not None:\n",
    "            layers.append(SigmoidRange(*y_range))\n",
    "        super().__init__(*layers)\n",
    "\n",
    "    def __del__(self):\n",
    "        # Remove the encoder hooks; `sfs` is missing if __init__ failed early.\n",
    "        if hasattr(self, \"sfs\"):\n",
    "            self.sfs.remove()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_torch_rep = DynamicUnetDIY(img_size=IMG_SIZE)\n",
    "model_torch_rep"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# map_location keeps this working on a CPU-only machine whatever device the\n",
    "# weights were saved from.\n",
    "state = torch.load(WEIGHTS_PATH, map_location=\"cpu\")\n",
    "model_torch_rep.load_state_dict(state)\n",
    "model_torch_rep.eval();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image = Image.open(image_path).convert(\"RGB\")\n",
    "image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_tfm = transforms.Compose(\n",
    "    [\n",
    "        # must be consistent with model training\n",
    "        transforms.Resize(IMG_SIZE),\n",
    "        transforms.ToTensor(),\n",
    "        # default statistics from imagenet\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add the batch dimension expected by the model\n",
    "x = image_tfm(image).unsqueeze(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "# inference on CPU; no_grad avoids recording an autograd graph that is never used\n",
    "with torch.no_grad():\n",
    "    raw_out = model_torch_rep(x)\n",
    "raw_out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# per-pixel class index: argmax over the channel dimension of the first batch item\n",
    "pred_res = raw_out[0].argmax(dim=0).numpy().astype(np.uint8)\n",
    "pred_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip the mask through base64 text and check that it comes back unchanged.\n",
    "pred_encoded = base64.b64encode(pred_res).decode(\"utf-8\")\n",
    "pred_decoded_byte = base64.b64decode(pred_encoded)\n",
    "pred_decoded = np.frombuffer(pred_decoded_byte, dtype=np.uint8).reshape(pred_res.shape)\n",
    "\n",
    "# integer data: require exact equality rather than a float tolerance\n",
    "np.testing.assert_array_equal(pred_decoded, pred_res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(pred_decoded);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# True when the standalone model reproduces the fastai prediction pixel for pixel\n",
    "np.all(pred_fastai[0].numpy() == pred_res)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}