import { createColumnHelper } from "@tanstack/react-table";
import _, { groupBy } from "lodash";
import { useEffect, useMemo, useState } from "react";
import { Alert, Button } from "react-bootstrap";
import type {
  AccountHoldingResponse,
  AccountWithMetrics,
} from "../../../api/src/accounts/accounts.service";
import type { Transaction } from "../../../api/src/transactions/lib";
import Loading from "../Loading";
import { capitalize, displayAccountName } from "../lib/display";
import { formatCurrencyComponent, formatPercent } from "../lib/numbers";
import { CASH_SYMBOLS } from "../lib/security";
import SecurityLink from "./SecurityLink";
import { Table, useTable } from "./Table/Table";

/**
 * Risk bucket a holdings row belongs to. `createRowsFromHoldings` places
 * non-cash securities in "growth" and/or "defensive" according to their
 * `riskAllocation` split, and securities listed in `CASH_SYMBOLS` in "cash".
 */
type RiskAllocationCategory = "growth" | "defensive" | "cash";

/**
 * Aggregated figures for one security in one risk category, either within a
 * single account or across all accounts. Every figure is scaled by the
 * fraction of the security attributed to that risk category.
 */
type AccountHoldingTotal = Required<
  Pick<
    AccountHoldingResponse,
    "value" | "realizedGainLoss" | "unrealizedGainLoss"
  >
> & {
  // `value` divided by the summed value of every holding passed to
  // `createRowsFromHoldings` (0 when that sum is 0).
  weight: number;
  // Weight for the security's symbol from the account's model summary
  // (summed over accounts in a row's `totals`).
  modelWeight: number;
  // `weight - modelWeight`.
  modelWeightDifference: number;
  // Sum of the `amount` of the account's transactions in this security.
  income: number;
};

/** One table row: a single security within a single risk category. */
type AccountHoldingRow = {
  securityId: number;
  riskAllocation: RiskAllocationCategory;
  security: AccountHoldingResponse["security"];
  // Totals combined across every account that holds the security.
  totals: AccountHoldingTotal;
  // Per-account totals, keyed by account id.
  accounts: Record<number, AccountHoldingTotal>;
};

/**
 * Which `AccountHoldingTotal` field the holdings table displays. The weight
 * modes ("weight", "modelWeight", "modelWeightDifference") are rendered as
 * percentages; all other modes are rendered as currency.
 */
export type HoldingsFieldMode =
  | "value"
  | "weight"
  | "modelWeight"
  | "modelWeightDifference"
  | "realizedGainLoss"
  | "unrealizedGainLoss"
  | "income";

/**
 * Formats one table cell for the given field mode.
 *
 * @param fieldMode - The field being displayed; weight modes are shown as
 *   percentages and every other mode as currency.
 * @param value - The numeric value to render.
 * @returns An en dash for exactly 0, otherwise the formatted value with two
 *   decimal places.
 */
function getTableValue(fieldMode: HoldingsFieldMode, value: number) {
  if (value === 0) {
    return "–";
  }

  const isWeightField =
    fieldMode === "weight" ||
    fieldMode === "modelWeight" ||
    fieldMode === "modelWeightDifference";

  return isWeightField
    ? formatPercent(value, 2)
    : formatCurrencyComponent(value, 2);
}

/** Returns a fresh all-zero totals object used to seed the reductions below. */
function emptyHoldingTotal(): AccountHoldingTotal {
  return {
    value: 0,
    unrealizedGainLoss: 0,
    realizedGainLoss: 0,
    weight: 0,
    modelWeight: 0,
    modelWeightDifference: 0,
    income: 0,
  };
}

/**
 * Aggregates one account's holdings of a single security into totals for one
 * risk category.
 *
 * Holding-level figures (value, gains, weight) are summed over every holding
 * in the group. Account-level figures (the model weight and the income from
 * the account's transactions in this security) are applied exactly once per
 * account, however many holdings the group contains.
 *
 * @param accountHoldings - Holdings of one security in one account.
 * @param accountTransactions - That account's transactions for the security.
 * @param accountModelWeight - The account model's weight for the security's
 *   symbol (0 when there is none).
 * @param ratio - Fraction of the security attributed to the risk category.
 * @param totalValue - Summed value of all holdings; weights are relative to it.
 */
