#!/usr/bin/env python3

import torch
import argparse
from tqdm.auto import tqdm
import logging
import torchvision
import json
import os
import PIL
import PIL.Image
import torch
import multiprocessing



# Module-level logger; handlers/level are expected to be configured by the host process.
logger = logging.getLogger(__name__)
# Allow TensorFloat-32 for CUDA matmuls (faster on Ampere+ GPUs, slightly reduced precision).
torch.backends.cuda.matmul.allow_tf32 = True


class DatasetDecomposed(torch.utils.data.Dataset):
    """Dataset over images whose file names encode per-task labels.

    File names follow the pattern (fields separated by "---"):

        <task0_label>---<task1_label>---...---<original_class>---<file_key>

    The per-task label vocabularies are read from ``<root>/dataset.json``
    (key ``"datasets"``, each entry carrying ``"labels"``).
    """

    def __init__(self, root, classes_original, transform=None):
        """
        Args:
            root: directory containing dataset.json and the image files.
            classes_original: ordered list of original class names (e.g. the
                ``classes`` attribute of an ImageFolder over the same data).
            transform: optional transform applied to each PIL image.
        """
        self.class_to_idx = []            # per task: {class name -> index}
        self.class_to_idx_original = {}   # original class name -> index
        self.classes = []                 # per task: label vocabulary
        self.classes_original = classes_original
        self.config = {}
        self.file_paths = []
        self.labels = []                  # per task: labels parallel to file_paths
        self.labels_original = []
        self.root = root
        self.transform = transform

        if not os.path.isdir(root):
            logger.error("Invalid dataset directory.")
            return

        # Fix: context manager closes the file even if json.load raises.
        with open(os.path.join(root, "dataset.json"), "r") as file_config_dataset:
            self.config = json.load(file_config_dataset)

        for dataset in self.config["datasets"]:
            dataset_labels = dataset["labels"]
            # NOTE: the label list is mutated *in place* on purpose — callers
            # build model heads from self.config, so head sizes must reflect
            # the filtered vocabulary.
            # Fix: the original called remove("") unconditionally, which raises
            # ValueError when no empty label is present.
            if "" in dataset_labels:
                dataset_labels.remove("")
            self.classes.append(dataset_labels)
            self.labels.append([])

        for classes_dataset in self.classes:
            self.class_to_idx.append(
                {class_name: class_index
                 for (class_index, class_name) in enumerate(classes_dataset)})

        self.class_to_idx_original = {
            class_name: class_index
            for (class_index, class_name) in enumerate(self.classes_original)}

        for file_name in os.listdir(root):
            # Skip metadata files (dataset.json and friends).
            if ".json" in file_name:
                continue

            file_name_split = file_name.split("---")

            # A valid data should contain decomposed labels, an original label,
            # and a file key.
            if len(file_name_split) < len(self.labels) + 2:
                logger.warning("Invalid data \"" + file_name + "\".")
                continue

            class_original = file_name_split[-2]

            # Fix: validate every label BEFORE appending anything, so a file
            # with an unknown label cannot raise KeyError mid-append and leave
            # file_paths / labels / labels_original misaligned.
            if class_original not in self.class_to_idx_original:
                logger.warning("Invalid data \"" + file_name + "\".")
                continue

            if any(file_name_split[i] not in self.class_to_idx[i]
                   for i in range(len(self.labels))):
                logger.warning("Invalid data \"" + file_name + "\".")
                continue

            self.file_paths.append(os.path.abspath(os.path.join(root, file_name)))
            self.labels_original.append(self.class_to_idx_original[class_original])

            for i in range(len(self.labels)):
                self.labels[i].append(self.class_to_idx[i][file_name_split[i]])

        return

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, index):
        """Return (image, per-task labels as a LongTensor, original label index)."""
        image = PIL.Image.open(self.file_paths[index]).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        labels = [int(task_labels[index]) for task_labels in self.labels]

        return (image, torch.LongTensor(labels), self.labels_original[index])



