#!/usr/bin/env python3

import torch
import argparse
from tqdm.auto import tqdm
import logging
import torchvision
import os
import multiprocessing

logger = logging.getLogger(__name__)
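# Allow TF32 matmuls on Ampere and newer GPUs (faster forward/backward at slightly reduced precision).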
torch.backends.cuda.matmul.allow_tf32 = True

class ResNet152(torch.nn.Module):
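    """torchvision ResNet-152 with the final fully connected layer replaced to output class_count logits."""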
    def __init__(self, class_count):
        super().__init__()
        # Randomly initialised backbone; the trained weights are loaded from a checkpoint in main().
        self.net = torchvision.models.resnet152(pretrained=False)
        # Replace the ImageNet classification head with one sized for this dataset.
        self.net.fc = torch.nn.Linear(self.net.fc.in_features, class_count)

    def forward(self, x):
        return self.net(x)


def pgd_attack(model, x, y, num_steps, step_size, step_norm, eps, eps_norm,
               clamp=(0, 1), y_target=None):
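    """Projected Gradient Descent (PGD) attack.

    Performs num_steps gradient steps of size step_size (measured in step_norm)
    on x, projecting after every step onto the eps-ball (measured in eps_norm)
    around the original images and clamping pixels to the clamp range. If
    y_target is given the attack is targeted towards that class; otherwise it
    maximises the loss with respect to the true labels y. Returns the
    adversarial images as a list of uint8 CPU tensors.
    """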
    
    x_adv = x.clone().detach().requires_grad_(True)
    targeted = y_target is not None
    batch_size = x.shape[0]

    loss_fn = torch.nn.CrossEntropyLoss()
    # Keep the model in eval mode so BatchNorm/Dropout statistics are not perturbed by the attack.
    model.eval()

    for i in range(num_steps):

        prediction = model(x_adv)
        loss = loss_fn(prediction, y_target if targeted else y)

        # Differentiate w.r.t. the adversarial input only, so no gradients
        # accumulate in the model parameters.
        grad = torch.autograd.grad(loss, x_adv)[0]

        with torch.no_grad():
            if step_norm == 'inf':
                # L-inf step: move each pixel by step_size in the direction of the gradient sign.
                step = grad.sign() * step_size
            else:
                # Lp step: normalise the gradient per sample before scaling by step_size.
                grad_norms = grad.view(batch_size, -1).norm(float(step_norm), dim=-1).view(-1, 1, 1, 1)
                step = grad * step_size / grad_norms

            # Targeted attacks descend towards y_target; untargeted attacks ascend away from y.
            if targeted:
                x_adv -= step
            else:
                x_adv += step

            # Project back onto the epsilon-ball around the original images.
            if eps_norm == 'inf':
                x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
            else:
                delta = x_adv - x
                delta_norms = delta.view(batch_size, -1).norm(float(eps_norm), dim=1)
                # Samples already inside the ball keep a scale factor of 1; the
                # rest are projected onto the sphere of radius eps.
                scaling_factor = delta_norms.clamp(min=eps)
                delta = delta * (eps / scaling_factor).view(-1, 1, 1, 1)
                x_adv = x + delta

            # Keep pixels in the valid range and make x_adv a fresh leaf for the next step.
            x_adv = x_adv.clamp(*clamp).detach().requires_grad_(True)

    # Return the adversarial images as 8-bit CPU tensors.
    return [(img * 255).to(torch.uint8) for img in x_adv.detach().cpu()]




def main(args):
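    """Load the test split and a trained checkpoint, attack every batch with PGD, and save the results."""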

    dataset_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(
            (224, 224)),
        torchvision.transforms.ToTensor(),
    ])
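    # ToTensor scales pixel values to [0, 1], matching pgd_attack's default
    # clamp range and the uint8 conversion applied when saving the results.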
    dataset_test = torchvision.datasets.ImageFolder(args.dataset, dataset_transforms)
    class_count = len(dataset_test.classes)

    data_loader_test = torch.utils.data.DataLoader(dataset_test,
                                                   batch_size=128,
                                                   shuffle=False,
                                                   num_workers=multiprocessing.cpu_count()
                                                   )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = ResNet152(class_count)
    model = torch.nn.DataParallel(model)
    checkpoint = torch.load("../../mapillary-dataset-network/checkpoints/" + args.model, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    # Attacks are generated against the model in eval mode.
    model.eval()

    attacked_images = []
    for images, labels in tqdm(data_loader_test):
        images = images.to(device)
        labels = labels.to(device)
        attacked_images += pgd_attack(model,
                                      images,
                                      labels,
                                      args.num_steps,
                                      args.step_size,
                                      args.step_norm,
                                      args.epsilon,
                                      args.eps_norm)

    os.makedirs("../../mapillary-dataset/images/split/pgd/original", exist_ok=True)
    torch.save(attacked_images, "../../mapillary-dataset/images/split/pgd/original/" + args.model + ".pt")




if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--attack', type=str, default='pgd', help='attack type (only PGD is implemented)')
    parser.add_argument('--epsilon', type=float, default=0.2, help='maximum perturbation size (radius of the eps-ball)')
    parser.add_argument('--num_steps', type=int, default=20, help='number of PGD iterations')
    parser.add_argument('--step_size', type=float, default=1e-2, help='per-iteration step size')
    parser.add_argument('--step_norm', type=str, default='inf', help="norm of the gradient step: 'inf' or a p value such as '2'")
    parser.add_argument('--eps_norm', type=str, default='inf', help="norm of the eps-ball projection: 'inf' or a p value such as '2'")
    parser.add_argument('--model', type=str, default='', help='checkpoint file name under ../../mapillary-dataset-network/checkpoints/')
    parser.add_argument('--dataset', type=str, default='../../mapillary-dataset/images/split/original/test', help='path to the ImageFolder-style test split')
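    # Example invocation (script and checkpoint names are placeholders; the
    # checkpoint is expected under ../../mapillary-dataset-network/checkpoints/):
    #   python pgd_attack.py --model resnet152.pt --epsilon 0.03 --step_size 0.005 --num_steps 40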

    args = parser.parse_args()

    main(args)
