import string
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from collections.abc import Sequence
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import cast

from jsonschema._format import FormatChecker
from jsonschema.exceptions import SchemaError
from jsonschema.exceptions import ValidationError
from jsonschema.protocols import Validator
from jsonschema.validators import validator_for
from jsonschema_path.paths import SchemaPath
from openapi_schema_validator import OAS31_BASE_DIALECT_ID
from openapi_schema_validator import OAS32_BASE_DIALECT_ID
from openapi_schema_validator import oas30_format_checker
from openapi_schema_validator import oas31_format_checker
from openapi_schema_validator import oas32_format_checker
from openapi_schema_validator.validators import OAS30Validator
from openapi_schema_validator.validators import OAS31Validator
from openapi_schema_validator.validators import OAS32Validator

from openapi_spec_validator.validation.exceptions import (
    DuplicateOperationIDError,
)
from openapi_spec_validator.validation.exceptions import ExtraParametersError
from openapi_spec_validator.validation.exceptions import OpenAPIValidationError
from openapi_spec_validator.validation.exceptions import (
    ParameterDuplicateError,
)
from openapi_spec_validator.validation.exceptions import (
    UnresolvableParameterError,
)

if TYPE_CHECKING:
    from openapi_spec_validator.validation.registries import (
        KeywordValidatorRegistry,
    )


class KeywordValidator:
    """Base class for keyword validators; stores the shared registry."""

    def __init__(self, registry: "KeywordValidatorRegistry"):
        # Subclasses look up other validators through this registry by key.
        self.registry = registry


class ValueValidator(KeywordValidator):
    """Validate a concrete value against a schema.

    The base class leaves ``value_validator_cls`` and
    ``value_validator_format_checker`` as ``NotImplemented``; subclasses
    assign the validator class and format checker to use.
    """

    value_validator_cls: Callable[..., Validator] = NotImplemented
    value_validator_format_checker: FormatChecker = NotImplemented

    def __call__(
        self, schema: SchemaPath, value: Any
    ) -> Iterator[ValidationError]:
        """Yield each error reported for *value* when checked against *schema*."""
        validator_factory = self.value_validator_cls
        format_checker = self.value_validator_format_checker
        # Errors are produced lazily, so iteration stays inside the ``with``
        # block that provides the resolved schema.
        with schema.resolve() as resolved:
            validator = validator_factory(
                resolved.contents,
                _resolver=resolved.resolver,
                format_checker=format_checker,
            )
            yield from validator.iter_errors(value)


class OpenAPIV30ValueValidator(ValueValidator):
    """Value validator wired to ``OAS30Validator`` and ``oas30_format_checker``."""

    value_validator_cls = OAS30Validator
    value_validator_format_checker = oas30_format_checker


class OpenAPIV31ValueValidator(ValueValidator):
    """Value validator wired to ``OAS31Validator`` and ``oas31_format_checker``."""

    value_validator_cls = OAS31Validator
    value_validator_format_checker = oas31_format_checker


class OpenAPIV32ValueValidator(ValueValidator):
    """Value validator wired to ``OAS32Validator`` and ``oas32_format_checker``."""

    value_validator_cls = OAS32Validator
    value_validator_format_checker = oas32_format_checker


