In [None]:
%reload_ext autoreload
%autoreload 2

%matplotlib inline

In [None]:
import sys

sys.path.insert(0, "..")
from deployment.handler import TwinHandler, HookCAM, HookCAMBwd

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import torch

## Recreate Model and Load Weights

In [None]:
# TwinHandler (deployment/handler.py) supplies the preprocessing transform,
# the head that scores a pair of embeddings, and a factory for the encoder.
twin_handler = TwinHandler()
image_tfm = twin_handler.image_tfm

head_reload = twin_handler.head_reload
# NOTE(review): pre_train=False presumably skips loading pretrained backbone
# weights, since trained weights are loaded from disk in the next cells —
# confirm in handler.py. Only the first returned value is kept.
encoder_reload, _ = twin_handler.get_encoder(pre_train=False)
# Display both module types as a quick sanity check.
type(encoder_reload), type(head_reload)

In [None]:
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict then copies the values onto whatever
# device the head's parameters live on.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
state_head = torch.load("../model/resnet50_0.962_head.pth", map_location="cpu")
head_reload.load_state_dict(state_head)
head_reload.eval()

In [None]:
# map_location="cpu" lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict then copies the values onto whatever
# device the encoder's parameters live on.
# torch.load unpickles the file, so only load checkpoints from trusted sources.
state_encoder = torch.load("../model/resnet50_0.962_encoder.pth", map_location="cpu")
encoder_reload.load_state_dict(state_encoder)
encoder_reload.eval()

## PyTorch Inference and CAM

In [None]:
# Open each sample inside a context manager so the underlying file handle is
# released as soon as the RGB copy has been made.
with Image.open("../sample/c3.jpg") as raw_test, Image.open("../sample/c4.jpg") as raw_val:
    imgtest, imgval = raw_test.convert("RGB"), raw_val.convert("RGB")

# Show the two inputs side by side.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for panel, img in zip(axes, (imgtest, imgval)):
    panel.imshow(img)
    panel.axis("off")
plt.show()

# Preprocess with the handler's transform and add a leading size-1 (batch) axis
# before the tensors are fed to the encoder.
imgtest_ts, imgval_ts = image_tfm(imgtest)[None, ...], image_tfm(imgval)[None, ...]

In [None]:
def img_restore(image: torch.Tensor) -> torch.Tensor:
    """Min-max rescale a tensor into [0, 1] so it can be displayed.

    Args:
        image: Tensor of any shape, e.g. a normalized image permuted to HWC
            before being passed to ``plt.imshow``. Integer tensors are
            promoted to the default floating dtype, as true division would.

    Returns:
        ``(image - image.min()) / (image.max() - image.min())``. When every
        element is equal the range is zero; an all-zeros tensor is returned
        instead of the NaNs a 0/0 division would produce.
    """
    lo, hi = image.min(), image.max()
    shifted = image - lo
    if not shifted.is_floating_point():
        # Match torch's true-division promotion for integer inputs.
        shifted = shifted.to(torch.get_default_dtype())
    span = hi - lo
    if span == 0:
        # Constant image: nothing to stretch; avoid 0/0 -> NaN.
        return torch.zeros_like(shifted)
    return shifted / span

In [None]:
%%time
with HookCAMBwd(encoder_reload) as hookg:
    with HookCAM(encoder_reload) as hook:
        
        encoder_reload.to(torch.device("cuda"))
        l_emb = encoder_reload(imgtest_ts.cuda())
        r_emb = encoder_reload(imgval_ts.cuda())

        ftrs = torch.cat([l_emb, r_emb], dim=1)
        head_reload.to(torch.device("cuda"))
        res = head_reload(ftrs)[0]
        act = hook.stored

    pred_cls = res.argmax().item()
    res[pred_cls].backward()
    grad = hookg.stored
    
encoder_reload.zero_grad(), head_reload.zero_grad()
pred_cls, ["Not Similar", "Similar"][pred_cls]

In [None]:
def grad_cam_map(grad_out, activation):
    """Build a Grad-CAM style map from one stored gradient/activation pair.

    Each channel of ``activation[0]`` is weighted by the mean of
    ``grad_out[0]`` over dims 1 and 2, and the weighted channels are summed
    over dim 0.

    NOTE(review): the ``[0]`` index is presumably the single batch item, and
    dims 1-2 are presumably the spatial axes of a (C, H, W) tensor — confirm
    against HookCAM / HookCAMBwd in deployment/handler.py.
    """
    weights = grad_out[0].mean(dim=[1, 2], keepdim=True)
    return (weights * activation[0]).sum(0)


# NOTE(review): this pairs grad[i] with act[i]. `act` was filled while the
# left image was encoded first, but PyTorch's autograd engine tends to run
# the most recently created graph nodes first, which could put the right
# image's gradient at grad[0]. Verify the order HookCAMBwd stores them in.
cam_map_left = grad_cam_map(grad[0], act[0])
cam_map_right = grad_cam_map(grad[1], act[1])
cam_map_left.shape, cam_map_right.shape

In [None]:
# Overlay each CAM map on its (rescaled) input image.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, img_ts, cam_map in zip(
    ax, (imgtest_ts, imgval_ts), (cam_map_left, cam_map_right)
):
    # Stretch the low-resolution CAM over the full image; take H/W from the
    # tensor itself rather than hard-coding the transform's output size.
    height, width = img_ts.shape[-2:]
    panel.imshow(img_restore(img_ts[0].permute(1, 2, 0)))
    panel.imshow(
        cam_map.detach().cpu(),
        alpha=0.3,
        extent=(0, width, height, 0),
        interpolation="bilinear",
        cmap="jet",
    )
    panel.axis("off")

plt.show()

## TorchServe

### handler.py

In [None]:
# Print the TorchServe handler source (syntax-highlighted) for reference.
!pygmentize ../deployment/handler.py

### Archive Model: Encoder and Head

```bash
cd ../

cp model/resnet50_0.962_head.pth model/head_weight.pth
cp model/resnet50_0.962_encoder.pth model/encoder_weight.pth

torch-model-archiver --model-name twin --version 1.0 --serialized-file ./model/encoder_weight.pth --export-path model_store --handler ./deployment/handler.py -f --extra-files ./model/head_weight.pth
```

### Serve the Model

```bash
torchserve --start --ncs --model-store model_store --models twin.mar
```

### Call API

```bash
time http --form POST http://127.0.0.1:8080/predictions/twin left@sample/c1.jpg right@sample/c2.jpg cam=False
```

### Sample Response

```text
HTTP/1.1 200
Cache-Control: no-cache; no-store, must-revalidate, private
Expires: Thu, 01 Jan 1970 00:00:00 UTC
Pragma: no-cache
connection: keep-alive
content-length: 46
x-request-id: e51e2f15-a9b6-4522-b700-c0f5e35008a7

[
  -0.938737690448761,
  0.7865392565727234
]


real    0m1.765s
user    0m0.334s
sys     0m0.036s
```

### Stop TorchServe

```sh
torchserve --stop
```