function summarizeAccountHoldings(
  accountHoldings: AccountHoldingResponse[],
  accountTransactions: Transaction[],
  accountModelWeight: number,
  ratio: number,
  totalValue: number,
): AccountHoldingTotal {
  const holdingTotals = accountHoldings.reduce<AccountHoldingTotal>(
    (sum, holding) => ({
      ...sum,
      value: sum.value + (holding.value ?? 0) * ratio,
      unrealizedGainLoss:
        sum.unrealizedGainLoss + holding.unrealizedGainLoss * ratio,
      realizedGainLoss:
        sum.realizedGainLoss + (holding.realizedGainLoss ?? 0) * ratio,
      weight:
        sum.weight +
        (totalValue === 0 ? 0 : holding.value / totalValue) * ratio,
    }),
    emptyHoldingTotal(),
  );

  // Account-level figures: counted once, not once per holding.
  const modelWeight = accountModelWeight * ratio;
  const income =
    accountTransactions.reduce(
      (sum, transaction) => sum + (transaction.amount ?? 0),
      0,
    ) * ratio;

  return {
    ...holdingTotals,
    modelWeight,
    modelWeightDifference: holdingTotals.weight - modelWeight,
    income,
  };
}

/**
 * Combines per-account totals into a single total across all accounts.
 *
 * @param accountTotals - Totals produced by `summarizeAccountHoldings`.
 * @param totalValue - Summed value of all holdings; weights are relative to it.
 */
function sumAccountTotals(
  accountTotals: AccountHoldingTotal[],
  totalValue: number,
): AccountHoldingTotal {
  return accountTotals.reduce<AccountHoldingTotal>((sum, account) => {
    // `account.value` is already scaled by the risk-category ratio.
    const weight =
      sum.weight + (totalValue === 0 ? 0 : account.value / totalValue);
    const modelWeight = sum.modelWeight + account.modelWeight;

    return {
      value: sum.value + account.value,
      unrealizedGainLoss: sum.unrealizedGainLoss + account.unrealizedGainLoss,
      realizedGainLoss: sum.realizedGainLoss + account.realizedGainLoss,
      weight,
      modelWeight,
      modelWeightDifference: weight - modelWeight,
      income: sum.income + account.income,
    };
  }, emptyHoldingTotal());
}

/**
 * Builds one table row per (risk category, security) pair from raw holdings.
 *
 * Non-cash holdings go into "growth" and/or "defensive" according to the
 * security's `riskAllocation` split. A security without allocation data is
 * treated as 50/50. Holdings whose symbol is in `CASH_SYMBOLS` go entirely
 * into "cash". Every figure in a row is scaled by the fraction of the security
 * attributed to the row's category.
 *
 * @param holdings - Holdings across all selected accounts.
 * @param transactions - Transactions whose `amount`s are summed as income per
 *   security and account.
 * @param modelSummaries - Per-account model weights keyed by security symbol.
 * @returns Rows carrying per-account totals and totals across all accounts.
 */
