// Functions implementation for the main agent, see src/services/goapi/internal/goapi/chat/functions.json

import { DesignOptimizationRequest } from 'app/api/custom_types/optimizations';
import {
  GptFunction,
  ToolCompleteCallback,
} from 'app/chat/responseStreamParser';
import { useAppSelector } from 'app/hooks';
import { ModelState } from 'app/modelState/ModelState';
import { OutputLogLevel } from 'app/slices/simResultsSlice';
import { Agent } from 'app/third_party_types/chat-types';
import React from 'react';
import { scanNodesAndSignals } from 'ui/modelEditor/optimizations/optimizerModalUtils';
import { useSimworkerJobRunner } from '../../../../app/SimworkerJobRunner';
import { useChatContext } from '../ChatContextProvider';
import { usePythonExecutor, usePythonModelCallback } from '../PythonHooks';
import { useImageUpload } from '../useImageUpload';
import { EditorMode } from '../useModelEditorInfo';
import { usePythonToJsonConverter } from '../usePythonToJsonConverter';
import { useSearchBlocks } from './SearchBlocks';

/**
 * Wraps user-supplied Python so that its `results` variable is stringified
 * and capped at 100 lines before being handed back to the agent.
 *
 * The user code is expected to read sim results from `df` and to assign
 * `results`; the wrapper itself only adds the common imports and the
 * truncation epilogue.
 */
const executePythonCode = (code: string) => {
  const importLines = [
    'import pandas as pd',
    'import numpy as np',
    'import control as ct',
  ];
  // Stringify `results`, then keep at most the first 100 lines of it.
  const truncateLines = [
    'results = str(results)',
    "results = results.split('\\n')",
    'if len(results) > 100:',
    "    results = '\\n'.join(results[:100]) + '\\n... (truncated)'",
    'else:',
    "    results = '\\n'.join(results)",
  ];
  // Leading/trailing '' entries reproduce the blank first line and the
  // final newline of the generated script.
  return ['', ...importLines, '', code, '', ...truncateLines, ''].join('\n');
};

/**
 * Wraps user-supplied plotting code so that the current matplotlib figure is
 * rendered off-screen ('agg' backend) to a PNG and left in `results` as a
 * base64-encoded string.
 *
 * The user code is expected to read sim results from `df` and draw with
 * `plt`; the wrapper overwrites `results` with the encoded image.
 */
const plotCode = (code: string) => {
  const setup = [
    'import matplotlib',
    "matplotlib.use('agg')",
    'import matplotlib.pyplot as plt',
    'from io import BytesIO',
    'import base64',
    'import control as ct',
  ].join('\n');

  // Save the figure into an in-memory buffer and base64-encode its bytes.
  const capture = [
    'buf = BytesIO()',
    "plt.savefig(buf, format='png')",
    "plt.close('all')",
    'buf.seek(0)',
    'results = base64.b64encode(buf.read()).decode()',
  ].join('\n');

  return `\n${setup}\n\n${code}\n\n${capture}\n`;
};

/**
 * Prepends a Python prelude that loads the last simulation/optimization
 * results from `inputs` before running the given code.
 *
 * - Single sims: when "continuous_results.csv" is present it is read into `df`.
 * - Otherwise each entry is loaded with `np.load`; plain arrays become columns
 *   of `df`, anything else is stored in the `outputs` dict as a dict of Series
 *   (the ensemble case).
 * - Non-empty optimization results are concatenated onto `df` as extra columns.
 */
