# %%


import pandas as pd
from matplotlib import pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error as mse
from sklearn.preprocessing import Normalizer
from sklearn.preprocessing import StandardScaler

"""
    This is an example script to show the basic princples of the data.
    It trains a gradient boosting regression model and compares the model to a baseline on the test data.
    It also contains basic plots of the data.


    The script assumes that the extracted "processed" photovoltaic and wind data is within the same folder.
    Otherwise adapt the paths below.
    Through the IS_WIND flag you change between the wind and photovoltaic data.

    Author: Jens Schreiber, University of Kassel
    License: Creative Commons Attribution 4.0
"""

# Select the data set: False -> photovoltaic (PV) data, True -> wind data.
# The branch below sets three names used by the rest of the script:
#   path_to_file          - HDF5 file the power/metadata/baseline tables are read from
#   baseline_name         - column of the "baseline" table used as the reference forecast
#   most_important_column - feature plotted on the x-axis of the scatter plot
IS_WIND = False

if IS_WIND:
    path_to_file = "./DAF_ICON_Synthetic_Wind_Power_processed/00011.h5"
    baseline_name = "EnerconBaseline"
    most_important_column = "WindSpeed60m"
else:
    path_to_file = "./DAF_ICON_Synthetic_PV_Power_processed/00164.h5"
    baseline_name = "PVLibBaseline"
    most_important_column = "ASWDIRS_SFC_0_M_INSTANT"
# %%
# Weather features + historical power data; the "TestFlag" column marks which
# rows belong to the test split.
df = pd.read_hdf(path_to_file, key="powerdata")
# Vectorised conversion to a boolean mask (same result as applying bool() to
# every element, without a Python-level call per row). Bracket assignment is
# used because attribute-style assignment is fragile for DataFrame columns.
df["TestFlag"] = df["TestFlag"].astype(bool)
# Metadata, such as the location.
df_meta = pd.read_hdf(path_to_file, key="metadata")
# Baseline model predictions; a baseline is selected later by its column name.
df_baseline = pd.read_hdf(path_to_file, key="baseline")

# Split on the boolean flag: False -> training rows, True -> test rows.
df_train = df[~df["TestFlag"]]
df_test = df[df["TestFlag"]]

# %%
# The regression target, plus every other column except the split flag as a feature.
y_column = "PowerGeneration"
_non_feature_columns = {"TestFlag", y_column}
x_columns = [
    column for column in df_train.columns if column not in _non_feature_columns
]

# %%
# Scale each feature column to zero mean / unit variance, using statistics
# estimated on the training split only and reused on the test split.
# StandardScaler is used instead of sklearn's Normalizer. Normalizer rescales
# each *row* (sample) to unit norm, which multiplies every feature of a sample
# by a factor that depends on all the other features of that sample. That does
# not scale features at all.
scaler = StandardScaler()
X_train = scaler.fit_transform(df_train[x_columns])
X_test = scaler.transform(df_test[x_columns])

# Fixed random_state so that repeated runs produce the same model.
model = GradientBoostingRegressor(random_state=0)
# Series.ravel() is deprecated (pandas >= 2.2); to_numpy() gives the same
# 1-D array for a Series.
model = model.fit(X_train, df_train[y_column].to_numpy())

preds = model.predict(X_test)
# %%
# Root-mean-squared error on the test split. It is printed as "nRMSE", which
# assumes PowerGeneration is already normalised.
# NOTE(review): confirm the normalisation of the processed data; otherwise
# this value is a plain RMSE.
y_true = df_test[y_column].to_numpy()
error = mse(y_true, preds) ** 0.5
print(f"GBRT has an nRMSE of {error:0.04f}")

# Select the baseline rows by the test index labels with .loc, so they come out
# in exactly the same order as df_test. A boolean isin-mask would keep
# df_baseline's own row order and could silently misalign the two series.
# .loc raises a KeyError if a test label has no baseline row.
df_baseline_test = df_baseline.loc[df_test.index]
preds_baseline = df_baseline_test[baseline_name].to_numpy()
error = mse(y_true, preds_baseline) ** 0.5
print(f"{baseline_name} has an nRMSE of {error:0.04f}")

# %%
# Scatter the measured target and both predictions against the selected feature.
x_values = df_test[most_important_column]
fig, ax = plt.subplots(figsize=(16, 9))
for y_values, series_label, opacity in (
    (df_test.PowerGeneration, "PowerGeneration", None),
    (preds, "Preds GBRT", 0.5),
    (preds_baseline, "Preds Baseline", 0.5),
):
    ax.scatter(x_values, y_values, label=series_label, alpha=opacity)
ax.set_xlabel(most_important_column)
ax.set_ylabel("Target/Target Prediction")
ax.legend()
plt.show()
# %%
# Line plot of the measured target and both predictions over the test samples.
fig, ax = plt.subplots(figsize=(16, 9))
series_by_label = {
    "PowerGeneration": df_test["PowerGeneration"].values,
    "Preds GBRT": preds,
    "Preds Baseline": preds_baseline,
}
for series_label, y_values in series_by_label.items():
    ax.plot(y_values, label=series_label)
ax.legend()
# Zoom to the first 500 samples only; the full series would be too dense to read.
ax.set_xlim(0, 500)
ax.set_xlabel("Timestep")
ax.set_ylabel("Target/Target Prediction")
plt.show()
# %%
