import { ChatMessageEdit } from 'ui/appBottomBar/assistant/ChatContextProvider';
import {
  PythonModel,
  PythonRenderConfig,
} from 'ui/appBottomBar/assistant/PythonHooks';
import { WebSocketContext } from 'ui/common/WebSocketProvider';
import {
  Agent,
  ChatCompleteRequestPayload,
  ChatCompleteStartStreamPayload,
  ChatMessage,
  FinishReason,
} from '../third_party_types/chat-types';
import { WebSocketMessageType } from '../third_party_types/websocket/websocket-message-type';
import { OpenAiModels } from './models';
import { GptFunction, ResponseStreamParser } from './responseStreamParser';

/** Maximum time (ms) to wait for a streamed chat response before aborting it: five minutes. */
const CHAT_RESPONSE_TIMEOUT_MS = 1000 * 60 * 5;

/**
 * Arguments accepted by `callCompletion`.
 */
interface callCompletionParams {
  // --- transport & lifecycle ---
  /** Websocket used to publish the request and receive the streamed reply. */
  websocket: WebSocketContext;
  /** Cancels the in-flight completion when aborted. */
  abortSignal: AbortSignal;

  // --- request content ---
  /** Conversation history sent to the model. */
  messages: ChatMessage[];
  /** Id of an entry in `OpenAiModels`, or `'none'` to skip the request. */
  modelId?: string;
  /** Sampling temperature forwarded in the request payload. */
  temperature?: number;
  /** Forwarded in the request payload. */
  systemPrompt?: string;
  /** Forwarded in the request payload. */
  tools?: string;
  /** Forwarded in the request payload. */
  seed?: number;
  /** Handed to the response parser; presumably tool implementations keyed by name — confirm. */
  functions?: Record<string, GptFunction>;
  /** Holds the active agent; stamped on new messages and sent with the request. */
  agentRef?: React.MutableRefObject<Agent>;

  // --- callbacks ---
  /** Invoked once for each new message produced by the stream. */
  onNewMessage: (message: ChatMessage) => void;
  /** Invoked with each incremental edit; `index` selects the message being edited. */
  onChunkReceived: (message: ChatMessageEdit, index?: number) => void;
  /** Renders the current model as Python, used to build the request's `diagram` field. */
  getCurrentModelInPython: (
    pyRenderConfig: PythonRenderConfig,
  ) => Promise<PythonModel>;
}

/**
 * Returns a copy of `messages` in which each message's `content` keeps only
 * the parts that have a truthy `text`, `image_url`, or `error`; any other part
 * (for example an empty text chunk) is dropped before the conversation is sent.
 *
 * The input array and its message objects are not mutated: each message is
 * shallow-copied and given a freshly filtered `content` array.
 */
const cleanupMessages = (messages: ChatMessage[]) =>
  messages.map((message) => ({
    ...message,
    content: message.content.filter(
      (part) => part.text || part.image_url || part.error,
    ),
  }));

/**
 * Wires `responseParser` to the websocket.
 *
 * The internal-error handler is registered right away under the fixed key
 * `response_parser_error`. A one-shot `CHAT_COMPLETE_START_STREAM` handler is
 * also registered. When the server announces a stream, that handler starts the
 * parser and subscribes the response and error handlers under the per-stream
 * key `response_parser_<streamUuid>`. It then unregisters itself.
 */