const withSimResults = (code: string) => {
  const importLines = ['import io', '', 'import numpy as np', 'import pandas as pd'];

  const loadSimLines = [
    'if "continuous_results.csv" in inputs["last_sim_results"]:',
    '    csv_file = inputs["last_sim_results"]["continuous_results.csv"]',
    '    df = pd.read_csv(io.BytesIO(csv_file))',
    'else:',
    '    df = pd.DataFrame()',
    '    outputs = {}',
    '    for key, value in inputs["last_sim_results"].items():',
    '        np_file = np.load(io.BytesIO(value))',
    '        if isinstance(np_file, np.ndarray):',
    '            df[key] = pd.Series(np_file)',
    '        else:',
    '            outputs[key] = {k: pd.Series(v) for k, v in np_file.items()}',
  ];

  const loadOptimLines = [
    'if inputs["last_optim_results"]:',
    '    optim_df = pd.DataFrame({name: value for name, value in inputs["last_optim_results"].items()})',
    '    df = pd.concat([df, optim_df], axis=1)',
  ];

  // Empty entries stand for the blank separator lines (and the leading and
  // trailing newline) of the generated script.
  return [
    '',
    ...importLines,
    '',
    ...loadSimLines,
    '',
    ...loadOptimLines,
    '',
    code,
    '',
  ].join('\n');
};

/**
 * Generates the Python model-builder snippet that adds a cost-function
 * `core.PythonScript` block wired to every top-level signal of the model.
 *
 * Each top-level signal `block.port` becomes an input port named `block_port`
 * (dots are not allowed in port names). A small prelude inside the script
 * re-exposes those inputs as `block.port` attribute lookups so the cost code
 * can refer to signals by their path name.
 *
 * @param name - Python identifier used for the generated block.
 * @param code - User cost code; expected to assign `cost`.
 * @param integrate - When true, also adds `<name>_integrator` fed by `<name>.cost`.
 * @param modelState - Model whose signals are scanned for inputs.
 * @returns Python source creating the block, its links and (optionally) the integrator.
 */
const makeCostPythonScriptBlock = (
  name: string,
  code: string,
  integrate: boolean,
  modelState: ModelState,
) => {
  const { signalOptions } = scanNodesAndSignals(modelState);
  const signals = signalOptions.filter(
    (signal) => signal.locationData.parentPath.length === 0,
  );
  // port names can't have "." in them
  const ports = signals.map((signal) => ({
    pathName: signal.signalPathName,
    portName: signal.signalPathName.replace(/\./g, '_'),
  }));
  // Quote each name individually so that an empty signal list yields `[]`
  // rather than `[""]` (a single, empty-named input port).
  const inputNamesList = ports
    .map(({ portName }) => `"${portName}"`)
    .join(', ');
  const inputPortNamesToVar = ports
    .map(({ portName, pathName }) => `"${portName}": "${pathName}"`)
    .join(', ');
  // The unpacking loop keeps one wrapper per block: when a block exposes
  // several signals, later ports are attached to the existing wrapper instead
  // of rebinding the block name to a fresh wrapper that only has the last port.
  const unpackPorts = `
import numpy as np
import jax.numpy as jnp
class _BlockWrapper:
  def __init__(self, port_name, port_value):
      self.__dict__[port_name] = port_value
port_names_to_var = {${inputPortNamesToVar}}
_block_wrappers = {}
for real_port_name, wrapper_port in port_names_to_var.items():
    block_name, port_name = wrapper_port.split(".")
    if block_name in _block_wrappers:
        _block_wrappers[block_name].__dict__[port_name] = locals()[real_port_name]
    else:
        b = _BlockWrapper(port_name, locals()[real_port_name])
        _block_wrappers[block_name] = b
        exec(f"{block_name} = b")
`;

  // JSON string escaping is used to embed the script as a Python string literal.
  const pythonScriptCode = JSON.stringify(`${unpackPorts}\n${code}`);

  const createPythonScriptBlock = `
${name} = core.PythonScript(
user_statements=${pythonScriptCode},
finalize_script="",
input_names=[${inputNamesList}],
output_names=["cost"],
time_mode="agnostic",
accelerate_with_jax=True,
)
`;

  const addLinks = ports
    .map(
      ({ pathName, portName }) => `add_link(${pathName}, ${name}.${portName})`,
    )
    .join('\n');

  if (!integrate) {
    return `${createPythonScriptBlock}\n${addLinks}`;
  }

  const integratorBlock = `${name}_integrator = core.Integrator(); add_link(${name}.cost, ${name}_integrator.in_0)`;
  return `${createPythonScriptBlock}\n${addLinks}\n${integratorBlock}`;
};

