import StickyBox from "@/components/utils/sticky-box";
import {
  CustomModelEditorTab,
  CustomModelSetPromptEditorStateEventHandler,
  customModelTrainingBackendToDisplayName,
  CustomModelTrainingBackendType,
  CustomModelTrainingEditorStatus,
  CustomModelTrainingInputGenericHardwareSpecific,
  CustomModelTrainingItem,
  CustomModelTrainingStatus,
  EditorAssetContentType,
  FrontendDisplayTemplateType,
  getCustomModelTypeFromFrontendDisplayTemplateType,
  getCustomModelWorkflowFromCustomModelInfo,
  getTrainingDisplayName,
  isCustomModelTrainingStatusActive,
  isHighQualityTrainingBackendType,
  resetCustomModelEditorState,
  UiDisplayMessageEventHandler,
} from "@/core/common/types";
import { Assets } from "@/core/controllers/assets";
import { classNames } from "@/core/utils/classname-utils";
import { debugError, debugLog } from "@/core/utils/print-utilts";
import { capitalizeFirstLetter } from "@/core/utils/string-utils";
import { getTimeDifferenceInSeconds, sortByTimeModified } from "@/core/utils/time-utils";
import * as Dialog from "@radix-ui/react-dialog";
import { ChevronLeftIcon, ChevronRightIcon } from "@radix-ui/react-icons";
import * as Tabs from "@radix-ui/react-tabs";
import {
  DropdownClassName,
  PrimaryButtonClassName,
  PrimaryButtonClassNameDisabled,
  SecondaryButtonClassNameInactive,
} from "components/constants/class-names";
import { SimpleSpinner } from "components/icons/simple-spinner";
import { ManageSubscriptionDialogProvider } from "components/popup/message-dialog/manage-subscription-dialog";
import { ScrollAreaContainer } from "components/scroll-area/scroll-area";
import { ProgressHandler } from "components/utils/progress-handler";
import { editorContextStore } from "contexts/editor-context";
import { formatDistanceToNow } from "date-fns";
import { useCustomModelWorkflowSliderConfigUpdateEffect } from "hooks/use-custom-model-workflow-slider-config-update";
import {
  useCanUserEditCustomModel,
  useCustomModelsEffect,
  usePublicCustomModelsEffect,
} from "hooks/use-custom-models-effect";
import { DragAndDropOverlay } from "hooks/use-file-drop";
import { clamp, noop } from "lodash";
import React from "react";
import { useParams, useSearchParams } from "react-router-dom";
import { Chatbot } from "../chatbot";
import { CustomModelEditorHeaderZIndex, FloatTagZIndex } from "../constants/zIndex";
import { CustomModelEditorProvider, useCustomModelEditor } from "./custom-model-editor-context";
import { getCustomModelPlaygroundPromptEditorStateFromSerializedEditorState } from "./custom-model-mention-plugin";
import {
  CustomModelPlaygroundProvider,
  getCustomModelPlaygroundEditorStateFromPredictionInput,
  getCustomModelPlaygroundPromptEditorStateFromTrainingId,
  useCustomModelPlayground,
} from "./custom-model-playground-context";
import { GenerateOutput, GenerateSettings } from "./custom-model-playground-editor";
import { CustomModelStartTrainEditor } from "./custom-model-start-train-editor";
import {
  CustomModelTrainingInputProvider,
  getDurationSecondsFromTrainingSteps,
} from "./custom-model-training-context";
import styles from "./custom-model.module.css";
import { CustomModelDatasetGridItem, DeleteDataAlert } from "./dataset-grid-item";
import { CreateCustomModelNavbar } from "./navbar";
import { UploadButton } from "./upload-button";
import { handleUploadImage, UploadedImageData, uploadModelDatasetImage } from "./upload-image";

import { CircleOff, Pause, PlayIcon } from "lucide-react";
import {
  OnboardingDialog,
  OnboardingDialogProvider,
} from "../popup/message-dialog/onboarding-dialog";
import { InfoBox } from "../utils/info-box";
import { Tooltip } from "../utils/tooltip";
import { customModelWorkflowData } from "./custom-model-workflows-data";

// Class names for the editor tab triggers (Tabs.Trigger); combines layout utilities with the
// module-scoped `styles.TabsTrigger` rule from custom-model.module.css.
const triggerClassName = classNames(
  "w-fit flex flex-row items-center justify-center gap-2",
  styles.TabsTrigger,
  "group px-6 py-2 md:min-w-[4rem] xl:min-w-[8rem] transition-colors text-center truncate",
);

// NOTE(review): `/* @tw */` presumably marks these literals for Tailwind tooling — keep the
// class strings literal and unsplit so class scanning still sees every utility.
/* @tw */
// Small numbered badge shown inside a tab trigger; reacts to the trigger's `group` hover state.
const triggerIndexClassName =
  "w-[22px] h-[22px] text-xs flex flex-row items-center justify-center text-center rounded bg-zinc-800 text-zinc-500 group-hover:text-lime-500 group-hover:bg-zinc-800/50 transition-colors";

/* @tw */
// Pill-shaped header button. (Name typo "Headr" kept: it may be referenced elsewhere in the file.)
const editorHeadrButtonClassName =
  "md:min-w-[128px] min-h-[40px] flex flex-row items-center justify-center rounded-full px-4 py-2 font-semibold gap-2 cursor-pointer transition-colors";

/* @tw */
// Shared base for the per-training status/stop/test chips; callers append color classes.
const trainingButtonClassName =
  "flex flex-row items-center gap-2 max-w-full w-fit p-2 text-sm rounded-lg truncate border transition-colors";

// Responsive column layout for the dataset image grid (2 → 3 → 4 columns).
const datasetGridClassName = "w-full grid grid-cols-2 xl:grid-cols-3 2xl:grid-cols-4 gap-2";

/**
 * Grid of the current custom model's dataset images.
 *
 * Renders an upload tile first (only when the user may edit the model), then one
 * `CustomModelDatasetGridItem` per dataset entry. Deleting a grid item opens a
 * confirmation alert; deletions requested by the upload tile go straight to the backend.
 * Renders nothing until both a model id and a dataset are available in the store.
 */
