import { z } from "zod";

import { AggregateZod, DavidsAggregateZod } from "../../../models/aggregate";
import {
  InfinityZod,
  DateTimeZod,
  DayOfWeekZod,
  NumberOrInfinityZod,
  PeriodUnitZod,
  PeriodZod,
  convertSnakeToCamelObject,
  convertSnakeToCamelValue,
  firstUpperCase,
} from "../../../models/primitives";
import { FrozenAlgorithmConfigZod, NaiveModelZod } from "./algorithmConfig";
import { ModelZod } from "./model";
import { RunDetailZod } from "./run";
import { ResultMetric } from "./settings";
import {
  TaskErrorClassZod,
  TaskErrorCodeZod,
  TaskNameZod,
  TaskStatusZod,
  TaskSubStatusZod,
  TaskWarningClassZod,
  TaskWarningCodeZod,
} from "./task";

/**
 * Schema for a frozen model: every `ModelZod` field except the editable
 * `algorithmConfigs`, plus the frozen model id, the frozen algorithm configs
 * and a free-text description.
 */
export const FrozenModelZod = ModelZod.omit({ algorithmConfigs: true }).extend({
  frozenModelId: z.number(),
  frozenAlgorithmConfigs: FrozenAlgorithmConfigZod.array(),
  description: z.string(),
});
export type FrozenModel = z.infer<typeof FrozenModelZod>;

/**
 * Schema for a frozen run: `RunDetailZod` without its live `runConfigId`,
 * `partitions` and `measurements`, referencing the frozen run config by id
 * instead.
 */
export const FrozenRunZod = RunDetailZod.omit({
  runConfigId: true,
  partitions: true,
  measurements: true,
}).extend({ frozenRunConfigId: z.number() });
export type FrozenRun = z.infer<typeof FrozenRunZod>;

/** Fields of a run result that describe the models it was produced with. */
const resultModelFields = {
  frozenModels: FrozenModelZod.array(),
  modelsAndVersions: z.string(),
  modelIds: z.number().array(),
};

/** Fields of a run result that describe the run configuration it used. */
const resultRunFields = {
  runConfigId: z.number(),
  frozenRunConfig: FrozenRunZod,
};

/**
 * Schema for a run result summary.
 *
 * `status` and `subStatus` use `.catch(null)`: a missing or unrecognised value
 * parses to `null` instead of failing the whole object.
 */
export const ResultZod = z.object({
  runResultId: z.number(),
  started: DateTimeZod.nullable(),
  finished: DateTimeZod.nullable(),
  name: z.string(),
  taskId: z.string().nullable(),
  taskName: TaskNameZod,
  status: TaskStatusZod.nullable().catch(null),
  subStatus: TaskSubStatusZod.nullable().catch(null),
  measurementIds: z.number().array(),
  partitionIds: z.number().array(),
  timeScales: AggregateZod.array(),
  ...resultModelFields,
  ...resultRunFields,
});

export type Result = z.infer<typeof ResultZod>;

/**
 * Metric keys in lower-case form, matching the metric field names used on
 * result segment rows and as `summaryStats` keys.
 */
export const resultMetricKeys = [
  "rsquare",
  "mae",
  "mape",
  "smape",
  "mase",
  "rmse",
] as const;
export const ResultMetricKeyZod = z.enum(resultMetricKeys);
/** Union of the literals in `resultMetricKeys`. */
export type ResultMetricKey = (typeof resultMetricKeys)[number];

/**
 * Maps a `ResultMetric` (from the settings) to the key under which that metric
 * is stored on result objects, e.g. as a `summaryStats` key.
 *
 * Each case returns a literal, so the compiler checks every mapping against
 * `ResultMetricKey` — no type assertion is involved.
 *
 * @param metric - the metric to look up.
 * @returns the matching `ResultMetricKey`.
 * @throws Error if `metric` is not one of the handled metrics at runtime
 *   (e.g. an unvalidated value that bypassed the type system).
 */
export function getMetricKey(metric: ResultMetric): ResultMetricKey {
  switch (metric) {
    case "RSquared":
      return "rsquare";
    case "MAE":
      return "mae";
    case "MAPE":
      return "mape";
    case "MASE":
      return "mase";
    case "RMSE":
      return "rmse";
    case "SMAPE":
      return "smape";
    default:
      // Unreachable for well-typed input; fail loudly instead of returning
      // undefined, which would violate the declared return type.
      throw new Error(`Unknown result metric: ${String(metric)}`);
  }
}

/** Rounds a number to two decimal places using `Math.round`. */
const roundToHundredths = (value: number) => Math.round(value * 100) / 100;