class ResNet152MTL(torch.nn.Module):
    """ResNet-152 backbone with one linear classification head per sub-task.

    The stock fully-connected layer is replaced by an Identity so the backbone
    emits features; ``heads_mtl`` holds one {task name: Linear} ModuleDict per
    dataset entry (the ModuleDict keeps the task name in the state-dict keys).
    """

    def __init__(self, config_dataset):
        super().__init__()

        self.net = torchvision.models.resnet152()
        feature_dim = self.net.fc.in_features
        self.net.fc = torch.nn.Identity()

        heads = [
            torch.nn.ModuleDict({
                entry["name"]: torch.nn.Linear(feature_dim, len(entry["labels"]))
            })
            for entry in config_dataset["datasets"]
        ]
        self.net.heads_mtl = torch.nn.ModuleList(heads)

        return

    def forward(self, input):
        """Run the backbone once and return one logit tensor per task head."""
        features = self.net(input)
        return [next(iter(head.values()))(features) for head in self.net.heads_mtl]

class ViTB32MTL(torch.nn.Module):
    """ViT-B/32 backbone with one linear classification head per sub-task.

    The stock classification head is replaced by an Identity so the backbone
    emits features; ``heads_mtl`` holds one {task name: Linear} ModuleDict per
    dataset entry (the ModuleDict keeps the task name in the state-dict keys).
    """

    def __init__(self, config_dataset):
        super().__init__()

        self.net = torchvision.models.vit_b_32()
        feature_dim = self.net.heads.head.in_features
        self.net.heads.head = torch.nn.Identity()

        heads = [
            torch.nn.ModuleDict({
                entry["name"]: torch.nn.Linear(feature_dim, len(entry["labels"]))
            })
            for entry in config_dataset["datasets"]
        ]
        self.net.heads_mtl = torch.nn.ModuleList(heads)

        return

    def forward(self, input):
        """Run the backbone once and return one logit tensor per task head."""
        features = self.net(input)
        return [next(iter(head.values()))(features) for head in self.net.heads_mtl]



def pgd_attack_mtl(model, x, y, task, num_steps, step_size, step_norm, eps, eps_norm,
               clamp=(0, 1), y_target=None):
    """Projected Gradient Descent attack against one head of a multi-task model.

    Args:
        model: module whose forward returns a list of per-task logit tensors.
        x: input batch of shape (B, C, H, W), values assumed in ``clamp`` range.
        y: per-task label tensor of shape (B, num_tasks); column ``task`` is used.
        task: index of the attacked head.
        num_steps: number of PGD iterations.
        step_size: per-iteration step size.
        step_norm: 'inf' for sign steps, otherwise a p-norm for normalized steps.
        eps: radius of the projection ball around ``x``.
        eps_norm: 'inf' or a p-norm defining the projection ball.
        clamp: valid value range the adversarial input is clamped to.
        y_target: if given, run a targeted attack toward these labels.

    Returns:
        List of uint8 CPU tensors (one per batch element), scaled to [0, 255].
    """
    x_adv = x.clone().detach().requires_grad_(True)
    targeted = y_target is not None
    batch_size = x.shape[0]

    loss_fn = torch.nn.CrossEntropyLoss()
    # NOTE(review): train() keeps e.g. BatchNorm in training mode during the
    # attack (matches the original call site) — confirm this is intended.
    model = model.train()

    for _ in range(num_steps):
        # Fix: clear stale parameter gradients so they do not accumulate
        # across iterations (x_adv gets a fresh leaf each step, but the model
        # parameters do not).
        model.zero_grad()

        prediction = model(x_adv)[task]
        loss = loss_fn(prediction, y_target if targeted else y[:, task])
        loss.backward()

        with torch.no_grad():
            if step_norm == 'inf':
                gradients = x_adv.grad.sign() * step_size
            else:
                # Fix: the per-sample gradient norm has shape (B,); it must be
                # reshaped to (B, 1, 1, 1) to broadcast over C/H/W. The
                # original view(-1, num_channels, 1, 1) mis-shaped it.
                grad_norm = x_adv.grad.view(batch_size, -1).norm(step_norm, dim=-1)
                gradients = x_adv.grad * step_size / grad_norm.view(-1, 1, 1, 1)

            if targeted:
                x_adv -= gradients  # descend toward the target class
            else:
                x_adv += gradients  # ascend the true-class loss

            # Project back into the eps-ball around the clean input.
            if eps_norm == 'inf':
                x_adv = torch.max(torch.min(x_adv, x + eps), x - eps)
            else:
                delta = x_adv - x
                norms = delta.view(batch_size, -1).norm(eps_norm, dim=1)
                # Samples already inside the ball get a factor of eps/eps == 1,
                # i.e. they are left unchanged.
                norms = torch.where(norms <= eps, torch.full_like(norms, eps), norms)
                delta = delta * (eps / norms.view(-1, 1, 1, 1))
                x_adv = x + delta

            x_adv = x_adv.clamp(*clamp)

        x_adv = x_adv.detach().requires_grad_(True)

    # Fix: do not shadow the input ``x`` in the output comprehension.
    return [(sample * 255).to(torch.uint8) for sample in x_adv.detach().cpu()]