function DatasetGrid() {
  // Fall back to the Custom workflow exactly like `Data` does, so the lookup into
  // `customModelWorkflowData` never uses an undefined key.
  const workflow = editorContextStore(
    (state) => state.customModelWorkflow || FrontendDisplayTemplateType.Custom,
  );
  const modelId = editorContextStore((state) => state.customModelId);
  const dataset = editorContextStore((state) => state.customModelDataset);
  // Id of the entry awaiting confirmation in the delete alert.
  const dataIdToDeleteRef = React.useRef<string | undefined>();
  const [isDeleteAlertOpen, setDeleteAlertOpen] = React.useState(false);
  const canUserEdit = useCanUserEditCustomModel();

  if (!modelId || !dataset) {
    return null;
  }

  /**
   * Deletes one dataset entry on the backend, then removes it from the local store.
   * Failures are logged and reported to the user instead of becoming an unhandled rejection.
   */
  const deleteDataEntry = (dataId: string) => {
    const backend = editorContextStore.getState()?.backend;

    if (!dataId || !backend) {
      return Promise.resolve();
    }

    return backend
      .deleteCustomModelDataItem({
        modelId,
        dataId,
      })
      .then(() => {
        editorContextStore.getState().setCustomModelDataset((dataset) => {
          if (!dataset?.[dataId]) {
            return dataset;
          }
          const newDataset = { ...dataset };
          delete newDataset[dataId];
          return newDataset;
        });
        dataIdToDeleteRef.current = undefined;
      })
      .catch((error: unknown) => {
        debugError(`Error deleting dataset item ${dataId}: `, error);
        editorContextStore
          .getState()
          .eventEmitter.emit<UiDisplayMessageEventHandler>(
            "ui:display-message",
            "error",
            "Error deleting dataset image.",
          );
      });
  };

  return (
    <div className="w-full flex flex-col">
      <div className={datasetGridClassName}>
        {canUserEdit && (
          <UploadButton
            modelId={modelId}
            workflow={customModelWorkflowData[workflow]}
            onDelete={(dataId) => {
              deleteDataEntry(dataId);
            }}
          />
        )}
        {Object.entries(dataset).map(([key, item]) => (
          <CustomModelDatasetGridItem
            id={key}
            key={key}
            item={item}
            modelId={modelId}
            onDelete={() => {
              setDeleteAlertOpen(true);
              dataIdToDeleteRef.current = key;
            }}
          />
        ))}

        <DeleteDataAlert
          open={isDeleteAlertOpen}
          onOpenChange={setDeleteAlertOpen}
          onDelete={() => {
            const dataId = dataIdToDeleteRef.current;
            if (dataId) {
              deleteDataEntry(dataId);
            }
          }}
        />
      </div>
    </div>
  );
}

// Starting and Processing share one look: both are in-flight trainings.
const TrainingItemActiveClassName = "bg-lime-900 border-lime-700 text-lime-500";

/** Color classes for the status chip, keyed by training status. */
const TrainingItemToClassNames = {
  [CustomModelTrainingStatus.Starting]: TrainingItemActiveClassName,
  [CustomModelTrainingStatus.Processing]: TrainingItemActiveClassName,
  [CustomModelTrainingStatus.Succeeded]: "bg-lime-900/20 border-lime-900 text-lime-500",
  [CustomModelTrainingStatus.Canceled]: "bg-zinc-800/50 border-zinc-800 text-zinc-500",
  [CustomModelTrainingStatus.Failed]: "bg-red-900/10 border-red-900/60 text-red-500",
};

/**
 * Right-aligned status chip for one training.
 * Colors come from `TrainingItemToClassNames` (Canceled styling for unmapped statuses);
 * a spinner is prepended while the training is active.
 */
function TrainingItemStatus({ training }: { training: CustomModelTrainingItem }) {
  const { status } = training;
  const chipColorClassName =
    TrainingItemToClassNames[status] ||
    TrainingItemToClassNames[CustomModelTrainingStatus.Canceled];
  const showSpinner = isCustomModelTrainingStatusActive(status);

  return (
    <div className="relative flex flex-row justify-end">
      <div className={classNames(trainingButtonClassName, chipColorClassName)}>
        {showSpinner && <SimpleSpinner width={18} height={18} pathClassName="fill-lime-500" />}
        <span className="flex-1 truncate max-w-28">{capitalizeFirstLetter(status)}</span>
      </div>
    </div>
  );
}

/**
 * Button that asks the backend to stop an in-flight training.
 *
 * While the request is pending the button shows a spinner and ignores further clicks.
 * Both a non-ok response and a rejected request are surfaced to the user via the
 * "ui:display-message" event.
 *
 * @param setStatus - accepted for API compatibility; not currently called.
 * @param training - the training to stop (its `id` is sent to the backend).
 * @param modelId - id of the custom model owning the training.
 */
function StopTrainingButton({
  setStatus,
  training,
  modelId,
}: {
  setStatus: (value: CustomModelTrainingEditorStatus) => void;
  training: CustomModelTrainingItem;
  modelId: string;
}) {
  const backend = editorContextStore((state) => state.backend);
  const [isLoading, setIsLoading] = React.useState(false);

  return (
    <div className="w-full flex flex-col gap-4">
      <div className="flex flex-row items-center gap-2">
        <button
          type="button"
          aria-busy={isLoading}
          className={classNames(
            trainingButtonClassName,
            isLoading
              ? "bg-red-900/10 border-red-900/60 text-red-900 cursor-wait"
              : "bg-red-900/20 hover:bg-red-900/50 border-red-900/60 text-red-500 cursor-pointer",
            "flex-1 gap-1",
          )}
          onClick={() => {
            // Ignore repeated clicks while a stop request is in flight.
            if (isLoading) {
              return;
            }

            const trainingId = training?.id;

            if (!backend || !trainingId) {
              return;
            }

            const { eventEmitter } = editorContextStore.getState();

            setIsLoading(true);

            backend
              .stopCustomModelTraining({
                trainingId,
                modelId,
              })
              .then((response) => {
                if (!response.ok) {
                  eventEmitter.emit<UiDisplayMessageEventHandler>(
                    "ui:display-message",
                    "error",
                    `Error stopping training: ${response.message}`,
                  );
                }
              })
              .catch((error: unknown) => {
                // A rejected request (e.g. network failure) must not become an unhandled
                // rejection; report it the same way as a non-ok response.
                debugError("Error stopping training: ", error);
                eventEmitter.emit<UiDisplayMessageEventHandler>(
                  "ui:display-message",
                  "error",
                  `Error stopping training: ${error instanceof Error ? error.message : String(error)}`,
                );
              })
              .finally(() => {
                setIsLoading(false);
              });
          }}
        >
          {isLoading ? (
            <>
              <SimpleSpinner width={18} height={18} pathClassName="fill-red-900" />
              <span>Stopping ...</span>
            </>
          ) : (
            <>
              <Pause size={18} />
              <span>Stop</span>
            </>
          )}
        </button>
      </div>
    </div>
  );
}

/**
 * Short timing label for a training row.
 *
 * - Active trainings with a usable duration estimate: "<distance> left". The remaining
 *   time is the expected duration minus the time elapsed since `timeCreated`.
 * - Otherwise: how long ago the training was last modified (or created), e.g. "5 minutes ago".
 *
 * The expected duration comes from `getExpectedDurationSecondsFromTraining`, which also
 * drives the progress bar's estimate, so the label and the bar agree. That helper accounts
 * for Fal's `iter_multiplier` and per-backend step counts.
 *
 * @returns the label, or "" when no timestamp is available or formatting fails.
 */
function getTrainingMessage({ training }: { training: CustomModelTrainingItem }) {
  try {
    const time = training.timeModified || training.timeCreated;

    if (!time) {
      return "";
    }

    if (isCustomModelTrainingStatusActive(training.status)) {
      const expectedDurationMs = getExpectedDurationSecondsFromTraining(training) * 1000;

      // Only show a countdown when the estimate is meaningful; otherwise fall through
      // to the "time ago" label below.
      if (Number.isFinite(expectedDurationMs) && expectedDurationMs > 0) {
        const elapsedMs = Date.now() - training.timeCreated.toDate().getTime();
        const timeLeftMs = Math.max(expectedDurationMs - elapsedMs, 0);
        const timeLeft = formatDistanceToNow(new Date(Date.now() + timeLeftMs));
        return `${timeLeft} left`;
      }
    }

    return formatDistanceToNow(time.toDate(), {
      addSuffix: true,
    });
  } catch (error) {
    debugError("Error getting training message: ", error);

    return "";
  }
}

/**
 * Estimated wall-clock duration (seconds) of a training, derived from its step count.
 *
 * - Fal: 100 steps per `iter_multiplier` (multiplier defaults to 1).
 * - High-quality backends: the `steps` value from the training input.
 * - Anything else: a 1000-step estimate.
 */
function getExpectedDurationSecondsFromTraining(training: CustomModelTrainingItem) {
  const { input } = training;

  if (input.backendType === CustomModelTrainingBackendType.Fal) {
    const iterationMultiplier = input.iter_multiplier ?? 1;
    return getDurationSecondsFromTrainingSteps(iterationMultiplier * 100);
  }

  if (isHighQualityTrainingBackendType(input.backendType)) {
    return getDurationSecondsFromTrainingSteps(input.steps);
  }

  // Other backends have no step information here; use a generic estimate.
  return getDurationSecondsFromTrainingSteps(1000);
}

/**
 * Progress of a training as a fraction in [0.01, 1].
 *
 * When the backend reports a positive `progress` (a percentage), that value is used,
 * clamped to [0.01, 1] so the progress bar never exceeds its track.
 *
 * Otherwise progress is estimated from the time elapsed since `timeModified` (or
 * `timeCreated`) relative to `getExpectedDurationSecondsFromTraining`. The estimate is
 * capped at 0.99 so it never claims completion. A non-numeric ratio, e.g. from a zero or
 * NaN expected duration, falls back to 0.01 instead of producing a "NaN%" width.
 */
function getProgressFromTraining(training: CustomModelTrainingItem) {
  if (typeof training.progress === "number" && training.progress > 0) {
    return clamp(training.progress * 0.01, 0.01, 1);
  }

  const timestamp = training.timeModified || training.timeCreated;

  const elapsedSeconds = getTimeDifferenceInSeconds({
    startTime: timestamp.toDate(),
  });

  const expectedDurationSeconds = getExpectedDurationSecondsFromTraining(training);

  const ratio = elapsedSeconds / expectedDurationSeconds;

  // lodash `clamp` passes NaN through unchanged, so handle it explicitly.
  if (Number.isNaN(ratio)) {
    return 0.01;
  }

  return clamp(ratio, 0.01, 0.99);
}

/**
 * "Test" button for a training row.
 *
 * For a succeeded training, clicking it loads the training into the playground
 * prompt editor and switches to the "generate" tab. For any other status the button
 * stays in the layout as an invisible (`opacity-0`) placeholder. It is then disabled and
 * hidden from assistive technology, so it cannot be focused, clicked, or show its
 * tooltip while invisible.
 */
function TestTrainingButton({
  training,
  modelId,
}: {
  modelId: string;
  training: CustomModelTrainingItem;
}) {
  const { setTab } = useCustomModelEditor();

  const { setApiState } = useCustomModelPlayground();

  const isSucceeded = training.status === CustomModelTrainingStatus.Succeeded;

  return (
    <Tooltip
      triggerProps={{
        asChild: true,
      }}
      triggerChildren={
        <button
          type="button"
          disabled={!isSucceeded}
          aria-hidden={!isSucceeded}
          className={classNames(
            isSucceeded ? PrimaryButtonClassName : "opacity-0",
            "flex flex-row items-center gap-1 max-w-full w-fit px-2 py-1 text-sm rounded truncate border transition-colors",
            trainingButtonClassName,
            "gap-1",
          )}
          onClick={(e) => {
            if (!isSucceeded) {
              return;
            }

            e.preventDefault();

            const { customModelInfo, customModelWorkflow } = editorContextStore.getState();

            // Prefill the prompt editor with a mention of this training.
            const promptEditorState = getCustomModelPlaygroundPromptEditorStateFromTrainingId({
              customModelType:
                customModelInfo?.customModelType ??
                getCustomModelTypeFromFrontendDisplayTemplateType(customModelWorkflow),
              trainingId: training.id,
              trainingDisplayName: training.displayName ?? "",
              caption: training.caption ?? "",
              modelId,
              modelDisplayName: customModelInfo?.displayName ?? "",
            });

            setApiState((apiState) => ({
              ...apiState,
              promptEditorState,
            }));

            setTab("generate");
          }}
        >
          <PlayIcon size={18} />
          <span>Test</span>
        </button>
      }
      contentChildren={
        <span className="text-xs">Test this training in Generate prompt editor.</span>
      }
    />
  );
}

/**
 * One row in the training list: name, timing label, model-type/quality tags, the
 * Stop (active) or Test (otherwise) action, a status chip, and a thin progress bar
 * that is visible only while the training is active.
 *
 * While mounted, the row subscribes to backend updates for its training. It writes each
 * snapshot back into the store's `customModelTrainings` and feeds it to the progress handler.
 */
function TrainingItem({
  modelId,
  training,
}: {
  modelId: string;
  training: CustomModelTrainingItem;
}) {
  const backend = editorContextStore((state) => state.backend);
  // Only the setter is used (forwarded to StopTrainingButton).
  const [, setStatus] = React.useState<CustomModelTrainingEditorStatus>(
    CustomModelTrainingEditorStatus.Default,
  );

  const trainingMessage = React.useMemo(
    () =>
      getTrainingMessage({
        training,
      }),
    [training],
  );

  const trainingId = training.id;

  const trainingInputType = training.input.backendType;

  // NOTE(review): field is named "Percent" but compared against 0.5 — confirm its scale.
  const trainingInputModelType =
    training.input?.trainingStrengthPercent >= 0.5 ? "Product" : "Style";
  const trainingInputModelQuality =
    training.input?.backendType === CustomModelTrainingBackendType.GenericHardware
      ? customModelTrainingBackendToDisplayName[CustomModelTrainingBackendType.GenericHardware]
      : customModelTrainingBackendToDisplayName[CustomModelTrainingBackendType.Fal];

  // Lazy initializer: compute the initial estimate once instead of on every render.
  const [progressInternal, setProgressInternal] = React.useState(() =>
    getProgressFromTraining(training),
  );

  // Create the handler once. `useRef(new ProgressHandler(...))` would construct (and
  // throw away) a fresh handler on every render. The speed is fixed from the first
  // render's backend type, as before.
  const progressHandlerRef = React.useRef<ProgressHandler | null>(null);
  if (progressHandlerRef.current === null) {
    progressHandlerRef.current = new ProgressHandler({
      speed: trainingInputType === CustomModelTrainingBackendType.Fal ? 5e-4 : 1e-5,
      setProgress: setProgressInternal,
    });
  }

  const setProgressFromTraining = React.useCallback((nextTraining: CustomModelTrainingItem) => {
    progressHandlerRef.current?.setProgress(getProgressFromTraining(nextTraining));
  }, []);

  React.useEffect(() => {
    debugLog(`Training ${training.id} updated`);
    setProgressFromTraining(training);
  }, [training, setProgressFromTraining]);

  // Live subscription to this training; the returned function is used as the effect cleanup.
  React.useEffect(() => {
    if (!backend || !trainingId) {
      return;
    }

    return backend.onCustomModelTrainingUpdate({
      modelId,
      trainingId,
      callback: (newTraining) => {
        const { setCustomModelTrainings } = editorContextStore.getState();

        setCustomModelTrainings((prevTrainings) => {
          return {
            ...prevTrainings,
            [trainingId]: newTraining,
          };
        });

        debugLog(`Training ${newTraining.id} snapshot`);

        setProgressFromTraining(newTraining);
      },
    });
  }, [backend, modelId, trainingId, setProgressFromTraining]);

  const isActive = isCustomModelTrainingStatusActive(training.status);

  return (
    <div
      className={classNames(
        SecondaryButtonClassNameInactive,
        "relative max-w-full flex flex-row items-center justify-start gap-2 p-2 rounded-lg overflow-hidden text-zinc-300 active:border-lime-700",
      )}
    >
      <div className="flex-1 flex flex-col gap-2">
        <div className="text-sm font-semibold">
          <div className="flex flex-col lg:flex-row items-start lg:items-center gap-2 text-nowrap">
            <span>{getTrainingDisplayName(training)}</span>
            <div className="min-w-0 max-w-20 xl:max-w-28 text-sm font-normal text-zinc-500 truncate">
              {trainingMessage}
            </div>
          </div>
        </div>
        <div className={classNames("flex flex-row gap-1")}>
          <div className="min-w-0 w-fit text-zinc-500 truncate border border-zinc-700 rounded-full px-2 py-0.5 text-xs">
            {trainingInputModelType}
          </div>
          <div className="min-w-0 w-fit text-zinc-500 truncate border border-zinc-700 rounded-full px-2 py-0.5 text-xs">
            {trainingInputModelQuality}
          </div>
        </div>
      </div>
      <div className="flex flex-col pb-2 xl:pb-0 xl:flex-row gap-2 justify-end">
        {isActive ? (
          <StopTrainingButton setStatus={setStatus} training={training} modelId={modelId} />
        ) : (
          <TestTrainingButton modelId={modelId} training={training} />
        )}
        <TrainingItemStatus training={training} />
      </div>
      <div
        className={classNames(
          "absolute left-0 bottom-0 w-full h-px rounded-b-full bg-zinc-800",
          isActive ? "" : "hidden",
        )}
      >
        <div
          className={`${styles.TransitionWidth} rounded-full h-full bg-lime-500 animate-[pulse_3s_ease-in-out_infinite] shadow-[0_0_8px] shadow-lime-500/50`}
          style={{
            width: `${progressInternal * 100}%`,
          }}
        />
      </div>
    </div>
  );
}

/** Placeholder shown in place of the training list when the model has no trainings yet. */
function TrainingHistoryEmpty() {
  const containerClassName =
    "px-3 py-2.5 flex flex-col border border-zinc-800 rounded-lg text-sm text-zinc-700 gap-2";
  const headingClassName = "flex flex-row items-center gap-2";

  return (
    <div className={containerClassName}>
      <div className={headingClassName}>
        <CircleOff size={16} />
        <span>No training yet.</span>
      </div>
      <div>Click "Start Training" button above to create a new training.</div>
    </div>
  );
}

/**
 * Modal listing every training of a model, sorted with `sortByTimeModified`.
 *
 * @param open - controlled open state.
 * @param onOpenChange - called by Radix when the dialog requests to open/close.
 * @param modelId - id of the custom model owning the trainings.
 * @param trainings - trainings keyed by id.
 */
function TrainingHistoryDialog({
  open,
  onOpenChange,
  modelId,
  trainings,
}: {
  open: boolean;
  onOpenChange: (open: boolean) => void;
  modelId: string;
  trainings: Record<string, CustomModelTrainingItem>;
}) {
  return (
    <Dialog.Root open={open} onOpenChange={onOpenChange}>
      <Dialog.Portal>
        <Dialog.Overlay
          className={styles.DialogOverlay}
          style={{
            zIndex: FloatTagZIndex,
          }}
        />
        {/* No Dialog.Description is rendered, so opt out of aria-describedby explicitly. */}
        <Dialog.Content
          aria-describedby={undefined}
          className={classNames(
            styles.DialogContent,
            DropdownClassName,
            "py-2 w-[90vw] md:w-[50vw] lg:w-[600px] rounded-xl text-sm",
          )}
          style={{
            zIndex: FloatTagZIndex,
          }}
        >
          <div className="pb-4">
            <Dialog.Title className="text-lg font-semibold text-zinc-300">
              Training History
            </Dialog.Title>
            {/* Icon-only button: give it an accessible name and hide the glyph from screen readers. */}
            <Dialog.Close
              aria-label="Close"
              className="absolute top-4 right-4 text-zinc-400 hover:text-zinc-200"
            >
              <span aria-hidden="true">✕</span>
            </Dialog.Close>
          </div>
          <div className="flex flex-col gap-2 max-h-[calc(100vh-10rem)] overflow-y-scroll">
            {Object.values(trainings)
              .sort(sortByTimeModified)
              .map((training) => (
                <TrainingItem key={training.id} modelId={modelId} training={training} />
              ))}
          </div>
        </Dialog.Content>
      </Dialog.Portal>
    </Dialog.Root>
  );
}

/**
 * Left panel of the Train tab: the start-training form followed by the training history.
 *
 * Starting a new training is disabled while any training is active. Activity is judged by
 * `isCustomModelTrainingStatusActive`, the same predicate `TrainingItem` uses to show the
 * Stop button and spinner, so the two never disagree. Renders nothing without a model id.
 */
function TrainingConfig() {
  const modelId = editorContextStore((state) => state.customModelId);
  const trainings = editorContextStore((state) => state.customModelTrainings);
  // NOTE(review): nothing in this block sets this to true, so the dialog below never opens.
  const [isHistoryDialogOpen, setHistoryDialogOpen] = React.useState(false);

  // Sort once per store update instead of on every render.
  const sortedTrainings = React.useMemo(
    () => Object.values(trainings).sort(sortByTimeModified),
    [trainings],
  );

  const hasActiveTrainings = React.useMemo(
    () => sortedTrainings.some((training) => isCustomModelTrainingStatusActive(training.status)),
    [sortedTrainings],
  );

  if (!modelId) {
    return null;
  }

  return (
    <div className="relative flex flex-col gap-8">
      <CustomModelStartTrainEditor
        disabledStart={hasActiveTrainings}
        isDialog={false}
        modelId={modelId}
        onExit={noop}
      />
      <div className="flex flex-col gap-2">
        <div className="flex flex-row justify-between items-center">
          <div className="text-sm font-semibold">Training History</div>
        </div>
        {sortedTrainings.length > 0 ? (
          <div className="flex flex-col gap-2">
            {sortedTrainings.map((training) => (
              <TrainingItem key={training.id} modelId={modelId} training={training} />
            ))}
          </div>
        ) : (
          <TrainingHistoryEmpty />
        )}
      </div>
      <TrainingHistoryDialog
        open={isHistoryDialogOpen}
        onOpenChange={setHistoryDialogOpen}
        modelId={modelId}
        trainings={trainings}
      />
    </div>
  );
}

/**
 * Right panel of the Train tab: upload instructions/examples for the current workflow
 * and the dataset grid, wrapped in a drag-and-drop upload area.
 *
 * Dropping files uploads each one independently. A file rejected by
 * `handleUploadImage` does not stop the remaining dropped files from being uploaded.
 * Dropping is disabled when the user cannot edit the model or no model/dataset is loaded.
 */
function Data() {
  const canUserEdit = useCanUserEditCustomModel();
  const modelId = editorContextStore((state) => state.customModelId);
  const dataset = editorContextStore((state) => state.customModelDataset);
  const workflow = editorContextStore(
    (state) => state.customModelWorkflow || FrontendDisplayTemplateType.Custom,
  );
  const uploadFile = React.useCallback(
    ({ data, contentType }: { data: File | Blob; contentType: EditorAssetContentType }) => {
      if (!modelId || !dataset) {
        return Promise.resolve(undefined);
      }
      return uploadModelDatasetImage({
        data,
        modelId,
        contentType,
      });
    },
    [modelId, dataset],
  );

  debugLog("Workflow: ", workflow);

  // Look up the workflow's display data once instead of on every JSX access.
  const workflowData = customModelWorkflowData[workflow];
  const infoboxText = workflowData?.moodboardInfoboxText;
  const exampleUrls = workflowData?.moodboardExampleUrls;

  return (
    <div className="flex flex-col gap-4">
      <InfoBox title="Upload Dataset Images" className="w-full" defaultValue="info">
        <div className="flex flex-col md:flex-row gap-4">
          <ol className="max-w-[400px] list-decimal list-outside pl-5 flex flex-col gap-2 text-lime-200 text-sm leading-relaxed">
            {infoboxText ? (
              Array.isArray(infoboxText) ? (
                (infoboxText as string[]).map((textLine, index) => (
                  <li key={index} className="">
                    {textLine}
                  </li>
                ))
              ) : (
                <li className="">{infoboxText}</li>
              )
            ) : (
              <li className="">Upload at least 3 photos to help your model get better results</li>
            )}
            <li className="">Image size must be at least 384×384</li>
          </ol>
          <div>
            {exampleUrls && (
              <div className="flex flex-1 flex-row gap-2">
                {exampleUrls.map((exampleUrl, index) => (
                  <img
                    key={index}
                    src={exampleUrl}
                    alt="example"
                    className="h-[12vh] object-cover rounded-md"
                  />
                ))}
              </div>
            )}
          </div>
        </div>
      </InfoBox>
      <DragAndDropOverlay
        overlayClassName="absolute top-0 left-0 w-full h-full"
        containerClassName="relative"
        handleDropFiles={(files) => {
          for (let i = 0; i < files.length; ++i) {
            const file = files[i];
            if (!file) {
              continue;
            }
            // Each dropped file is handled independently; a file that handleUploadImage
            // rejects (falsy return) must not abort the rest of the batch.
            handleUploadImage({
              image: file,
              onError: () => {},
              uploadFile,
            });
          }
        }}
        disabled={!canUserEdit || !modelId || !dataset}
      >
        <DatasetGrid />
        <div className="w-full h-[30vh]" />
      </DragAndDropOverlay>
    </div>
  );
}

/**
 * Fixed-width left column of a two-panel editor (full width below `md`).
 * Forwards all div attributes and merges `className` after the base classes, matching
 * `TwoPanelEditorContainer`. The original version silently dropped both.
 */
function TwoPanelEditorLeftPanel({
  className,
  children,
  ...props
}: React.PropsWithChildren & React.HTMLAttributes<HTMLDivElement>) {
  return (
    <div
      {...props}
      className={classNames("flex min-h-0 w-full md:w-[40vw] xl:w-[512px]", className)}
    >
      {children}
    </div>
  );
}

/**
 * Outer flex wrapper of a two-panel editor: stacked on small screens, side-by-side from `md`.
 * All div attributes are forwarded; `className` is appended to the base layout classes.
 */
function TwoPanelEditorContainer({
  className,
  children,
  ...props
}: React.PropsWithChildren & React.HTMLAttributes<HTMLDivElement>) {
  const containerClassName = classNames("min-h-0 flex flex-col md:flex-row", className);

  return (
    <div {...props} className={containerClassName}>
      {children}
    </div>
  );
}

/**
 * Two-panel editor layout.
 *
 * From `md` up, `leftPanelChildren` render in a fixed-width scrollable left column and
 * `rightPanelChildren` in a flexible scrollable right column. Below `md` the right column
 * is hidden. Both children are then stacked inside the left scroll area, right panel
 * content first in DOM order, followed by a `mb-16` gap.
 *
 * @param smallScreenLeftPanelChildrenContainerClassName - extra classes for the small-screen
 *   container holding `leftPanelChildren` (e.g. flex `order-*` utilities).
 * @param smallScreenRightPanelChildrenContainerClassName - extra classes for the small-screen
 *   container holding `rightPanelChildren`.
 *
 * Note: each class name is applied to the container of the panel it names. Previously the
 * two were swapped, which inverted callers' `order-*` intent.
 */
function TwoPanelEditor({
  className,
  leftPanelChildren,
  leftPanelProps = {},
  rightPanelChildren,
  rightPanelProps = {},
  smallScreenLeftPanelChildrenContainerClassName = "",
  smallScreenRightPanelChildrenContainerClassName = "",
  ...props
}: React.HTMLAttributes<HTMLDivElement> & {
  leftPanelChildren: React.ReactNode;
  leftPanelProps?: React.HTMLAttributes<HTMLDivElement>;
  rightPanelChildren: React.ReactNode;
  rightPanelProps?: React.HTMLAttributes<HTMLDivElement>;
  smallScreenLeftPanelChildrenContainerClassName?: string;
  smallScreenRightPanelChildrenContainerClassName?: string;
}) {
  return (
    <div {...props} className={classNames("min-h-0 flex flex-col md:flex-row", className)}>
      <div
        {...leftPanelProps}
        className={classNames(
          "flex min-h-0 w-full md:w-[40vw] xl:w-[512px]",
          leftPanelProps.className ?? "",
        )}
      >
        <ScrollAreaContainer
          className="w-full"
          viewportProps={{
            className: "p-4 xl:px-8 xl:py-6 min-h-0 md:border-r md:border-zinc-800",
          }}
        >
          <div className="flex flex-col">
            {/* Small screens only: right-panel content stacked above the left-panel content. */}
            <div
              className={classNames(
                smallScreenRightPanelChildrenContainerClassName,
                "block md:hidden mb-16",
              )}
            >
              {rightPanelChildren}
            </div>
            <div className={smallScreenLeftPanelChildrenContainerClassName}>
              {leftPanelChildren}
            </div>
          </div>
        </ScrollAreaContainer>
      </div>
      <div
        {...rightPanelProps}
        className={classNames(
          "hidden md:flex flex-col min-h-0 flex-1",
          rightPanelProps.className ?? "",
        )}
      >
        <ScrollAreaContainer
          viewportProps={{
            className: "p-4 xl:px-8 xl:py-6",
          }}
        >
          {rightPanelChildren}
        </ScrollAreaContainer>
      </div>
    </div>
  );
}

/** "Train" tab: training configuration on the left, the dataset on the right. */
function Train() {
  const trainingPanel = <TrainingConfig />;
  const datasetPanel = <Data />;

  return <TwoPanelEditor leftPanelChildren={trainingPanel} rightPanelChildren={datasetPanel} />;
}

/**
 * "Generate" tab: prompt/settings on the left, generated output on the right.
 * The `order-*` classes are forwarded to TwoPanelEditor for its small-screen stacked layout.
 */
function Generate() {
  const settingsPanel = <GenerateSettings />;
  const outputPanel = <GenerateOutput />;

  return (
    <TwoPanelEditor
      leftPanelChildren={settingsPanel}
      rightPanelChildren={outputPanel}
      smallScreenLeftPanelChildrenContainerClassName="order-2 md:order-1"
      smallScreenRightPanelChildrenContainerClassName="order-1 md:order-2"
    />
  );
}

/**
 * Type guard for editor tab identifiers (e.g. values read from the URL).
 *
 * Accepts `unknown` rather than `any`: every caller's argument is still assignable, and the
 * body only performs equality checks, so no unchecked access can slip in.
 *
 * @returns true only for the literal strings "train" and "generate".
 */
function isCustomModelEditorTab(value: unknown): value is CustomModelEditorTab {
  return value === "train" || value === "generate";
}

/**
 * Full-width editor wrapper that caps its content at 1920px on 2xl screens.
 * Forwards the ref to the outer div; `className` is appended to the outer div's classes
 * and all other div attributes are spread onto it.
 */
const CustomModelEditorContainer = React.forwardRef(
  (
    {
      className,
      children,
      ...props
    }: React.DetailedHTMLProps<React.HTMLAttributes<HTMLDivElement>, HTMLDivElement>,
    forwardedRef: React.Ref<HTMLDivElement>,
  ) => {
    return (
      <div ref={forwardedRef} className={classNames("w-full flex", className ?? "")} {...props}>
        <div className="w-full 2xl:max-w-[1920px]">{children}</div>
      </div>
    );
  },
);

// forwardRef components wrapping an anonymous arrow have no inferred name; set one so
// React DevTools and error stacks show "CustomModelEditorContainer" instead of "ForwardRef".
CustomModelEditorContainer.displayName = "CustomModelEditorContainer";

/**
 * Tabbed body of the custom model editor ("train" / "generate").
 *
 * Besides rendering the tabs, it:
 * - syncs the editor-context tab from the `tab` prop and writes the current tab
 *   back into the `tab` URL search param;
 * - seeds the generation playground once from the most recent successful prediction,
 *   or otherwise builds a default prompt from the requested / most recent
 *   successful training.
 *
 * @param tab - Tab requested by the caller (read from the URL by `CustomModelEditor`).
 * @param trainingId - Optional training to prefer when building the default prompt.
 */
