diff --git a/upath/extensions.py b/upath/extensions.py index a60e7ca8..f4b7f442 100644 --- a/upath/extensions.py +++ b/upath/extensions.py @@ -38,6 +38,9 @@ else: from typing_extensions import Self + from pydantic import GetCoreSchemaHandler + from pydantic_core.core_schema import CoreSchema + __all__ = [ "ProxyUPath", ] @@ -576,5 +579,12 @@ def full_match( ) -> bool: return self.__wrapped__.full_match(pattern, case_sensitive=case_sensitive) + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + cs = UPath.__get_pydantic_core_schema__.__func__ # type: ignore[attr-defined] + return cs(cls, source_type, handler) + UPath.register(ProxyUPath) diff --git a/upath/implementations/local.py b/upath/implementations/local.py index 62f9d472..884199d9 100644 --- a/upath/implementations/local.py +++ b/upath/implementations/local.py @@ -48,6 +48,9 @@ from typing_extensions import Self from typing_extensions import Unpack + from pydantic import GetCoreSchemaHandler + from pydantic_core.core_schema import CoreSchema + from upath.types.storage_options import FileStorageOptions _WT = TypeVar("_WT", bound="WritablePath") @@ -725,6 +728,13 @@ def chmod( ) return super().chmod(mode) + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> CoreSchema: + cs = UPath.__get_pydantic_core_schema__.__func__ # type: ignore[attr-defined] + return cs(cls, source_type, handler) + UPath.register(LocalPath) diff --git a/upath/tests/test_pydantic.py b/upath/tests/test_pydantic.py index 00ac78db..9f318841 100644 --- a/upath/tests/test_pydantic.py +++ b/upath/tests/test_pydantic.py @@ -7,6 +7,12 @@ from fsspec.implementations.http import get_client from upath import UPath +from upath.implementations.local import FilePath +from upath.implementations.local import PosixUPath +from upath.implementations.local import WindowsUPath + +from .utils import only_on_windows +from .utils import skip_on_windows @pytest.mark.parametrize( @@ -113,6 +119,43 @@ def test_dump_non_serializable_json(): ) +def test_proxyupath_serialization(): + from upath.extensions import ProxyUPath + + u = ProxyUPath("memory://my/path", some_option=True) + + ta = pydantic.TypeAdapter(ProxyUPath) + dumped = ta.dump_python(u, mode="python") + loaded = ta.validate_python(dumped) + + assert isinstance(loaded, ProxyUPath) + assert loaded.path == u.path + assert loaded.protocol == u.protocol + assert loaded.storage_options == u.storage_options + + +@pytest.mark.parametrize( + "path,cls", + [ + pytest.param("/my/path", PosixUPath, marks=skip_on_windows(None)), + pytest.param("C:\\my\\path", WindowsUPath, marks=only_on_windows(None)), + ("file:///my/path", FilePath), + ], +) +def test_localpath_serialization(path, cls): + u = UPath(path) + assert type(u) is cls + + ta = pydantic.TypeAdapter(cls) + dumped = ta.dump_python(u, mode="python") + loaded = ta.validate_python(dumped) + + assert isinstance(loaded, cls) + assert loaded.path == u.path + assert loaded.protocol == u.protocol + assert loaded.storage_options == u.storage_options + + def test_json_schema(): ta = pydantic.TypeAdapter(UPath) ta.json_schema()