export function createRowsFromHoldings(
  holdings: AccountHoldingResponse[],
  transactions: Transaction[],
  modelSummaries: {
    accountId: number;
    model: Record<string, number>;
  }[],
): AccountHoldingRow[] {
  const isCashHolding = (holding: AccountHoldingResponse) =>
    CASH_SYMBOLS.includes(holding.security?.symbol ?? "");

  // A holding may land in both "growth" and "defensive"; its figures are
  // split between the two by the allocation ratio below.
  const allHoldings: Record<RiskAllocationCategory, AccountHoldingResponse[]> =
    {
      growth: holdings.filter(
        (holding) =>
          !isCashHolding(holding) &&
          (holding.security?.riskAllocation.growth ?? 0.5) > 0,
      ),
      defensive: holdings.filter(
        (holding) =>
          !isCashHolding(holding) &&
          (holding.security?.riskAllocation.defensive ?? 0.5) > 0,
      ),
      cash: holdings.filter((holding) => isCashHolding(holding)),
    };
  const categories: RiskAllocationCategory[] = ["growth", "defensive", "cash"];

  const transactionsBySecurity = groupBy(transactions, "securityId");
  const modelSummariesMap = _.keyBy(modelSummaries, "accountId");
  const totalValue = holdings.reduce((sum, holding) => sum + holding.value, 0);

  return categories.flatMap((category) =>
    Object.values(groupBy(allHoldings[category], "securityId")).map(
      (securityHoldings): AccountHoldingRow => {
        const { security, securityId } = securityHoldings[0];
        const symbol = security?.symbol;
        // Fraction of this security that belongs to `category`.
        const ratio =
          category === "cash"
            ? 1
            : category === "growth"
              ? (security?.riskAllocation.growth ?? 0.5)
              : (security?.riskAllocation.defensive ?? 0.5);
        const transactionsByAccount = groupBy(
          transactionsBySecurity[securityId] ?? [],
          "accountId",
        );

        const accounts: Record<number, AccountHoldingTotal> =
          Object.fromEntries(
            Object.entries(groupBy(securityHoldings, "accountId")).map(
              ([accountId, accountHoldings]): [number, AccountHoldingTotal] => [
                Number(accountId),
                summarizeAccountHoldings(
                  accountHoldings,
                  transactionsByAccount[accountId] ?? [],
                  symbol === undefined
                    ? 0
                    : (modelSummariesMap[accountId]?.model[symbol] ?? 0),
                  ratio,
                  totalValue,
                ),
              ],
            ),
          );

        return {
          securityId,
          riskAllocation: category,
          security,
          totals: sumAccountTotals(Object.values(accounts), totalValue),
          accounts,
        };
      },
    ),
  );
}

/** Props accepted by `HoldingsByAccountTable`. */
type HoldingsByAccountTableProps = {
  holdings: AccountHoldingResponse[];
  grouped?: "riskAllocation";
  accounts?: AccountWithMetrics[];
  modelSummaries: {
    accountId: number;
    model: Record<string, number>;
  }[];
  transactions: Transaction[];
  fieldMode?: HoldingsFieldMode;
  isLoading?: boolean;
};

/**
 * Renders holdings as a table with one row per (risk category, security).
 *
 * The row set has sticky columns for risk category, symbol and the selected
 * field totalled across all accounts. A header button toggles one extra
 * column per entry in `accounts`. Rows with a total value of 0 are hidden
 * unless `fieldMode` is "realizedGainLoss" or "income".
 */
