{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.insert(0, \"..\")\n", "from deployment.handler import TwinHandler" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get Dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from fastai.vision.all import *\n", "path = untar_data(URLs.PETS)\n", "files = get_image_files(path/\"images\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pygmentize ../deployment/handler.py" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "twin_handler = TwinHandler()\n", "image_tfm = twin_handler.image_tfm\n", "image_tfm" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TwinImage(fastuple):\n", " @staticmethod\n", " def img_restore(image: torch.Tensor):\n", " return (image - image.min()) / (image.max() - image.min())\n", "\n", " def show(self, ctx=None, **kwargs):\n", " if len(self) > 2:\n", " img1, img2, same_breed = self\n", " else:\n", " img1, img2 = self\n", " same_breed = \"Undetermined\"\n", " if not isinstance(img1, Tensor):\n", " t1, t2 = image_tfm(img1), image_tfm(img2)\n", " else:\n", " t1, t2 = img1, img2\n", " line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)\n", " return show_image(\n", " torch.cat([self.img_restore(t1), line, self.img_restore(t2)], dim=2),\n", " title=same_breed,\n", " ctx=ctx,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img = PILImage.create(files[0])\n", "s = TwinImage(img, img, True)\n", "s.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "img1 = PILImage.create(files[1])\n", "s1 = TwinImage(img, img1, False)\n", "s1.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s2 = Resize(224)(s1)\n", "s2.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "s1[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Label and Dataloader" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def label_func(fname):\n", " return re.match(r\"^(.*)_\\d+.jpg$\", fname.name).groups()[0]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TwinTransform(Transform):\n", " def __init__(self, files, label_func, splits):\n", " self.labels = files.map(label_func).unique()\n", " self.lbl2files = {\n", " l: L(f for f in files if label_func(f) == l) for l in self.labels\n", " }\n", " self.label_func = label_func\n", " self.valid = {f: self._draw(f) for f in files[splits[1]]}\n", "\n", " def encodes(self, f):\n", " f2, t = self.valid.get(f, self._draw(f))\n", " img1, img2 = PILImage.create(f), PILImage.create(f2)\n", " if (f not in self.valid) and random.random() < 0.5:\n", " img1, img2 = img2, img1\n", " img1, img2 = image_tfm(img1), image_tfm(img2)\n", " return TwinImage(img1, img2, t)\n", "\n", " def _draw(self, f):\n", " same = random.random() < 0.5\n", " cls = self.label_func(f)\n", " if not same:\n", " cls = random.choice(L(l for l in self.labels if l != cls))\n", " return random.choice(self.lbl2files[cls]), same" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "splits = RandomSplitter()(files)\n", "tfm = TwinTransform(files, label_func, splits)\n", "tfm(files[0]).show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tls = TfmdLists(files, tfm, splits=splits)\n", "show_at(tls.valid, 0);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls = tls.dataloaders(\n", " bs=32,\n", " after_batch=[*aug_transforms()],\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@typedispatch\n", "def show_batch(\n", " x: TwinImage,\n", " y,\n", " samples,\n", " ctxs=None,\n", " max_n=6,\n", " nrows=None,\n", " ncols=2,\n", " figsize=None,\n", " **kwargs\n", "):\n", " if figsize is None:\n", " figsize = (ncols * 6, max_n // ncols * 3)\n", " if ctxs is None:\n", " ctxs = get_grid(\n", " min(x[0].shape[0], max_n), nrows=None, ncols=ncols, figsize=figsize\n", " )\n", " for i, ctx in enumerate(ctxs):\n", " TwinImage(x[0][i], x[1][i], [\"Not similar\", \"Similar\"][x[2][i].item()]).show(\n", " ctx=ctx\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dls.show_batch()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Model and Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class TwinModel(Module):\n", " def __init__(self, encoder, head):\n", " self.encoder, self.head = encoder, head\n", "\n", " def forward(self, x1, x2):\n", " ftrs = torch.cat([self.encoder(x1), self.encoder(x2)], dim=1)\n", " return self.head(ftrs)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "head = twin_handler.head_reload\n", "encoder, _ = twin_handler.get_encoder(pre_train=True)\n", "model = TwinModel(encoder, head)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def loss_func(out, targ):\n", " return nn.CrossEntropyLoss()(out, targ.long())\n", "\n", "\n", "def twin_splitter(model):\n", " return [params(model.encoder), params(model.head)]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn = Learner(\n", " dls, model, loss_func=loss_func, splitter=twin_splitter, metrics=accuracy\n", ")\n", "learn.freeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib\n", "import matplotlib.pyplot as plt\n", "\n", "# Customize matplotlib\n", "matplotlib.rcParams.update(\n", " {\n", " 'text.usetex': False,\n", " 'font.family': 'stixgeneral',\n", " 'mathtext.fontset': 'stix',\n", " }\n", ")\n", "\n", "learn.lr_find()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.fit_one_cycle(4, 1e-3)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.export(\"../model/resnet50_0.959.pkl\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.unfreeze()\n", "learn.fit_one_cycle(4, slice(1e-6,1e-4))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.export(\"../model/resnet50_0.962.pkl\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Diagnose" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@typedispatch\n", "def show_results(\n", " x: TwinImage,\n", " y,\n", " samples,\n", " outs,\n", " ctxs=None,\n", " max_n=6,\n", " nrows=None,\n", " ncols=2,\n", " figsize=None,\n", " **kwargs,\n", "):\n", " if figsize is None:\n", " figsize = (ncols * 6, max_n // ncols * 4)\n", " if ctxs is None:\n", " ctxs = get_grid(\n", " min(x[0].shape[0], max_n), nrows=None, ncols=ncols, figsize=figsize\n", " )\n", " for i, ctx in enumerate(ctxs):\n", " title = f'Actual: {[\"Not similar\",\"Similar\"][x[2][i].item()]} \\n \\\n", " Prediction: {[\"Not similar\",\"Similar\"][y[2][i].argmax().item()]}'\n", " TwinImage(x[0][i], x[1][i], title).show(ctx=ctx)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learn.show_results()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner_reload = load_learner(\"../model/resnet50_0.962.pkl\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@patch\n", "def twinpredict(\n", " self: Learner,\n", " item,\n", " rm_type_tfms=None,\n", " with_input=False,\n", "):\n", " res = self.predict(item, rm_type_tfms=None, with_input=False)\n", " if res[0].argmax().item() == 0:\n", " label = \"Prediction: Not similar\"\n", " else:\n", " label = \"Prediction: Similar\"\n", " TwinImage(item[0], item[1], label).show()\n", " return res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgtest = image_tfm(PILImage.create(files[0]))\n", "imgval = image_tfm(PILImage.create(files[1]))\n", "twintest = TwinImage(imgtest, imgval)\n", "twintest.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = learner_reload.twinpredict(twintest)\n", "res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgtest = image_tfm(PILImage.create(files[1]))\n", "imgval = image_tfm(PILImage.create(files[1]))\n", "twintest = TwinImage(imgtest, imgval)\n", "twintest.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = learner_reload.twinpredict(twintest)\n", "res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgtest = image_tfm(PILImage.create(\"../sample/c1.jpg\"))\n", "imgval = image_tfm(PILImage.create(\"../sample/c2.jpg\"))\n", "\n", "twintest = TwinImage(imgtest, imgval)\n", "twintest.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = learner_reload.twinpredict(twintest)\n", "res" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "imgtest = image_tfm(PILImage.create(\"../sample/c3.jpg\"))\n", "imgval = image_tfm(PILImage.create(\"../sample/c2.jpg\"))\n", "twintest = TwinImage(imgtest, imgval)\n", "twintest.show();" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "res = learner_reload.twinpredict(twintest)\n", "res" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Export to PyTorch" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "learner_reload = load_learner(\"../model/resnet50_0.962.pkl\")\n", "torch.save(learner_reload.encoder.state_dict(), \"../model/resnet50_0.962_encoder.pth\")\n", "torch.save(learner_reload.head.state_dict(), \"../model/resnet50_0.962_head.pth\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "latex_envs": { "LaTeX_envs_menu_present": true, "autoclose": false, "autocomplete": true, "bibliofile": "biblio.bib", "cite_by": "apalike", "current_citInitial": 1, "eqLabelWithNumbers": true, "eqNumInitial": 1, "hotkeys": { "equation": "Ctrl-E", "itemize": "Ctrl-I" }, "labels_anchors": false, "latex_user_defs": false, "report_style_numbering": false, "user_envs_cfg": false }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }