# import
import os
import json
import pandas as pd
import numpy as np
import copy
import plotly.graph_objects as go

from openfisca_core.simulation_builder import SimulationBuilder
from leximpact_survey_scenario.leximpact_tax_and_benefit_system import leximpact_tbs
# Directory holding the test-case ("cas type") JSON files, resolved against
# the current working directory.
path_cas_types = os.path.join(os.getcwd(), "cas_types")
# Path of the JSON test case used as the template for every income level.
cas_types = os.path.join(path_cas_types, "irqf_01_salarie_prive_non_cadre.json")
# JSON is UTF-8 by specification: pass the encoding explicitly instead of
# relying on the platform default.
with open(cas_types, encoding="utf-8") as f:
    temp = json.load(f)
# Remove the keys that the rest of the script does not read (only "familles",
# "foyers_fiscaux", "individus" and "menages" are used afterwards). Each key is
# treated as optional, so a test case lacking one of them no longer raises
# KeyError.
for cle in ("sliders", "description", "linked_variables", "title"):
    temp.pop(cle, None)
# Income levels used to build the data.

# Alternative grid expressed in PSS units (kept disabled):
# x_min_enpss = 0
# x_max_enpss = 10
# x_pas_enpss = 0.25

# NOTE(review): presumably the monthly social-security ceiling (PSS) in
# euros -- confirm the value and the year it applies to.
pss_mensuel = 3864

# Lower bound, upper bound and step, each written as a twelfth of a round amount.
x_min = 0
x_max = 200000 / 12
x_pas = 1000 / 12

# The grid is scaled back by 12. The extra 12 added to the stop value keeps the
# upper bound (x_max * 12) inside the half-open interval covered by np.arange.
borne_haute = x_max * 12 + 12
pas_grille = x_pas * 12
x_range = np.arange(x_min, borne_haute, pas_grille)

# Number of income levels in the grid.
x_range_length = len(x_range)
# Data (annual amounts).

# Duplicate the template test case once per income level.
list_cas_types = []

# Income loop: every copy receives the same base salary for the three periods.
for revenu in x_range:
    cas = copy.deepcopy(temp)
    cas["individus"]["Adulte 1"]["salaire_de_base"] = dict.fromkeys(
        ("2023", "2024", "2025"), revenu
    )
    list_cas_types.append(cas)

# One row per income level, one column per top-level key of the test case.
data_frame = pd.DataFrame(list_cas_types)

# For each income level, build a simulation and collect the detailed results.

# Sections of a test case that are handed to the simulation builder.
entites = ["familles", "foyers_fiscaux", "individus", "menages"]
# Variables computed for period "2025"; each one becomes an output column.
variables = [
    "nb_adult",
    "nbptr",
    "revenu_assimile_salaire_apres_abattements",
    "rni",
    "ir_ss_qf",
    "taux_effectif",
    "ir_brut",
    "ir_plaf_qf",
    "decote",
    "ip_net",
]

# One single-row frame per income level. They are concatenated once after the
# loop: concatenating inside the loop recopies the whole table at every
# iteration and starts from an empty frame, which recent pandas versions
# deprecate as a concat input.
lignes = []
for i in range(x_range_length):
    cas_type = dict(data_frame.iloc[i][entites])
    simulation = SimulationBuilder().build_from_entities(leximpact_tbs, cas_type)
    # Only the first element of each computed array is kept.
    indiv = {
        variable: [simulation.calculate(variable, "2025")[0]]
        for variable in variables
    }
    lignes.append(pd.DataFrame(indiv))

# ignore_index gives the rows distinct labels (0..n-1) instead of the label 0
# repeated on every row.
donnees = (
    pd.concat(lignes, axis=0, ignore_index=True) if lignes else pd.DataFrame()
)
# Simulated salary as the first column.
donnees.insert(0, "salaire_de_base", x_range)

# Export of the raw (unprocessed) results.
donnees.to_excel("output.xlsx")
# Data processing.

# Contributions as positive values: take the absolute value of every column.
donnees = donnees.abs()

# Salary expressed in PSS units. "salaire_de_base" holds annual amounts (the
# grid is scaled by 12 and assigned to yearly periods) whereas pss_mensuel is a
# monthly amount, so the ceiling is annualised before dividing. Dividing by the
# monthly value overstated the ratio by a factor of 12.
donnees["pss"] = donnees["salaire_de_base"] / (pss_mensuel * 12)

# Drop the columns that contain only zeros.
donnees = donnees.loc[:, (donnees**2).sum() != 0]

# Totals of employee ("salarie") and employer ("employeur") contributions:
# sum of every column whose name contains the keyword. The matching columns
# are listed before the total column is created, so the total never includes
# itself.
for partie in ("salarie", "employeur"):
    colonnes = [nom for nom in donnees.columns if partie in nom]
    total = f"total_cotis_{partie}"
    donnees[total] = 0
    for nom in colonnes:
        donnees[total] = donnees[total] + donnees[nom]
# Disabled: adding the "ags" column to the employer total.
# donnees["total_cotis_employeur"] = donnees["total_cotis_employeur"] + donnees["ags"]

# Overall total.
donnees["total_cotis"] = (
    donnees["total_cotis_salarie"] + donnees["total_cotis_employeur"]
)
# Chart.

# Columns drawn by the generic loop: everything except the salary, the PSS
# ratio and the three totals, which get their own traces.
colonnes_exclues = [
    "salaire_de_base",
    "pss",
    "total_cotis_salarie",
    "total_cotis_employeur",
    "total_cotis",
]
col_graph_loop = donnees.columns.difference(colonnes_exclues, sort=False)

fig = go.Figure(layout=go.Layout(template="plotly_white"))

# Totals first, as dash-dot lines with one fixed colour each.
for nom_total, couleur in (
    ("total_cotis", "black"),
    ("total_cotis_salarie", "blue"),
    ("total_cotis_employeur", "red"),
):
    fig.add_trace(
        go.Scatter(
            x=donnees["pss"],
            y=donnees[nom_total],
            line={"color": couleur, "dash": "dashdot"},
            name=nom_total,
        )
    )

# Then one plain line per remaining column.
print(col_graph_loop)
for nom_colonne in col_graph_loop:
    fig.add_trace(
        go.Scatter(
            x=donnees["pss"],
            y=donnees[nom_colonne],
            mode="lines",
            name=nom_colonne,
        )
    )

# Legend, centred title and x-axis label.
fig.update_layout(
    showlegend=True,
    title=dict(
        text="Cotisations (mensuelles) : Salarié privé non cadre",
        y=0.9,
        x=0.5,
        xanchor="center",
        yanchor="top",
    ),
    xaxis_title="PSS",
)
fig.show()
# The three lines below look like pasted notebook output (apparently from
# print(col_graph_loop)). Left as executable code they raised NameError,
# because `Index` is not defined in this script, and the export that follows
# was never reached. They are kept as a comment for reference only.
# Index(['nb_adult', 'nbptr', 'revenu_assimile_salaire_apres_abattements', 'rni',
#        'ir_ss_qf', 'ir_brut', 'ir_plaf_qf', 'decote', 'ip_net'],
#       dtype='object')
# Export the processed table.
donnees.to_excel("output.xlsx")
# arbre de parametres
# simulation.tax_benefit_system.parameters.children.keys

# for k in simulation.tax_benefit_system.parameters.cotsoc.cotisations_salarie.children.keys():
#   print(k)
# retrouver les parametres
# simulation.tax_benefit_system.parameters.chomage.allocations_assurance_chomage.afd
# export
# donnees.to_excel("output.xlsx")