import { FetchBaseQueryError } from '@reduxjs/toolkit/dist/query';
import {
  generatedApi,
  usePostJobCreateMutation,
  usePostSimulationCreateMutation,
} from 'app/apiGenerated/generatedApi';
import {
  DownsamplingAlgorithm,
  JobSummary,
  SimulationCreateRequest,
  SimulationResultsS3Url,
} from 'app/apiGenerated/generatedApiTypes';
import { useAppDispatch, useAppSelector } from 'app/hooks';
import React, { useRef } from 'react';
import { parseSimulationLogs } from 'ui/modelEditor/utils';
import { useAppParams } from 'util/useAppParams';
import { v4 as uuid } from 'uuid';
import { OptimizationRequest } from './api/custom_types/optimizations';
import { SIGNAL_TYPE_DATA_FILE_NAME } from './api/useSignalTypeData';
import {
  OptimalParameterJson,
  OptimizationMetricJson,
  OptimizationResultsJson,
  SignalTypes,
} from './generated_types/collimator/dashboard/serialization/ui_types.gen';
import { projectActions } from './slices/projectSlice';
import { ModelLogLine } from './slices/simResultsSlice';

/**
 * Operations exposed through `SimworkerJobRunnerProvider` for creating backend
 * jobs (simulations, model checks, optimizations, code generation) and
 * receiving their outcome.
 *
 * The `run*` functions are typed as returning `void`; results and errors are
 * delivered through their `onComplete` callback.
 */
export interface SimworkerJobRunnerContext {
  /**
   * Runs a single (non-ensemble) simulation and downloads the listed signals.
   * Polling stops early if `abortSignal` is aborted.
   */
  runSimulation: (
    signals: string[],
    onComplete: OnRunSimulationComplete,
    abortSignal?: AbortSignal,
  ) => void;
  /**
   * Runs an ensemble simulation sweeping `parameters` with the given strategy.
   * `numRuns` is forwarded as the ensemble's `num_sims`; `saveSignalsToNpz` is
   * forwarded as `save_npz` on the simulation request.
   */
  runEnsembleSim: (
    sweepStrategy: 'all_combinations' | 'monte_carlo',
    // NOTE(review): untyped; the implementation maps it as an array of
    // `{ name, values }` objects.
    parameters: any,
    signals: string[],
    onComplete: OnRunSimulationComplete,
    abortSignal?: AbortSignal,
    saveSignalsToNpz?: boolean,
    numRuns?: number,
  ) => void;
  /** Compiles the model without simulating and returns the resulting logs. */
  runModelCheck: (
    onComplete: OnRunCheckComplete,
    abortSignal?: AbortSignal,
  ) => void;
  /**
   * Submits an optimization job. `onJobUpdated` receives each job summary
   * observed while the job is polled.
   */
  runOptimization: (
    request: OptimizationRequest,
    onComplete: OnRunOptimizationComplete,
    abortSignal?: AbortSignal,
    onJobUpdated?: OnJobUpdated,
  ) => void;
  /** Submits a job that generates Python code for the current model. */
  runGenerateModelPythonCode: (onComplete: OnRunGenerateCodeComplete) => void;
  /** Signal files downloaded by the most recent successful simulation. */
  lastSimResults?: React.MutableRefObject<SignalFile[]>;
  /** Results of the most recent optimization job, if any. */
  lastOptimResults?: React.MutableRefObject<
    OptimizationResultsJson | undefined
  >;
  /** True while an optimization job is in progress. */
  isRunning: boolean;
}

/** Completion callback for code generation: the generated code or an error message. */
type OnRunGenerateCodeComplete = (input: {
  code?: string;
  error?: string;
}) => void;

/** Completion callback for a model check: the raw (unparsed) log text or an error message. */
type OnRunCheckComplete = (input: { logs?: string; error?: string }) => void;

/**
 * Completion callback for single and ensemble simulations: the downloaded
 * signal files and parsed log lines, or an error message (possibly with logs).
 */
type OnRunSimulationComplete = (input: {
  results?: SignalFile[];
  logs?: ModelLogLine[];
  error?: string;
}) => void;