class SchemaValidator(KeywordValidator):
    """Validate a schema object and, recursively, its subschemas.

    Subclasses provide :meth:`_get_schema_checker`; the implementation here
    raises ``NotImplementedError``.
    """

    def __init__(self, registry: "KeywordValidatorRegistry"):
        super().__init__(registry)

        # recursion/visit dedupe registry
        self.visited_schema_ids: list[int] | None = []
        # meta-schema-check dedupe registry
        # to avoid validating the same schema multiple times
        self.meta_checked_schema_ids: list[int] | None = []

    @property
    def default_validator(self) -> ValueValidator:
        """The validator registered under ``"default"``."""
        return cast(ValueValidator, self.registry["default"])

    def _collect_properties(
        self,
        schema: SchemaPath,
        _visited: dict[int, Any] | None = None,
    ) -> set[str]:
        """Return *all* property names reachable from this schema.

        Follows ``allOf``, ``anyOf``, ``oneOf``, ``items`` and ``not``.

        ``_visited`` is internal bookkeeping for the recursion: it maps the
        ``id()`` of every schema value already expanded in this traversal to
        that value. A schema that is reached again (for example an array
        whose ``items`` refers back to an ancestor) contributes nothing new,
        so it is skipped instead of being expanded without end.
        """
        if _visited is None:
            _visited = {}

        schema_value = schema.read_value()
        schema_id = id(schema_value)
        if schema_id in _visited:
            return set()
        # Keep a reference to the value so its id cannot be reused by another
        # object while the traversal is still running.
        _visited[schema_id] = schema_value

        props: set[str] = set()

        if "properties" in schema:
            schema_props = (schema / "properties").keys()
            props.update(cast(Sequence[str], schema_props))

        for kw in ("allOf", "anyOf", "oneOf"):
            if kw in schema:
                for sub in schema / kw:
                    props.update(self._collect_properties(sub, _visited))

        for kw in ("items", "not"):
            if kw in schema:
                props.update(self._collect_properties(schema / kw, _visited))

        return props

    def _get_schema_checker(
        self, schema: SchemaPath, schema_value: Any
    ) -> Callable[[Any], None]:
        """Return the callable used to meta-check *schema_value*.

        Must be overridden; this base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def _validate_schema_meta(
        self, schema: SchemaPath, schema_value: Any
    ) -> OpenAPIValidationError | None:
        """Run the schema checker and return the failure, or ``None`` if it passes.

        A ``ValueError`` raised while obtaining the checker is reported as an
        ``OpenAPIValidationError`` carrying the same message. A
        ``SchemaError`` or ``ValidationError`` raised by the checker is
        converted with ``OpenAPIValidationError.create_from``.
        """
        try:
            schema_checker = self._get_schema_checker(schema, schema_value)
        except ValueError as exc:
            return OpenAPIValidationError(str(exc))
        try:
            schema_checker(schema_value)
        except (SchemaError, ValidationError) as err:
            return cast(
                OpenAPIValidationError, OpenAPIValidationError.create_from(err)
            )
        return None

    def __call__(
        self,
        schema: SchemaPath,
        require_properties: bool = True,
        meta_checked: bool = False,
    ) -> Iterator[ValidationError]:
        """Yield validation errors for *schema* and its subschemas.

        Args:
            schema: Path to the schema to validate.
            require_properties: When true, a schema with ``allOf`` whose
                ``required`` list names properties that are defined neither
                on the schema nor in its ``allOf`` branches yields an
                ``ExtraParametersError``.
            meta_checked: When true, the meta-schema check is skipped for
                this schema. Recursive calls on subschemas pass ``True``.
        """
        schema_value = schema.read_value()
        if not isinstance(schema_value, (Mapping, bool)):
            yield OpenAPIValidationError(
                f"{schema_value!r} is not of type 'object', 'boolean'"
            )
            return

        schema_id = id(schema_value)
        if not meta_checked:
            assert self.meta_checked_schema_ids is not None
            if schema_id not in self.meta_checked_schema_ids:
                self.meta_checked_schema_ids.append(schema_id)
                err = self._validate_schema_meta(schema, schema_value)
                if err is not None:
                    # A schema that fails its meta-check is not walked further.
                    yield err
                    return

        assert self.visited_schema_ids is not None
        if schema_id in self.visited_schema_ids:
            return
        self.visited_schema_ids.append(schema_id)

        # Property names contributed by the allOf branches.
        nested_properties: set[str] = set()
        if "allOf" in schema:
            for inner_schema in schema / "allOf":
                yield from self(
                    inner_schema,
                    require_properties=False,
                    meta_checked=True,
                )
                nested_properties.update(
                    self._collect_properties(inner_schema)
                )

        # Keywords whose value is a list of subschemas.
        for keyword in ("anyOf", "oneOf"):
            if keyword in schema:
                for inner_schema in schema / keyword:
                    yield from self(
                        inner_schema,
                        require_properties=False,
                        meta_checked=True,
                    )

        # Keywords whose value is a single subschema.
        for keyword in ("not", "items"):
            if keyword in schema:
                yield from self(
                    schema / keyword,
                    require_properties=False,
                    meta_checked=True,
                )

        if "properties" in schema:
            for _, prop_schema in (schema / "properties").items():
                yield from self(
                    prop_schema,
                    require_properties=False,
                    meta_checked=True,
                )

        # The required-vs-defined comparison is only made for schemas that
        # use allOf, and only reported when the caller asked for it.
        if require_properties and "allOf" in schema:
            required = (
                (schema / "required").read_value() or []
                if "required" in schema
                else []
            )
            properties = (
                (schema / "properties").keys() or []
                if "properties" in schema
                else []
            )
            extra_properties = list(
                set(required) - set(properties) - nested_properties
            )
            if extra_properties:
                yield ExtraParametersError(
                    f"Required list has not defined properties: {extra_properties}"
                )

        if "default" in schema:
            default_value = (schema / "default").read_value()
            nullable_value = False
            if "nullable" in schema:
                nullable_value = (schema / "nullable").read_value()
            # A None default on a schema with ``nullable: true`` is accepted
            # without validation; every other default is checked.
            if default_value is not None or nullable_value is not True:
                yield from self.default_validator(schema, default_value)


class OpenAPIV30SchemaValidator(SchemaValidator):
    """Schema validator that meta-checks schemas with ``OAS30Validator``."""

    schema_validator_cls = OAS30Validator

    def _get_schema_checker(
        self, schema: SchemaPath, schema_value: Any
    ) -> Callable[[Any], None]:
        """Return ``schema_validator_cls.check_schema``.

        Neither *schema* nor *schema_value* affects the result.
        """
        check_schema = self.schema_validator_cls.check_schema
        return cast(Callable[[Any], None], check_schema)


class OpenAPIV31SchemaValidator(SchemaValidator):
    """Schema validator that selects the meta-schema checker per dialect.

    The dialect is the schema's own ``$schema`` value when it has one;
    otherwise it is the document root's ``jsonSchemaDialect``, falling back
    to ``default_jsonschema_dialect_id``.
    """

    default_jsonschema_dialect_id = OAS31_BASE_DIALECT_ID
    schema_validator_format_checker = oas31_format_checker

    def __init__(self, registry: "KeywordValidatorRegistry"):
        super().__init__(registry)
        # Document-level default dialect, filled in on first use.
        self._default_jsonschema_dialect_id: str | None = None
        # Dialect id -> validator class (None when no class is known for it).
        self._validator_classes_by_dialect: dict[
            str, type[Validator] | None
        ] = {}

    def _get_schema_checker(
        self, schema: SchemaPath, schema_value: Any
    ) -> Callable[[Any], None]:
        """Return ``check_schema`` of the dialect's validator class.

        The returned callable has ``schema_validator_format_checker`` bound
        as its ``format_checker``.

        Raises:
            ValueError: If the dialect is not a string or no validator class
                is known for it.
        """
        dialect_id = self._get_schema_dialect_id(
            schema,
            schema_value,
        )

        validator_cls = self._get_validator_class_for_dialect(dialect_id)
        if validator_cls is None:
            raise ValueError(f"Unknown JSON Schema dialect: {dialect_id!r}")

        return partial(
            validator_cls.check_schema,
            format_checker=self.schema_validator_format_checker,
        )

    def _get_schema_dialect_id(
        self, schema: SchemaPath, schema_value: Any
    ) -> str:
        """Return the dialect id that applies to *schema_value*.

        A mapping with a ``$schema`` key uses that value; anything else
        (including boolean schemas) uses the document default.

        Raises:
            ValueError: If ``$schema`` is present but is not a string.
        """
        if isinstance(schema_value, Mapping) and "$schema" in schema_value:
            dialect_value = schema_value["$schema"]
            if not isinstance(dialect_value, str):
                raise ValueError(
                    f"Unknown JSON Schema dialect: {dialect_value!r}"
                )
            return dialect_value

        return self._get_default_jsonschema_dialect_id(schema)

    def _get_validator_class_for_dialect(
        self, dialect_id: str
    ) -> type[Validator] | None:
        """Return the validator class for *dialect_id*, or ``None`` if unknown.

        Results, including ``None``, are cached per dialect id.
        """
        if dialect_id in self._validator_classes_by_dialect:
            return self._validator_classes_by_dialect[dialect_id]

        validator_cls = cast(
            type[Validator] | None,
            validator_for(
                {"$schema": dialect_id},
                default=cast(Any, None),
            ),
        )
        self._validator_classes_by_dialect[dialect_id] = validator_cls
        return validator_cls

    def _get_default_jsonschema_dialect_id(self, schema: SchemaPath) -> str:
        """Return the document-level default dialect, computing it once.

        Reads ``jsonSchemaDialect`` from the spec root and falls back to
        ``default_jsonschema_dialect_id`` when it is absent.
        """
        if self._default_jsonschema_dialect_id is not None:
            return self._default_jsonschema_dialect_id

        spec_root = self._get_spec_root(schema)
        dialect_id = (spec_root / "jsonSchemaDialect").read_str(
            default=self.default_jsonschema_dialect_id
        )

        self._default_jsonschema_dialect_id = dialect_id
        return dialect_id

    def _get_spec_root(self, schema: SchemaPath) -> SchemaPath:
        """Return a path with no parts, derived from *schema*."""
        # jsonschema-path currently has no public API for root traversal.
        return schema._clone_with_parts(())


class OpenAPIV32SchemaValidator(OpenAPIV31SchemaValidator):
    """Variant of the 3.1 schema validator using the OAS 3.2 dialect and format checker."""

    default_jsonschema_dialect_id = OAS32_BASE_DIALECT_ID
    schema_validator_format_checker = oas32_format_checker


class SchemasValidator(KeywordValidator):
    """Validate every entry of a schemas mapping."""

    @property
    def schema_validator(self) -> SchemaValidator:
        """The validator registered under ``"schema"``."""
        return cast(SchemaValidator, self.registry["schema"])

    def __call__(self, schemas: SchemaPath) -> Iterator[ValidationError]:
        """Yield the errors of each schema in *schemas*, in iteration order."""
        for entry in schemas.items():
            # Only the schema is validated; the key is not used.
            named_schema = entry[1]
            yield from self.schema_validator(named_schema)


class ParameterValidator(KeywordValidator):
    """Validate the ``schema`` of a single parameter, when it has one."""

    @property
    def schema_validator(self) -> SchemaValidator:
        """The validator registered under ``"schema"``."""
        return cast(SchemaValidator, self.registry["schema"])

    def __call__(self, parameter: SchemaPath) -> Iterator[ValidationError]:
        """Yield schema errors for *parameter*; nothing if it has no ``schema``."""
        if "schema" not in parameter:
            return
        parameter_schema = parameter / "schema"
        yield from self.schema_validator(parameter_schema)


class OpenAPIV2ParameterValidator(ParameterValidator):
    """Parameter validator that additionally checks a parameter-level ``default``."""

    @property
    def default_validator(self) -> ValueValidator:
        """The validator registered under ``"default"``."""
        return cast(ValueValidator, self.registry["default"])

    def __call__(self, parameter: SchemaPath) -> Iterator[ValidationError]:
        """Yield the base parameter errors, then errors for its ``default``.

        The default value is validated against the parameter object itself.
        """
        yield from super().__call__(parameter)

        # only possible in swagger 2.0
        if "default" in parameter:
            default_value = (parameter / "default").read_value()
            yield from self.default_validator(parameter, default_value)


class ParametersValidator(KeywordValidator):
    """Validate a list of parameters and report duplicates."""

    @property
    def parameter_validator(self) -> ParameterValidator:
        """The validator registered under ``"parameter"``."""
        return cast(ParameterValidator, self.registry["parameter"])

    def __call__(self, parameters: SchemaPath) -> Iterator[ValidationError]:
        """Yield per-parameter errors plus a ``ParameterDuplicateError`` per repeat.

        Two parameters count as duplicates when both their ``name`` and
        their ``in`` values match.
        """
        seen_keys: set[tuple[Any, Any]] = set()
        for parameter in parameters:
            yield from self.parameter_validator(parameter)

            name, location = parameter["name"], parameter["in"]
            identity = (name, location)
            if identity in seen_keys:
                yield ParameterDuplicateError(f"Duplicate parameter '{name}'")
            else:
                seen_keys.add(identity)


class MediaTypeValidator(KeywordValidator):
    """Validate the ``schema`` of a media type object, when it has one."""

    @property
    def schema_validator(self) -> SchemaValidator:
        """The validator registered under ``"schema"``."""
        return cast(SchemaValidator, self.registry["schema"])

    def __call__(
        self, mimetype: str, media_type: SchemaPath
    ) -> Iterator[ValidationError]:
        """Yield schema errors for *media_type*; *mimetype* is not inspected."""
        if "schema" not in media_type:
            return
        media_schema = media_type / "schema"
        yield from self.schema_validator(media_schema)


class ContentValidator(KeywordValidator):
    """Validate each media type entry of a ``content`` mapping."""

    @property
    def media_type_validator(self) -> MediaTypeValidator:
        """The validator registered under ``"mediaType"``."""
        return cast(MediaTypeValidator, self.registry["mediaType"])

    def __call__(self, content: SchemaPath) -> Iterator[ValidationError]:
        """Yield the errors of every media type in *content*."""
        for entry in content.items():
            mimetype, media_type = entry
            assert isinstance(mimetype, str)
            yield from self.media_type_validator(mimetype, media_type)


class ResponseValidator(KeywordValidator):
    """Interface for response validators; subclasses override ``__call__``."""

    def __call__(
        self, response_code: str, response: SchemaPath
    ) -> Iterator[ValidationError]:
        """Not implemented here; calling it raises ``NotImplementedError``."""
        raise NotImplementedError


class OpenAPIV2ResponseValidator(ResponseValidator):
    """Response validator that checks the response's own ``schema``."""

    @property
    def schema_validator(self) -> SchemaValidator:
        """The validator registered under ``"schema"``."""
        return cast(SchemaValidator, self.registry["schema"])

    def __call__(
        self, response_code: str, response: SchemaPath
    ) -> Iterator[ValidationError]:
        """Yield schema errors for *response*; nothing if it has no ``schema``."""
        # openapi 2
        if "schema" not in response:
            return
        response_schema = response / "schema"
        yield from self.schema_validator(response_schema)