function CustomModelEditorInner({
  tab: initTab,
  trainingId,
}: {
  tab: CustomModelEditorTab;
  trainingId?: string | null;
}) {
  const { tab, setTab } = useCustomModelEditor();
  const [searchParams, setSearchParams] = useSearchParams();
  const tabsRootRef = React.useRef<HTMLDivElement | null>(null);
  const tabsListRef = React.useRef<HTMLDivElement | null>(null);

  // Push the requested tab into the editor context once the tab list is mounted.
  React.useEffect(() => {
    if (!tabsListRef.current) {
      return;
    }
    setTab(initTab);
  }, [setTab, initTab]);

  const { setApiState, setPredictionId, setOutputImages } = useCustomModelPlayground();

  // Ensures the playground is seeded from a prediction only once per mount.
  const isFirstLoadPredictionsRef = React.useRef(true);
  const customModelInfo = editorContextStore((state) => state.customModelInfo);
  const customModelTrainings = editorContextStore((state) => state.customModelTrainings);
  const customModelPredictions = editorContextStore((state) => state.customModelPredictions);

  React.useEffect(() => {
    const { backend, eventEmitter, customModelWorkflow } = editorContextStore.getState();

    // Prefer the explicitly requested training, else the first succeeded training
    // in `sortByTimeModified` order.
    const training =
      (trainingId ? customModelTrainings[trainingId] : undefined) ||
      Object.values(customModelTrainings)
        .sort(sortByTimeModified)
        .find((training) => training.status === CustomModelTrainingStatus.Succeeded);

    // First succeeded prediction with at least one output, in `sortByTimeModified` order.
    const predictionItem = Object.values(customModelPredictions)
      .filter(
        (prediction) =>
          prediction?.output?.length && prediction.status === CustomModelTrainingStatus.Succeeded,
      )
      .sort(sortByTimeModified)?.[0];

    const promptJson = predictionItem?.input?.promptJson;

    const predictionOutputs = predictionItem?.output ?? [];

    if (backend && predictionItem && promptJson && isFirstLoadPredictionsRef.current) {
      // Find the most recent prediction and use that as the default prompt
      setApiState((apiState) => ({
        ...apiState,
        ...getCustomModelPlaygroundEditorStateFromPredictionInput(predictionItem.input),
        promptEditorState: getCustomModelPlaygroundPromptEditorStateFromSerializedEditorState({
          promptEditorState: JSON.parse(promptJson),
        }),
        numImages: Math.max(predictionOutputs.length, 2),
      }));

      eventEmitter.emit<CustomModelSetPromptEditorStateEventHandler>(
        "custom-model:set-prompt-editor-state",
        {
          promptEditorStateJson: promptJson,
        },
      );

      isFirstLoadPredictionsRef.current = false;

      debugLog("Init api state from the most recent prediction:\n", predictionItem);

      Promise.all(
        predictionOutputs.map((path) =>
          Assets.loadAssetFromPath({
            path,
            backend,
          }),
        ),
      )
        .then((urls) => {
          setOutputImages(urls.filter(Boolean) as string[]);
        })
        .catch((error) => {
          // Without this handler a single failed asset load became an unhandled rejection.
          debugError("Error loading the most recent prediction outputs:\n", error);
        });

      setPredictionId(predictionItem.id);
    } else if (training) {
      setApiState((apiState) => {
        // Keep any prompt text that is already present.
        if (apiState.promptEditorState.text) {
          return apiState;
        }
        const promptEditorState = getCustomModelPlaygroundPromptEditorStateFromTrainingId({
          customModelType:
            customModelInfo?.customModelType ??
            getCustomModelTypeFromFrontendDisplayTemplateType(customModelWorkflow),
          trainingId: training.id,
          trainingDisplayName: training.displayName ?? "",
          caption: training.caption ?? "",
          modelId: customModelInfo?.id ?? "",
          modelDisplayName: customModelInfo?.displayName ?? "",
        });

        // NOTE(review): emitting from inside a state updater is a side effect; React may
        // invoke updaters twice in StrictMode, which would emit this event twice.
        if (promptEditorState.json) {
          eventEmitter.emit<CustomModelSetPromptEditorStateEventHandler>(
            "custom-model:set-prompt-editor-state",
            {
              promptEditorStateJson: promptEditorState.json,
            },
          );
        }

        return {
          ...apiState,
          promptEditorState,
        };
      });

      debugLog("Init api state from the most recent training:\n", training);
    }
  }, [
    trainingId,
    setApiState,
    setOutputImages,
    setPredictionId,
    customModelInfo,
    customModelTrainings,
    customModelPredictions,
  ]);

  useCustomModelWorkflowSliderConfigUpdateEffect();

  // The generate tab (and the "next step" button) require at least one succeeded training.
  const hasSuccessfulTraining = React.useMemo(() => {
    return Object.values(customModelTrainings).some(
      (training) => training.status === CustomModelTrainingStatus.Succeeded,
    );
  }, [customModelTrainings]);

  // Mirror the current tab into the `tab` search param.
  React.useEffect(() => {
    if (!isCustomModelEditorTab(tab)) {
      return;
    }

    // Already in sync: skip, otherwise every re-run issues a redundant history replace.
    if (searchParams.get("tab") === tab) {
      return;
    }

    // Update search params while preserving other params
    const newSearchParams = new URLSearchParams(searchParams);
    newSearchParams.set("tab", tab);
    setSearchParams(newSearchParams, { replace: true });
  }, [tab, searchParams, setSearchParams]);

  // Only moving forward to generation needs a finished training; going back never does.
  const isStepButtonDisabled = tab === "train" && !hasSuccessfulTraining;

  return (
    <Tabs.Root
      ref={tabsRootRef}
      value={tab}
      onValueChange={(value) => {
        if (isCustomModelEditorTab(value)) {
          setTab(value);
        }
      }}
      className="w-full min-h-0 flex-1 flex flex-col"
    >
      <StickyBox
        className="w-full flex justify-center bg-zinc-900 border-b border-zinc-800"
        style={{
          zIndex: CustomModelEditorHeaderZIndex,
        }}
      >
        <div className="w-full 2xl:max-w-[1920px]">
          <Tabs.List
            ref={tabsListRef}
            className="w-full text-md font-semibold flex flex-row items-center justify-between"
          >
            <div className="flex flex-row items-center md:pt-8 pt-2">
              <Tabs.Trigger value="train" className={triggerClassName}>
                <span className={triggerIndexClassName}>1</span>
                <span>Model Training</span>
              </Tabs.Trigger>
              <Tabs.Trigger
                value="generate"
                className={classNames(
                  triggerClassName,
                  !hasSuccessfulTraining && "opacity-50 cursor-not-allowed",
                )}
                disabled={!hasSuccessfulTraining}
              >
                <span className={triggerIndexClassName}>2</span>
                <span>Image Generation</span>
              </Tabs.Trigger>
            </div>
            <button
              onClick={() => setTab(tab === "train" ? "generate" : "train")}
              className={classNames(
                tab === "train"
                  ? hasSuccessfulTraining
                    ? PrimaryButtonClassName
                    : PrimaryButtonClassNameDisabled
                  : SecondaryButtonClassNameInactive,
                "mr-4 xl:mr-8 py-2 px-4 rounded-lg hidden md:flex items-center justify-center gap-2",
              )}
              disabled={isStepButtonDisabled}
            >
              {tab === "train" ? (
                <div className="flex flex-row items-center gap-2">
                  Next Step: Generate Images
                  <ChevronRightIcon />
                </div>
              ) : (
                <div className="flex flex-row items-center gap-2">
                  <ChevronLeftIcon />
                  Back to Model Training
                </div>
              )}
            </button>
          </Tabs.List>
        </div>
      </StickyBox>
      <CustomModelEditorContainer className="min-h-0 flex flex-row bg-zinc-900 justify-center">
        <div className="flex flex-col min-h-0 max-h-full">
          <Tabs.Content value="train" className="min-h-0 max-h-full flex flex-col">
            <Train />
          </Tabs.Content>
        </div>
        <div className="flex flex-col min-h-0 max-h-full">
          <Tabs.Content value="generate" className="min-h-0 max-h-full flex flex-col">
            <Generate />
          </Tabs.Content>
        </div>
      </CustomModelEditorContainer>
    </Tabs.Root>
  );
}

