import {
  DistributionType,
  OptimizationStochasticParam,
} from 'app/api/custom_types/optimizations';
import { useAppSelector } from 'app/hooks';
import React from 'react';
import { Remove } from 'ui/common/Icons/Standard';
import Input from 'ui/common/Input/Input';
import SectionHeading from 'ui/common/Inputs/SectionHeading';
import SelectInput from 'ui/common/SelectInput';
import { BATCH_SIZE, NUM_BATCHES } from './OptimizerAlgoList';
import {
  OptimizerInputLabel,
  ParamContainer,
  ParamInputGroup,
  ParamSelectRow,
  StochasticParamsConfig,
} from './optimizerModalUtils';

/** Names of the arguments a stochastic distribution can take. */
type DistributionArg = 'mean' | 'std_dev' | 'min' | 'max';

/**
 * Arguments rendered (in this order) for each distribution.
 *
 * The editor indexes this table with the selected distribution value, so it
 * must have an entry for every distribution the editor offers. Values are
 * typed as `DistributionArg` so the label/default tables below are checked
 * against the same key set at compile time.
 */
const DISTRIBUTION_ARGS: Record<
  'normal' | 'uniform' | 'lognormal',
  DistributionArg[]
> = {
  normal: ['mean', 'std_dev'],
  uniform: ['min', 'max'],
  lognormal: ['mean', 'std_dev'],
};

/** Human-readable label shown next to each distribution argument input. */
const DISTRIBUTION_LABELS: Record<DistributionArg, string> = {
  mean: 'Mean',
  std_dev: 'Std. deviation',
  min: 'Minimum',
  max: 'Maximum',
};

/**
 * Fallback value for each distribution argument; used as the input
 * placeholder and as the value restored by the input's reset button.
 */
const DISTRIBUTION_DEFAULTS: Record<DistributionArg, string> = {
  mean: '0',
  std_dev: '1',
  min: '0',
  max: '1',
};

/** Distribution choices offered for each stochastic parameter row. */
const DISTRIBUTION_OPTIONS = [
  { label: 'Normal', value: 'normal' },
  { label: 'Uniform', value: 'uniform' },
  { label: 'LogNormal', value: 'lognormal' },
];

/**
 * Editor for the stochastic-parameter section of the optimizer config.
 *
 * Each row lets the user pick a model parameter (from the model's parameter
 * list in the store), choose a distribution for it, and fill in that
 * distribution's arguments. Once at least one row exists, inputs for the
 * number of batches and the batch size are shown as well.
 *
 * The component is controlled: it never mutates `config`; every edit is
 * reported through `onUpdate` with a new config object.
 *
 * @param config - Current stochastic-parameter configuration.
 * @param onUpdate - Called with the updated configuration after each edit.
 */
