#!/usr/bin/env python3

import header
import logger
import network
import torch
import torch.nn
import torch.optim
import torchsummary
import torchvision
import test
import tqdm
import train
import type
import utility
import validate
import wandb

def main():
    """Train, validate, and test the baseline model, logging everything to Weights & Biases.

    Command-line arguments (parsed by utility.processArgumentsTrainBaseline) decide
    whether an existing W&B run is resumed. All hyperparameters are read from
    header.config_baseline; checkpoints are written after every epoch and whenever
    the validation accuracy improves.
    """
    resume = utility.processArgumentsTrainBaseline()

    # Fix RNG seeds for reproducibility, then allow TF32 matmuls for speed on Ampere+ GPUs.
    utility.setSeed(header.config_baseline["seed"])
    torch.backends.cuda.matmul.allow_tf32 = True

    if header.run_mode == "online":
        wandb.login()

    wandb.init(project = header.project_name, name = header.run_name_baseline, config = header.config_baseline, resume = resume, mode = header.run_mode)

    utility.wAndBDefineMetrics()

    logger.log_info("Started run \"" + header.run_name_baseline + "\".")

    accuracy_validation_best = 0
    batch_step_test = 1
    batch_step_train = 1
    batch_step_validate = 1
    criterion = torch.nn.CrossEntropyLoss()

    # All splits share the same resize + tensor-conversion pipeline.
    dataset_transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize((header.config_baseline["model_input_height"], header.config_baseline["model_input_width"])),
        torchvision.transforms.ToTensor(),
    ])
    dataset_test = torchvision.datasets.ImageFolder(header.config_baseline["dir_dataset_test"], dataset_transforms)
    dataset_train = torchvision.datasets.ImageFolder(header.config_baseline["dir_dataset_train"], dataset_transforms)
    dataset_validation = torchvision.datasets.ImageFolder(header.config_baseline["dir_dataset_validation"], dataset_transforms)
    class_count = len(dataset_train.classes)
    data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size = header.config_baseline["data_loader_batch_size"], shuffle = False, num_workers = header.config_baseline["data_loader_worker_count"], pin_memory = True)
    data_loader_train = torch.utils.data.DataLoader(dataset_train, batch_size = header.config_baseline["data_loader_batch_size"], shuffle = header.config_baseline["data_loader_shuffle"], num_workers = header.config_baseline["data_loader_worker_count"], pin_memory = True)
    data_loader_validation = torch.utils.data.DataLoader(dataset_validation, batch_size = header.config_baseline["data_loader_batch_size"], shuffle = header.config_baseline["data_loader_shuffle"], num_workers = header.config_baseline["data_loader_worker_count"], pin_memory = True)

    # Fall back to CPU when no GPU is present instead of crashing in torch.device("cuda").
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    epoch = 1
    model = network.createModelBaseline(class_count)
    model = torch.nn.DataParallel(model)
    model = model.to(device)
    optimizer = torch.optim.SGD(model.module.getOptimizerParameters(), lr = header.config_baseline["optimizer_learning_rate"], momentum = header.config_baseline["optimizer_momentum"], weight_decay = header.config_baseline["optimizer_weight_decay"])

    # Keyword arguments instead of ten positionals: same values, but immune to
    # signature-order mistakes. The "min_learning_rate_decay" config entry maps
    # to the scheduler's "eps" parameter (minimal LR change applied).
    learning_rate_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode = header.config_baseline["learning_rate_scheduler_mode"],
        factor = header.config_baseline["learning_rate_scheduler_factor"],
        patience = header.config_baseline["learning_rate_scheduler_patience"],
        threshold = header.config_baseline["learning_rate_scheduler_threshold"],
        threshold_mode = header.config_baseline["learning_rate_scheduler_threshold_mode"],
        cooldown = header.config_baseline["learning_rate_scheduler_cooldown"],
        min_lr = header.config_baseline["learning_rate_scheduler_min_learning_rate"],
        eps = header.config_baseline["learning_rate_scheduler_min_learning_rate_decay"],
        verbose = header.config_baseline["learning_rate_scheduler_verbose"],
    )
    progress_bar = None

    # Resume training state (scheduler, model, and optimizer are restored in place).
    (accuracy_validation_best, batch_step_train, batch_step_validate, criterion, epoch) = utility.loadCheckpoint(header.config_baseline["dir_checkpoints"], header.config_baseline["file_name_checkpoint"], accuracy_validation_best, batch_step_train, batch_step_validate, criterion, epoch, learning_rate_scheduler, model, optimizer)

    if header.log_level >= type.LogLevel.trace:
        model_input_size = (header.config_baseline["model_input_channels"], header.config_baseline["model_input_height"], header.config_baseline["model_input_width"])
        # torchsummary defaults to device="cuda"; pass the actual device so the
        # summary also works on the CPU fallback.
        torchsummary.summary(model, input_size = model_input_size, device = device.type)

    if epoch <= header.config_baseline["epochs"]:
        progress_bar = tqdm.tqdm(total = header.config_baseline["epochs"], position = 0)
        progress_bar.set_description_str("[INFO]: Epoch")

    while epoch <= header.config_baseline["epochs"]:
        if progress_bar is not None:
            progress_bar.n = epoch
            progress_bar.refresh()

        wandb.log({"training/epoch/step": epoch})
        wandb.log({"validation/epoch/step": epoch})

        batch_step_train = train.trainBaseline(model, data_loader_train, criterion, optimizer, device, batch_step_train)
        (accuracy_validation_epoch, loss_validation_epoch, batch_step_validate) = validate.validateBaseline(model, data_loader_validation, criterion, device, batch_step_validate)

        # Plateau scheduler steps on the validation loss, not the epoch index.
        learning_rate_scheduler.step(loss_validation_epoch)

        if accuracy_validation_epoch > accuracy_validation_best:
            accuracy_validation_best = accuracy_validation_epoch
            wandb.log({"validation/epoch/accuracy_best": accuracy_validation_best})
            utility.saveCheckpoint(header.config_baseline["dir_checkpoints"], header.config_baseline["file_name_checkpoint_best"], accuracy_validation_best, batch_step_train, batch_step_validate, criterion, epoch, learning_rate_scheduler, model, optimizer)

        # Always keep a rolling "latest" checkpoint so interrupted runs can resume.
        utility.saveCheckpoint(header.config_baseline["dir_checkpoints"], header.config_baseline["file_name_checkpoint"], accuracy_validation_best, batch_step_train, batch_step_validate, criterion, epoch, learning_rate_scheduler, model, optimizer)

        epoch += 1

    if progress_bar is not None:
        progress_bar.close()

    logger.log_info("Best validation accuracy: " + str(accuracy_validation_best) + ".")
    wandb.summary["validation/epoch/accuracy_best"] = accuracy_validation_best

    # NOTE(review): testing runs on the final-epoch weights, not the best
    # checkpoint saved above — confirm this is intentional before comparing
    # test accuracy against the best validation accuracy.
    wandb.log({"testing/epoch/step": 1})
    batch_step_test = test.testBaseline(model, data_loader_test, device, batch_step_test)

    wandb.finish()

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
