import { FirebaseBackend } from "@/backend/firebase/firebase-backend";
import { displayUiMessage } from "@/components/utils/display-message";
import { editorContextStore } from "@/contexts/editor-context";
import {
  RealTimeRenderStartColorCorrectionEventHandler,
  RealTimeRenderStartRenderEventHandler,
} from "@/core/common/types";
import {
  RealTimeRenderMode,
  RealTimeRenderStatus,
  RealTimeServerConfig,
  RealTimeServerId,
  RealTimeServerTier,
  RealTimeUserId,
  isRealTimeColorCorrectionMessage,
  isRealTimeServerConfig,
} from "@/core/common/types/realtime-render";
import { debugError, debugLog } from "@/core/utils/print-utilts";
import { TimeoutCallback } from "@/core/utils/time-utils";
import { decode, encode } from "@msgpack/msgpack";

// The real-time endpoints are configured in Cloudflare, whose A records require
// secure transports (https / wss). To test against a raw IP address, i.e. a host that
// does NOT require https/wss, switch WebsocketProtocol and HttpProtocol to their
// commented-out alternatives below and point the real-time server URL
// (VITE_REALTIME_SERVER_API_URL) at that address.

/** Scheme used for the real-time render WebSocket connection. */
export const WebsocketProtocol = "wss";
// export const WebsocketProtocol = "ws";

/** Scheme used for plain HTTP requests (the color-correction `/predict` call). */
export const HttpProtocol = "https";
// export const HttpProtocol = "http";

/** Host (plus optional port/path) of the color-correction service; `/predict` is appended where it is called. */
const ColorCorrectionServerUrl = import.meta.env.VITE_REALTIME_COLOR_CORRECTION_API_URL;

/**
 * Server timeout in seconds, read from the environment.
 * Must be kept in sync with the StreamDiffusion server's own timeout.
 */
export const RealTimeServerTimeOutSeconds = import.meta.env.VITE_REALTIME_SERVER_TIMEOUT_SECONDS;

/** Server ids grouped by tier; tiers without any configured server have no entry. */
type TierToServerIds = Partial<Record<RealTimeServerTier, RealTimeServerId[]>>;

/** Produces the payloads to send when the server asks for the next frame. */
type GetStreamDataCallback = () => Promise<any[]>;

/**
 * Closes a WebSocket and resolves once it has finished closing.
 *
 * Resolves immediately when the socket is already CLOSING or CLOSED; otherwise it
 * calls `close()` and waits for the socket's `close` event.
 *
 * The promise never rejects. Calling `close()` on a socket that is still CONNECTING
 * makes the browser fire `error` and then `close`. Treating that `error` as a failure
 * would reject a close that actually succeeded, and would leave an `error` listener
 * attached after normal closes.
 *
 * @param ws - Socket to close.
 * @returns A promise that resolves when the socket is closed (or already closing).
 */
function closeWebSocket(ws: WebSocket): Promise<void> {
  return new Promise((resolve) => {
    // Informational, not an error condition.
    debugLog("Close the websocket");

    if (ws.readyState === WebSocket.CLOSED || ws.readyState === WebSocket.CLOSING) {
      resolve();
      return;
    }

    // `close` is always dispatched once closing completes, including after an `error`.
    ws.addEventListener("close", () => resolve(), { once: true });

    ws.close();
  });
}

/**
 * Splits `input` into chunks of at most `chunkSizeKB` KiB and appends `endToken`
 * to the final chunk so the receiver can tell where the message ends.
 *
 * Every chunk except the last is a view into `input` (no copy). The last chunk is a
 * new array holding the remaining bytes followed by `endToken`. An empty `input`
 * yields no chunks.
 *
 * @param input - Bytes to split.
 * @param chunkSizeKB - Maximum payload bytes per chunk, in KiB (1024 bytes). The
 *   resulting byte count is rounded down and must be at least 1. `Infinity` sends
 *   the whole message as a single chunk.
 * @param endToken - Marker appended to the last chunk.
 * @returns The chunks in send order.
 * @throws RangeError if `chunkSizeKB` does not give a chunk size of at least 1 byte.
 *   Zero or negative sizes would otherwise loop forever, and NaN would emit an empty,
 *   unterminated chunk.
 */
