"""
Distributed MNIST training with fastai on top of plain PyTorch DataLoaders.

Adapted from the upstream fastai example: https://github.com/fastai/fastai/blob/master/nbs/examples/distrib_pytorch.py
"""

from fastai.vision.all import *
from fastai.distributed import *
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class Net(nn.Sequential):
    """Small CNN classifier for single-channel 28x28 digit images (MNIST).

    Two 3x3 convolutions (1 -> 32 -> 64 channels), 2x2 max-pooling, spatial
    dropout, flatten, a 9216 -> 128 hidden layer with dropout, and a 128 -> 10
    output layer followed by ``LogSoftmax``. The model therefore outputs
    log-probabilities and is meant to be paired with ``F.nll_loss``.

    Layer order/indices are kept stable so saved ``state_dict`` keys stay valid.
    """

    def __init__(self):
        super().__init__(
            nn.Conv2d(1, 32, 3, 1),  # 1x28x28 -> 32x26x26
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),  # -> 64x24x24
            nn.MaxPool2d(2),  # -> 64x12x12
            nn.Dropout2d(0.25),  # channel-wise dropout on 4-D conv feature maps
            Flatten(),  # -> 64 * 12 * 12 = 9216 features
            nn.Linear(9216, 128),
            nn.ReLU(),
            # Activations are 2-D (batch, 128) here. Dropout2d expects 3-D/4-D
            # feature maps and PyTorch deprecates it on 2-D input (recommending
            # plain dropout), so use element-wise Dropout.
            nn.Dropout(0.5),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1),  # log-probabilities over the 10 classes
        )


# Batch sizes for the training / validation loaders, number of training epochs,
# and the learning rate passed to fit_one_cycle in the __main__ block.
batch_size, test_batch_size = 256, 512
epochs, lr = 5, 1e-2

# Shared DataLoader options for both splits.
kwargs = {"num_workers": 1, "pin_memory": True}
# Presumably the commonly used MNIST pixel mean / std for normalization.
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# NOTE(review): these loaders are built at import time, so every process that
# imports or runs this module (e.g. each launched rank, if the launcher starts
# one process per rank) constructs them and may attempt the download into
# ../data concurrently.
train_loader = DataLoader(
    datasets.MNIST("../data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
    **kwargs
)
# Validation does not need a random order; a fixed order keeps evaluation
# reproducible across runs.
test_loader = DataLoader(
    datasets.MNIST("../data", train=False, transform=transform),
    batch_size=test_batch_size,
    shuffle=False,
    **kwargs
)

if __name__ == "__main__":
    # Wrap the plain PyTorch loaders for fastai: training loader first,
    # validation loader second.
    dls = DataLoaders(train_loader, test_loader)

    # Net ends in LogSoftmax, so NLL loss on its output is cross-entropy.
    # path/model_dir determine where learner.save() writes the checkpoint.
    learner_options = dict(
        loss_func=F.nll_loss,
        opt_func=Adam,
        metrics=accuracy,
        path="/opt/ml",
        model_dir="model",
    )
    learner = Learner(dls, Net(), **learner_options)

    # Train inside fastai's distributed context; saving happens only after
    # the context has exited.
    with learner.distrib_ctx():
        learner.fit_one_cycle(epochs, lr)

    # Checkpoint name "model" under path/model_dir (presumably
    # /opt/ml/model/model.pth given fastai's layout).
    # NOTE(review): fastai's Learner.save is expected to skip non-zero ranks
    # in distributed runs — confirm for the installed fastai version.
    learner.save("model")