/**
 * Route-level custom model editor page.
 *
 * Reads `modelId` from the route params and `tab` / `trainingId` from the search
 * params. It loads the model info and subscribes to the model's dataset, trainings
 * and predictions, writing each into `editorContextStore`. It also wraps the editor
 * in the providers it needs. Subscriptions are torn down, and the related store
 * state cleared, when the model or backend changes or the page unmounts.
 */
export function CustomModelEditor() {
  const params = useParams();
  const modelId = params?.modelId;

  const [searchParams] = useSearchParams();

  const tab = searchParams.get("tab");

  // Unknown or missing `tab` values fall back to the training tab.
  const validatedTab = isCustomModelEditorTab(tab) ? tab : "train";

  const trainingId = searchParams.get("trainingId");

  const publicUserId = editorContextStore((state) => state.publicUserId);
  const backend = editorContextStore((state) => state.backend);
  const customModelInfo = editorContextStore((state) => state.customModelInfo);
  const setCustomModelId = editorContextStore((state) => state.setCustomModelId);
  const setCustomModelInfo = editorContextStore((state) => state.setCustomModelInfo);
  const setDataset = editorContextStore((state) => state.setCustomModelDataset);
  const setCustomModelWorkflow = editorContextStore((state) => state.setCustomModelWorkflow);
  const setCustomModelTrainings = editorContextStore((state) => state.setCustomModelTrainings);
  const setCustomModelPredictions = editorContextStore((state) => state.setCustomModelPredictions);

  React.useEffect(() => {
    setCustomModelWorkflow(getCustomModelWorkflowFromCustomModelInfo(customModelInfo));
  }, [customModelInfo, setCustomModelWorkflow]);

  React.useEffect(() => {
    setCustomModelId(modelId);

    // Drop responses that arrive after the model/backend changed or the page unmounted,
    // so a slow request for a previous model cannot overwrite the current model info.
    let isCancelled = false;

    if (modelId) {
      backend
        ?.getCustomModelInfo(modelId)
        .then((info) => {
          if (!isCancelled) {
            setCustomModelInfo(info);
          }
        })
        .catch((error) => {
          debugError(`Error loading custom model info for ${modelId}:\n`, error);
        });
    }

    return () => {
      isCancelled = true;
      resetCustomModelEditorState(editorContextStore.getState());
    };
  }, [backend, modelId, setCustomModelId, setCustomModelInfo]);

  React.useEffect(() => {
    if (!modelId) {
      return;
    }

    if (!backend) {
      return;
    }

    // The returned unsubscribe function doubles as the effect cleanup.
    return backend.onCustomModelDatasetUpdate(modelId, setDataset);
  }, [backend, modelId, setDataset]);

  React.useEffect(() => {
    if (!modelId || !backend) {
      return;
    }

    const unsubscribeTrainingUpdate = backend.onCustomModelTrainingCollectionUpdate({
      modelId,
      callback: (trainings) => {
        // Store trainings keyed by id.
        setCustomModelTrainings(
          Object.fromEntries(trainings.map((training) => [training.id, training])),
        );
      },
    });

    return () => {
      unsubscribeTrainingUpdate();

      setCustomModelTrainings({});
    };
  }, [backend, modelId, setCustomModelTrainings]);

  React.useEffect(() => {
    setCustomModelPredictions({});

    if (!modelId || !backend || !publicUserId) {
      return;
    }

    // Set by the cleanup: prevents subscribing after teardown and stops late
    // callbacks from repopulating the cleared predictions.
    let isCancelled = false;

    const unsubscribePromise = backend
      .getPublicCustomModelPredictions({
        modelId,
      })
      .then((publicPredictions) => {
        if (isCancelled) {
          return undefined;
        }
        return backend.onCustomModelPredictionsUpdate({
          modelId,
          publicUserId,
          callback: (predictions) => {
            if (isCancelled) {
              return;
            }
            // `publicPredictions` is a one-time snapshot while `predictions` is live;
            // spread the live data last so a stale public copy of the same prediction
            // cannot mask its status updates.
            setCustomModelPredictions({
              ...publicPredictions,
              ...predictions,
            });
          },
        });
      })
      .catch((error) => {
        debugError(`Error subscribing to custom model predictions for ${modelId}:\n`, error);
        return undefined;
      });

    return () => {
      isCancelled = true;

      setCustomModelPredictions({});

      unsubscribePromise.then((unsubscribe) => {
        unsubscribe?.();
      });
    };
  }, [backend, modelId, publicUserId, setCustomModelPredictions]);

  useCustomModelsEffect();

  usePublicCustomModelsEffect();

  return (
    <ManageSubscriptionDialogProvider>
      <CustomModelEditorProvider>
        <CustomModelTrainingInputProvider>
          <CustomModelPlaygroundProvider>
            <OnboardingDialogProvider>
              <div className="h-screen bg-zinc-900 text-zinc-100 flex flex-col">
                <CreateCustomModelNavbar />
                {modelId && <CustomModelEditorInner tab={validatedTab} trainingId={trainingId} />}
                <Chatbot />
                <OnboardingDialog />
              </div>
            </OnboardingDialogProvider>
          </CustomModelPlaygroundProvider>
        </CustomModelTrainingInputProvider>
      </CustomModelEditorProvider>
    </ManageSubscriptionDialogProvider>
  );
}
