from _typeshed import Incomplete
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import Any, TypeVar, overload
from typing_extensions import Final, Literal, Self, TypeAlias, TypeGuard

from tensorflow import Tensor, _TensorCompatible
from tensorflow._aliases import KerasSerializable
from tensorflow.keras.metrics import (
    binary_crossentropy as binary_crossentropy,
    categorical_crossentropy as categorical_crossentropy,
)

class Loss(ABC):
    # Abstract base for every loss in this module. Subclasses must provide
    # `call`, which works on Tensors; `__call__` accepts tensor-compatible
    # inputs plus an optional `sample_weight`.
    name: str | None
    reduction: _ReductionValues
    def __init__(self, reduction: _ReductionValues = "auto", name: str | None = None) -> None: ...
    def __call__(
        self, y_true: _TensorCompatible, y_pred: _TensorCompatible, sample_weight: _TensorCompatible | None = None
    ) -> Tensor: ...
    @abstractmethod
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    # Config dict accessors; `from_config` returns an instance of the calling class.
    def get_config(self) -> dict[str, Any]: ...
    @classmethod
    def from_config(cls, config: dict[str, Any]) -> Self: ...

class BinaryCrossentropy(Loss):
    # Binary cross-entropy loss object; default name "binary_crossentropy".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "binary_crossentropy",
    ) -> None: ...

class BinaryFocalCrossentropy(Loss):
    # Binary focal cross-entropy loss object. `alpha` defaults to 0.25 and
    # `gamma` to 2.0; default name "binary_focal_crossentropy".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(
        self,
        apply_class_balancing: bool = False,
        alpha: float = 0.25,
        gamma: float = 2.0,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "binary_focal_crossentropy",
    ) -> None: ...

class CategoricalCrossentropy(Loss):
    # Categorical cross-entropy loss object; default name "categorical_crossentropy".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(
        self,
        from_logits: bool = False,
        label_smoothing: float = 0.0,
        axis: int = -1,
        reduction: _ReductionValues = ...,
        name: str | None = "categorical_crossentropy",
    ) -> None: ...

class CategoricalHinge(Loss):
    # Categorical hinge loss object; only `reduction` and `name` are configurable.
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "categorical_hinge") -> None: ...

class CosineSimilarity(Loss):
    # Cosine-similarity loss object; `axis` defaults to -1.
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, axis: int = -1, reduction: _ReductionValues = ..., name: str | None = "cosine_similarity") -> None: ...

class Hinge(Loss):
    # Hinge loss object; only `reduction` and `name` are configurable.
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "hinge") -> None: ...

class Huber(Loss):
    # Huber loss object; `delta` defaults to 1.0 and the default name is "huber_loss".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, delta: float = 1.0, reduction: _ReductionValues = ..., name: str | None = "huber_loss") -> None: ...

class KLDivergence(Loss):
    # Kullback-Leibler divergence loss object; default name "kl_divergence".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "kl_divergence") -> None: ...

class LogCosh(Loss):
    # Log-cosh loss object; default name "log_cosh".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "log_cosh") -> None: ...

class MeanAbsoluteError(Loss):
    # Mean absolute error loss object; default name "mean_absolute_error".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_error") -> None: ...

class MeanAbsolutePercentageError(Loss):
    # Mean absolute percentage error loss object; default name "mean_absolute_percentage_error".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_absolute_percentage_error") -> None: ...

class MeanSquaredError(Loss):
    # Mean squared error loss object; default name "mean_squared_error".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_error") -> None: ...

class MeanSquaredLogarithmicError(Loss):
    # Mean squared logarithmic error loss object; default name "mean_squared_logarithmic_error".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "mean_squared_logarithmic_error") -> None: ...

class Poisson(Loss):
    # Poisson loss object; default name "poisson".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "poisson") -> None: ...

class SparseCategoricalCrossentropy(Loss):
    # Sparse categorical cross-entropy loss object; `ignore_class` is optional.
    # `name` is typed `str | None` to match the base `Loss.__init__` and every
    # other loss in this module, so an explicit `name=None` is accepted.
    def __init__(
        self,
        from_logits: bool = False,
        ignore_class: int | None = None,
        reduction: _ReductionValues = ...,
        name: str | None = "sparse_categorical_crossentropy",
    ) -> None: ...
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...

class SquaredHinge(Loss):
    # Squared hinge loss object; default name "squared_hinge".
    def call(self, y_true: Tensor, y_pred: Tensor) -> Tensor: ...
    def __init__(self, reduction: _ReductionValues = ..., name: str | None = "squared_hinge") -> None: ...

class Reduction:
    # Named constants for the `reduction` argument; the four values are
    # exactly the members of `_ReductionValues`.
    AUTO: Final = "auto"
    NONE: Final = "none"
    SUM: Final = "sum"
    SUM_OVER_BATCH_SIZE: Final = "sum_over_batch_size"
    @classmethod
    def all(cls) -> tuple[_ReductionValues, ...]: ...
    # The Keras implementation of `validate` does not return a boolean: it
    # returns None for a valid key and raises ValueError otherwise. A
    # TypeGuard return type would let `if Reduction.validate(x):` type-check
    # (and narrow) even though that condition is always falsy at runtime.
    @classmethod
    def validate(cls, key: object) -> None: ...

# String values accepted for a loss's `reduction` argument; these mirror the
# constants on `Reduction`. The alias is defined after its first uses above,
# which is fine in a stub (.pyi) file where forward references are allowed.
_ReductionValues: TypeAlias = Literal["auto", "none", "sum", "sum_over_batch_size"]

# Functional form of the categorical hinge loss.
def categorical_hinge(
    y_true: _TensorCompatible,
    y_pred: _TensorCompatible,
) -> Tensor: ...
# Functional form of the Huber loss; `delta` defaults to 1.0.
def huber(
    y_true: _TensorCompatible,
    y_pred: _TensorCompatible,
    delta: float = 1.0,
) -> Tensor: ...
# Functional form of the log-cosh loss.
def log_cosh(
    y_true: _TensorCompatible,
    y_pred: _TensorCompatible,
) -> Tensor: ...
# Build a loss from a string name or a config dict. `custom_objects` may
# supply extra name -> object mappings.
# NOTE(review): for plain names such as "mse" the runtime may return a loss
# *function* rather than a `Loss` instance — confirm before tightening/loosening.
def deserialize(
    name: str | dict[str, Any],
    custom_objects: dict[str, Any] | None = None,
    use_legacy_format: bool = False,
) -> Loss: ...
# Turn a serializable loss into a config dict (the inverse direction of `deserialize`).
def serialize(
    loss: KerasSerializable,
    use_legacy_format: bool = False,
) -> dict[str, Any]: ...

# Type variable bound to any callable; used by the last `get` overload so a
# callable identifier comes back with its exact type preserved.
_FuncT = TypeVar("_FuncT", bound=Callable[..., Any])

# Resolve a loss identifier:
#   - None            -> None
#   - str or dict     -> a `Loss`
#   - other callables -> returned with their own type preserved (via _FuncT)
# Overload order matters to type checkers; keep it as is.
@overload
def get(identifier: None) -> None: ...
@overload
def get(identifier: str | dict[str, Any]) -> Loss: ...
@overload
def get(identifier: _FuncT) -> _FuncT: ...

# This stub is complete with respect to the functions and classes defined
# here, but many functions are re-exported here from tf.keras.metrics and
# are not covered yet. This module-level __getattr__ makes any other
# attribute access type-check as Incomplete instead of failing.
def __getattr__(name: str) -> Incomplete: ...
