In [None]:
%reload_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
import sys

sys.path.insert(0, "..")
from deployment.handler import TwinHandler

## Get Dataset

In [None]:
from fastai.vision.all import *
# Fetch and extract the Oxford-IIIT Pets dataset, then list its image files.
path = untar_data(URLs.PETS)
files = get_image_files(path/"images")

In [None]:
# Print the deployment handler's source (syntax-highlighted) for reference:
# it provides the image transform, encoder and head used throughout this notebook.
!pygmentize ../deployment/handler.py

In [None]:
# Take the preprocessing transform from the deployment handler, presumably so
# training uses the same preprocessing as serving (confirm in handler.py).
twin_handler = TwinHandler()
image_tfm = twin_handler.image_tfm
image_tfm

In [None]:
class TwinImage(fastuple):
    """A pair of images plus an optional label, displayable side by side.

    Holds either ``(img1, img2)`` or ``(img1, img2, label)``. The images may be
    raw images (run through the notebook-level ``image_tfm`` before display)
    or tensors that have already been transformed.
    """

    @staticmethod
    def img_restore(image: torch.Tensor):
        """Min-max rescale ``image`` to [0, 1] for display.

        A constant image (max == min) would otherwise give 0/0 = NaN for every
        pixel; it is returned as all zeros instead.
        """
        shifted = image - image.min()
        span = shifted.max()  # equals image.max() - image.min()
        if span == 0:
            return shifted if shifted.is_floating_point() else shifted.float()
        return shifted / span

    def show(self, ctx=None, **kwargs):
        """Draw both images next to each other, separated by a thin black bar.

        The third element, if present, is used as the title; otherwise the
        title is ``"Undetermined"``. Returns whatever ``show_image`` returns.
        """
        if len(self) > 2:
            img1, img2, same_breed = self
        else:
            img1, img2 = self
            same_breed = "Undetermined"
        if not isinstance(img1, Tensor):
            # Raw images (e.g. PILImage) still need the model's preprocessing.
            t1, t2 = image_tfm(img1), image_tfm(img2)
        else:
            t1, t2 = img1, img2
        # 10-pixel-wide black separator. Assumes 3-D channel-first (C, H, W)
        # tensors, since the images are concatenated along dim 2 (width).
        line = t1.new_zeros(t1.shape[0], t1.shape[1], 10)
        return show_image(
            torch.cat([self.img_restore(t1), line, self.img_restore(t2)], dim=2),
            title=same_breed,
            ctx=ctx,
        )

In [None]:
# Smoke-test TwinImage.show: an image paired with itself, titled with the label True.
img = PILImage.create(files[0])
s = TwinImage(img, img, True)
s.show();

In [None]:
# A pair built from two different files, titled with the label False.
img1 = PILImage.create(files[1])
s1 = TwinImage(img, img1, False)
s1.show();

In [None]:
# Check that an item transform (Resize) can be applied to the whole pair.
s2 = Resize(224)(s1)
s2.show();

In [None]:
# Inspect the first (raw, untransformed) image of the pair.
s1[0]

## Labels and DataLoaders

In [None]:
def label_func(fname):
    """Return the class label encoded in a Pets image filename.

    Filenames look like ``<breed>_<index>.jpg`` (e.g. ``great_pyrenees_12.jpg``);
    the label is everything before the final ``_<digits>.jpg`` suffix.

    Raises:
        ValueError: if ``fname.name`` does not follow that pattern.
    """
    # Escape the dot so only a literal ".jpg" suffix matches.
    match = re.match(r"^(.*)_\d+\.jpg$", fname.name)
    if match is None:
        raise ValueError(f"Cannot extract a label from file name {fname.name!r}")
    return match.group(1)

In [None]:
class TwinTransform(Transform):
    """Turn an image path into a labeled ``TwinImage`` pair.

    For each file a partner is drawn: with probability 0.5 from the same class
    (label ``True``), otherwise from a different class (label ``False``).
    Validation pairs are drawn once in ``__init__`` and reused; other files get
    a fresh partner (and a random left/right order) on every call.
    """

    def __init__(self, files, label_func, splits):
        """
        Args:
            files: image paths; must support list-of-indices indexing
                (e.g. a fastcore ``L``).
            label_func: maps a path to its class label.
            splits: ``(train_idxs, valid_idxs)``; only ``splits[1]`` is used.
        """
        self.label_func = label_func
        # Group files by label in a single pass, rather than rescanning every
        # file (and re-running label_func) once per label.
        grouped = {}
        for f in files:
            grouped.setdefault(label_func(f), []).append(f)
        # dict insertion order == first-occurrence order, as with
        # `files.map(label_func).unique()`.
        self.labels = L(list(grouped))
        self.lbl2files = {l: L(fs) for l, fs in grouped.items()}
        self.valid = {f: self._draw(f) for f in files[splits[1]]}

    def encodes(self, f):
        """Load ``f`` and its partner, preprocess both, return ``TwinImage(img1, img2, same)``."""
        # Only draw a new partner when `f` has no fixed validation pair.
        # `dict.get(f, self._draw(f))` would run `_draw` (and consume RNG state)
        # on every call.
        is_valid = f in self.valid
        f2, t = self.valid[f] if is_valid else self._draw(f)
        img1, img2 = PILImage.create(f), PILImage.create(f2)
        if not is_valid and random.random() < 0.5:
            # Randomize which side the anchor image appears on during training.
            img1, img2 = img2, img1
        img1, img2 = image_tfm(img1), image_tfm(img2)
        return TwinImage(img1, img2, t)

    def _draw(self, f):
        """Pick a partner file for ``f``; returns ``(partner_path, same_class)``."""
        same = random.random() < 0.5
        cls = self.label_func(f)
        if not same:
            cls = random.choice(L(l for l in self.labels if l != cls))
        return random.choice(self.lbl2files[cls]), same