const HoldingsByAccountTable = ({
  holdings,
  grouped = "riskAllocation",
  accounts,
  modelSummaries,
  transactions,
  fieldMode = "value",
  isLoading = false,
}: HoldingsByAccountTableProps) => {
  const columnHelper = useMemo(
    () => createColumnHelper<AccountHoldingRow>(),
    [],
  );

  const [isExpanded, setIsExpanded] = useState(false);

  const flatColumns = useMemo(() => {
    const formatCell = (value: number) => getTableValue(fieldMode, value);
    const formatAggregate = (value: number) => (
      <span className="fw-bold">{formatCell(value)}</span>
    );
    const expanderBorder = isExpanded ? "border-end" : "";

    const symbolColumn = columnHelper.accessor(
      (row) => row.security?.identifier ?? "",
      {
        id: "symbol",
        header: () => "Symbol",
        cell: (info) => (
          <SecurityLink
            symbol={info.getValue()}
            description={info.row.original.security?.description}
          />
        ),
        minSize: 200,
        size: 200,
        enableColumnFilter: false,
        meta: {
          className: "position-sticky",
          headerClassName: "position-sticky z-1",
          headerStyle: { left: 180 },
          style: { left: 180 },
        },
      },
    );

    const totalColumn = columnHelper.accessor(
      (row) => row.totals?.[fieldMode] ?? 0,
      {
        id: "totalValue",
        cell: (info) => formatCell(info.getValue()),
        header: () => "All Selected Accounts",
        aggregatedCell: (info) => formatAggregate(info.getValue()),
        minSize: 180,
        size: 180,
        enableColumnFilter: false,
        meta: {
          className: "text-end position-sticky",
          headerClassName: "text-end position-sticky z-1",
          headerStyle: { left: 380 },
          style: { left: 380 },
        },
      },
    );

    const expanderColumn = columnHelper.display({
      id: "expander",
      header: () => (
        <Button onClick={() => setIsExpanded(!isExpanded)}>
          {isExpanded ? "<" : ">"}
        </Button>
      ),
      meta: {
        className: `position-sticky ${expanderBorder}`,
        headerClassName: `position-sticky z-1 ${expanderBorder}`,
        headerStyle: { left: 560 },
        style: { left: 560 },
      },
    });

    // Per-account columns exist only while the table is expanded.
    const accountColumns = isExpanded
      ? (accounts ?? []).map((account) =>
          columnHelper.accessor(
            (row) => row.accounts[account.id]?.[fieldMode] ?? 0,
            {
              id: `account-${account.id}-${fieldMode}`,
              cell: (info) => formatCell(info.getValue()),
              header: () =>
                displayAccountName(account.displayName, account.displayNumber),
              aggregatedCell: (info) => formatAggregate(info.getValue()),
              minSize: 160,
              enableColumnFilter: false,
              meta: {
                className: "text-end",
                headerClassName: "text-end",
              },
            },
          ),
        )
      : [];

    return [symbolColumn, totalColumn, expanderColumn, ...accountColumns];
  }, [accounts, columnHelper, fieldMode, isExpanded]);

  const groupedColumns = useMemo(() => {
    const riskAllocationColumn = columnHelper.accessor(
      (row) => row.riskAllocation,
      {
        id: "riskAllocation",
        header: "Risk Category",
        cell: (info) => (
          <span className="fw-bold">{capitalize(info.getValue())}</span>
        ),
        minSize: 180,
        size: 180,
        enableColumnFilter: false,
        meta: {
          className: "position-sticky",
          headerClassName: "position-sticky z-1",
          headerStyle: { left: 0 },
          style: { left: 0 },
        },
      },
    );

    return [
      columnHelper.group({
        id: "category",
        header: "Category",
        columns: [riskAllocationColumn],
      }),
      ...flatColumns,
    ];
  }, [columnHelper, flatColumns]);

  const rows = useMemo(
    () => createRowsFromHoldings(holdings, transactions, modelSummaries),
    [holdings, modelSummaries, transactions],
  );

  const visibleRows = useMemo(() => {
    const keepZeroValueRows =
      fieldMode === "realizedGainLoss" || fieldMode === "income";
    return keepZeroValueRows
      ? rows
      : rows.filter((row) => row.totals.value !== 0);
  }, [rows, fieldMode]);

  const isGroupingDisabled = typeof grouped === "undefined";

  const { table } = useTable({
    columns: grouped === "riskAllocation" ? groupedColumns : flatColumns,
    data: visibleRows,
    initialState: {
      sorting: [{ id: "totalValue", desc: true }],
      grouping: isGroupingDisabled ? undefined : ["riskAllocation"],
      pagination: {
        pageSize: 9999,
      },
    },
    autoResetExpanded: false,
    manualPagination: true,
    getRowId: (row) => `${row.riskAllocation}-${row.securityId}`,
  });

  useEffect(() => {
    if (rows.length > 0) {
      table.toggleAllRowsExpanded(true);
    }
  }, [rows, grouped, table]);

  if (isLoading) {
    return <Loading message="Holdings" />;
  }

  if (visibleRows.length <= 0) {
    return <Alert>No holdings found</Alert>;
  }

  return (
    <Table
      table={table}
      disablePagination
      isDisableGrouping={isGroupingDisabled}
      isFixedGrouping
      className="w-auto"
    />
  );
};

export default HoldingsByAccountTable;