/**
 * Schema for one summary statistic value.
 *
 * Accepts a number or anything `InfinityZod` accepts. Any input that fails to
 * parse, including a missing value, becomes `undefined` via `.catch` rather
 * than an error. Numeric results are rounded to two decimal places.
 */
const SummaryMetricZod = z
  .union([z.number(), InfinityZod])
  .optional()
  .catch(undefined)
  .transform((arg) => (typeof arg === "number" ? roundToHundredths(arg) : arg));

/**
 * Mean / std / min / median / max of a single metric. Each value is parsed
 * with `SummaryMetricZod`, so it is rounded and may be `undefined`.
 */
const MetricSummaryStatsZod = z.object({
  mean: SummaryMetricZod,
  std: SummaryMetricZod,
  min: SummaryMetricZod,
  median: SummaryMetricZod,
  max: SummaryMetricZod,
});

/**
 * Schema for a detailed run result: `ResultZod` plus the ids of its data
 * segment results, summary statistics keyed by metric key, and a description.
 */
export const ResultDetailZod = ResultZod.extend({
  dataSegmentResultIds: z.number().array(),
  summaryStats: z.record(ResultMetricKeyZod, MetricSummaryStatsZod),
  description: z.string(),
});
export type ResultDetail = z.infer<typeof ResultDetailZod>;

/** Metric columns of a result segment row; each metric is a nullable number. */
const resultSegmentMetricFields = {
  rsquare: z.number().nullable(),
  mae: z.number().nullable(),
  mase: z.number().nullable(),
  mape: z.number().nullable(),
  smape: z.number().nullable(),
  rmse: z.number().nullable(),
};

/**
 * Schema for one data-segment result row: identifies the partition,
 * measurement, run result and model that produced it, its task status and
 * optional problem text, followed by its metric values.
 */
export const ResultSegmentZod = z.object({
  dataSegmentResultId: z.number(),
  partitionId: z.number(),
  partitionName: z.string(),
  partitionFullName: z.string(),
  partitionIsAggregate: z.boolean(),
  measurementId: z.number(),
  measurementName: z.string(),
  runResultId: z.number(),
  createdAt: DateTimeZod,
  status: TaskStatusZod.nullable(),
  subStatus: TaskSubStatusZod.nullable(),
  problem: z.string().nullable(),
  modelId: z.number(),
  modelName: z.string(),
  modelVersion: z.string(),
  ...resultSegmentMetricFields,
});
export type ResultSegment = z.infer<typeof ResultSegmentZod>;

/** One sparkline point: a timestamp with optional history and forecast values. */
const SparklinePointZod = z.object({
  x: DateTimeZod,
  yHistory: z.number().nullable().optional(),
  yForecast: z.number().nullable().optional(),
});

/**
 * Schema for the sparkline of one data-segment result. If `data` fails to
 * parse, it falls back to an empty array instead of failing the whole object.
 */
export const ResultSegmentSparklineZod = z.object({
  dataSegmentResultId: z.number(),
  forecastMean: z.number().nullable().optional(),
  data: SparklinePointZod.array().catch([]),
});
export type ResultSegmentSparkline = z.infer<typeof ResultSegmentSparklineZod>;

/** A numeric column that may be missing or explicitly null. */
const ResultDataOptionalNumberZod = z.number().nullable().optional();

/**
 * Schema for one timestamped result data entry. `value` defaults to `null`
 * when absent; the remaining numeric columns may be absent or null.
 */
export const ResultDataEntryZod = z.object({
  ts: DateTimeZod,
  value: z.number().nullable().default(null),
  interpolatedValue: ResultDataOptionalNumberZod,
  lowerValue: ResultDataOptionalNumberZod,
  upperValue: ResultDataOptionalNumberZod,
  residual: ResultDataOptionalNumberZod,
});
export type ResultDataEntry = z.infer<typeof ResultDataEntryZod>;

/** Factor kinds that can be attached to a segment result. */
const segmentFactorTypes = [
  "BusinessHours",
  "InfluencingFactor",
  "MeasurementCovariate",
] as const;
const SegmentFactorTypeZod = z.enum(segmentFactorTypes);
export type SegmentFactorType = (typeof segmentFactorTypes)[number];

/** One `(ts, value)` sample of a segment factor's series; `value` is required. */
const SegmentFactorSampleZod = z.object({
  ts: DateTimeZod,
  value: z.number(),
});

/**
 * Schema for a factor attached to a segment result. `influencingFactorId` is
 * nullable, and `measurementCovariateId` may be absent or null.
 */