const StochasticParamFields = ({
  config,
  onUpdate,
}: {
  config: StochasticParamsConfig;
  onUpdate: (config: StochasticParamsConfig) => void;
}) => {
  const params = config.params;
  const numBatches = config.numBatches ?? '';
  const batchSize = config.batchSize ?? '';

  const modelParameters = useAppSelector(
    (state) => state.model.present.parameters,
  );

  const paramOptions = React.useMemo(
    () =>
      modelParameters.map((param) => ({
        label: param.name,
        value: param.name,
      })),
    [modelParameters],
  );

  // Parameter names already claimed by some row. Built once per render rather
  // than rebuilt for every option of every row inside `isOptionDisabled`.
  const usedParamNames = new Set(params.map((p) => p.param_name));

  const updateParams = (newParams: OptimizationStochasticParam[]) =>
    onUpdate({ ...config, params: newParams });

  // New rows start with no parameter selected and a normal distribution.
  const addStochasticParam = () =>
    updateParams([...params, { distribution: 'normal' }]);

  const removeStochasticParam = (removeIndex: number) =>
    updateParams(params.filter((_, i) => i !== removeIndex));

  // NOTE(review): arguments belonging to the previous distribution (e.g.
  // mean/std_dev after switching to uniform) stay on the param object because
  // of the spread — confirm the backend ignores keys it does not expect.
  const setDistribution = (index: number, newVal: DistributionType) =>
    updateParams(
      params.map((param, i) =>
        i === index ? { ...param, distribution: newVal } : param,
      ),
    );

  // Selecting a parameter also seeds `mean` with that parameter's current
  // model value (undefined if the name is not found), whatever the
  // distribution, overwriting any mean the user had typed.
  const selectParamName = (index: number, newVal: string) =>
    updateParams(
      params.map((param, i) =>
        i === index
          ? {
              ...param,
              param_name: newVal,
              mean: modelParameters.find((p) => p.name === newVal)?.value,
            }
          : param,
      ),
    );

  // Curried so it can be passed directly as an `onChangeText` handler.
  const setDistributionArg = (index: number, arg: string) => (newVal: string) =>
    updateParams(
      params.map((param, i) =>
        i === index ? { ...param, [arg]: newVal } : param,
      ),
    );

  const setNumBatches = (newVal: string) =>
    onUpdate({
      ...config,
      numBatches: newVal,
    });

  const setBatchSize = (newVal: string) =>
    onUpdate({
      ...config,
      batchSize: newVal,
    });

  return (
    <>
      <SectionHeading noBorder onButtonClick={addStochasticParam}>
        Stochastic parameters
      </SectionHeading>
      {params.map((param, i) => (
        <ParamContainer key={`param_${i}`}>
          <ParamSelectRow>
            <SelectInput
              options={paramOptions}
              currentValue={param.param_name}
              onSelectValue={(newVal) => selectParamName(i, newVal)}
              isOptionDisabled={(option) =>
                // TODO: include design parameters in disabled list
                // A name is unavailable if another row already uses it; the
                // row's own current selection stays enabled.
                option.value !== param.param_name &&
                usedParamNames.has(option.value)
              }
            />
            <Remove onClick={() => removeStochasticParam(i)} />
          </ParamSelectRow>
          <ParamInputGroup>
            <ParamInputGroup>
              <OptimizerInputLabel>Distribution</OptimizerInputLabel>
              <SelectInput
                options={DISTRIBUTION_OPTIONS}
                currentValue={param.distribution}
                onSelectValue={(newVal) =>
                  setDistribution(i, newVal as DistributionType)
                }
              />
            </ParamInputGroup>
            {DISTRIBUTION_ARGS[param.distribution].map((arg) => (
              <ParamInputGroup key={arg}>
                <OptimizerInputLabel>
                  {DISTRIBUTION_LABELS[arg]}
                </OptimizerInputLabel>
                <Input
                  // Distribution args are stored as dynamic keys on the param
                  // object, so they are read back through an untyped view.
                  value={(param as Record<string, any>)[arg]}
                  onChangeText={setDistributionArg(i, arg)}
                  placeholder={DISTRIBUTION_DEFAULTS[arg]}
                  defaultValue=""
                  rightIconIsResetButton
                  onClickRightIcon={() =>
                    setDistributionArg(i, arg)(DISTRIBUTION_DEFAULTS[arg])
                  }
                  hasBorder
                />
              </ParamInputGroup>
            ))}
          </ParamInputGroup>
        </ParamContainer>
      ))}
      {params.length > 0 && (
        <ParamInputGroup>
          <OptimizerInputLabel>Number of batches</OptimizerInputLabel>
          <Input
            value={numBatches}
            onChangeText={setNumBatches}
            rightIconIsResetButton
            hasBorder
            placeholder={NUM_BATCHES.default}
          />
          <OptimizerInputLabel>Batch size</OptimizerInputLabel>
          <Input
            value={batchSize}
            onChangeText={setBatchSize}
            rightIconIsResetButton
            hasBorder
            placeholder={BATCH_SIZE.default}
          />
        </ParamInputGroup>
      )}
    </>
  );
};

export default StochasticParamFields;