function encodeBytesChunked(
  input: Uint8Array,
  chunkSizeKB: number,
  endToken: Uint8Array,
): Uint8Array[] {
  const chunkSize = Math.floor(chunkSizeKB * 1024);

  // `!(x >= 1)` also rejects NaN.
  if (!(chunkSize >= 1)) {
    throw new RangeError(
      `chunkSizeKB must give a chunk size of at least 1 byte, got ${chunkSizeKB}`,
    );
  }

  const chunks: Uint8Array[] = [];

  for (let start = 0; start < input.length; start += chunkSize) {
    const end = start + chunkSize;

    if (end >= input.length) {
      // Final chunk: copy the remaining bytes and terminate them with the end token.
      const tail = input.subarray(start);
      const lastChunk = new Uint8Array(tail.length + endToken.length);
      lastChunk.set(tail, 0);
      lastChunk.set(endToken, tail.length);
      chunks.push(lastChunk);
    } else {
      chunks.push(input.subarray(start, end));
    }
  }

  return chunks;
}

/**
 * Reassembles received chunks into one byte array, dropping the trailing
 * `endToken` when the joined bytes end with it.
 *
 * The token is matched against the end of the whole message rather than the last
 * chunk alone. This means:
 * - a token split across the last two chunks is still removed;
 * - a short last chunk that only matches a prefix of the token is kept intact.
 *
 * `chunks` is not modified, and an empty list yields an empty array.
 *
 * @param chunks - Received chunks in arrival order.
 * @param endToken - Marker expected at the end of the message.
 * @returns A standalone array holding the message bytes without the end token.
 */
function decodeBytesChunked(chunks: Uint8Array[], endToken: Uint8Array): Uint8Array {
  const totalLength = chunks.reduce((sum, chunk) => sum + chunk.length, 0);

  const joined = new Uint8Array(totalLength);
  let offset = 0;
  for (const chunk of chunks) {
    joined.set(chunk, offset);
    offset += chunk.length;
  }

  const tokenStart = totalLength - endToken.length;
  const endsWithToken =
    endToken.length > 0 &&
    tokenStart >= 0 &&
    endToken.every((byte, index) => joined[tokenStart + index] === byte);

  // `slice` (not `subarray`) so the result owns an exactly-sized buffer.
  return endsWithToken ? joined.slice(0, tokenStart) : joined;
}

/**
 * WebSocket wrapper that exchanges msgpack-encoded messages as binary chunks.
 *
 * Sending: messages are msgpack-encoded, split into chunks of at most
 * `messageChunkSizeKB` KiB, and the last chunk is terminated with `EndToken`.
 *
 * Receiving: binary frames are buffered until one ends with `EndToken`. The buffered
 * frames are then joined, the token is removed, and the bytes are msgpack-decoded and
 * handed to `onMessage`. Non-binary frames are ignored.
 */
class StreamingWebSocket<T = object> {
  /** Message terminator: the ASCII bytes of "|=FINISH=|". */
  static EndToken = new Uint8Array([0x7c, 0x3d, 0x46, 0x49, 0x4e, 0x49, 0x53, 0x48, 0x3d, 0x7c]);

  private websocket: WebSocket;

  /** Frames of the message currently being received, in arrival order. */
  private readQueue: Uint8Array[] = [];

  private messageChunkSizeKB: number;

  private onMessage: (message: T) => void;

  /** Removes the listeners that were registered for the owner's callbacks. */
  private unregisterEventListeners: () => void;