export const SegmentInfluencingFactorZod = z.object({
  influencingFactorId: z.number().nullable(),
  influencingFactorResultId: z.number(),
  name: z.string(),
  pipelinePosition: z.number(),
  factorType: SegmentFactorTypeZod,
  measurementCovariateId: z.number().nullable().optional(),
  data: SegmentFactorSampleZod.array(),
});
export type SegmentInfluencingFactor = z.infer<
  typeof SegmentInfluencingFactorZod
>;

/**
 * One cell of a tabular result payload.
 *
 * Zod tries union members in order and returns the first that validates. A
 * string that `NumberOrInfinityZod` rejects is therefore returned as-is by
 * `z.string()` and never reaches `DateTimeZod`.
 */
// NOTE(review): if DateTimeZod is meant to parse ISO date strings, it is
// shadowed by z.string() here — confirm the intended member order.
const ResultTableCellZod = z.union([
  NumberOrInfinityZod,
  z.string(),
  DateTimeZod,
  z.null(),
  z.undefined(),
]);

/**
 * Rows of a tabular result, each a string-keyed record of cells. Shared by
 * `SegmentResultObjectZod` ("Table") and `ResultDataTableZod` so the two
 * cannot drift apart.
 */
const ResultTablePayloadZod = z.record(z.string(), ResultTableCellZod).array();

/**
 * A named result object produced at some pipeline position, discriminated by
 * `resultType`. "Chart" carries an unvalidated payload; "Table" carries rows
 * validated by `ResultTablePayloadZod`.
 */
export const SegmentResultObjectZod = z.discriminatedUnion("resultType", [
  z.object({
    resultType: z.literal("Chart"),
    name: z.string(),
    pipelinePosition: z.number(),
    payload: z.any(),
  }),
  z.object({
    resultType: z.literal("Table"),
    pipelinePosition: z.number(),
    name: z.string(),
    payload: ResultTablePayloadZod,
  }),
]);
export type SegmentResultObject = z.infer<typeof SegmentResultObjectZod>;

/** A named chart result whose payload is not validated. */
export const ResultDataChartZod = z.object({
  name: z.string(),
  payload: z.any(),
});

export type ResultDataChart = z.infer<typeof ResultDataChartZod>;

/** A named table result; rows are validated by `ResultTablePayloadZod`. */
export const ResultDataTableZod = z.object({
  name: z.string(),
  payload: ResultTablePayloadZod,
});

export type ResultDataTable = z.infer<typeof ResultDataTableZod>;

/** Subtype names for "Trend" contributions. */
export const trendTypes = [
  "Trend",
  "QuadraticTrend",
  "LogarithmicTrend",
] as const;
const TrendTypeEnumZod = z.enum(trendTypes);
/**
 * Trend subtype schema: the plain enum wrapped by the shared `firstUpperCase`
 * and `convertSnakeToCamelValue` helpers (presumably normalising backend
 * spellings before validation — see models/primitives).
 */
export const TrendTypeZod = convertSnakeToCamelValue(
  firstUpperCase(TrendTypeEnumZod)
);
export type TrendType = z.infer<typeof TrendTypeZod>;

/** Subtype names for "Intersect" contributions. */
export const intersectTypes = ["Level", "Intersect"] as const;
const IntersectTypeEnumZod = z.enum(intersectTypes);
/** Intersect subtype schema, wrapped by the same helpers as `TrendTypeZod`. */
export const IntersectTypeZod = convertSnakeToCamelValue(
  firstUpperCase(IntersectTypeEnumZod)
);
export type IntersectType = z.infer<typeof IntersectTypeZod>;

/** Subtype names for "Seasonality" contributions. */
export const seasonalityTypes = [
  "Daily",
  "Weekly",
  "Monthly",
  "Yearly",
  "Seasonality",
  "Fourier",
] as const;
const SeasonalityTypeEnumZod = z.enum(seasonalityTypes);
/** Seasonality subtype schema, wrapped only by `firstUpperCase`. */
export const SeasonalityTypeZod = firstUpperCase(SeasonalityTypeEnumZod);
export type SeasonalityType = z.infer<typeof SeasonalityTypeZod>;

/** Subtype names for "ARIMA" contributions (AR / I / MA components). */
export const arimaTypes = ["AR", "I", "MA"] as const;
const ArimaTypeEnumZod = z.enum(arimaTypes);
/** ARIMA subtype schema, wrapped only by `firstUpperCase`. */
export const ArimaTypeZod = firstUpperCase(ArimaTypeEnumZod);
export type ArimaType = z.infer<typeof ArimaTypeZod>;