class OpenAPIV3ResponseValidator(ResponseValidator):
    """Response validator that checks the response's ``content``."""

    @property
    def content_validator(self) -> ContentValidator:
        """The validator registered under ``"content"``."""
        return cast(ContentValidator, self.registry["content"])

    def __call__(
        self, response_code: str, response: SchemaPath
    ) -> Iterator[ValidationError]:
        """Yield content errors for *response*; nothing if it has no ``content``."""
        # openapi 3
        if "content" not in response:
            return
        response_content = response / "content"
        yield from self.content_validator(response_content)


class ResponsesValidator(KeywordValidator):
    """Validate each entry of a ``responses`` mapping."""

    @property
    def response_validator(self) -> ResponseValidator:
        """The validator registered under ``"response"``."""
        return cast(ResponseValidator, self.registry["response"])

    def __call__(self, responses: SchemaPath) -> Iterator[ValidationError]:
        """Yield the errors of every response in *responses*."""
        for entry in responses.items():
            response_code, response = entry
            assert isinstance(response_code, str)
            yield from self.response_validator(response_code, response)


class OperationValidator(KeywordValidator):
    """Validate one operation: its id, responses, parameters and path params."""

    def __init__(self, registry: "KeywordValidatorRegistry"):
        super().__init__(registry)

        # Every operationId value read so far, in the order encountered.
        self.operation_ids_registry: list[str] | None = []

    @property
    def responses_validator(self) -> ResponsesValidator:
        """The validator registered under ``"responses"``."""
        return cast(ResponsesValidator, self.registry["responses"])

    @property
    def parameters_validator(self) -> ParametersValidator:
        """The validator registered under ``"parameters"``."""
        return cast(ParametersValidator, self.registry["parameters"])

    def __call__(
        self,
        url: str,
        name: str,
        operation: SchemaPath,
        path_parameters: SchemaPath | None,
    ) -> Iterator[ValidationError]:
        """Yield the validation errors of a single operation.

        Args:
            url: Path template the operation belongs to.
            name: Operation name used in error messages.
            operation: The operation object.
            path_parameters: Parameters declared on the enclosing path item,
                or ``None`` when there are none.
        """
        assert self.operation_ids_registry is not None

        if "operationId" in operation:
            operation_id = (operation / "operationId").read_value()
            already_used = (
                operation_id is not None
                and operation_id in self.operation_ids_registry
            )
            if already_used:
                yield DuplicateOperationIDError(
                    f"Operation ID '{operation_id}' for "
                    f"'{name}' in '{url}' is not unique"
                )
            self.operation_ids_registry.append(operation_id)

        if "responses" in operation:
            operation_responses = operation / "responses"
            yield from self.responses_validator(operation_responses)

        # Names of every ``in: path`` parameter declared for this operation,
        # on the operation itself or on the enclosing path item.
        declared: set[str] = set()
        if "parameters" in operation:
            operation_parameters = operation / "parameters"
            yield from self.parameters_validator(operation_parameters)
            declared.update(self._get_path_param_names(operation_parameters))

        if path_parameters is not None:
            declared.update(self._get_path_param_names(path_parameters))

        templated = set(self._get_path_params_from_url(url))

        # Report both directions with the same message: template fields that
        # are not declared first, then declared names missing from the template.
        mismatched = (
            *sorted(templated - declared),
            *sorted(declared - templated),
        )
        for param in mismatched:
            yield UnresolvableParameterError(
                f"Path parameter '{param}' for '{name}' operation in '{url}' was not resolved"
            )

    def _get_path_param_names(self, params: SchemaPath) -> Iterator[str]:
        """Yield the ``name`` of each parameter whose ``in`` is ``"path"``."""
        for param in params:
            location = (param / "in").read_str()
            if location != "path":
                continue
            yield (param / "name").read_str()

    def _get_path_params_from_url(self, url: str) -> Iterator[str]:
        """Return the non-empty replacement field names found in *url*.

        The URL is parsed with ``string.Formatter.parse``; the parse happens
        before this method returns.
        """
        parsed = string.Formatter().parse(url)
        field_names = [field_name for _, field_name, _, _ in parsed]
        return filter(None, field_names)