/** Arguments accepted by the callback returned from `useExecutePythonWithSimResults`. */
interface executePythonWithSimResultsArgs {
  /** Python source to run after the sim-results prelude; it must set `results`. */
  code: string;
}

/** Outcome of running Python against the last results: either `result` or `error` is set. */
interface executePythonWithSimResultsResult {
  /** Value of the Python `results` variable after execution. */
  result?: string;
  /** Error reported by the executor, or a message that no results are available yet. */
  error?: string;
}

/**
 * Hook returning a callback that runs Python code against the most recent
 * simulation and/or optimization results.
 *
 * The callback wraps the code with `withSimResults`, passes the raw result
 * files and optimization metrics as `inputs`, and returns the value of the
 * Python `results` variable. If neither simulation nor optimization results
 * are available it resolves with an `error` instead of executing anything.
 *
 * @returns Async function `({ code }) => Promise<{ result?, error? }>`.
 */
export const useExecutePythonWithSimResults = () => {
  const { lastSimResults, lastOptimResults } = useSimworkerJobRunner();
  const executePython = usePythonExecutor();
  const { optimizationEnabled } = useAppSelector(
    (state) => state.userOptions.options,
  );

  // This function expects the code to set `results` variable
  const executePythonWithSimResults = React.useCallback(
    async ({
      code,
    }: executePythonWithSimResultsArgs): Promise<executePythonWithSimResultsResult> => {
      const simSignals = lastSimResults?.current;
      const optimMetrics = lastOptimResults?.current?.metrics;
      const hasSimResults = !!simSignals && simSignals.length > 0;
      const hasOptimResults = !!optimMetrics && optimMetrics.length > 0;

      if (!hasSimResults && !hasOptimResults) {
        if (optimizationEnabled) {
          return {
            error:
              'No results from simulation or optimization. Please run a simulation or optimization first.',
          };
        }
        return {
          error: 'No results from simulation. Please run a simulation first.',
        };
      }

      // Always hand the prelude a map (possibly empty): it unconditionally
      // evaluates `... in inputs["last_sim_results"]`, which cannot work when
      // only optimization results exist and no map is provided.
      const signalsMap =
        simSignals?.reduce(
          (map, signal) => map.set(signal.name, signal.rawFile),
          new Map<string, ArrayBuffer>(),
        ) ?? new Map<string, ArrayBuffer>();

      type MetricValue = number[] | number;
      const metricsMap =
        optimMetrics?.reduce(
          (map, metric) => map.set(metric.name, metric.value),
          new Map<string, MetricValue>(),
        ) ?? new Map<string, MetricValue>();

      const { results, error } = await executePython({
        code: withSimResults(code),
        inputs: {
          last_sim_results: signalsMap,
          last_optim_results: metricsMap,
        },
        returnVariableNames: ['results'],
      });
      if (error) {
        return { error };
      }
      const resultsMap = results as Map<string, unknown>;
      return { result: resultsMap.get('results') as string };
    },
    [executePython, lastOptimResults, lastSimResults, optimizationEnabled],
  );

  return executePythonWithSimResults;
};

/**
 * Hook that builds the tool implementations available to the main chat agent.
 *
 * Every entry of the returned object is a `GptFunction`: it receives the
 * arguments produced by the model, a completion callback that must be invoked
 * with either a `result` or an `error`, and (for long-running jobs) an
 * optional `AbortSignal`.
 *
 * @returns Map from tool name (presumably matching the agent's function
 *   declarations referenced at the top of this file) to its implementation.
 */