/** Subtype names for "Categorical" contributions. */
export const categoricalTypes = [
  "DayInWeek",
  "Quarters",
  "DayInMonth",
  "MonthInYear",
  "Year",
  "DayInYear",
] as const;
const CategoricalTypeEnumZod = z.enum(categoricalTypes);
/**
 * Categorical subtype schema: the plain enum wrapped by `firstUpperCase` and
 * `convertSnakeToCamelValue` from models/primitives.
 */
export const CategoricalTypeZod = convertSnakeToCamelValue(
  firstUpperCase(CategoricalTypeEnumZod)
);
export type CategoricalType = z.infer<typeof CategoricalTypeZod>;

/** Subtype names for "Postprocessing" contributions (validated as-is). */
export const postprocessingTypes = [
  "RemoveNegative",
  "Overwrite",
  "ReplaceTrend",
  "ReplaceIntersect",
  "Absolute",
  "Relative",
] as const;
export const PostprocessingTypeZod = z.enum(postprocessingTypes);
export type PostprocessingType = (typeof postprocessingTypes)[number];

/** One `(ts, value)` point of a contribution series; `value` may be null. */
const ContributionPointZod = z.object({
  ts: DateTimeZod,
  value: z.number().nullable(),
});

/**
 * Fields shared by every contribution entry. Each variant of
 * `ResultContributionEntryZod` extends this with its own `factorType`,
 * `subtype` and `properties`.
 */
const ContributionBetaZod = z.object({
  explanationResultId: z.number(),
  pipelinePosition: z.number(),
  data: ContributionPointZod.array(),
});

/**
 * Normalizes a postprocessing subtype to the PascalCase spelling expected by
 * `PostprocessingTypeZod`. Underscore-separated strings are converted word by
 * word ("remove_negative" -> "RemoveNegative", "replace_trend" ->
 * "ReplaceTrend"). Strings without underscores and non-string input are
 * returned unchanged, so the enum still rejects anything it does not know.
 */
const normalizePostprocessingSubtype = (arg: unknown): unknown =>
  typeof arg === "string" && arg.includes("_")
    ? arg
        .split("_")
        .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
        .join("")
    : arg;

/**
 * One contribution series of a result explanation, discriminated by
 * `factorType`. Every variant shares the `ContributionBetaZod` fields and adds
 * a `subtype` and `properties` specific to that factor type.
 *
 * Several variants use `.catch(...)` so that an unrecognised subtype or
 * malformed properties fall back to a default instead of rejecting the entry.
 */
export const ResultContributionEntryZod = z.discriminatedUnion("factorType", [
  ContributionBetaZod.extend({
    factorType: z.literal("Trend"),
    subtype: TrendTypeZod.catch("Trend"),
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Intersect"),
    subtype: IntersectTypeZod.catch("Intersect"),
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Seasonality"),
    subtype: SeasonalityTypeZod.catch("Seasonality"),
    properties: z
      .object({
        function: z.enum(["sin", "cos"]),
        value: z.number(),
      })
      .partial()
      .catch({}),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("PreviousAlgorithm"),
    subtype: z.null(),
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("ARIMA"),
    subtype: ArimaTypeZod,
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("NaiveModel"),
    subtype: NaiveModelZod,
    properties: z
      .object({
        // NOTE(review): the payload might actually be `properties: PeriodZod`
        // rather than `{ timescale: PeriodZod }` — confirm with the backend.
        timescale: PeriodZod,
      })
      .partial()
      .catch({}),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Categorical"),
    subtype: CategoricalTypeZod,
    properties: z.object({ category: z.string() }).partial().catch({}),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Autoregressive"),
    subtype: z.null(),
    properties: PeriodZod,
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("InfluencingFactor"),
    subtype: z.null(),
    properties: z.object({ id: z.number(), name: z.string() }),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("BusinessHours"),
    subtype: z.null(),
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("MeasurementCovariate"),
    subtype: z.any(),
    properties: z.object({ id: z.number(), name: z.string() }),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("AggregateCorrection"),
    subtype: z.null(),
    properties: z.any(),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Postprocessing"),
    // Accept snake_case spellings (e.g. "remove_negative") as well as the
    // PascalCase enum values.
    subtype: z.preprocess(normalizePostprocessingSubtype, PostprocessingTypeZod),
    properties: convertSnakeToCamelObject(
      z
        .object({
          value: z.number().nullable(),
          perUnit: PeriodUnitZod.nullable(),
        })
        .partial()
        .catch({})
    ),
  }),
  ContributionBetaZod.extend({
    factorType: z.literal("Covariate"),
    subtype: z.any(),
    properties: z.any(),
  }),
]);
export type ResultContributionEntry = z.infer<
  typeof ResultContributionEntryZod
