diff --git a/tests/unit/vertexai/model_garden/test_model_garden.py b/tests/unit/vertexai/model_garden/test_model_garden.py index 3d1db70276..931b7e8f2c 100644 --- a/tests/unit/vertexai/model_garden/test_model_garden.py +++ b/tests/unit/vertexai/model_garden/test_model_garden.py @@ -1406,6 +1406,30 @@ def test_batch_prediction_success(self, batch_prediction_mock): timeout=None, ) + def test_deploy_with_psc_success(self, deploy_mock): + """Tests deploying a model with Private Service Connect.""" + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + model = model_garden.OpenModel(model_name=_TEST_MODEL_FULL_RESOURCE_NAME) + model.deploy( + enable_private_service_connect=True, + psc_project_allow_list=["project-1", "project-2"], + ) + deploy_mock.assert_called_once_with( + types.DeployRequest( + publisher_model_name=_TEST_MODEL_FULL_RESOURCE_NAME, + destination=f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}", + endpoint_config=types.DeployRequest.EndpointConfig( + private_service_connect_config=types.PrivateServiceConnectConfig( + enable_private_service_connect=True, + project_allowlist=["project-1", "project-2"], + ) + ), + ) + ) + def test_check_license_agreement_status_success( self, check_license_agreement_status_mock ): diff --git a/vertexai/model_garden/_model_garden.py b/vertexai/model_garden/_model_garden.py index 69c931c4ff..14b2062c39 100644 --- a/vertexai/model_garden/_model_garden.py +++ b/vertexai/model_garden/_model_garden.py @@ -416,6 +416,8 @@ def deploy( serving_container_health_probe_exec: Optional[Sequence[str]] = None, serving_container_health_probe_period_seconds: Optional[int] = None, serving_container_health_probe_timeout_seconds: Optional[int] = None, + enable_private_service_connect: bool = False, + psc_project_allow_list: Optional[Sequence[str]] = None, ) -> aiplatform.Endpoint: """Deploys an Open Model to an endpoint. @@ -550,6 +552,10 @@ def deploy( serving_container_health_probe_timeout_seconds (int): Optional. Number of seconds after which the health probe times out. Defaults to 1 second. Minimum value is 1. + enable_private_service_connect (bool): Whether to enable private service + connect. + psc_project_allow_list (Sequence[str]): The list of projects that are + allowed to access the endpoint over private service connect. Returns: endpoint (aiplatform.Endpoint): @@ -618,6 +624,14 @@ def deploy( dedicated_endpoint_disabled ) + if enable_private_service_connect and psc_project_allow_list: + request.endpoint_config.private_service_connect_config = ( + types.PrivateServiceConnectConfig( + enable_private_service_connect=enable_private_service_connect, + project_allowlist=psc_project_allow_list, + ) + ) + if fast_tryout_enabled: request.deploy_config.fast_tryout_enabled = fast_tryout_enabled