  /**
   * Opens the socket immediately and wires up the owner's callbacks.
   *
   * @param websocketURL - Full ws:// or wss:// URL to connect to.
   * @param uid - Sent as the Sec-WebSocket-Protocol value.
   * @param messageChunkSizeKB - Maximum outgoing chunk size in KiB (default 64).
   * @param onMessage - Receives each fully reassembled, decoded message.
   * @param onOpen - Forwarded from the socket's `open` event.
   * @param onError - Forwarded from the socket's `error` event.
   * @param onClose - Forwarded from the socket's `close` event.
   */
  constructor({
    websocketURL,
    uid,
    messageChunkSizeKB = 64,
    onMessage,
    onOpen,
    onError,
    onClose,
  }: {
    websocketURL: string;
    // uid is only needed as a constructor arg to send in websocket protocol, so the cloudflare worker can identify the user and run auth.
    uid: string;
    messageChunkSizeKB?: number;
    onMessage: (message: T) => void;
    onOpen: (event: Event) => void;
    onError: (event: Event) => void;
    onClose: (event: CloseEvent) => void;
  }) {
    // The second WebSocket argument is the Sec-WebSocket-Protocol (a string or list of strings).
    this.websocket = new WebSocket(websocketURL, uid);
    this.websocket.addEventListener("close", () => {
      debugLog(`WebSocket closed at: ${new Date().toISOString()}`);
    });
    this.websocket.addEventListener("error", () => {
      debugError(`WebSocket error at: ${new Date().toISOString()}`);
    });
    // Deliver binary frames as ArrayBuffer, the only payload type handleWebsocketMessageEvent accepts.
    this.websocket.binaryType = "arraybuffer";
    this.onMessage = onMessage;
    this.messageChunkSizeKB = messageChunkSizeKB;

    this.websocket.addEventListener("message", this.handleWebsocketMessageEvent);
    this.websocket.addEventListener("open", onOpen);
    this.websocket.addEventListener("error", onError);
    this.websocket.addEventListener("close", onClose);

    this.unregisterEventListeners = () => {
      this.websocket.removeEventListener("message", this.handleWebsocketMessageEvent);
      this.websocket.removeEventListener("open", onOpen);
      this.websocket.removeEventListener("error", onError);
      this.websocket.removeEventListener("close", onClose);
    };
  }

  /**
   * Tears the socket down: detaches the owner's callbacks, drops any partially
   * received message, and closes the connection.
   *
   * Owner callbacks are detached *before* closing. A deliberate teardown therefore
   * never re-enters the owner through `onClose`/`onError`, for example after the owner
   * has already replaced this socket with a new one. Detaching first also guarantees
   * the listeners are removed even if closing fails. Close failures are logged instead
   * of rethrown.
   */
  async destroy() {
    this.unregisterEventListeners();

    this.readQueue.length = 0;

    try {
      await closeWebSocket(this.websocket);
    } catch (error) {
      debugError(error);
    }
  }

  get readyState() {
    return this.websocket.readyState;
  }

  private encodeMessageChunk(input: Uint8Array): Uint8Array[] {
    return encodeBytesChunked(input, this.messageChunkSizeKB, StreamingWebSocket.EndToken);
  }

  private decodeMessageChunk(chunks: Uint8Array[]): Uint8Array {
    return decodeBytesChunked(chunks, StreamingWebSocket.EndToken);
  }

  /**
   * Reassembles and msgpack-decodes the buffered frames.
   *
   * @returns The decoded message, or undefined if reassembly/decoding failed.
   */
  private decodeReadQueue() {
    try {
      const decodedMessageBytes = this.decodeMessageChunk(this.readQueue);

      return decode(decodedMessageBytes) as T;
    } catch (error) {
      console.error(error);
    } finally {
      // Always start the next message from an empty buffer, even when decoding failed,
      // so one corrupt message cannot be prepended to every message after it.
      this.readQueue.length = 0;
    }

    return;
  }

  /**
   * True when `message` ends with `EndToken`. A frame shorter than the token can't
   * contain it; without this length check, a short frame matching only a prefix of
   * the token would be treated as the end of a message.
   */
  private isLastMessage(message: Uint8Array) {
    const footer = StreamingWebSocket.EndToken;

    if (message.length < footer.length) {
      return false;
    }

    const footerStart = message.length - footer.length;

    return footer.every((value, index) => message[footerStart + index] === value);
  }

  private handleWebsocketMessageBytes = (message: Uint8Array) => {
    this.readQueue.push(message);

    if (this.isLastMessage(message)) {
      const messageData = this.decodeReadQueue();

      if (messageData) {
        this.onMessage(messageData);
      }
    }
  };