>;

/** One point of an intraday profile: a string `x` label and a nullable `y`. */
const IntradayProfilePointZod = z.object({
  x: z.string(),
  y: z.number().nullable(),
});

/** Schema for the intraday profile of one day of week at one time scale. */
export const IntradayProfileZod = z.object({
  dayOfWeek: DayOfWeekZod,
  timeScale: DavidsAggregateZod,
  data: IntradayProfilePointZod.array(),
});
export type IntradayProfile = z.infer<typeof IntradayProfileZod>;

/** Schema for a set of intraday profiles with an optional (nullable) date range. */
export const IntradayProfileSetZod = z.object({
  intradayProfileSetId: z.number(),
  startDate: DateTimeZod.nullable(),
  endDate: DateTimeZod.nullable(),
  profiles: IntradayProfileZod.array(),
});
export type IntradayProfileSet = z.infer<typeof IntradayProfileSetZod>;

/** Named vertical-line categories. */
const verticalLineTypes = ["Changepoints", "Fitdates", "Calcdates"] as const;
export const VerticalLineTypeZod = z.enum(verticalLineTypes);
export type VerticalLineType = (typeof verticalLineTypes)[number];

/** A precision setting: either a number or the literal string "max". */
export const PrecisionZod = z.number().or(z.literal("max"));
export type Precision = z.infer<typeof PrecisionZod>;

/**
 * Contribution chart choices: "AllContributions" followed by one entry per
 * `factorType` of `ResultContributionEntryZod`. Keep this list in sync when a
 * contribution variant is added.
 */
const contributionChartNames = [
  "AllContributions",
  "Intersect",
  "Trend",
  "Seasonality",
  "PreviousAlgorithm",
  "ARIMA",
  "NaiveModel",
  "Categorical",
  "Autoregressive",
  "InfluencingFactor",
  "BusinessHours",
  "MeasurementCovariate",
  "AggregateCorrection",
  "Postprocessing",
  "Covariate",
] as const;
export const ContributionChartZod = z.enum(contributionChartNames);
export type ContributionChart = z.infer<typeof ContributionChartZod>;

/**
 * One run result's status and metric values for a partition/measurement pair;
 * every metric is a nullable number.
 */
const ComparedSegmentEntryZod = z.object({
  runResultId: z.number(),
  status: TaskStatusZod.nullable(),
  subStatus: TaskSubStatusZod.nullable(),
  rsquare: z.number().nullable(),
  mae: z.number().nullable(),
  mase: z.number().nullable(),
  mape: z.number().nullable(),
  smape: z.number().nullable(),
  rmse: z.number().nullable(),
});

/**
 * Schema for comparing one partition/measurement pair across run results: the
 * partition and measurement identity plus one entry per compared run result.
 */
export const ComparedResultSegmentZod = z.object({
  partitionId: z.number(),
  partitionName: z.string(),
  partitionFullName: z.string(),
  partitionIsAggregate: z.boolean(),
  measurementId: z.number(),
  measurementName: z.string(),
  segments: ComparedSegmentEntryZod.array(),
});
export type ComparedResultSegment = z.infer<typeof ComparedResultSegmentZod>;

/** A task error code/class with a count (presumably of affected segments). */
const ResultErrorInfoZod = z.object({
  code: TaskErrorCodeZod,
  name: TaskErrorClassZod,
  count: z.number(),
});
export type ResultErrorInfo = z.infer<typeof ResultErrorInfoZod>;

/** A task warning code/class with a count (presumably of affected segments). */
const ResultWarningInfoZod = z.object({
  code: TaskWarningCodeZod,
  name: TaskWarningClassZod,
  count: z.number(),
});
export type ResultWarningInfo = z.infer<typeof ResultWarningInfoZod>;

/** Segment counts per outcome (error, warning, canceled, success, unset). */
const resultSegmentOutcomeCounts = {
  numErrorSegments: z.number(),
  numWarningSegments: z.number(),
  numCanceledSegments: z.number(),
  numSuccessSegments: z.number(),
  numUnsetSegments: z.number(),
};

/**
 * Schema for the problem overview of a result: segment counts per outcome
 * plus the list of error and warning kinds with their counts.
 */
export const ResultProblemsZod = z.object({
  ...resultSegmentOutcomeCounts,
  errors: ResultErrorInfoZod.array(),
  warnings: ResultWarningInfoZod.array(),
});
export type ResultProblems = z.infer<typeof ResultProblemsZod>;