export const useMainFunctions = (): { [key: string]: GptFunction } => {
  const acausalModelingEnabled = useAppSelector(
    (state) => state.userOptions.options.acausalModelingEnabled,
  );
  const { getCurrentModelInPython } = usePythonModelCallback();

  const executePythonWithSimResults = useExecutePythonWithSimResults();

  const { search_blocks } = useSearchBlocks(acausalModelingEnabled);

  const { addPlot, setAgent } = useChatContext();

  const { uploadImageB64 } = useImageUpload();

  // Runs plotting code, registers the resulting PNG in the chat and tries to
  // upload it so the model can see the image.
  const plotCallback = React.useCallback(
    async (args: any, onComplete: ToolCompleteCallback) => {
      const { result, error: execError } = await executePythonWithSimResults({
        code: plotCode(args.code),
      });

      if (execError) {
        onComplete({ error: execError });
        return;
      }
      if (!result) {
        onComplete({ error: 'No result returned' });
        return;
      }
      const plotId = addPlot(result);
      let imageUrl;
      let error;
      try {
        imageUrl = await uploadImageB64(`plot-${plotId}.png`, result);
      } catch (e) {
        // The plot is still shown locally; only the upload is reported as failed.
        error = 'Failed to upload image';
      }

      // HACK: localhost image are not accessible from the cloud
      if (imageUrl?.includes('localhost')) {
        imageUrl = undefined;
        if (process.env.NODE_ENV === 'development') {
          // eslint-disable-next-line no-console
          console.debug('Localhost image detected, not uploading');
        }
      }

      onComplete({ result: `[[plot_id:${plotId}]]`, imageUrl, error });
    },
    [addPlot, executePythonWithSimResults, uploadImageB64],
  );

  // `execute_python`: delegates to the plot path when `has_plot` is set,
  // otherwise returns the (truncated) textual `results`.
  const executePythonCallback = React.useCallback(
    async (args: any, onComplete: ToolCompleteCallback) => {
      const { code, has_plot } = args;
      if (has_plot) {
        // Awaited so a failure surfaces through this callback's promise
        // instead of becoming an unhandled rejection.
        await plotCallback(args, onComplete);
        return;
      }
      const { result, error } = await executePythonWithSimResults({
        code: executePythonCode(code),
      });
      onComplete({ result, error });
    },
    [executePythonWithSimResults, plotCallback],
  );

  // Shared implementation of `get_user_model` / `get_group`: fetches the
  // Python representation of the whole model or of one group block.
  const getUserModel = React.useCallback(
    async (
      args: any,
      onComplete: ToolCompleteCallback,
      groupBlockName?: string,
    ) => {
      let pyModel;
      try {
        pyModel = await getCurrentModelInPython({
          output_groups: false,
          output_submodels: false,
          group_block_name: groupBlockName,
        });
      } catch (e) {
        if (process.env.NODE_ENV === 'development') {
          console.error(e);
        }
      }
      if (!pyModel) {
        onComplete({
          error: 'Could not retrieve current model',
        });
        return;
      }
      onComplete({
        result: `${pyModel.pythonStr}\n\n${pyModel.stdout}\n${pyModel.stderr}`,
      });
    },
    [getCurrentModelInPython],
  );

  const getUserModelCallback = React.useCallback(
    async (args: any, onComplete: ToolCompleteCallback) => {
      await getUserModel(args, onComplete);
    },
    [getUserModel],
  );

  const getGroupCallback = React.useCallback(
    async (args: any, onComplete: ToolCompleteCallback) => {
      await getUserModel(args, onComplete, args.block_name);
    },
    [getUserModel],
  );

  // `ask_model_builder`: switches the active agent and echoes the request back.
  const askModelBuilderCallback = React.useCallback(
    async (
      args: any,
      onComplete: ToolCompleteCallback,
      abortSignal?: AbortSignal,
    ) => {
      const { request } = args;
      setAgent(
        acausalModelingEnabled
          ? Agent.ModelBuilderWithAcausal
          : Agent.ModelBuilder,
      );
      onComplete({ result: request });
    },
    [acausalModelingEnabled, setAgent],
  );

  // `ask_custom_block_builder`: switches the active agent; no payload is returned.
  const askCustomBlockBuilderCallback = React.useCallback(
    async (args: any, onComplete: ToolCompleteCallback) => {
      setAgent(Agent.CustomBlockBuilder);
      onComplete({ result: '' });
    },
    [setAgent],
  );

  const { runSimulation, runEnsembleSim, runOptimization } =
    useSimworkerJobRunner();

  // `run_simulation`: runs a single simulation and reports the columns that
  // will be available in `df` for subsequent `execute_python` calls.
  const runSimulationCallback = React.useCallback(
    async (
      args: any,
      onComplete: ToolCompleteCallback,
      abortSignal?: AbortSignal,
    ) => {
      // Make sure 'time' is always requested, without mutating the caller's
      // argument object.
      let requestedSignals: string[] | undefined;
      if (args.signals) {
        requestedSignals = args.signals.includes('time')
          ? [...args.signals]
          : [...args.signals, 'time'];
      }
      await runSimulation(
        requestedSignals ?? ['continuous_results.csv'],
        ({ results, logs, error }) => {
          const logsWithoutTimestamps = logs?.map((l) => ({
            level: l.level,
            message: l.message,
          }));
          if (error) {
            const response = {
              status: 'failed',
              logs: logsWithoutTimestamps,
            };
            onComplete({
              error: JSON.stringify(response),
            });
            return;
          }
          let dfColumns = results?.map((s) => s.name);

          const continuousResults = results?.find(
            (s) => s.name === 'continuous_results.csv',
          );
          if (continuousResults) {
            // The CSV header row holds the actual column names of `df`.
            const fileAsText = new TextDecoder().decode(
              continuousResults.rawFile,
            );
            dfColumns = fileAsText.split('\n')[0].split(',');
          }

          const response = {
            status: 'succeeded',
            df_columns: dfColumns,
          };
          onComplete({
            result: JSON.stringify(response),
          });
        },
        abortSignal,
      );
    },
    [runSimulation],
  );

  // `run_ensemble_simulation`: runs a parameter sweep and reports the recorded
  // signals together with the run ids found in `outputs`.
  const runEnsembleSimCallback = React.useCallback(
    async (
      args: any,
      onComplete: ToolCompleteCallback,
      abortSignal?: AbortSignal,
    ) => {
      const sweepStrategy =
        args.sweep_strategy === 'grid' ? 'all_combinations' : 'monte_carlo';
      // Tolerate a missing `signals` argument when filtering the results.
      const requestedSignals = args.signals ?? [];
      await runEnsembleSim(
        sweepStrategy,
        args.parameters,
        args.signals,
        ({ results, error }) => {
          if (error) {
            onComplete({ error: `Simulation failed: ${error}` });
            return;
          }
          const signals = results
            ?.map((f) => f.name)
            .filter((n) => requestedSignals.includes(n));
          const firstSignal = signals?.[0];
          if (firstSignal === undefined) {
            // Without a matching signal there is nothing to look up in
            // `outputs`; report that instead of a Python KeyError.
            onComplete({
              error:
                'Simulation produced no results for the requested signals.',
            });
            return;
          }
          executePythonWithSimResults({
            code: `results = list(outputs[${JSON.stringify(
              firstSignal,
            )}].keys())`,
          }).then(
            ({ result: runIds, error: execError }) => {
              const response = {
                signals,
                run_ids: runIds,
              };
              onComplete({
                error: execError,
                result: JSON.stringify(response),
              });
            },
            // Rejection handler (rather than a trailing `.catch`) so that
            // `onComplete` is called exactly once.
            (e) => {
              onComplete({ error: `${e}` });
            },
          );
        },
        abortSignal,
        true,
        args.num_runs,
      );
    },
    [executePythonWithSimResults, runEnsembleSim],
  );

  const convertPythonToJson = usePythonToJsonConverter();
  const modelState: ModelState = useAppSelector((state) => state.model.present);

  // Builds the JSON model containing the running cost (integrated over time),
  // the terminal cost, and an adder producing `final_cost`.
  const compileCostFunction = React.useCallback(
    async (runningCost: string, terminalCost: string) => {
      const runningCostBlock = makeCostPythonScriptBlock(
        'running_cost',
        runningCost,
        true,
        modelState,
      );
      const terminalCostBlock = makeCostPythonScriptBlock(
        'terminal_cost',
        terminalCost,
        false,
        modelState,
      );
      const addCosts = `final_cost = core.Adder(operation="in_0+in_1"); add_link(running_cost_integrator.out_0, final_cost.in_0); add_link(terminal_cost.cost, final_cost.in_1);`;
      const code = `${runningCostBlock}\n${terminalCostBlock}\n${addCosts}`;
      const { jsonModel, stderr } = await convertPythonToJson(
        code,
        false,
        EditorMode.Model,
      );

      if (!jsonModel) {
        // Surface the conversion failure to the caller (which reports it via
        // onComplete) instead of submitting an optimization without a model.
        throw new Error(
          `Failed to compile cost function: ${
            stderr || 'no model was generated'
          }`,
        );
      }

      return jsonModel;
    },
    [modelState, convertPythonToJson],
  );

  // `run_optimization`: compiles the cost function and launches a design
  // optimization, reporting optimal parameters and available metric names.
  const runOptimizationCallback = React.useCallback(
    async (
      args: any,
      onComplete: ToolCompleteCallback,
      abortSignal?: AbortSignal,
    ) => {
      const { algorithm, terminal_cost, running_cost, optimizable_parameters } =
        args;
      let modelJson;
      try {
        modelJson = await compileCostFunction(
          running_cost || 'cost = 0',
          terminal_cost || 'cost = 0',
        );
      } catch (e) {
        onComplete({ error: `${e}` });
        return;
      }
      if (process.env.NODE_ENV === 'development') {
        // eslint-disable-next-line no-console
        console.debug('optim model:', modelJson);
      }
      const request: DesignOptimizationRequest = {
        type: 'design',
        algorithm: algorithm?.name || 'adam',
        options: algorithm?.parameters || {},
        objective: 'final_cost.out_0',
        // constraints: filterConstraints(constraints),
        design_parameters: optimizable_parameters,
        stochastic_parameters: [],
        json_model_with_cost: modelJson,
      };
      // Awaited for consistency with the simulation callbacks.
      await runOptimization(
        request,
        ({ optimalParameters, metrics, error, logs }) => {
          // Only errors and warnings are forwarded to keep the payload small.
          const logsWithoutTimestamps = logs
            ?.filter(
              (l) =>
                l.level &&
                [OutputLogLevel.ERROR, OutputLogLevel.WARNING].includes(
                  l.level,
                ),
            )
            .map((l) => ({
              level: l.level,
              message: l.message,
            }));
          if (error) {
            const response = {
              status: 'failed',
              logs: logsWithoutTimestamps,
            };
            onComplete({ error: JSON.stringify(response) });
            return;
          }
          const optimalParams = JSON.stringify({ optimalParameters });
          const metricNames = metrics?.map((m) => m.name);
          const response = {
            status: 'succeeded',
            optimal_parameters: optimalParams,
            metrics: metricNames,
          };
          onComplete({ result: JSON.stringify(response) });
        },
        abortSignal,
      );
    },
    [compileCostFunction, runOptimization],
  );

  return {
    run_simulation: runSimulationCallback,
    run_ensemble_simulation: runEnsembleSimCallback,
    run_optimization: runOptimizationCallback,
    execute_python: executePythonCallback,
    get_user_model: getUserModelCallback,
    get_group: getGroupCallback,
    search_blocks,
    ask_model_builder: askModelBuilderCallback,
    ask_custom_block_builder: askCustomBlockBuilderCallback,
  };
};
