In [None]:
# Load (or reload) the autoreload extension and re-import edited modules
# automatically before each cell executes.
%reload_ext autoreload
%autoreload 2


# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
from fastai.vision.all import *
import matplotlib.pyplot as plt
from PIL import Image

## Load the fastai Model and Save Its Trained Weights

### Load Model

In [None]:
from pathlib import Path


# Placeholders so unpickling the exported Learner can resolve these names;
# presumably the training notebook defined functions with these names (a metric
# and a label getter) that the pickle refers to. They do nothing here.
def acc_camvid(*_):
    """Unpickling placeholder for the training-time `acc_camvid` function."""
    pass


def get_y(*_):
    """Unpickling placeholder for the training-time `get_y` function."""
    pass


# Build the path from the current user's home instead of hard-coding
# /home/ubuntu, so the notebook runs for any user with the same layout.
MODEL_PKL_PATH = Path.home() / ".fastai" / "data" / "camvid_tiny" / "fastai_unet.pkl"
# load_learner unpickles the file, which can execute arbitrary code:
# only load exports you produced or otherwise trust.
learn = load_learner(MODEL_PKL_PATH)

### Load a Sample Image

In [None]:
# Sample street-scene image used for every inference comparison below.
image_path = "../sample/street_view_of_a_small_neighborhood.png"
# Keep a handle to the preview and show it as the cell's rich output.
sample_preview = Image.open(image_path)
sample_preview

### Inference

In [None]:
%%time
pred_fastai = learn.predict(image_path)
plt.imshow(pred_fastai[0].numpy());

### Save Torch Weights

In [None]:
import os

# Destination for the raw PyTorch weights. The file name (including its
# "fasti" spelling) must match the path the weights are loaded from later.
WEIGHTS_PATH = "../model_store/fasti_unet_weights.pth"
# torch.save does not create missing directories, so create the target first.
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
torch.save(learn.model.state_dict(), WEIGHTS_PATH)
# Show the module tree that the hand-written class below has to reproduce.
learn.model

## Rebuild the fastai Model in Plain PyTorch

In [None]:
# Show the source of fastai's unet_learner to see how the U-Net is assembled.
??unet_learner

In [None]:
# Show the source of fastai's DynamicUnet, the class re-implemented below.
??DynamicUnet

In [None]:
from fastai.vision.all import *
from fastai.vision.learner import _default_meta
from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig


class DynamicUnetDIY(SequentialEx):
    """Create a U-Net from a given architecture, using only constructor arguments.

    Local re-implementation of fastai's DynamicUnet construction logic. It takes
    the backbone constructor ``arch`` and builds the encoder itself
    (``create_body(..., pretrained=False)``). The network can therefore be
    re-created without the pickled Learner and then filled with a saved
    state dict.

    NOTE(review): the order in which modules are appended to ``layers`` is
    presumably what fixes the state-dict key names. Keep it identical to the
    architecture the weights were saved from, or ``load_state_dict`` will fail.

    Args:
        arch: backbone constructor passed to ``create_body`` and used to look
            up ``model_meta`` (falls back to ``_default_meta``).
        n_classes: output channels of the final 1x1 ``ConvLayer``.
        img_size: input size used to trace feature-map shapes with
            ``model_sizes`` / ``dummy_eval``; presumably (height, width) of the
            training images -- TODO confirm.
        blur: apply blur in the UnetBlocks; for the last block only if
            ``blur_final`` is also true.
        blur_final: see ``blur``.
        y_range: if not None, a ``SigmoidRange(*y_range)`` is appended at the end.
        last_cross: if true, append a dense ``MergeLayer`` plus a ``ResBlock``
            before the final conv.
        bottle: halve the ``ResBlock`` output channels (only with ``last_cross``).
        init: init function passed to the UnetBlocks and applied to the middle
            conv and the second-to-last layer.
        norm_type: normalization type forwarded to the conv layers.
        self_attention: if truthy, enable self-attention in the UnetBlock that
            is third from the end.
        act_cls: activation class forwarded to the conv layers.
        n_in: number of input channels of the encoder.
        cut: overrides ``meta["cut"]`` when cutting the backbone into an encoder.
        **kwargs: forwarded to ConvLayer / UnetBlock / ResBlock (not to
            PixelShuffle_ICNR).
    """

    def __init__(
        self,
        arch=resnet50,
        n_classes=32,
        img_size=(96, 128),
        blur=False,
        blur_final=True,
        y_range=None,
        last_cross=True,
        bottle=False,
        init=nn.init.kaiming_normal_,
        norm_type=None,
        self_attention=None,
        act_cls=defaults.activation,
        n_in=3,
        cut=None,
        **kwargs
    ):
        meta = model_meta.get(arch, _default_meta)
        # pretrained=False: no backbone weights are fetched; in this notebook all
        # weights come from the saved state dict instead.
        encoder = create_body(
            arch, n_in, pretrained=False, cut=ifnone(cut, meta["cut"])
        )
        imsize = img_size

        # Output shape of every encoder layer for an input of size `imsize`.
        sizes = model_sizes(encoder, size=imsize)
        # Encoder layers where the resolution changes (fastai private helper),
        # reversed so the UnetBlocks below consume the deepest skip first.
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        # Forward hooks that keep those layers' outputs for the skip
        # connections; detach=False keeps them attached to the autograd graph.
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        # Dummy forward pass: its output shape drives the channel bookkeeping.
        x = dummy_eval(encoder, imsize).detach()

        # Channels of the encoder's last output (shape index 1 = channels).
        ni = sizes[-1][1]
        middle_conv = nn.Sequential(
            ConvLayer(ni, ni * 2, act_cls=act_cls, norm_type=norm_type, **kwargs),
            ConvLayer(ni * 2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs),
        ).eval()
        x = middle_conv(x)
        # layers[3] is middle_conv; it is referenced by index in apply_init below.
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        # One UnetBlock per resolution change, each fed by the matching hook;
        # `x` is pushed through each block only to learn the next input width.
        for i, idx in enumerate(sz_chg_idxs):
            not_final = i != len(sz_chg_idxs) - 1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i == len(sz_chg_idxs) - 3)
            unet_block = UnetBlock(
                up_in_c,
                x_in_c,
                self.sfs[i],
                final_div=not_final,
                blur=do_blur,
                self_attention=sa,
                act_cls=act_cls,
                init=init,
                norm_type=norm_type,
                **kwargs
            ).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        # Extra upsampling when the first encoder output is not at input resolution.
        if imsize != sizes[0][-2:]:
            layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        # fastai helper; presumably resizes back to the original input size.
        layers.append(ResizeToOrig())
        if last_cross:
            # Dense merge presumably concatenates the network input, so its
            # channel count is added to `ni`.
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(
                ResBlock(
                    1,
                    ni,
                    ni // 2 if bottle else ni,
                    act_cls=act_cls,
                    norm_type=norm_type,
                    **kwargs
                )
            )
        layers += [
            ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)
        ]
        # layers[3] = middle_conv; layers[-2] = the ResBlock when last_cross is
        # true (ResizeToOrig otherwise). SigmoidRange is not appended yet here.
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        # NOTE(review): dead experiment -- layers[2] is nn.ReLU (no weights), so
        # this would be a no-op; consider deleting.
        # apply_init(nn.Sequential(layers[2]), init)
        if y_range is not None:
            layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        """Remove the encoder forward hooks when the module is collected."""
        # Guard: `sfs` is missing if __init__ failed before creating the hooks.
        if hasattr(self, "sfs"):
            self.sfs.remove()

