{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "sys.path.insert(0, \"..\")\n",
    "from deployment.handler import TwinHandler"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fastai.vision.all import *\n",
    "# Fetch the PETS dataset and collect the image files under its images/ folder.\n",
    "path = untar_data(URLs.PETS)\n",
    "images_dir = path / \"images\"\n",
    "files = get_image_files(images_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pygmentize ../deployment/handler.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the deployment handler and take its `image_tfm`, so the notebook\n",
    "# preprocesses images with the same transform object the handler exposes.\n",
    "twin_handler = TwinHandler()\n",
    "image_tfm = twin_handler.image_tfm\n",
    "# Bare expression: display the transform's repr as the cell output.\n",
    "image_tfm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TwinImage(fastuple):\n",
    "    \"\"\"Pair of images with an optional label, shown side by side.\n",
    "\n",
    "    Holds `(img1, img2)` or `(img1, img2, label)`; when a third item is\n",
    "    present it is used as the plot title.\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    def img_restore(image: torch.Tensor):\n",
    "        \"\"\"Min-max rescale `image` to the [0, 1] range for display.\n",
    "\n",
    "        A constant image has zero range; dividing by it would fill the\n",
    "        result with NaNs, so that case returns the all-zero `image - min`.\n",
    "        \"\"\"\n",
    "        lowest = image.min()\n",
    "        value_range = image.max() - lowest\n",
    "        shifted = image - lowest\n",
    "        if value_range == 0:\n",
    "            return shifted\n",
    "        return shifted / value_range\n",
    "\n",
    "    def show(self, ctx=None, **kwargs):\n",
    "        \"\"\"Draw both images next to each other with the label as title.\n",
    "\n",
    "        ctx: forwarded to `show_image` as the drawing context.\n",
    "        kwargs: accepted for signature compatibility; not used here.\n",
    "\n",
    "        Returns whatever `show_image` returns.\n",
    "        \"\"\"\n",
    "        if len(self) > 2:\n",
    "            img1, img2, same_breed = self\n",
    "        else:\n",
    "            img1, img2 = self\n",
    "            same_breed = \"Undetermined\"\n",
    "        # Non-tensor inputs (e.g. PIL images) go through the notebook-level\n",
    "        # `image_tfm` first; tensors are used as they are.\n",
    "        if not isinstance(img1, Tensor):\n",
    "            t1, t2 = image_tfm(img1), image_tfm(img2)\n",
    "        else:\n",
    "            t1, t2 = img1, img2\n",
    "        # Zero-filled strip, 10 wide along the last dim, placed between the images.\n",
    "        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)\n",
    "        return show_image(\n",
    "            torch.cat([self.img_restore(t1), line, self.img_restore(t2)], dim=2),\n",
    "            title=same_breed,\n",
    "            ctx=ctx,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: pair the first image with itself, labelled True, and display it.\n",
    "img = PILImage.create(files[0])\n",
    "s = TwinImage(img, img, True)\n",
    "s.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair the first image (`img`, created in the previous cell) with the second\n",
    "# file, labelled False, and display it.\n",
    "img1 = PILImage.create(files[1])\n",
    "s1 = TwinImage(img, img1, False)\n",
    "s1.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply Resize(224) to the whole pair object and display the result.\n",
    "s2 = Resize(224)(s1)\n",
    "s2.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s1[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Label and Dataloader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def label_func(fname):\n",
    "    return re.match(r\"^(.*)_\\d+.jpg$\", fname.name).groups()[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TwinTransform(Transform):\n",
    "    \"\"\"Turn a file into a `TwinImage` of (image, partner image, same-label flag).\n",
    "\n",
    "    Partners for validation files are drawn once in `__init__` and reused;\n",
    "    every other file gets a fresh random partner on each call.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, files, label_func, splits):\n",
    "        \"\"\"Index `files` by label and pre-draw the validation pairs.\n",
    "\n",
    "        files: collection of image files (must support `.map` and indexing by `splits[1]`).\n",
    "        label_func: maps a file to its label.\n",
    "        splits: pair of index collections; `splits[1]` selects the validation files.\n",
    "        \"\"\"\n",
    "        self.label_func = label_func\n",
    "        self.labels = files.map(label_func).unique()\n",
    "        # Group the files by label in one pass instead of re-scanning every\n",
    "        # file (and re-running label_func) once per label.\n",
    "        grouped = {l: [] for l in self.labels}\n",
    "        for f in files:\n",
    "            grouped[label_func(f)].append(f)\n",
    "        self.lbl2files = {l: L(fs) for l, fs in grouped.items()}\n",
    "        self.valid = {f: self._draw(f) for f in files[splits[1]]}\n",
    "\n",
    "    def encodes(self, f):\n",
    "        \"\"\"Build the `TwinImage` for file `f`.\"\"\"\n",
    "        in_valid = f in self.valid\n",
    "        # Only draw a new partner when no fixed validation pair exists;\n",
    "        # `dict.get(f, self._draw(f))` would run the draw on every call.\n",
    "        f2, t = self.valid[f] if in_valid else self._draw(f)\n",
    "        img1, img2 = PILImage.create(f), PILImage.create(f2)\n",
    "        # Outside the validation set, swap the two sides half of the time.\n",
    "        if not in_valid and random.random() < 0.5:\n",
    "            img1, img2 = img2, img1\n",
    "        img1, img2 = image_tfm(img1), image_tfm(img2)\n",
    "        return TwinImage(img1, img2, t)\n",
    "\n",
    "    def _draw(self, f):\n",
    "        \"\"\"Pick a random partner file for `f`.\n",
    "\n",
    "        With probability 0.5 the partner shares `f`'s label, otherwise it comes\n",
    "        from a different, randomly chosen label. Returns `(partner, same)`.\n",
    "        A same-label draw chooses among all files of that label, `f` included.\n",
    "        \"\"\"\n",
    "        same = random.random() < 0.5\n",
    "        cls = self.label_func(f)\n",
    "        if not same:\n",
    "            cls = random.choice(L(l for l in self.labels if l != cls))\n",
    "        return random.choice(self.lbl2files[cls]), same"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both random sources so every run gets the same validation split\n",
    "# (RandomSplitter) and the same validation pairs (drawn with `random` in\n",
    "# TwinTransform.__init__).\n",
    "random.seed(42)\n",
    "splits = RandomSplitter(seed=42)(files)\n",
    "tfm = TwinTransform(files, label_func, splits)\n",
    "# Show the pair produced for the first file.\n",
    "tfm(files[0]).show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a TfmdLists over the files with `tfm` and the train/valid `splits`,\n",
    "# then display the first item of the validation subset.\n",
    "tls = TfmdLists(files, tfm, splits=splits)\n",
    "show_at(tls.valid, 0);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 32 pairs, with the aug_transforms() defaults applied after batching.\n",
    "# NOTE(review): confirm image_tfm's output type is one these augmentations act on.\n",
    "batch_augmentations = list(aug_transforms())\n",
    "dls = tls.dataloaders(bs=32, after_batch=batch_augmentations)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typedispatch\n",
    "def show_batch(\n",
    "    x: TwinImage,\n",
    "    y,\n",
    "    samples,\n",
    "    ctxs=None,\n",
    "    max_n=6,\n",
    "    nrows=None,\n",
    "    ncols=2,\n",
    "    figsize=None,\n",
    "    **kwargs\n",
    "):\n",
    "    \"\"\"Show a batch of `TwinImage` pairs on a grid, titled by similarity.\n",
    "\n",
    "    x: batch whose items 0 and 1 are the image tensors and item 2 the labels.\n",
    "    y, samples, kwargs: part of the dispatch signature; not used here.\n",
    "    ctxs: axes to draw on; a grid is created when None.\n",
    "    max_n: upper bound on the number of pairs in a newly created grid.\n",
    "    nrows, ncols: grid layout passed to `get_grid`.\n",
    "    figsize: figure size; defaults to `(ncols * 6, max_n // ncols * 3)`.\n",
    "    \"\"\"\n",
    "    if figsize is None:\n",
    "        figsize = (ncols * 6, max_n // ncols * 3)\n",
    "    if ctxs is None:\n",
    "        # Forward the caller's nrows rather than a hard-coded None.\n",
    "        ctxs = get_grid(\n",
    "            min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize\n",
    "        )\n",
    "    pair_labels = [\"Not similar\", \"Similar\"]\n",
    "    for i, ctx in enumerate(ctxs):\n",
    "        title = pair_labels[x[2][i].item()]\n",
    "        TwinImage(x[0][i], x[1][i], title).show(ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dls.show_batch()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model and Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TwinModel(Module):\n",
    "    \"\"\"Twin network: one encoder applied to both inputs, followed by a head.\"\"\"\n",
    "\n",
    "    def __init__(self, encoder, head):\n",
    "        self.encoder = encoder\n",
    "        self.head = head\n",
    "\n",
    "    def forward(self, x1, x2):\n",
    "        \"\"\"Encode `x1` and `x2`, concatenate the results along dim 1, apply the head.\"\"\"\n",
    "        encoded = [self.encoder(x1), self.encoder(x2)]\n",
    "        joined = torch.cat(encoded, dim=1)\n",
    "        return self.head(joined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the head and an encoder (requested with pre_train=True) from the\n",
    "# deployment handler and combine them into the twin model. get_encoder\n",
    "# returns two values; the second one is not needed here and is discarded.\n",
    "head = twin_handler.head_reload\n",
    "encoder, _ = twin_handler.get_encoder(pre_train=True)\n",
    "model = TwinModel(encoder, head)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_func(out, targ):\n",
    "    return nn.CrossEntropyLoss()(out, targ.long())\n",
    "\n",
    "\n",
    "def twin_splitter(model):\n",
    "    \"\"\"Return two parameter groups: the encoder's first, then the head's.\"\"\"\n",
    "    encoder_group = params(model.encoder)\n",
    "    head_group = params(model.head)\n",
    "    return [encoder_group, head_group]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the learner with the custom loss and the two-group splitter\n",
    "# (encoder, head), tracking accuracy as the metric.\n",
    "learn = Learner(\n",
    "    dls, model, loss_func=loss_func, splitter=twin_splitter, metrics=accuracy\n",
    ")\n",
    "# Freeze before the first training stage.\n",
    "# NOTE(review): presumably this leaves only the last group (the head) trainable - confirm.\n",
    "learn.freeze()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Customize matplotlib: turn off TeX text rendering (text.usetex) and use the\n",
    "# STIX fonts for both regular text and math text.\n",
    "matplotlib.rcParams.update(\n",
    "    {\n",
    "        'text.usetex': False,\n",
    "        'font.family': 'stixgeneral',\n",
    "        'mathtext.fontset': 'stix',\n",
    "    }\n",
    ")\n",
    "\n",
    "# Run the learning-rate finder on the learner before training.\n",
    "learn.lr_find()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First stage: 4 epochs of one-cycle training at learning rate 1e-3, run\n",
    "# while the learner is still frozen from the cell above.\n",
    "learn.fit_one_cycle(4, 1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the learner after the first training stage.\n",
    "# NOTE(review): the 0.959 in the file name is hard-coded, not taken from this\n",
    "# run's metrics - rename if a re-run gives a different accuracy.\n",
    "learn.export(\"../model/resnet50_0.959.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second stage: unfreeze the learner and train 4 more epochs with the learning\n",
    "# rate given as slice(1e-6, 1e-4).\n",
    "# NOTE(review): presumably the slice spreads rates across the encoder and head\n",
    "# groups, lowest for the encoder - confirm.\n",
    "learn.unfreeze()\n",
    "learn.fit_one_cycle(4, slice(1e-6,1e-4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Export the fine-tuned learner; later cells reload this file by name.\n",
    "# NOTE(review): the 0.962 in the file name is hard-coded, not taken from this\n",
    "# run's metrics.\n",
    "learn.export(\"../model/resnet50_0.962.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Diagnose"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@typedispatch\n",
    "def show_results(\n",
    "    x: TwinImage,\n",
    "    y,\n",
    "    samples,\n",
    "    outs,\n",
    "    ctxs=None,\n",
    "    max_n=6,\n",
    "    nrows=None,\n",
    "    ncols=2,\n",
    "    figsize=None,\n",
    "    **kwargs,\n",
    "):\n",
    "    \"\"\"Show `TwinImage` pairs on a grid with actual and predicted similarity.\n",
    "\n",
    "    x: batch whose items 0 and 1 are the image tensors and item 2 the targets.\n",
    "    y: batch whose item 2 holds per-sample scores; their argmax is the prediction.\n",
    "    samples, outs, kwargs: part of the dispatch signature; not used here.\n",
    "    ctxs: axes to draw on; a grid is created when None.\n",
    "    max_n: upper bound on the number of pairs in a newly created grid.\n",
    "    nrows, ncols: grid layout passed to `get_grid`.\n",
    "    figsize: figure size; defaults to `(ncols * 6, max_n // ncols * 4)`.\n",
    "    \"\"\"\n",
    "    if figsize is None:\n",
    "        figsize = (ncols * 6, max_n // ncols * 4)\n",
    "    if ctxs is None:\n",
    "        # Forward the caller's nrows rather than a hard-coded None.\n",
    "        ctxs = get_grid(\n",
    "            min(x[0].shape[0], max_n), nrows=nrows, ncols=ncols, figsize=figsize\n",
    "        )\n",
    "    pair_labels = [\"Not similar\", \"Similar\"]\n",
    "    for i, ctx in enumerate(ctxs):\n",
    "        actual = pair_labels[x[2][i].item()]\n",
    "        predicted = pair_labels[y[2][i].argmax().item()]\n",
    "        # Keep the title in a single literal: a backslash line continuation\n",
    "        # inside the string would pull the source indentation into the text.\n",
    "        title = f\"Actual: {actual} \\n Prediction: {predicted}\"\n",
    "        TwinImage(x[0][i], x[1][i], title).show(ctx=ctx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learn.show_results()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the exported learner from disk for inference.\n",
    "# NOTE(review): the .pkl export is presumably unpickled on load - only load\n",
    "# files you trust, and keep TwinImage, TwinModel, loss_func and twin_splitter\n",
    "# defined in the session before calling this.\n",
    "learner_reload = load_learner(\"../model/resnet50_0.962.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@patch\n",
    "def twinpredict(\n",
    "    self: Learner,\n",
    "    item,\n",
    "    rm_type_tfms=None,\n",
    "    with_input=False,\n",
    "):\n",
    "    \"\"\"Predict on a `TwinImage` pair, plot it with the verdict, return the result.\n",
    "\n",
    "    item: pair whose first two entries are the images to compare.\n",
    "    rm_type_tfms, with_input: forwarded unchanged to `Learner.predict`.\n",
    "\n",
    "    Returns the tuple produced by `Learner.predict`.\n",
    "    \"\"\"\n",
    "    # Forward the caller's options instead of overriding them with the defaults.\n",
    "    res = self.predict(item, rm_type_tfms=rm_type_tfms, with_input=with_input)\n",
    "    # Per fastai's Learner.predict, with_input=True puts the decoded input in\n",
    "    # front of the usual results, so the scores move from index 0 to index 1.\n",
    "    scores = res[1] if with_input else res[0]\n",
    "    if scores.argmax().item() == 0:\n",
    "        label = \"Prediction: Not similar\"\n",
    "    else:\n",
    "        label = \"Prediction: Similar\"\n",
    "    TwinImage(item[0], item[1], label).show()\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference pair 1: the first two dataset files, preprocessed with image_tfm.\n",
    "# No label is passed, so show() titles the plot \"Undetermined\".\n",
    "imgtest = image_tfm(PILImage.create(files[0]))\n",
    "imgval = image_tfm(PILImage.create(files[1]))\n",
    "twintest = TwinImage(imgtest, imgval)\n",
    "twintest.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = learner_reload.twinpredict(twintest)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference pair 2: the same dataset file (files[1]) on both sides.\n",
    "imgtest = image_tfm(PILImage.create(files[1]))\n",
    "imgval = image_tfm(PILImage.create(files[1]))\n",
    "twintest = TwinImage(imgtest, imgval)\n",
    "twintest.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = learner_reload.twinpredict(twintest)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference pair 3: two images from ../sample (c1.jpg and c2.jpg) rather than\n",
    "# from the dataset files.\n",
    "imgtest = image_tfm(PILImage.create(\"../sample/c1.jpg\"))\n",
    "imgval = image_tfm(PILImage.create(\"../sample/c2.jpg\"))\n",
    "\n",
    "twintest = TwinImage(imgtest, imgval)\n",
    "twintest.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = learner_reload.twinpredict(twintest)\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference pair 4: sample images c3.jpg and c2.jpg from ../sample.\n",
    "imgtest = image_tfm(PILImage.create(\"../sample/c3.jpg\"))\n",
    "imgval = image_tfm(PILImage.create(\"../sample/c2.jpg\"))\n",
    "twintest = TwinImage(imgtest, imgval)\n",
    "twintest.show();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = learner_reload.twinpredict(twintest)\n",
    "res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export to PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the exported learner and save the encoder and head weights as two\n",
    "# separate state-dict files.\n",
    "# NOTE(review): `.encoder` / `.head` are read from the learner object itself;\n",
    "# presumably fastai delegates them to `learner_reload.model` - confirm.\n",
    "learner_reload = load_learner(\"../model/resnet50_0.962.pkl\")\n",
    "torch.save(learner_reload.encoder.state_dict(), \"../model/resnet50_0.962_encoder.pth\")\n",
    "torch.save(learner_reload.head.state_dict(), \"../model/resnet50_0.962_head.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}