  private handleWebsocketMessageEvent = (event: MessageEvent<any>) => {
    try {
      const message = event.data;

      if (message instanceof ArrayBuffer) {
        this.handleWebsocketMessageBytes(new Uint8Array(message));
      }
    } catch (error) {
      console.error(error);
    }
  };

  /**
   * msgpack-encodes `message` and sends it as one or more binary chunks.
   * Falsy messages are ignored.
   */
  send<Message = any>(message: Message) {
    if (!message) {
      return;
    }

    const messageBytes = encode(message);

    const messageChunks = this.encodeMessageChunk(messageBytes);

    if (messageChunks.length <= 0) {
      return;
    }

    for (const chunk of messageChunks) {
      this.websocket.send(chunk);
    }
  }
}

/** Values of the `status` field in messages received from the real-time server. */
enum RealTimeMessageStatus {
  SetServerId = "set_server_id",
  Connected = "connected",
  SendFrame = "send_frame",
  ColorCorrect = "color_correct",
  Wait = "wait",
  Timeout = "timeout",
  Error = "error",
}

/** Carries the id of the render server for this session. */
interface RealTimeMessageSetServerId {
  status: RealTimeMessageStatus.SetServerId;
  serverId: RealTimeServerId;
}

/** Confirms the connection; includes a user id. */
interface RealTimeMessageConnected {
  status: RealTimeMessageStatus.Connected;
  userId: string;
}

/** Asks the client for the next frame. */
interface RealTimeMessageSendFrame {
  status: RealTimeMessageStatus.SendFrame;
}

/** Color-correction message; its payload is validated separately by the handler. */
interface RealTimeMessageColorCorrect {
  status: RealTimeMessageStatus.ColorCorrect;
}

/** Tells the client to wait. */
interface RealTimeMessageWait {
  status: RealTimeMessageStatus.Wait;
}

/** Reports that the session timed out. */
interface RealTimeMessageTimeout {
  status: RealTimeMessageStatus.Timeout;
}

/** Reports a server-side error with a human-readable message. */
interface RealTimeMessageError {
  status: RealTimeMessageStatus.Error;
  message: string;
}

/** Any message the real-time server may send, discriminated by `status`. */
type RealTimeMessage =
  | RealTimeMessageSetServerId
  | RealTimeMessageConnected
  | RealTimeMessageSendFrame
  | RealTimeMessageColorCorrect
  | RealTimeMessageWait
  | RealTimeMessageTimeout
  | RealTimeMessageError;

/**
 * Manages the editor's real-time render session.
 *
 * - On construction, loads the real-time server configs from Firebase and subscribes
 *   to real-time status changes in `editorContextStore`.
 * - Opens a `StreamingWebSocket` and reacts to the server's status messages
 *   (`set_server_id`, `connected`, `send_frame`, `color_correct`, `wait`, `timeout`, `error`).
 * - Streams frame data when the server asks for a frame.
 * - Schedules a reconnect after socket/server errors while real-time mode is active.
 * - Posts color-correction requests over HTTP and publishes the resulting image.
 *
 * UI-visible state (status, progress, user/server ids, color-corrected image) is read
 * from and written to `editorContextStore`.
 */
export class RealTimeRenderWebSocketController {
  private _websocket?: StreamingWebSocket<RealTimeMessage>;

  // `renderIndex` is bumped by every frame sent (`send`) and every color-correction
  // request. The two snapshots below record the latest of each, so a color-corrected
  // image is only published if no newer plain frame was sent after it was requested.
  private renderIndex = 0;

  private colorCorrectRenderIndex = 0;

  private defaultRenderIndex = 0;

  // Lets `send` abort the in-flight color-correction request.
  private colorCorrectAbortController?: AbortController;

  private realtimeServerConfigs: Record<RealTimeServerId, RealTimeServerConfig> = {};

  private tierToServerIds: TierToServerIds = {};