In [None]:
# Rebuild the network from constructor arguments only. The values spelled out
# are the class defaults and must describe the model the weights came from.
model_torch_rep = DynamicUnetDIY(arch=resnet50, n_classes=32, img_size=(96, 128))
model_torch_rep

In [None]:
# Inference below runs on CPU, so map the tensors there regardless of the
# device they were saved from (e.g. a GPU-trained Learner).
state = torch.load("../model_store/fasti_unet_weights.pth", map_location="cpu")
# Strict loading (the default) raises on missing/unexpected keys or shape
# mismatches, which is the check that the hand-built architecture matches.
model_torch_rep.load_state_dict(state)
model_torch_rep.eval();

### Test the Rebuilt Model Against the fastai Prediction

In [None]:
# Decode the sample image and force 3-channel RGB (e.g. drops a PNG alpha
# channel) to match the model's default n_in=3. The `with` block closes the
# file handle; convert() returns an independent, fully loaded image.
with Image.open(image_path) as raw_image:
    image = raw_image.convert("RGB")
image

In [None]:
from torchvision import transforms

In [None]:
# Preprocessing for the plain-PyTorch model; it must reproduce what the model
# saw during training.
RESIZE_HW = (96, 128)  # (height, width); must be consistent with model training
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # default statistics from ImageNet
IMAGENET_STD = [0.229, 0.224, 0.225]

image_tfm = transforms.Compose([
    transforms.Resize(RESIZE_HW),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [None]:
# ToTensor yields (C, H, W); prepend a batch dimension -> (1, C, H, W).
single_tensor = image_tfm(image)
x = single_tensor.unsqueeze(0)

In [None]:
%%time
# inference on CPU
raw_out = model_torch_rep(x)
raw_out.shape

In [None]:
# Per-pixel class index: argmax over the class (channel) dimension of the single
# batch item. uint8 keeps the mask compact for serialization, but it is only
# lossless while every class index fits in 0..255.
assert raw_out.shape[1] <= 256, f"{raw_out.shape[1]} classes do not fit in uint8"
pred_res = raw_out[0].argmax(dim=0).numpy().astype(np.uint8)
pred_res

In [None]:
import base64
import numpy as np

# Round-trip the mask through base64 text and check nothing is lost
# (presumably how a serving handler would return it -- verify).
pred_encoded = base64.b64encode(pred_res.tobytes()).decode("utf-8")
pred_decoded_byte = base64.b64decode(pred_encoded)
# Raw bytes carry no dtype/shape information, so both are restored explicitly.
pred_decoded = np.frombuffer(pred_decoded_byte, dtype=np.uint8).reshape(pred_res.shape)

# Integer labels: compare exactly (array_equal also checks the shape) rather
# than with np.allclose's floating-point tolerance.
assert np.array_equal(pred_decoded, pred_res)

In [None]:
# Show the decoded mask; it should look identical to the fastai prediction above.
fig, ax = plt.subplots()
ax.imshow(pred_decoded);

In [None]:
# Does the plain-PyTorch mask reproduce the fastai prediction pixel for pixel?
# array_equal returns False on a shape mismatch, whereas elementwise `==` on
# mismatched shapes warns or raises depending on the NumPy version. Values are
# compared across dtypes (fastai's integer mask vs our uint8).
fastai_mask = pred_fastai[0].numpy()
np.array_equal(fastai_mask, pred_res)