class PathValidator(KeywordValidator):
    """Validate a path item: its shared parameters and each operation."""

    # Path item fields that are treated as operations.
    OPERATIONS = [
        "get",
        "put",
        "post",
        "delete",
        "options",
        "head",
        "patch",
        "trace",
    ]

    @property
    def parameters_validator(self) -> ParametersValidator:
        """The validator registered under ``"parameters"``."""
        return cast(ParametersValidator, self.registry["parameters"])

    @property
    def operation_validator(self) -> OperationValidator:
        """The validator registered under ``"operation"``."""
        return cast(OperationValidator, self.registry["operation"])

    def __call__(
        self, url: str, path_item: SchemaPath
    ) -> Iterator[ValidationError]:
        """Yield errors for the path-level parameters, then for each operation.

        Fields of *path_item* whose name is not in ``OPERATIONS`` are skipped.
        """
        shared_parameters = None
        if "parameters" in path_item:
            shared_parameters = path_item / "parameters"
            yield from self.parameters_validator(shared_parameters)

        for entry in path_item.items():
            field_name, operation = entry
            assert isinstance(field_name, str)
            if field_name in self.OPERATIONS:
                yield from self.operation_validator(
                    url, field_name, operation, shared_parameters
                )


class OpenAPIV32PathValidator(PathValidator):
    """Path validator that also handles ``query`` and ``additionalOperations``."""

    OPERATIONS = [*PathValidator.OPERATIONS, "query"]

    def __call__(
        self, url: str, path_item: SchemaPath
    ) -> Iterator[ValidationError]:
        """Yield errors for the path-level parameters and every operation.

        Fields named in ``OPERATIONS`` are validated directly. Each entry of
        ``additionalOperations`` is validated as an operation under its own
        key. Other fields are skipped.
        """
        shared_parameters = None
        if "parameters" in path_item:
            shared_parameters = path_item / "parameters"
            yield from self.parameters_validator(shared_parameters)

        for field_name, field in path_item.items():
            assert isinstance(field_name, str)
            if field_name in self.OPERATIONS:
                yield from self.operation_validator(
                    url, field_name, field, shared_parameters
                )
            elif field_name == "additionalOperations":
                for method_name, additional_operation in field.items():
                    assert isinstance(method_name, str)
                    yield from self.operation_validator(
                        url,
                        method_name,
                        additional_operation,
                        shared_parameters,
                    )