  // Started while waiting on the server. On expiry it currently only logs, because the
  // reconnect call is commented out.
  private waitTimeoutCallback = new TimeoutCallback(5, () => {
    debugLog("Wait timeout reconnect");
    // this.reconnect();
  });

  private unsubscribeToRealTimeRenderStatusChange = () => {};

  // Rejects the promise of the connection attempt started by `createWebsocketConnection`.
  // Without it, a socket that closes or is destroyed before the server sends "connected"
  // leaves that promise pending forever. `start()` then never finishes, `isStarting`
  // stays true, and every later `start()` (including the one `reconnect` triggers) is a
  // no-op. Calling it after the promise has settled is harmless.
  private rejectPendingConnection?: (reason: Error) => void;

  // Set by `destroy()` so a connection attempt that fails during teardown cannot
  // schedule a reconnect on a controller that is no longer in use.
  private isDisposed = false;

  constructor() {
    FirebaseBackend.getRealTimeRenderConfigs()
      .then((configs) => {
        this.realtimeServerConfigs = configs.reduce<Record<RealTimeServerId, RealTimeServerConfig>>(
          (configs, config) => {
            if (isRealTimeServerConfig(config)) {
              configs[config.id] = config;
            }

            return configs;
          },
          {} as Record<RealTimeServerId, RealTimeServerConfig>,
        );

        this.tierToServerIds = configs.reduce<TierToServerIds>((output, config) => {
          if (isRealTimeServerConfig(config)) {
            output[config.tier] = [...(output[config.tier] ?? []), config.id];
          }

          return output;
        }, {} as TierToServerIds);
      })
      .catch((error) => debugError(error));

    this.unsubscribeToRealTimeRenderStatusChange = editorContextStore.subscribe(
      (state) => state.realtimeRenderStatus,
      this.handleRealTimeStatusUpdate,
    );
  }

  /** True while the editor's real-time render mode is `Active`. */
  get enabled() {
    return editorContextStore.getState().realtimeRenderMode === RealTimeRenderMode.Active;
  }

  get websocket() {
    return this._websocket;
  }

  /** Replacing the socket destroys the previous one, if any. */
  set websocket(socket: StreamingWebSocket<RealTimeMessage> | undefined) {
    this._websocket?.destroy();
    this._websocket = socket;
  }

  get isWebsocketCreated() {
    return Boolean(this.websocket);
  }

  get websocketReadyState() {
    return this.websocket?.readyState;
  }

  get isWebsocketOpen() {
    return Boolean(this.websocketReadyState === WebSocket.OPEN);
  }

  private destroyWebSocket() {
    // Fail a connection attempt that is still waiting for "connected". Once its socket
    // is gone, nothing else would ever settle it.
    const rejectPendingConnection = this.rejectPendingConnection;
    this.rejectPendingConnection = undefined;
    rejectPendingConnection?.(
      new Error("WebSocket was destroyed before the connection was established."),
    );

    if (!this.websocket) {
      return;
    }

    debugLog("Destroy websocket");

    // NOTE: assigning undefined runs the setter, which calls destroy() on the same socket again.
    this.websocket?.destroy();
    this.websocket = undefined;
  }

  /** Stops the session, closes the socket and unsubscribes from the store. */
  destroy() {
    this.isDisposed = true;

    this.stop();

    this.destroyWebSocket();
    this.unsubscribeToRealTimeRenderStatusChange();
  }

  handleWebSocketClose = () => {
    const { setRealtimeRenderProgress } = editorContextStore.getState();

    setRealtimeRenderProgress(1.0);

    // setRealtimeRenderMode(RealTimeRenderMode.Disabled);

    this.status = RealTimeRenderStatus.DISCONNECTED;
  };

  handleWebSocketError = (message: unknown) => {
    debugLog("WebSocket raised error");
    debugError(message);

    const { setRealtimeRenderProgress } = editorContextStore.getState();

    setRealtimeRenderProgress(0);

    // setRealtimeRenderMode(RealTimeRenderMode.Disabled);

    // this.status = RealTimeRenderStatus.DISCONNECTED;

    displayUiMessage("Reconnecting to the real-time render server ...", "info");

    this.reconnect();
  };