def main(args):
    """Build the test dataset, load the checkpointed MTL model, run the PGD
    attack on one task head over the whole test split, and save the resulting
    uint8 adversarial images to disk."""

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
    ])

    model_path = "../../mapillary-dataset-network/checkpoints/"

    # The ImageFolder is only used to recover the original class ordering.
    dataset_original = torchvision.datasets.ImageFolder(args.dataset, transforms)

    dataset_test = DatasetDecomposed(
        "../../mapillary-dataset/images/split/generated/test",
        dataset_original.classes,
        transforms)
    config_dataset = dataset_test.config

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=128,
        shuffle=False,
        num_workers=multiprocessing.cpu_count())

    device = torch.device("cuda")

    # Model family is encoded in the checkpoint file name, e.g. "<x>.vit_<...>".
    model_type = args.model.split(".")[1].split("_")[0]

    if model_type == 'vit':
        model = ViTB32MTL(config_dataset)
    else:
        model = ResNet152MTL(config_dataset)
    checkpoint = torch.load(model_path + args.model)

    model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.train()

    attacked_images = []
    for batch, batch_labels, _ in tqdm(data_loader_test):
        batch = batch.to(device)
        batch_labels = batch_labels.to(device)
        attacked_images += pgd_attack_mtl(model,
                                          batch,
                                          batch_labels,
                                          args.task,
                                          args.num_steps,
                                          args.step_size,
                                          args.step_norm,
                                          args.epsilon,
                                          args.eps_norm)

    os.makedirs("../../mapillary-dataset/images/split/pgd/generated", exist_ok = True)
    torch.save(attacked_images,
               "../../mapillary-dataset/images/split/pgd/generated/"
               + args.model + "_task_" + str(args.task) + ".pt")




if __name__ == "__main__":

    # Quick sanity check that a CUDA device is visible before doing any work.
    print(torch.cuda.is_available())

    parser = argparse.ArgumentParser()
    parser.add_argument('--attack', type=str, default='pgd', help='attack')
    parser.add_argument('--epsilon', type=float, default=0.2, help='epsilon')
    parser.add_argument('--num_steps', type=int, default=20, help='num_steps')
    parser.add_argument('--step_size', type=float, default=1e-2, help='step_size')
    # Fix: corrected "Taks" typo in the user-facing help text.
    parser.add_argument('--task', type=int, default=0, help='Task id from 0-4')
    # NOTE(review): a non-'inf' norm passed on the CLI arrives as a *string*
    # (e.g. "2"), which torch's norm() will not accept — confirm whether these
    # should be parsed to numbers when != 'inf'.
    parser.add_argument('--step_norm', type=str, default='inf', help='step_norm')
    parser.add_argument('--eps_norm', type=str, default='inf', help='eps_norm')
    parser.add_argument('--model', type=str, default='', help='name of model')
    parser.add_argument('--dataset', type=str, default='../../mapillary-dataset/images/split/original/test', help='loc of data')

    args = parser.parse_args()

    main(args)