class PathsValidator(KeywordValidator):
    """Validate each entry of a ``paths`` mapping."""

    @property
    def path_validator(self) -> PathValidator:
        """The validator registered under ``"path"``."""
        return cast(PathValidator, self.registry["path"])

    def __call__(self, paths: SchemaPath) -> Iterator[ValidationError]:
        """Yield the errors of every path item in *paths*."""
        for entry in paths.items():
            url, path_item = entry
            assert isinstance(url, str)
            yield from self.path_validator(url, path_item)


class ComponentsValidator(KeywordValidator):
    """Validate the ``schemas`` section of a components object."""

    @property
    def schemas_validator(self) -> SchemasValidator:
        """The validator registered under ``"schemas"``."""
        return cast(SchemasValidator, self.registry["schemas"])

    def __call__(self, components: SchemaPath) -> Iterator[ValidationError]:
        """Yield schema errors for *components*; nothing if it has no ``schemas``."""
        if "schemas" not in components:
            return
        component_schemas = components / "schemas"
        yield from self.schemas_validator(component_schemas)


class TagsValidator(KeywordValidator):
    """Report tag names that occur more than once."""

    def __call__(self, tags: SchemaPath) -> Iterator[ValidationError]:
        """Yield one error for each tag whose name was already seen earlier."""
        known_names: set[str] = set()
        for tag in tags:
            current_name = (tag / "name").read_str()
            if current_name not in known_names:
                known_names.add(current_name)
                continue
            yield OpenAPIValidationError(
                f"Duplicate tag name '{current_name}'"
            )