/** Payload delivered to an optimization completion callback. */
export type OnRunOptimizationCompleteArgs = {
  optimalParameters?: OptimalParameterJson[];
  metrics?: OptimizationMetricJson[];
  // Failure reason reported by the job summary, or a generic message.
  error?: string;
  logs?: ModelLogLine[];
};

/** Completion callback for an optimization job. */
type OnRunOptimizationComplete = (input: OnRunOptimizationCompleteArgs) => void;
/** Receives each job summary observed while a job is being polled. */
type OnJobUpdated = (summary: JobSummary) => void;

/** A downloaded signal file: its name and raw, unparsed bytes. */
type SignalFile = {
  name: string;
  rawFile: ArrayBuffer;
};

/** Shared do-nothing implementation for every runner in the default context. */
const noopJobRunnerAction = () => {};

/**
 * Context value seen by consumers rendered outside a
 * `SimworkerJobRunnerProvider`: every runner is a no-op and nothing is running.
 */
const defaultJobRunnerContext: SimworkerJobRunnerContext = {
  isRunning: false,
  runSimulation: noopJobRunnerAction,
  runEnsembleSim: noopJobRunnerAction,
  runOptimization: noopJobRunnerAction,
  runModelCheck: noopJobRunnerAction,
  runGenerateModelPythonCode: noopJobRunnerAction,
};

const singletonContext = React.createContext<SimworkerJobRunnerContext>(
  defaultJobRunnerContext,
);

/** Arguments for the client-side `getFiles` query. */
type FetchFilesArgs = {
  // Files to download (name + URL each); when omitted, no files are fetched.
  files?: SimulationResultsS3Url[];
};

/** Result of the `getFiles` query: the downloaded files, in request order. */
type FetchFilesResponse = {
  raw: SignalFile[];
};

/**
 * Downloads every file listed in `args.files` in parallel and returns their
 * raw bytes, shaped as an RTK Query `queryFn` result.
 *
 * Resolves to `{ data: { raw } }` (files in request order) on success. If any
 * file lacks a URL, fails to fetch, or answers with a non-2xx status, it
 * resolves to `{ error }` carrying a `CUSTOM_ERROR` `FetchBaseQueryError`.
 * It never rejects.
 *
 * @param args.files Files to download; when absent, resolves to no files.
 */
const getFiles = async (args: FetchFilesArgs) => {
  const { files } = args;
  if (!files) return { data: { raw: [] } };
  try {
    const raw = await Promise.all(
      files.map(async (f) => {
        if (!f.url) {
          throw new Error(`File URL not found for file ${f.name}: ${f.url}`);
        }
        const response = await fetch(f.url);
        // `fetch` only rejects on network failure. HTTP errors (e.g. 403/404)
        // resolve normally, and their error body must not be returned as if
        // it were signal data.
        if (!response.ok) {
          throw new Error(
            `Failed to download file ${f.name}: HTTP ${response.status}`,
          );
        }
        return { name: f.name || '', rawFile: await response.arrayBuffer() };
      }),
    );
    return { data: { raw } };
  } catch (error: unknown) {
    // Build a well-formed FetchBaseQueryError instead of casting whatever was
    // thrown (usually a plain Error) to that type.
    const queryError: FetchBaseQueryError = {
      status: 'CUSTOM_ERROR',
      error: error instanceof Error ? error.message : String(error),
    };
    return { error: queryError };
  }
};

/**
 * Adds a client-side `getFiles` query to the generated API. It downloads
 * signal files from URLs (see `getFiles`) instead of calling a backend route,
 * so results go through the RTK Query cache like any other endpoint.
 */
const injectedApi = generatedApi.injectEndpoints({
  endpoints: (builder) => {
    const getFilesEndpoint = builder.query<FetchFilesResponse, FetchFilesArgs>(
      { queryFn: getFiles },
    );
    return { getFiles: getFilesEndpoint };
  },
});

const useLazyGetFilesQuery = injectedApi.useLazyGetFilesQuery;