  private reconnectTimeId: NodeJS.Timeout | undefined = undefined;

  private cancelReconnect() {
    if (this.reconnectTimeId) {
      clearTimeout(this.reconnectTimeId);
      this.reconnectTimeId = undefined;
    }
  }

  /**
   * Stops the current session and, after a fixed delay, asks the editor to restart
   * real-time rendering (via the "realtime-render:start-render" event).
   * Does nothing unless real-time mode is active and the controller is not disposed.
   *
   * @returns A promise that resolves once the restart has been attempted. It stays
   *   pending if the reconnect is cancelled; callers do not await it.
   */
  private reconnect(): Promise<void> {
    try {
      if (this.isDisposed) {
        return Promise.resolve();
      }

      // Only reconnect if the real-time preview mode is active.
      const { realtimeRenderMode, setRealtimeRenderProgress } = editorContextStore.getState();

      if (realtimeRenderMode !== RealTimeRenderMode.Active) {
        debugLog(
          `Cannot re-connect because the real time render mode ${realtimeRenderMode} is not active.`,
        );
        return Promise.resolve();
      }

      // Also cancels any reconnect that is already scheduled, so at most one is pending.
      this.stop();

      setRealtimeRenderProgress(0.33);

      // Fixed delay before retrying.
      const waitSeconds = 1;

      debugLog(`Wait for ${waitSeconds} seconds to reconnect real-time`);

      return new Promise<void>((resolve) => {
        this.reconnectTimeId = setTimeout(() => {
          this.reconnectTimeId = undefined;

          debugLog("Try reconnect");

          const { editor, realtimeRenderMode } = editorContextStore.getState();

          if (this.enabled) {
            debugLog("Restart realtime render");

            editor?.emit<RealTimeRenderStartRenderEventHandler>("realtime-render:start-render");
          } else {
            debugLog(
              `Cannot re-start because the real time render mode ${realtimeRenderMode} is not active.`,
            );
          }

          resolve();
        }, waitSeconds * 1000);
      });
    } catch (error) {
      debugError(error);
    }

    return Promise.resolve();
  }

  get userId() {
    return editorContextStore.getState().user?.uid;
  }

  get status() {
    return editorContextStore.getState().realtimeRenderStatus;
  }

  private handleRealTimeDisable = () => {
    const {
      backend,
      realtimeUserId,
      setRealtimeUserId,
      setRealtimeServerId,
      setRealtimeServerTier,
    } = editorContextStore.getState();

    // setRealtimeRenderMode(RealTimeRenderMode.Disabled);
    this.waitTimeoutCallback.stop();

    if (realtimeUserId) {
      backend?.disconnectRealTimeState(realtimeUserId);
    }

    this.destroyWebSocket();

    setRealtimeUserId(undefined);

    setRealtimeServerId(undefined);

    setRealtimeServerTier(undefined);
  };

  private handleRealTimeStatusUpdate = (value: RealTimeRenderStatus) => {
    debugLog(`Set real-time status to ${value}`);

    if (value === RealTimeRenderStatus.DISCONNECTED || value === RealTimeRenderStatus.TIMEOUT) {
      this.handleRealTimeDisable();
    }
  };

  set status(value: RealTimeRenderStatus) {
    const { setRealtimeRenderStatus } = editorContextStore.getState();

    setRealtimeRenderStatus(value);
  }