const subscribeToStream = (
  websocket: WebSocketContext,
  responseParser: ResponseStreamParser,
) => {
  websocket.subscribe(
    WebSocketMessageType.INTERNAL_ERROR,
    'response_parser_error',
    responseParser.onInternalError,
    true,
  );

  websocket.subscribe(
    WebSocketMessageType.CHAT_COMPLETE_START_STREAM,
    'response_parser',
    (data) => {
      const { streamUuid } = data.payload as ChatCompleteStartStreamPayload;
      const streamKey = `response_parser_${streamUuid}`;

      responseParser.start(streamUuid);

      websocket.subscribe(
        WebSocketMessageType.CHAT_COMPLETE_RESPONSE,
        streamKey,
        responseParser.onChatResponseReceived,
      );
      websocket.subscribe(
        WebSocketMessageType.CHAT_COMPLETE_ERROR,
        streamKey,
        responseParser.onError,
      );

      // Only the first stream start belongs to this parser.
      websocket.unsubscribe(
        WebSocketMessageType.CHAT_COMPLETE_START_STREAM,
        'response_parser',
      );
    },
  );
};

/**
 * Removes the websocket handlers that `subscribeToStream` may have registered
 * for `responseParser`.
 *
 * `callCompletion` calls this more than once for the same parser (after the
 * response completes and again in its `finally`), so it must tolerate repeated
 * calls. If the stream never started, `responseParser.streamUuid` is presumably
 * unset and the per-stream keys match nothing — confirm against the parser.
 */
const unsubscribeFromStream = (
  websocket: WebSocketContext,
  responseParser: ResponseStreamParser,
) => {
  const streamKey = `response_parser_${responseParser.streamUuid}`;
  websocket.unsubscribe(WebSocketMessageType.CHAT_COMPLETE_RESPONSE, streamKey);
  websocket.unsubscribe(WebSocketMessageType.CHAT_COMPLETE_ERROR, streamKey);
  // subscribeToStream registers the internal-error handler under this fixed
  // key, not the per-stream one. Unsubscribing with the per-stream key left
  // the handler registered forever.
  websocket.unsubscribe(
    WebSocketMessageType.INTERNAL_ERROR,
    'response_parser_error',
  );
  // If the request was aborted, timed out, or threw before the server
  // announced the stream, the one-shot start handler is still registered.
  // Drop it so a late CHAT_COMPLETE_START_STREAM cannot revive this parser.
  websocket.unsubscribe(
    WebSocketMessageType.CHAT_COMPLETE_START_STREAM,
    'response_parser',
  );
};

/**
 * Outcome of `callCompletion`: `error` carries a failure description (an
 * exception message, a parser error, or the timeout finish reason); it is
 * empty or absent when the completion succeeded.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- kept loose for existing callers
type CallCompletionResult = { error?: any };

/**
 * Polls `responseParser` once per second until it reports done, running queued
 * tool calls one at a time between polls.
 *
 * @param responseParser - Parser fed by the handlers from `subscribeToStream`.
 * @param abortSignal - When aborted, the parser is aborted and no further
 *   queued tool call is started.
 * @returns `FinishReason.Timeout` if the response did not finish within
 *   `CHAT_RESPONSE_TIMEOUT_MS`, otherwise `''`.
 */
const waitForResponse = async (
  responseParser: ResponseStreamParser,
  abortSignal?: AbortSignal,
): Promise<string> => {
  const startTime = Date.now();
  while (!responseParser.isDone()) {
    // eslint-disable-next-line no-await-in-loop
    await new Promise((resolve) => {
      setTimeout(resolve, 1000);
    });

    if (abortSignal?.aborted) {
      responseParser.abort();
    }

    if (Date.now() - startTime > CHAT_RESPONSE_TIMEOUT_MS) {
      responseParser.abort();
      // Stop waiting even if abort() does not flip isDone().
      return FinishReason.Timeout;
    }

    // tool calls should be processed sequentially because gpt often tries to
    // call run_simulation and plot at the same time.
    while (
      responseParser.toolCallQueue.length > 0 &&
      !responseParser.isDone()
    ) {
      if (abortSignal?.aborted) {
        // Never start another tool call once the user has cancelled.
        responseParser.abort();
        break;
      }
      const item = responseParser.toolCallQueue.shift();
      if (!item) {
        break;
      }
      // eslint-disable-next-line no-await-in-loop
      await item();
    }
  }
  return '';
};

