import ast
import json
import os
import re

import yaml

#### FUNCTIONS ####
def get_channel_name_combi(channel_combi_num, channel_dict):
    """Turn a channel-combination number into an underscore-prefixed name string.

    Every decimal digit of ``channel_combi_num`` is converted to ``int`` and
    looked up in ``channel_dict``. Each resulting name gets a "_" prefix, and
    the pieces are concatenated in digit order.
    Example: ``12`` with ``{1: "DAPI", 2: "GFP"}`` gives ``"_DAPI_GFP"``.

    Because the lookup is per digit, only single-digit integer keys (0-9)
    can be addressed.

    Args:
        channel_combi_num: Number (or digit string) encoding the channels.
        channel_dict: Mapping from integer channel index to channel name.

    Returns:
        The concatenated names, each preceded by "_". Returns "" when
        there are no digits.

    Raises:
        KeyError: If a digit has no entry in ``channel_dict``.
    """
    return "".join("_" + channel_dict[int(digit)] for digit in str(channel_combi_num))

def get_channel_number_combi(channel_names, channel_dict):
    """Map an underscore-separated channel-name string back to its digit string.

    This is the inverse of ``get_channel_name_combi``. ``channel_names`` is
    split on "_". For each piece, every ``channel_dict`` key whose value
    equals that piece is appended as ``str(key)``, in the dict's iteration
    order. Pieces that match nothing are skipped silently. That includes the
    empty piece produced by a leading "_".

    Args:
        channel_names: String such as ``"_DAPI_GFP"``.
        channel_dict: Mapping from channel index to channel name.

    Returns:
        The concatenated keys as a string, e.g. ``"12"``. Returns "" if
        nothing matches.
    """
    return "".join(
        str(index)
        for piece in channel_names.split('_')
        for index, name in channel_dict.items()
        if name == piece
    )

def get_channel_name_combi_list(selected_channels, channel_dict):
    """Convert each channel-combination number to its name string.

    Args:
        selected_channels: Iterable of channel-combination numbers, each
            accepted by ``get_channel_name_combi``.
        channel_dict: Mapping from integer channel index to channel name.

    Returns:
        A list of name strings, in the same order as ``selected_channels``.
    """
    return [get_channel_name_combi(combi, channel_dict) for combi in selected_channels]

def save_config_file(config, save_dir):
    """Write ``config`` to ``save_dir`` as both JSON and YAML.

    Creates ``save_dir`` (and any missing parents) first. It then writes
    ``run_config_dump.json`` followed by ``run_config_dump.yaml``, replacing
    any existing files with those names.

    Args:
        config: The configuration object to serialise.
        save_dir: Target directory; the paths are built with "/".
    """
    os.makedirs(save_dir, exist_ok=True)
    # The lambdas defer the json/yaml lookups until each file is open,
    # which keeps the same write-then-dump order per format.
    writers = {
        "json": lambda handle: json.dump(config, handle),
        "yaml": lambda handle: yaml.dump(config, handle),
    }
    for extension, write in writers.items():
        with open(f"{save_dir}/run_config_dump.{extension}", "w") as handle:
            write(handle)

def load_norm_per_channel(filepath_mean_and_std_of_dataset):
    """Load dataset mean/std statistics from a JSON file and return them as a string.

    The JSON document must have top-level ``"mean"`` and ``"std"`` entries,
    both sequences; presumably one value per channel (not verified here).
    The result is ``str([tuple(mean), tuple(std)])``, for example
    ``"[(0.1, 0.2), (1.0, 2.0)]"``.
    NOTE(review): a string is returned, presumably so it can be passed on
    as text and parsed later (e.g. with ``ast.literal_eval``). Confirm this
    against the callers.

    Args:
        filepath_mean_and_std_of_dataset: Path to the JSON statistics file.

    Returns:
        The string representation described above.

    Raises:
        KeyError: If ``"mean"`` or ``"std"`` is missing from the document.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by spec (RFC 8259). Without an explicit encoding,
    # decoding would depend on the platform's locale.
    with open(filepath_mean_and_std_of_dataset, encoding="utf-8") as f:
        norm_per_channel_json = json.load(f)
    # Format after the file is closed; only the parsed data is needed now.
    return str([tuple(norm_per_channel_json['mean']), tuple(norm_per_channel_json['std'])])


#### PARSING ####
# NOTE(review): `config` is not defined in this file. It is presumably the
# Snakemake config dict injected via `configfile:`/`include:`; confirm.
name_of_run = config['meta']['name_of_run']
sk_save_dir = config['meta']['output_dir']
ViT_name = config['train_scDINO']['dino_vit_name']
# Convert first, then add. A string-typed value such as "100" would raise
# TypeError if `+ 1` ran before `int()`. The +1 makes `epochs - 1` below
# equal the configured epoch count.
epochs = int(config['train_scDINO']['epochs']) + 1
save_dir_downstream_run = sk_save_dir + "/" + name_of_run
selected_channels = config['meta']['selected_channel_combination_per_run']
channel_dict = config['meta']['channel_dict']
saveckp_freq = int(config['train_scDINO']['saveckp_freq'])
if saveckp_freq <= 0:
    # Report a clear error instead of an opaque ZeroDivisionError.
    raise ValueError(
        f"config['train_scDINO']['saveckp_freq'] must be a positive integer, got {saveckp_freq}"
    )
# Epochs 0, f, 2f, ... for the first floor(epochs / f) multiples of the
# checkpoint frequency f. The configured final epoch is appended if missing.
epoch_nums = [epoch_num * saveckp_freq for epoch_num in range(epochs // saveckp_freq)]
if (epochs - 1) not in epoch_nums:
    epoch_nums.append(epochs - 1)
print('Epochs for downstream analyses:', epoch_nums)
save_config_file(config, save_dir_downstream_run)

# Path to the epoch-0 checkpoint (checkpoint0.pth) of the scDINO ViT trained
# on the FIRST selected channel combination; used for extracting labels.
# Raises IndexError if no channel combination is selected.
scDINO_ViTs_for_path_extract_labels = (
    f"{save_dir_downstream_run}/scDINO_ViTs/"
    f"{ViT_name}_channel{get_channel_name_combi_list(selected_channels, channel_dict)[0]}"
    "/checkpoint0.pth"
)