  /**
   * Opens a new socket (destroying any previous one) and handles server messages.
   *
   * Resolves on "connected". Rejects on the following, whichever happens first:
   * - "timeout" or "error" from the server;
   * - a missing user;
   * - the socket closing or being destroyed before "connected".
   */
  private createWebsocketConnection(getStreamData: GetStreamDataCallback) {
    return new Promise((resolve, reject) => {
      try {
        const { userId } = this;

        if (!userId) {
          reject(new Error("User is not logged in yet."));
          return;
        }

        const realtimeUserId = crypto.randomUUID() as RealTimeUserId;

        const { setRealtimeUserId, setRealtimeServerId } = editorContextStore.getState();

        const websocketURL = `${WebsocketProtocol}://${
          import.meta.env.VITE_REALTIME_SERVER_API_URL
        }/api/ws/${realtimeUserId}`;

        this.destroyWebSocket();

        // Registered only after the previous socket is destroyed, so that call cannot
        // reject this new attempt.
        this.rejectPendingConnection = reject;

        debugLog("Create websocket");

        const handleWebsocketMessageString = (data: RealTimeMessage) => {
          try {
            switch (data.status) {
              case "set_server_id": {
                const serverId = data.serverId;

                setRealtimeServerId(serverId);

                debugLog(
                  `Received set_server_id data from websocket backend with server: ${serverId}`,
                );

                break;
              }

              case "connected":
                this.status = RealTimeRenderStatus.CONNECTED;

                setRealtimeUserId(realtimeUserId);

                this.waitTimeoutCallback.stop();

                resolve({ status: "connected", userId });

                break;

              case "send_frame":
                this.waitTimeoutCallback.stop();

                this.status = RealTimeRenderStatus.SEND_FRAME;

                getStreamData?.()
                  .then((streamData) => {
                    if (!streamData) {
                      return;
                    }

                    for (const d of streamData) {
                      this.send(d);
                    }
                  })
                  .catch((error) => debugError(error));

                break;

              case "color_correct":
                this.handleColorCorrect(data);

                this.waitTimeoutCallback.stop();

                break;

              case "wait":
                this.status = RealTimeRenderStatus.WAIT;

                this.waitTimeoutCallback.start();

                break;

              case "timeout":
                debugLog("timeout");

                this.status = RealTimeRenderStatus.TIMEOUT;

                this.waitTimeoutCallback.stop();

                reject(new Error("timeout"));

                break;

              case "error":
                debugLog(data.message);

                this.waitTimeoutCallback.stop();

                this.reconnect();

                reject(new Error(data.message));

                break;
            }
          } catch (error) {
            console.error(error);
          }
        };

        const websocket = new StreamingWebSocket<RealTimeMessage>({
          websocketURL,
          uid: userId,
          onMessage: handleWebsocketMessageString,
          onOpen: () => {
            websocket.send({
              uid: userId,
            });
          },
          onClose: () => {
            // Fails the attempt if the socket closes before "connected"; no-op once settled.
            reject(new Error("WebSocket closed before the connection was established."));

            this.handleWebSocketClose();
          },
          onError: this.handleWebSocketError,
        });

        this.websocket = websocket;

        this.waitTimeoutCallback.start();
      } catch (err) {
        debugError(err);
        this.status = RealTimeRenderStatus.DISCONNECTED;
        reject(err);
      }
    });
  }

  private isStarting = false;

  /**
   * Connects to the real-time server unless a start is already in progress or
   * real-time mode is not active. On failure, schedules a reconnect while still enabled.
   *
   * @returns The "connected" payload, or undefined if skipped or failed.
   */
  async start(getStreamData: GetStreamDataCallback) {
    if (this.isStarting || !this.enabled) {
      return;
    }

    try {
      debugLog(`Start websocket. Is starting? ${this.isStarting}`);

      this.isStarting = true;

      const { setRealtimeRenderProgress } = editorContextStore.getState();

      setRealtimeRenderProgress(0.33);

      const response = await this.createWebsocketConnection(getStreamData);

      debugLog(`Started websocket. Is starting? ${this.isStarting}`);

      return response;
    } catch (error) {
      debugError("Unknown error occured when creating the websocket");
      debugError(error);

      if (this.enabled) {
        this.reconnect();
      }
    } finally {
      this.isStarting = false;

      debugLog(`Finalize starting websocket`);
      // this.waitTimeoutCallback.stop();
    }
  }