/**
 * Repeatedly calls `fn` until `validate` accepts one of its results.
 *
 * Each new call is scheduled `interval` ms after the previous one settles.
 * The returned promise:
 * - resolves with the first result for which `validate` returns true;
 * - rejects with `Error('Exceeded max attempts')` once `maxAttempts` results
 *   have been rejected by `validate`;
 * - rejects with `Error('Aborted')` as soon as `abortSignal` fires (or right
 *   away if it is already aborted), cancelling any pending call;
 * - rejects with whatever `fn` rejects with or `validate` throws.
 *
 * @param fn Produces the value to check (e.g. a status request).
 * @param validate Returns true when polling should stop with that result.
 * @param abortSignal Optional signal that stops polling.
 * @param interval Delay in ms between a settled call and the next one.
 * @param maxAttempts Maximum number of calls to `fn` before giving up.
 */
function poll<T>(
  fn: () => Promise<T>,
  validate: (result: T) => boolean,
  abortSignal?: AbortSignal,
  interval = 1000,
  maxAttempts = 900,
): Promise<T> {
  return new Promise<T>((resolve, reject) => {
    let attempts = 0;
    let timer: ReturnType<typeof setTimeout> | undefined;
    // Guards against settling twice, e.g. when an abort arrives while `fn()`
    // is still in flight and that call later resolves.
    let settled = false;

    const finish = (settle: () => void) => {
      if (settled) return;
      settled = true;
      if (timer !== undefined) clearTimeout(timer);
      abortSignal?.removeEventListener('abort', onAbort);
      settle();
    };

    // React to an abort immediately instead of waiting for the next poll
    // round-trip, and never issue another request once aborted.
    const onAbort = () => finish(() => reject(new Error('Aborted')));

    const attempt = () => {
      timer = undefined;
      if (settled) return;
      fn()
        .then((result) => {
          if (settled) return;
          attempts++;
          if (validate(result)) {
            finish(() => resolve(result));
          } else if (attempts >= maxAttempts) {
            finish(() => reject(new Error('Exceeded max attempts')));
          } else {
            timer = setTimeout(attempt, interval);
          }
        })
        .catch((err) => finish(() => reject(err)));
    };

    if (abortSignal?.aborted) {
      onAbort();
      return;
    }
    abortSignal?.addEventListener('abort', onAbort, { once: true });
    attempt();
  });
}

/**
 * Builds the value shared through `SimworkerJobRunnerProvider`.
 *
 * Each `run*` function works the same way:
 * 1. create a backend job (a simulation or a generic job);
 * 2. poll its summary until the status is `completed`, `failed` or `cancelled`;
 * 3. fetch the relevant outputs (logs, signal files, optimization results or
 *    generated code);
 * 4. report them through the caller's `onComplete` callback.
 *
 * Failures are also reported through `onComplete` rather than thrown.
 *
 * @returns The runner functions, refs to the latest simulation and
 *   optimization results, and the `isRunning` flag.
 */