/**
 * Sends the conversation to the chat backend over `websocket` and streams the
 * reply through the `onNewMessage` / `onChunkReceived` callbacks.
 *
 * When the reply finishes with `FinishReason.ToolCalls`, the completion is
 * re-run with the updated conversation (which now includes the messages
 * received in this turn), unless the user has aborted.
 *
 * @returns `{}` when `modelId` is `'none'`. Otherwise returns `{ error }`,
 *   where `error` is `''` on success, or a failure description (an exception
 *   message, the parser's error, or `FinishReason.Timeout`).
 */
export const callCompletion = async ({
  messages,
  websocket,
  abortSignal,
  temperature = 0.1,
  modelId,
  functions,
  systemPrompt,
  tools,
  seed,
  onNewMessage,
  onChunkReceived,
  agentRef,
  getCurrentModelInPython,
}: callCompletionParams): Promise<CallCompletionResult> => {
  let error = '';
  let responseParser: ResponseStreamParser | undefined;
  try {
    // 'none' means "do not request a completion". Check it before the model
    // lookup so it is not rejected as an unknown model id.
    if (modelId === 'none') {
      return {};
    }

    const aiModel = OpenAiModels.find((m) => m.id === modelId);
    if (!aiModel) throw new Error(`Unknown model ID: ${modelId}`);

    // Local copy that tracks messages streamed back in this turn; it is what
    // a follow-up tool-call turn sends.
    const conversation = [...messages];

    responseParser = new ResponseStreamParser({
      onNewMessage: (message) => {
        message.agentId = agentRef?.current;
        onNewMessage(message);
        conversation.push(message);
      },
      onChunkReceived: (edit, index) => {
        onChunkReceived(edit, index);
        const msgIdx = index ?? conversation.length - 1;
        conversation[msgIdx] = {
          ...conversation[msgIdx],
          ...edit,
        };
      },
      functions,
      abortSignal,
    });

    subscribeToStream(websocket, responseParser);

    const cleanConversation = cleanupMessages(conversation);

    // Best effort: if rendering fails, the request is sent without a diagram.
    const pythonModel = await getCurrentModelInPython({
      output_groups: false,
      output_submodels: false,
    }).catch((e: unknown) => {
      console.error(e instanceof Error ? e.message : e);
      return undefined;
    });
    const diagram = pythonModel
      ? `${pythonModel.pythonStr}\n\n${pythonModel.stdout}\n\n${pythonModel.stderr}`
      : '';
    const payload: ChatCompleteRequestPayload = {
      messages: cleanConversation,
      temperature,
      seed,
      genAiModel: aiModel.id,
      systemPrompt,
      tools,
      agentId: agentRef?.current,
      diagram,
    };
    websocket.publish({
      id: '', // FIXME: set the conversation ID + turn index.
      type: WebSocketMessageType.CHAT_COMPLETE_REQUEST,
      payload,
    });

    const timeoutError = await waitForResponse(responseParser, abortSignal);

    // Release this turn's handlers before a possible follow-up turn
    // registers its own under the same keys.
    unsubscribeFromStream(websocket, responseParser);

    if (timeoutError) {
      // Report the timeout instead of letting the parser's state overwrite it.
      return { error: timeoutError };
    }

    if (
      responseParser.finishReason === FinishReason.ToolCalls &&
      !abortSignal?.aborted
    ) {
      return await callCompletion({
        messages: conversation,
        websocket,
        abortSignal,
        temperature,
        modelId,
        systemPrompt,
        functions,
        tools,
        seed,
        onNewMessage,
        onChunkReceived,
        agentRef,
        getCurrentModelInPython,
      });
    }
    error = responseParser.error || '';
  } catch (e: unknown) {
    error = e instanceof Error ? e.message : String(e);
  } finally {
    if (responseParser) {
      unsubscribeFromStream(websocket, responseParser);
    }
  }
  return { error };
};