  /**
   * Sends one frame payload, preceded by a `{ status: "next_frame" }` message, when the
   * socket is open. It also:
   * - clears a color-corrected image that is older than the latest frame;
   * - aborts any in-flight color-correction request.
   */
  send(data?: Blob | { [key: string]: any }) {
    if (!data) {
      debugLog("Data is invalid");
      return;
    }

    const websocket = this.websocket;

    if (websocket && websocket.readyState === WebSocket.OPEN) {
      this.renderIndex += 1;

      if (this.defaultRenderIndex >= this.colorCorrectRenderIndex) {
        const { setRealtimeColorCorrectImageUrl } = editorContextStore.getState();

        setRealtimeColorCorrectImageUrl(undefined);
      }

      this.defaultRenderIndex = this.renderIndex;

      this.waitTimeoutCallback.stop();

      websocket.send({ status: "next_frame" });

      websocket.send(data);

      this.colorCorrectAbortController?.abort();
    } else {
      debugLog("WebSocket not connected");
    }
  }

  /** Marks the session disconnected and cancels any scheduled reconnect. */
  async stop() {
    this.status = RealTimeRenderStatus.DISCONNECTED;
    this.cancelReconnect();
  }

  /**
   * Posts the color-correction request built from a server `color_correct` message and,
   * if no newer plain frame was sent meanwhile, publishes the first returned image.
   * Errors (including aborts triggered by `send`) are logged, not thrown.
   */
  private async handleColorCorrect(data: any) {
    try {
      if (!data) {
        return;
      }

      if (!isRealTimeColorCorrectionMessage(data)) {
        return;
      }

      const { user } = editorContextStore.getState();

      const uid = user?.uid;

      if (!uid) {
        return;
      }

      this.renderIndex += 1;

      this.colorCorrectRenderIndex = this.renderIndex;

      const abortController = new AbortController();

      this.colorCorrectAbortController = abortController;

      const response = await fetch(
        // "http://34.28.12.38:8000/predict",
        `${HttpProtocol}://${ColorCorrectionServerUrl}/predict`,
        {
          method: "POST",
          headers: {
            UserId: uid,
            "Content-Type": "application/json",
            // NOTE(review): this key is a literal in client code, so any user can read
            // it; it cannot serve as a secret.
            Api_Key: "flair-render-realtime-v1.0",
          },
          body: JSON.stringify({
            prompt: data.input_params.prompt,
            negative_prompt: data.input_params.negative_prompt,
            composite_image: data.input_params.composite_image,
            composite_mask_image: data.input_params.composite_mask_image,
            initial_render_image: data.output_raw_image,
            product_mask_channel: "R",
          }),
          signal: abortController.signal,
        },
      );

      // Only clear the controller if a newer request has not replaced it meanwhile;
      // otherwise `send` could no longer abort that newer request.
      if (this.colorCorrectAbortController === abortController) {
        this.colorCorrectAbortController = undefined;
      }

      const responseData = await response.json();

      if (!response.ok) {
        debugError(responseData);
        return;
      }

      const images = responseData.images;

      if (!Array.isArray(images) || images.length <= 0) {
        return;
      }

      const image = images[0];

      // Publish only if no plain frame was sent after this request started.
      if (this.colorCorrectRenderIndex > this.defaultRenderIndex) {
        const { setRealtimeRenderProgress, setRealtimeColorCorrectImageUrl } =
          editorContextStore.getState();

        setRealtimeColorCorrectImageUrl(image);

        setRealtimeRenderProgress(1.0);
      }
    } catch (error) {
      debugError(error);
    }
  }

  private async getColorCorrectionArgs() {
    const websocket = this.websocket;

    if (!websocket || websocket.readyState !== WebSocket.OPEN) {
      return;
    }

    this.waitTimeoutCallback.stop();

    websocket.send({
      status: "color_correct",
    });

    editorContextStore
      .getState()
      .editor?.emit<RealTimeRenderStartColorCorrectionEventHandler>(
        "realtime-render:start-color-correction",
      );
  }

  /** Asks the server for color-correction input and notifies the editor. */
  async startColorCorrection() {
    await this.getColorCorrectionArgs();
  }

  /** Clears the color-corrected image when it is not newer than the latest plain frame. */
  handleSendFrame() {
    if (this.defaultRenderIndex >= this.colorCorrectRenderIndex) {
      const { setRealtimeColorCorrectImageUrl } = editorContextStore.getState();

      setRealtimeColorCorrectImageUrl(undefined);
    }
  }
}