In [None]:
# Seed Python's `random`, NumPy and PyTorch so the random train/valid split
# (RandomSplitter draws a torch permutation) and the fixed validation pairs
# drawn by TwinTransform (Python `random`) are reproducible across runs.
set_seed(42)
splits = RandomSplitter()(files)
tfm = TwinTransform(files, label_func, splits)
tfm(files[0]).show();

In [None]:
# Wrap the files in a TfmdLists that applies TwinTransform to each item, split into
# train/valid by `splits`; then display the first validation pair.
tls = TfmdLists(files, tfm, splits=splits)
show_at(tls.valid, 0);

In [None]:
# Batches of 32 pairs; fastai's default augmentations are applied to each
# collated batch (`after_batch`). aug_transforms() already returns a list.
dls = tls.dataloaders(after_batch=aug_transforms(), bs=32)

In [None]:
@typedispatch
def show_batch(
    x: TwinImage,
    y,
    samples,
    ctxs=None,
    max_n=6,
    nrows=None,
    ncols=2,
    figsize=None,
    **kwargs
):
    """Display up to ``max_n`` pairs from a decoded batch of ``TwinImage``.

    ``x`` holds batched ``(images_1, images_2, same_class)``; each pair is
    titled "Similar" or "Not similar" according to its label.
    """
    names = ["Not similar", "Similar"]
    if figsize is None:
        figsize = (ncols * 6, max_n // ncols * 3)
    n = min(x[0].shape[0], max_n)
    if ctxs is None:
        # Honour a caller-supplied `nrows` instead of always letting get_grid choose.
        ctxs = get_grid(n, nrows=nrows, ncols=ncols, figsize=figsize)
    # Bound by `n` so caller-supplied `ctxs` cannot exceed max_n or the batch size.
    for i, ctx in zip(range(n), ctxs):
        TwinImage(x[0][i], x[1][i], names[x[2][i].item()]).show(ctx=ctx)

In [None]:
# Preview a few training pairs; dispatches to the TwinImage `show_batch` defined above.
dls.show_batch()

## Model and Training

In [None]:
class TwinModel(Module):
    """Siamese network: one shared encoder applied to both images, then a head.

    The two encodings are concatenated along dim 1 before the head, so the
    head's input width must be the sum of both encoder outputs (assumes the
    encoder returns ``(batch, features)`` — confirm against the handler).
    """

    def __init__(self, encoder, head):
        self.encoder = encoder
        self.head = head

    def forward(self, x1, x2):
        left = self.encoder(x1)
        right = self.encoder(x2)
        return self.head(torch.cat((left, right), dim=1))

In [None]:
# Build the model from the deployment handler's head and encoder.
# NOTE(review): `pre_train=True` presumably loads pretrained encoder weights — confirm
# in handler.py. The second value returned by get_encoder is unused here.
head = twin_handler.head_reload
encoder, _ = twin_handler.get_encoder(pre_train=True)
model = TwinModel(encoder, head)

In [None]:
def loss_func(out, targ):
    """Cross-entropy over the model's class scores.

    ``targ`` (presumably the boolean same-class flags produced by
    TwinTransform) is cast to int64 class indices, as cross-entropy requires.
    """
    return nn.functional.cross_entropy(out, targ.long())


def twin_splitter(model):
    """Parameter groups ``[encoder, head]`` for freezing and discriminative LRs."""
    encoder_params = params(model.encoder)
    head_params = params(model.head)
    return [encoder_params, head_params]

In [None]:
# Wrap data, model and loss in a Learner. `twin_splitter` yields two parameter
# groups [encoder, head]; `freeze()` freezes all groups but the last, so only
# the head is trained in the first phase.
learn = Learner(
 dls, model, loss_func=loss_func, splitter=twin_splitter, metrics=accuracy
)
learn.freeze()

In [None]:
import matplotlib
import matplotlib.pyplot as plt

# Customize matplotlib: STIX fonts for text and math, without requiring a LaTeX install.
matplotlib.rcParams.update(
 {
 'text.usetex': False,
 'font.family': 'stixgeneral',
 'mathtext.fontset': 'stix',
 }
)

# Learning-rate finder for the frozen (head-only) phase.
learn.lr_find()

In [None]:
# Phase 1: train the head only (encoder frozen), 4 epochs, max LR 1e-3.
learn.fit_one_cycle(4, 1e-3)

In [None]:
# NOTE(review): the accuracy in this filename is hard-coded from an earlier run;
# it is not derived from the metrics of the fit above.
learn.export("../model/resnet50_0.959.pkl")

In [None]:
# Phase 2: unfreeze and fine-tune everything; slice(1e-6, 1e-4) gives the encoder
# group the lowest LR and the head group the highest.
learn.unfreeze()
learn.fit_one_cycle(4, slice(1e-6,1e-4))

In [None]:
# NOTE(review): the accuracy in this filename is hard-coded as well.
learn.export("../model/resnet50_0.962.pkl")

## Diagnose

In [None]:
@typedispatch
def show_results(
    x: TwinImage,
    y,
    samples,
    outs,
    ctxs=None,
    max_n=6,
    nrows=None,
    ncols=2,
    figsize=None,
    **kwargs,
):
    """Display pairs titled with their actual label and the model's prediction.

    ``x`` holds the decoded inputs ``(images_1, images_2, same_class)``;
    ``y[2]`` presumably holds the model outputs for the same pairs (one score
    per class), and its argmax is shown as the prediction.
    """
    names = ["Not similar", "Similar"]
    if figsize is None:
        figsize = (ncols * 6, max_n // ncols * 4)
    n = min(x[0].shape[0], max_n)
    if ctxs is None:
        ctxs = get_grid(n, nrows=nrows, ncols=ncols, figsize=figsize)
    for i, ctx in zip(range(n), ctxs):
        # Build the two-line title explicitly. A backslash continuation inside the
        # string literal would copy the next line's indentation into the title.
        title = (
            f"Actual: {names[x[2][i].item()]}\n"
            f"Prediction: {names[y[2][i].argmax().item()]}"
        )
        TwinImage(x[0][i], x[1][i], title).show(ctx=ctx)

In [None]:
# Plot validation pairs with actual vs. predicted labels; dispatches to the
# TwinImage `show_results` defined above.
learn.show_results()

## Inference

In [None]:
# Reload the fine-tuned learner. load_learner unpickles the file, so only load trusted
# files; the notebook-defined objects it references (e.g. TwinTransform, TwinModel,
# loss_func, label_func) must already be defined when this cell runs.
learner_reload = load_learner("../model/resnet50_0.962.pkl")

In [None]:
@patch
def twinpredict(
    self: Learner,
    item,
    rm_type_tfms=None,
    with_input=False,
):
    """Predict whether the two images in ``item`` are similar, and show the pair.

    Args:
        item: a ``TwinImage``; its first two elements are displayed.
        rm_type_tfms: forwarded to ``Learner.predict``.
        with_input: forwarded to ``Learner.predict``; when True the decoded
            input is prepended to the returned tuple.

    Returns:
        The tuple returned by ``Learner.predict``.
    """
    res = self.predict(item, rm_type_tfms=rm_type_tfms, with_input=with_input)
    # With `with_input=True`, predict prepends the decoded input, so the decoded
    # prediction moves from index 0 to index 1.
    pred = res[1] if with_input else res[0]
    if pred.argmax().item() == 0:
        label = "Prediction: Not similar"
    else:
        label = "Prediction: Similar"
    TwinImage(item[0], item[1], label).show()
    return res

In [None]:
imgtest = image_tfm(PILImage.create(files[0]))
imgval = image_tfm(PILImage.create(files[1]))
twintest = TwinImage(imgtest, imgval)
twintest.show();

In [None]:
res = learner_reload.twinpredict(twintest)
res

In [None]:
imgtest = image_tfm(PILImage.create(files[1]))
imgval = image_tfm(PILImage.create(files[1]))
twintest = TwinImage(imgtest, imgval)
twintest.show();

In [None]:
res = learner_reload.twinpredict(twintest)
res

In [None]:
imgtest = image_tfm(PILImage.create("../sample/c1.jpg"))
imgval = image_tfm(PILImage.create("../sample/c2.jpg"))

twintest = TwinImage(imgtest, imgval)
twintest.show();

In [None]:
res = learner_reload.twinpredict(twintest)
res

In [None]:
imgtest = image_tfm(PILImage.create("../sample/c3.jpg"))
imgval = image_tfm(PILImage.create("../sample/c2.jpg"))
twintest = TwinImage(imgtest, imgval)
twintest.show();

In [None]:
res = learner_reload.twinpredict(twintest)
res

## Export to PyTorch

In [None]:
# Save encoder and head weights as plain PyTorch state dicts (presumably for loading
# in the deployment handler without fastai). `learner_reload.encoder` / `.head`
# resolve to `learner_reload.model.encoder` / `.head` through Learner's attribute
# delegation to its model.
learner_reload = load_learner("../model/resnet50_0.962.pkl")
torch.save(learner_reload.encoder.state_dict(), "../model/resnet50_0.962_encoder.pth")
torch.save(learner_reload.head.state_dict(), "../model/resnet50_0.962_head.pth")