const useSimworkerJobRunnerContext = () => {
  // Latest results are kept in refs so consumers can read them without the
  // provider re-rendering when they change.
  const lastSimResults = useRef<SignalFile[]>([]);
  const lastOptimResults = useRef<OptimizationResultsJson>();
  // NOTE(review): only runOptimization toggles this flag; the other runners
  // do not. Confirm whether that is intended.
  const [isRunning, setIsRunning] = React.useState(false);

  const dispatch = useAppDispatch();
  const { modelId } = useAppParams();

  const worker_type = useAppSelector(
    (state) => state.model.present.configuration.worker_type,
  );
  const modelUuid = useAppSelector(
    (state) => state.modelMetadata.loadedModelId,
  );
  const [callPostSimulationCreateApi] = usePostSimulationCreateMutation();
  const [getJobSummary] = generatedApi.endpoints.getJobSummary.useLazyQuery();
  const [getSignalsS3Urls] =
    generatedApi.endpoints.getSimulationProcessResultsReadByUuid.useLazyQuery();
  const [triggerGetFiles] = useLazyGetFilesQuery();
  const [getSimulationLogs] =
    generatedApi.endpoints.getSimulationLogsReadByUuid.useLazyQuery();
  const [getSimulationLogFile] =
    generatedApi.endpoints.getSimulationLogFileReadByUuid.useLazyQuery();

  // Polls a job summary until it reaches a terminal status, forwarding every
  // summary seen along the way to `onJobUpdated`.
  const pollSummary = React.useCallback(
    async (
      jobUuid: string,
      abortSignal?: AbortSignal,
      onJobUpdated?: OnJobUpdated,
    ) => {
      const finalSummary = await poll(
        () => getJobSummary({ jobUuid }),
        ({ data: summary }) => {
          // No data (e.g. the request errored) just means "poll again".
          if (!summary) return false;
          if (onJobUpdated) onJobUpdated(summary as JobSummary);
          if (
            summary.status === 'completed' ||
            summary.status === 'failed' ||
            summary.status === 'cancelled'
          ) {
            return true;
          }
          return false;
        },
        abortSignal,
      );
      return finalSummary.data as JobSummary;
    },
    [getJobSummary],
  );

  // Resolves download URLs for the requested signals (comma-separated names),
  // then downloads the files themselves.
  const downloadSignals = React.useCallback(
    async (
      signalNames: string,
      summary: JobSummary,
      downsamplingAlgorithm: DownsamplingAlgorithm,
    ) => {
      const s3Urls = await getSignalsS3Urls({
        modelUuid,
        downsamplingAlgorithm,
        simulationUuid: summary.uuid,
        signalNames,
      }).unwrap();
      if (s3Urls.error) {
        return { error: s3Urls.error };
      }

      const files = await triggerGetFiles({
        files: s3Urls.s3_urls,
      }).unwrap();
      return { files };
    },
    [getSignalsS3Urls, modelUuid, triggerGetFiles],
  );

  // Fetches and parses the logs of a simulation. This is best-effort: a fetch
  // failure yields an empty log instead of an error.
  const getLogs = React.useCallback(
    async (modelUuid: string, simulationUuid: string) => {
      const logsRaw: string = await getSimulationLogs({
        modelUuid,
        simulationUuid,
      })
        .unwrap()
        .catch((error) => {
          if (process.env.NODE_ENV === 'development') {
            console.error('Failed to get simulation logs:', error);
          }
          return '';
        });
      return parseSimulationLogs(logsRaw ?? '');
    },
    [getSimulationLogs],
  );

  // Fetches the signal-type log file of a simulation. This is best-effort: any
  // failure yields `{ nodes: [] }`.
  const getSignalTypes = React.useCallback(
    async (modelUuid: string, simulationUuid: string) => {
      const blob: unknown = await getSimulationLogFile({
        modelUuid,
        simulationUuid,
        simulationJsonLogFile: SIGNAL_TYPE_DATA_FILE_NAME,
      })
        .unwrap()
        .catch((error) => {
          if (process.env.NODE_ENV === 'development') {
            console.error('Failed to get signal types:', error);
          }
          return null;
        });
      if (!blob) {
        return { nodes: [] } as SignalTypes;
      }
      return blob as SignalTypes;
    },
    [getSimulationLogFile],
  );

  // Shared implementation behind the single and ensemble simulation runners.
  const runSimulation = React.useCallback(
    async (
      simulationCreateRequest: SimulationCreateRequest,
      signals: string[],
      downsamplingAlgorithm: DownsamplingAlgorithm,
      onComplete: OnRunSimulationComplete,
      abortSignal?: AbortSignal,
    ) => {
      try {
        // Clear up simulationSummary for app-wide tracking. This will make
        // signals color overlays and visualizer refresh properly after completion.
        const correlationId = uuid();
        dispatch(projectActions.runModelRequestSent({ correlationId }));

        // Do not switch to Visualizer tab after done here. We can't restore it
        // at the end either because it tends to switch when setting the flag.
        dispatch(projectActions.setAutoSwitchBottomTabOnSimDone(false));

        const summary = await callPostSimulationCreateApi({
          modelUuid,
          simulationCreateRequest,
          'X-Correlation-ID': correlationId,
        }).unwrap();

        const finalSummary = await pollSummary(
          summary.uuid,
          abortSignal,
          (summary) =>
            dispatch(projectActions.simulationSummaryUpdated(summary)),
        );
        // The signal-type result is not used here.
        // NOTE(review): the request is kept in case it primes the RTK Query
        // cache for other consumers of the same file. Confirm before removing.
        const [logs] = await Promise.all([
          getLogs(modelUuid, finalSummary.uuid),
          getSignalTypes(modelUuid, finalSummary.uuid),
        ]);

        if (finalSummary.status === 'failed') {
          onComplete({ error: finalSummary.fail_reason, logs });
          return;
        }

        const { files, error } = await downloadSignals(
          signals.join(','),
          finalSummary,
          downsamplingAlgorithm,
        );
        if (error) {
          // Logs are already available; pass them along as the failed-status
          // path does.
          onComplete({ error: `Failed to download signals: ${error}`, logs });
          return;
        }
        lastSimResults.current = files?.raw || [];
        onComplete({
          results: lastSimResults.current,
          logs,
        });
      } catch (error) {
        if (process.env.NODE_ENV === 'development') {
          console.error('Simulation failed:', error);
        }
        onComplete({ error: `Failed to run simulation.` });
      }
    },
    [
      callPostSimulationCreateApi,
      modelUuid,
      pollSummary,
      getLogs,
      getSignalTypes,
      downloadSignals,
      dispatch,
    ],
  );

  // Runs a single visualizer simulation with the configured worker type.
  const runSingleSimulation = React.useCallback(
    async (
      signals: string[],
      onComplete: OnRunSimulationComplete,
      abortSignal?: AbortSignal,
    ) =>
      runSimulation(
        {
          compile_only: false,
          ignore_cache: true,
          worker_type,
          target: 'visualizer',
        },
        signals,
        'none', // FIXME: chat does not support downsampled signals yet
        onComplete,
        abortSignal,
      ),
    [runSimulation, worker_type],
  );

  // Runs an ensemble simulation. `time` is always downloaded along with the
  // requested signals.
  const runEnsembleSim = React.useCallback(
    async (
      sweepStrategy: 'all_combinations' | 'monte_carlo',
      parameters: any, // FIXME typing
      signals: string[],
      onComplete: OnRunSimulationComplete,
      abortSignal?: AbortSignal,
      saveSignalsToNpz = false,
      numRuns?: number,
    ) =>
      runSimulation(
        {
          compile_only: false,
          ignore_cache: true,
          model_overrides: {
            recorded_signals: {
              signal_ids: signals,
            },
            ensemble_config: {
              sweep_strategy: sweepStrategy,
              model_parameter_sweeps: parameters.map((p: any) => ({
                parameter_name: p.name,
                sweep_expression: p.values,
              })),
              num_sims: numRuns,
            },
          },
          target: 'ensemble',
          save_npz: saveSignalsToNpz,
        },
        signals.concat(['time']),
        'none', // TODO: implement downsampling for npz files
        onComplete,
        abortSignal,
      ),
    [runSimulation],
  );

  const [callPostJobCreateApi] = usePostJobCreateMutation();
  // Submits an optimization job, remembers its results in lastOptimResults,
  // and keeps isRunning set for its whole duration.
  const runOptimization = React.useCallback(
    async (
      request: OptimizationRequest,
      onComplete: OnRunOptimizationComplete,
      abortSignal?: AbortSignal,
      onJobUpdated?: (summary: JobSummary) => void,
    ) => {
      try {
        setIsRunning(true);
        const summary = await callPostJobCreateApi({
          jobCreateRequest: {
            model_id: modelUuid,
            request,
            kind: 'Optimization',
          },
        }).unwrap();
        if (onJobUpdated) onJobUpdated(summary);
        const finalSummary = await pollSummary(
          summary.uuid,
          abortSignal,
          onJobUpdated,
        );
        const logs = await getLogs(modelUuid, finalSummary.uuid);
        const results = finalSummary.results as OptimizationResultsJson;
        lastOptimResults.current = results;
        onComplete({
          optimalParameters: results?.optimal_parameters,
          metrics: results?.metrics,
          error: finalSummary.fail_reason,
          logs,
        });
      } catch (error) {
        if (process.env.NODE_ENV === 'development') {
          console.error('Optimization failed:', error);
        }
        onComplete({
          error: `Failed to run optimization.`,
        });
      } finally {
        setIsRunning(false);
      }
    },
    [callPostJobCreateApi, getLogs, modelUuid, pollSummary, setIsRunning],
  );

  // Compiles the model (no simulation) and reports the raw log text.
  const runModelCheck = React.useCallback(
    async (onComplete: OnRunCheckComplete, abortSignal?: AbortSignal) => {
      try {
        const summary = await callPostSimulationCreateApi({
          modelUuid,
          simulationCreateRequest: {
            compile_only: true,
            ignore_cache: true,
            worker_type,
            target: 'visualizer',
          },
        }).unwrap();

        const finalSummary = await pollSummary(summary.uuid, abortSignal);
        // Use the hook-managed lazy trigger rather than dispatching
        // `initiate` directly. A manually dispatched `initiate` keeps its cache
        // subscription (and the log text) alive until `unsubscribe()` is
        // called, and the previous code never called it.
        // NOTE(review): the logs are read with the route's modelId, while the
        // simulation was created with modelUuid. Confirm these always match.
        const { data: simulationLogsRaw, error } = await getSimulationLogs({
          modelUuid: modelId || '',
          simulationUuid: finalSummary.uuid,
        });
        if (error) {
          onComplete({
            error: `Failed to retrieve simulation logs: ${
              (error as FetchBaseQueryError).data
            }`,
          });
        } else {
          onComplete({ logs: simulationLogsRaw || '' });
        }
      } catch (error) {
        onComplete({
          error: `Failed to run model check.`,
        });
      }
    },
    [
      callPostSimulationCreateApi,
      getSimulationLogs,
      modelId,
      modelUuid,
      pollSummary,
      worker_type,
    ],
  );

  // Submits a Python code generation job and reports the generated code.
  const runGenerateModelPythonCode = React.useCallback(
    async (onComplete: OnRunGenerateCodeComplete) => {
      try {
        const summary = await callPostJobCreateApi({
          jobCreateRequest: {
            model_id: modelUuid,
            kind: 'GenerateModelPythonCode',
          },
        }).unwrap();
        const finalSummary = await pollSummary(summary.uuid);
        const results = finalSummary.results as { code: string };
        onComplete({ code: results?.code, error: finalSummary.fail_reason });
      } catch (error) {
        onComplete({ error: `Failed to generate model python code.` });
      }
    },
    [callPostJobCreateApi, modelUuid, pollSummary],
  );

  return {
    runSimulation: runSingleSimulation,
    runEnsembleSim,
    runOptimization,
    runModelCheck,
    runGenerateModelPythonCode,
    lastSimResults,
    lastOptimResults,
    isRunning,
  };
};

/**
 * Creates one job runner instance and shares it with every descendant through
 * `useSimworkerJobRunner`.
 */
export const SimworkerJobRunnerProvider = (props: {
  children: React.ReactNode;
}) => {
  const jobRunner = useSimworkerJobRunnerContext();
  const { children } = props;
  return (
    <singletonContext.Provider value={jobRunner}>{children}</singletonContext.Provider>
  );
};

/**
 * Returns the job runner provided by the nearest `SimworkerJobRunnerProvider`.
 *
 * NOTE(review): `singletonContext` is created with a non-null default value,
 * so the throw below cannot trigger. Outside a provider, callers silently get
 * the no-op default instead. Confirm which behavior is intended.
 */
export const useSimworkerJobRunner = () => {
  const runner = React.useContext(singletonContext);
  if (runner) {
    return runner;
  }
  throw new Error(
    'useSimworkerJobRunner must be used within a SimworkerJobRunnerProvider',
  );
};