class OpenAPIV32TagsValidator(TagsValidator):
    """Tags validator that also checks ``parent`` references and cycles."""

    def __call__(self, tags: SchemaPath) -> Iterator[ValidationError]:
        """Yield duplicate-name, unknown-parent and circular-hierarchy errors.

        Errors are produced in that order. When a tag name occurs more than
        once, the parent of its last occurrence is the one that is used.
        """
        yield from super().__call__(tags)

        # Tag name -> declared parent name (None when the tag has no parent).
        parent_of: dict[str, str | None] = {}
        for tag in tags:
            name = (tag / "name").read_str()
            parent_of[name] = (
                (tag / "parent").read_str() if "parent" in tag else None
            )
        known_names = set(parent_of)

        for child, parent in parent_of.items():
            if parent is None or parent in known_names:
                continue
            yield OpenAPIValidationError(
                f"Tag '{child}' references unknown parent tag '{parent}'"
            )

        # Cycles are deduplicated by their rendered text. The same loop
        # entered at a different member renders differently and is therefore
        # reported again.
        reported: set[str] = set()
        for start in parent_of:
            path: list[str] = []
            position: dict[str, int] = {}
            current = start

            while current not in position:
                position[current] = len(path)
                path.append(current)

                following = parent_of.get(current)
                if following is None or following not in known_names:
                    # Reached a root or an unknown parent: no cycle from here.
                    break
                current = following
            else:
                # The walk returned to a tag already on the path.
                loop = path[position[current] :] + [current]
                rendered = " -> ".join(loop)
                if rendered not in reported:
                    reported.add(rendered)
                    yield OpenAPIValidationError(
                        f"Circular tag hierarchy detected: {rendered}"
                    )


class RootValidator(KeywordValidator):
    """Entry point: validate the ``tags``, ``paths`` and ``components`` of a spec."""

    @property
    def paths_validator(self) -> PathsValidator:
        """The validator registered under ``"paths"``."""
        return cast(PathsValidator, self.registry["paths"])

    @property
    def components_validator(self) -> ComponentsValidator:
        """The validator registered under ``"components"``."""
        return cast(ComponentsValidator, self.registry["components"])

    def __call__(self, spec: SchemaPath) -> Iterator[ValidationError]:
        """Yield errors for tags, then paths, then components.

        Tags are only validated when the registry's ``keyword_validators``
        contains a ``"tags"`` entry.
        """
        tags_enabled = (
            "tags" in spec and "tags" in self.registry.keyword_validators
        )
        if tags_enabled:
            spec_tags = spec / "tags"
            validate_tags = cast(Any, self.registry["tags"])
            yield from validate_tags(spec_tags)

        if "paths" in spec:
            spec_paths = spec / "paths"
            yield from self.paths_validator(spec_paths)

        if "components" in spec:
            spec_components = spec / "components"
            yield from self.components_validator(spec_components)
