diff --git a/doc/code/targets/4_openai_video_target.ipynb b/doc/code/targets/4_openai_video_target.ipynb index bad89e0d5..c27bf91e8 100644 --- a/doc/code/targets/4_openai_video_target.ipynb +++ b/doc/code/targets/4_openai_video_target.ipynb @@ -7,11 +7,24 @@ "source": [ "# 4. OpenAI Video Target\n", "\n", - "This example shows how to use the video target to create a video from a text prompt.\n", + "`OpenAIVideoTarget` supports three modes:\n", + "- **Text-to-video**: Generate a video from a text prompt.\n", + "- **Remix**: Create a variation of an existing video (using `video_id` from a prior generation).\n", + "- **Image-to-video**: Use an image as the first frame of the generated video.\n", "\n", "Note that the video scorer requires `opencv`, which is not a default PyRIT dependency. You need to install it manually or using `pip install pyrit[opencv]`." ] }, + { + "cell_type": "markdown", + "id": "0ebc1dc5", + "metadata": {}, + "source": [ + "## Text-to-Video\n", + "\n", + "This example shows the simplest mode: generating video from text prompts, with scoring." + ] + }, { "cell_type": "code", "execution_count": null, @@ -762,6 +775,104 @@ "for result in results:\n", " await ConsoleAttackResultPrinter().print_result_async(result=result, include_auxiliary_scores=True) # type: ignore" ] + }, + { + "cell_type": "markdown", + "id": "e21b0718", + "metadata": {}, + "source": [ + "## Remix (Video Variation)\n", + "\n", + "Remix creates a variation of an existing video. After any successful generation, the response\n", + "includes a `video_id` in `prompt_metadata`. Pass this back via `prompt_metadata={\"video_id\": \"\"}` to remix." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0a29f796", + "metadata": {}, + "outputs": [], + "source": [ + "from pyrit.models import Message, MessagePiece\n", + "\n", + "# Use the same target from above, or create a new one\n", + "remix_target = OpenAIVideoTarget()\n", + "\n", + "# Step 1: Generate a video\n", + "text_piece = MessagePiece(\n", + " role=\"user\",\n", + " original_value=\"A bird flying over a lake at sunset\",\n", + ")\n", + "result = await remix_target.send_prompt_async(message=Message([text_piece])) # type: ignore\n", + "response = result[0].message_pieces[0]\n", + "print(f\"Generated video: {response.converted_value}\")\n", + "video_id = response.prompt_metadata[\"video_id\"]\n", + "print(f\"Video ID for remix: {video_id}\")\n", + "\n", + "# Step 2: Remix using the video_id\n", + "remix_piece = MessagePiece(\n", + " role=\"user\",\n", + " original_value=\"Make it a watercolor painting style\",\n", + " prompt_metadata={\"video_id\": video_id},\n", + ")\n", + "remix_result = await remix_target.send_prompt_async(message=Message([remix_piece])) # type: ignore\n", + "print(f\"Remixed video: {remix_result[0].message_pieces[0].converted_value}\")" + ] + }, + { + "cell_type": "markdown", + "id": "a7f0708b", + "metadata": {}, + "source": [ + "## Image-to-Video\n", + "\n", + "Use an image as the first frame of the generated video. The input image dimensions must match\n", + "the video resolution (e.g. 1280x720). Pass both a text piece and an `image_path` piece in the same message." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b417ec67", + "metadata": {}, + "outputs": [], + "source": [ + "import uuid\n", + "\n", + "# Create a simple test image matching the video resolution (1280x720)\n", + "from PIL import Image\n", + "\n", + "from pyrit.common.path import HOME_PATH\n", + "\n", + "sample_image = HOME_PATH / \"assets\" / \"pyrit_architecture.png\"\n", + "resized = Image.open(sample_image).resize((1280, 720)).convert(\"RGB\")\n", + "\n", + "import tempfile\n", + "\n", + "tmp = tempfile.NamedTemporaryFile(suffix=\".jpg\", delete=False)\n", + "resized.save(tmp, format=\"JPEG\")\n", + "tmp.close()\n", + "image_path = tmp.name\n", + "\n", + "# Send text + image to the video target\n", + "i2v_target = OpenAIVideoTarget()\n", + "conversation_id = str(uuid.uuid4())\n", + "\n", + "text_piece = MessagePiece(\n", + " role=\"user\",\n", + " original_value=\"Animate this image with gentle camera motion\",\n", + " conversation_id=conversation_id,\n", + ")\n", + "image_piece = MessagePiece(\n", + " role=\"user\",\n", + " original_value=image_path,\n", + " converted_value_data_type=\"image_path\",\n", + " conversation_id=conversation_id,\n", + ")\n", + "result = await i2v_target.send_prompt_async(message=Message([text_piece, image_piece])) # type: ignore\n", + "print(f\"Image-to-video result: {result[0].message_pieces[0].converted_value}\")" + ] } ], "metadata": { diff --git a/doc/code/targets/4_openai_video_target.py b/doc/code/targets/4_openai_video_target.py index fb1b4ae70..0182c3a1a 100644 --- a/doc/code/targets/4_openai_video_target.py +++ b/doc/code/targets/4_openai_video_target.py @@ -11,10 +11,18 @@ # %% [markdown] # # 4. OpenAI Video Target # -# This example shows how to use the video target to create a video from a text prompt. +# `OpenAIVideoTarget` supports three modes: +# - **Text-to-video**: Generate a video from a text prompt. 
+# - **Remix**: Create a variation of an existing video (using `video_id` from a prior generation). +# - **Image-to-video**: Use an image as the first frame of the generated video. # # Note that the video scorer requires `opencv`, which is not a default PyRIT dependency. You need to install it manually or using `pip install pyrit[opencv]`. +# %% [markdown] +# ## Text-to-Video +# +# This example shows the simplest mode: generating video from text prompts, with scoring. + # %% from pyrit.executor.attack import ( AttackExecutor, @@ -65,3 +73,77 @@ for result in results: await ConsoleAttackResultPrinter().print_result_async(result=result, include_auxiliary_scores=True) # type: ignore + +# %% [markdown] +# ## Remix (Video Variation) +# +# Remix creates a variation of an existing video. After any successful generation, the response +# includes a `video_id` in `prompt_metadata`. Pass this back via `prompt_metadata={"video_id": ""}` to remix. + +# %% +from pyrit.models import Message, MessagePiece + +# Use the same target from above, or create a new one +remix_target = OpenAIVideoTarget() + +# Step 1: Generate a video +text_piece = MessagePiece( + role="user", + original_value="A bird flying over a lake at sunset", +) +result = await remix_target.send_prompt_async(message=Message([text_piece])) # type: ignore +response = result[0].message_pieces[0] +print(f"Generated video: {response.converted_value}") +video_id = response.prompt_metadata["video_id"] +print(f"Video ID for remix: {video_id}") + +# Step 2: Remix using the video_id +remix_piece = MessagePiece( + role="user", + original_value="Make it a watercolor painting style", + prompt_metadata={"video_id": video_id}, +) +remix_result = await remix_target.send_prompt_async(message=Message([remix_piece])) # type: ignore +print(f"Remixed video: {remix_result[0].message_pieces[0].converted_value}") + +# %% [markdown] +# ## Image-to-Video +# +# Use an image as the first frame of the generated video. 
The input image dimensions must match +# the video resolution (e.g. 1280x720). Pass both a text piece and an `image_path` piece in the same message. + +# %% +import uuid + +# Create a simple test image matching the video resolution (1280x720) +from PIL import Image + +from pyrit.common.path import HOME_PATH + +sample_image = HOME_PATH / "assets" / "pyrit_architecture.png" +resized = Image.open(sample_image).resize((1280, 720)).convert("RGB") + +import tempfile + +tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) +resized.save(tmp, format="JPEG") +tmp.close() +image_path = tmp.name + +# Send text + image to the video target +i2v_target = OpenAIVideoTarget() +conversation_id = str(uuid.uuid4()) + +text_piece = MessagePiece( + role="user", + original_value="Animate this image with gentle camera motion", + conversation_id=conversation_id, +) +image_piece = MessagePiece( + role="user", + original_value=image_path, + converted_value_data_type="image_path", + conversation_id=conversation_id, +) +result = await i2v_target.send_prompt_async(message=Message([text_piece, image_piece])) # type: ignore +print(f"Image-to-video result: {result[0].message_pieces[0].converted_value}") diff --git a/pyrit/models/message.py b/pyrit/models/message.py index 4c8c6e334..07ccc8b59 100644 --- a/pyrit/models/message.py +++ b/pyrit/models/message.py @@ -51,6 +51,30 @@ def get_piece(self, n: int = 0) -> MessagePiece: return self.message_pieces[n] + def get_pieces_by_type(self, *, data_type: PromptDataType) -> list[MessagePiece]: + """ + Return all message pieces matching the given data type. + + Args: + data_type: The converted_value_data_type to filter by. + + Returns: + A list of matching MessagePiece objects (may be empty). + """ + return [p for p in self.message_pieces if p.converted_value_data_type == data_type] + + def get_piece_by_type(self, *, data_type: PromptDataType) -> Optional[MessagePiece]: + """ + Return the first message piece matching the given data type, or None. 
+ + Args: + data_type: The converted_value_data_type to filter by. + + Returns: + The first matching MessagePiece, or None if no match is found. + """ + return next((p for p in self.message_pieces if p.converted_value_data_type == data_type), None) + @property def api_role(self) -> ChatMessageRole: """ diff --git a/pyrit/prompt_target/openai/openai_video_target.py b/pyrit/prompt_target/openai/openai_video_target.py index f6915c027..3e4ebfad0 100644 --- a/pyrit/prompt_target/openai/openai_video_target.py +++ b/pyrit/prompt_target/openai/openai_video_target.py @@ -2,13 +2,16 @@ # Licensed under the MIT license. import logging -from typing import Any +import os +from typing import Any, Optional + +from openai.types import VideoSeconds, VideoSize from pyrit.exceptions import ( pyrit_target_retry, ) -from pyrit.identifiers import TargetIdentifier from pyrit.models import ( + DataTypeSerializer, Message, MessagePiece, construct_response_from_request, @@ -27,6 +30,11 @@ class OpenAIVideoTarget(OpenAITarget): Supports Sora-2 and Sora-2-Pro models via the OpenAI videos API. 
+ Supports three modes: + - Text-to-video: Generate video from a text prompt + - Image-to-video: Generate video using an image as the first frame (include image_path piece) + - Remix: Create variation of existing video (include video_id in prompt_metadata) + Supported resolutions: - Sora-2: 720x1280, 1280x720 - Sora-2-Pro: 720x1280, 1280x720, 1024x1792, 1792x1024 @@ -34,16 +42,18 @@ class OpenAIVideoTarget(OpenAITarget): Supported durations: 4, 8, or 12 seconds Default: resolution="1280x720", duration=4 seconds + + Supported image formats for image-to-video: JPEG, PNG, WEBP """ - SUPPORTED_RESOLUTIONS = ["720x1280", "1280x720", "1024x1792", "1792x1024"] - SUPPORTED_DURATIONS = [4, 8, 12] + SUPPORTED_RESOLUTIONS: list[VideoSize] = ["720x1280", "1280x720", "1024x1792", "1792x1024"] + SUPPORTED_DURATIONS: list[VideoSeconds] = ["4", "8", "12"] def __init__( self, *, - resolution_dimensions: str = "1280x720", - n_seconds: int = 4, + resolution_dimensions: VideoSize = "1280x720", + n_seconds: int | VideoSeconds = 4, **kwargs: Any, ) -> None: """ @@ -61,22 +71,28 @@ def __init__( headers (str, Optional): Extra headers of the endpoint (JSON). max_requests_per_minute (int, Optional): Number of requests the target can handle per minute before hitting a rate limit. - resolution_dimensions (str, Optional): Resolution dimensions for the video in WIDTHxHEIGHT format. + resolution_dimensions (VideoSize, Optional): Resolution dimensions for the video. Defaults to "1280x720". Supported resolutions: - Sora-2: "720x1280", "1280x720" - Sora-2-Pro: "720x1280", "1280x720", "1024x1792", "1792x1024" - n_seconds (int, Optional): The duration of the generated video (in seconds). - Defaults to 4. Supported values: 4, 8, or 12 seconds. + n_seconds (int | VideoSeconds, Optional): The duration of the generated video. + Accepts an int (4, 8, 12) or a VideoSeconds string ("4", "8", "12"). + Defaults to 4. **kwargs: Additional keyword arguments passed to the parent OpenAITarget class. 
httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the ``httpx.AsyncClient()`` constructor. For example, to specify a 3 minute timeout: ``httpx_client_kwargs={"timeout": 180}`` + + Remix workflow: + To remix an existing video, set ``prompt_metadata={"video_id": ""}`` on the text + MessagePiece. The video_id is returned in the response metadata after any successful + generation (``response.message_pieces[0].prompt_metadata["video_id"]``). """ super().__init__(**kwargs) - self._n_seconds = n_seconds + self._n_seconds: VideoSeconds = str(n_seconds) if isinstance(n_seconds, int) else n_seconds self._validate_duration() - self._size = self._validate_resolution(resolution_dimensions=resolution_dimensions) + self._size: VideoSize = self._validate_resolution(resolution_dimensions=resolution_dimensions) def _set_openai_env_configuration_vars(self) -> None: """Set environment variable names.""" @@ -96,21 +112,7 @@ def _get_provider_examples(self) -> dict[str, str]: "api.openai.com": "https://api.openai.com/v1", } - def _build_identifier(self) -> TargetIdentifier: - """ - Build the identifier with video generation-specific parameters. - - Returns: - TargetIdentifier: The identifier for this target instance. - """ - return self._create_identifier( - target_specific_params={ - "resolution": self._size, - "n_seconds": self._n_seconds, - }, - ) - - def _validate_resolution(self, *, resolution_dimensions: str) -> str: + def _validate_resolution(self, *, resolution_dimensions: VideoSize) -> VideoSize: """ Validate resolution dimensions. @@ -139,8 +141,8 @@ def _validate_duration(self) -> None: """ if self._n_seconds not in self.SUPPORTED_DURATIONS: raise ValueError( - f"Invalid duration {self._n_seconds}s. " - f"Supported durations: {', '.join(map(str, self.SUPPORTED_DURATIONS))} seconds" + f"Invalid duration '{self._n_seconds}'. 
" + f"Supported durations: {', '.join(self.SUPPORTED_DURATIONS)} seconds" ) @limit_requests_per_minute @@ -149,33 +151,151 @@ async def send_prompt_async(self, *, message: Message) -> list[Message]: """ Asynchronously sends a message and generates a video using the OpenAI SDK. + Supports three modes: + - Text-to-video: Single text piece + - Image-to-video: Text piece + image_path piece (image becomes first frame) + - Remix: Text piece with prompt_metadata["video_id"] set to an existing video ID + Args: - message (Message): The message object containing the prompt. + message: The message object containing the prompt. Returns: - list[Message]: A list containing the response with the generated video path. + A list containing the response with the generated video path. Raises: RateLimitException: If the rate limit is exceeded. ValueError: If the request is invalid. """ self._validate_request(message=message) - message_piece = message.message_pieces[0] - prompt = message_piece.converted_value + + text_piece = message.get_piece_by_type(data_type="text") + image_piece = message.get_piece_by_type(data_type="image_path") + prompt = text_piece.converted_value + + # Check for remix mode via prompt_metadata + remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None logger.info(f"Sending video generation prompt: {prompt}") - # Use unified error handler - automatically detects Video and validates - response = await self._handle_openai_request( + if remix_video_id: + response = await self._send_remix_async(video_id=remix_video_id, prompt=prompt, request=message) + elif image_piece: + response = await self._send_image_to_video_async(image_piece=image_piece, prompt=prompt, request=message) + else: + response = await self._send_text_to_video_async(prompt=prompt, request=message) + + return [response] + + async def _send_remix_async(self, *, video_id: str, prompt: str, request: Message) -> Message: + """ + Send a remix request for an 
existing video. + + Args: + video_id: The ID of the completed video to remix. + prompt: The text prompt directing the remix. + request: The original request message. + + Returns: + The response Message with the generated video path. + """ + logger.info(f"Remix mode: Creating variation of video {video_id}") + return await self._handle_openai_request( + api_call=lambda: self._remix_and_poll_async(video_id=video_id, prompt=prompt), + request=request, + ) + + async def _send_image_to_video_async(self, *, image_piece: MessagePiece, prompt: str, request: Message) -> Message: + """ + Send an image-to-video request using an image as the first frame. + + Args: + image_piece: The MessagePiece containing the image path. + prompt: The text prompt describing the desired video. + request: The original request message. + + Returns: + The response Message with the generated video path. + """ + logger.info("Image-to-video mode: Using image as first frame") + input_file = await self._prepare_image_input_async(image_piece=image_piece) + return await self._handle_openai_request( api_call=lambda: self._async_client.videos.create_and_poll( model=self._model_name, prompt=prompt, - size=self._size, # type: ignore[arg-type] - seconds=str(self._n_seconds), # type: ignore[arg-type] + size=self._size, + seconds=self._n_seconds, + input_reference=input_file, ), - request=message, + request=request, ) - return [response] + + async def _send_text_to_video_async(self, *, prompt: str, request: Message) -> Message: + """ + Send a text-to-video generation request. + + Args: + prompt: The text prompt describing the desired video. + request: The original request message. + + Returns: + The response Message with the generated video path. 
+ """ + return await self._handle_openai_request( + api_call=lambda: self._async_client.videos.create_and_poll( + model=self._model_name, + prompt=prompt, + size=self._size, + seconds=self._n_seconds, + ), + request=request, + ) + + async def _prepare_image_input_async(self, *, image_piece: MessagePiece) -> tuple[str, bytes, str]: + """ + Prepare image data for the OpenAI video API input_reference parameter. + + Reads the image bytes from storage and determines the MIME type. + + Args: + image_piece: The MessagePiece containing the image path. + + Returns: + A tuple of (filename, image_bytes, mime_type) for the SDK. + """ + image_path = image_piece.converted_value + image_serializer = data_serializer_factory( + value=image_path, data_type="image_path", category="prompt-memory-entries" + ) + image_bytes = await image_serializer.read_data() + + mime_type = DataTypeSerializer.get_mime_type(image_path) + if not mime_type: + mime_type = "image/png" + + filename = os.path.basename(image_path) + return (filename, image_bytes, mime_type) + + async def _remix_and_poll_async(self, *, video_id: str, prompt: str) -> Any: + """ + Create a remix of an existing video and poll until complete. + + The OpenAI SDK's remix() method returns immediately with a job status. + This method polls until the job completes or fails. + + Args: + video_id: The ID of the completed video to remix. + prompt: The text prompt directing the remix. + + Returns: + The completed Video object from the OpenAI SDK. 
+ """ + video = await self._async_client.videos.remix(video_id, prompt=prompt) + + # Poll until completion if not already done + if video.status not in ["completed", "failed"]: + video = await self._async_client.videos.poll(video.id) + + return video def _check_content_filter(self, response: Any) -> bool: """ @@ -218,13 +338,17 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> if video.status == "completed": logger.info(f"Video generation completed successfully: {video.id}") + # Log remix metadata if available + if hasattr(video, "remixed_from_video_id") and video.remixed_from_video_id: + logger.info(f"Video was remixed from: {video.remixed_from_video_id}") + # Download video content using SDK video_response = await self._async_client.videos.download_content(video.id) # Extract bytes from HttpxBinaryResponseContent video_content = video_response.content - # Save the video to storage - return await self._save_video_response(request=request, video_data=video_content) + # Save the video to storage (include video.id for chaining remixes) + return await self._save_video_response(request=request, video_data=video_content, video_id=video.id) elif video.status == "failed": # Handle failed video generation (non-content-filter) @@ -249,13 +373,16 @@ async def _construct_message_from_response(self, response: Any, request: Any) -> error="unknown", ) - async def _save_video_response(self, *, request: MessagePiece, video_data: bytes) -> Message: + async def _save_video_response( + self, *, request: MessagePiece, video_data: bytes, video_id: Optional[str] = None + ) -> Message: """ Save video data to storage and construct response. Args: request: The original request message piece. video_data: The video content as bytes. + video_id: The video ID from the API (stored in metadata for chaining remixes). Returns: Message: The response with the video file path. 
@@ -267,11 +394,15 @@ async def _save_video_response(self, *, request: MessagePiece, video_data: bytes logger.info(f"Video saved to: {video_path}") + # Include video_id in metadata for chaining (e.g., remix the generated video later) + prompt_metadata = {"video_id": video_id} if video_id else None + # Construct response response_entry = construct_response_from_request( request=request, response_text_pieces=[video_path], response_type="video_path", + prompt_metadata=prompt_metadata, ) return response_entry @@ -280,19 +411,42 @@ def _validate_request(self, *, message: Message) -> None: """ Validate the request message. + Accepts: + - Single text piece (text-to-video or remix mode) + - Text piece + image_path piece (image-to-video mode) + Args: message: The message to validate. Raises: ValueError: If the request is invalid. """ - n_pieces = len(message.message_pieces) - if n_pieces != 1: - raise ValueError(f"This target only supports a single message piece. Received: {n_pieces} pieces.") - - piece_type = message.message_pieces[0].converted_value_data_type - if piece_type != "text": - raise ValueError(f"This target only supports text prompt input. Received: {piece_type}.") + text_pieces = message.get_pieces_by_type(data_type="text") + image_pieces = message.get_pieces_by_type(data_type="image_path") + + # Check for unsupported types + supported_count = len(text_pieces) + len(image_pieces) + if supported_count != len(message.message_pieces): + other_types = [ + p.converted_value_data_type + for p in message.message_pieces + if p.converted_value_data_type not in ("text", "image_path") + ] + raise ValueError(f"Unsupported piece types: {other_types}. 
Only 'text' and 'image_path' are supported.") + + # Must have exactly one text piece + if len(text_pieces) != 1: + raise ValueError(f"Expected exactly 1 text piece, got {len(text_pieces)}.") + + # At most one image piece + if len(image_pieces) > 1: + raise ValueError(f"Expected at most 1 image piece, got {len(image_pieces)}.") + + # Check for conflicting modes: remix + image + text_piece = text_pieces[0] + remix_video_id = text_piece.prompt_metadata.get("video_id") if text_piece.prompt_metadata else None + if remix_video_id and image_pieces: + raise ValueError("Cannot use image input in remix mode. Remix uses existing video as reference.") def is_json_response_supported(self) -> bool: """ diff --git a/tests/integration/targets/test_entra_auth_targets.py b/tests/integration/targets/test_entra_auth_targets.py index 82dd17793..19ba564aa 100644 --- a/tests/integration/targets/test_entra_auth_targets.py +++ b/tests/integration/targets/test_entra_auth_targets.py @@ -275,6 +275,40 @@ async def test_video_target_entra_auth(sqlite_instance): assert result.last_response is not None +@pytest.mark.asyncio +async def test_video_target_remix_entra_auth(sqlite_instance): + """Test video remix mode with Entra authentication.""" + endpoint = os.environ["OPENAI_VIDEO2_ENDPOINT"] + target = OpenAIVideoTarget( + endpoint=endpoint, + model_name=os.environ["OPENAI_VIDEO2_MODEL"], + api_key=get_azure_openai_auth(endpoint), + n_seconds=4, + ) + + # Generate initial video + text_piece = MessagePiece( + role="user", + original_value="A bird flying over a lake", + converted_value="A bird flying over a lake", + ) + result = await target.send_prompt_async(message=Message([text_piece])) + response_piece = result[0].message_pieces[0] + assert response_piece.response_error == "none" + video_id = response_piece.prompt_metadata.get("video_id") + assert video_id + + # Remix + remix_piece = MessagePiece( + role="user", + original_value="Add a sunset", + converted_value="Add a sunset", + 
prompt_metadata={"video_id": video_id}, + ) + remix_result = await target.send_prompt_async(message=Message([remix_piece])) + assert remix_result[0].message_pieces[0].response_error == "none" + + @pytest.mark.asyncio async def test_prompt_shield_target_entra_auth(sqlite_instance): # Make sure to assign the Cognitive Services User or Contributor role diff --git a/tests/integration/targets/test_targets_and_secrets.py b/tests/integration/targets/test_targets_and_secrets.py index 31a3a9851..cb9f55978 100644 --- a/tests/integration/targets/test_targets_and_secrets.py +++ b/tests/integration/targets/test_targets_and_secrets.py @@ -7,7 +7,6 @@ import pytest -from pyrit.common.path import HOME_PATH from pyrit.executor.attack import AttackExecutor, PromptSendingAttack from pyrit.models import Message, MessagePiece from pyrit.prompt_target import ( @@ -329,111 +328,6 @@ async def test_connect_image(sqlite_instance, endpoint, api_key, model_name): assert image_path.is_file(), f"Path exists but is not a file: {image_path}" -# Path to sample image file for image editing tests -SAMPLE_IMAGE_FILE = HOME_PATH / "assets" / "pyrit_architecture.png" - - -@pytest.mark.asyncio -async def test_image_editing_single_image_api_key(sqlite_instance): - """ - Test image editing with a single image input using API key authentication. - Uses gpt-image-1 which supports image editing/remix. - - Verifies that: - 1. A text prompt + single image generates a modified image - 2. The edit endpoint is correctly called - 3. 
The output image file is created - """ - endpoint_value = _get_required_env_var("OPENAI_IMAGE_ENDPOINT2") - api_key_value = _get_required_env_var("OPENAI_IMAGE_API_KEY2") - model_name_value = os.getenv("OPENAI_IMAGE_MODEL2") or "gpt-image-1" - - target = OpenAIImageTarget( - endpoint=endpoint_value, - api_key=api_key_value, - model_name=model_name_value, - ) - - conv_id = str(uuid.uuid4()) - text_piece = MessagePiece( - role="user", - original_value="Add a red border around this image", - original_value_data_type="text", - conversation_id=conv_id, - ) - image_piece = MessagePiece( - role="user", - original_value=str(SAMPLE_IMAGE_FILE), - original_value_data_type="image_path", - conversation_id=conv_id, - ) - - message = Message(message_pieces=[text_piece, image_piece]) - result = await target.send_prompt_async(message=message) - - assert result is not None - assert len(result) >= 1 - assert result[0].message_pieces[0].response_error == "none" - - # Validate we got a valid image file path - output_path = Path(result[0].message_pieces[0].converted_value) - assert output_path.exists(), f"Output image file not found at path: {output_path}" - assert output_path.is_file(), f"Path exists but is not a file: {output_path}" - - -@pytest.mark.asyncio -async def test_image_editing_multiple_images_api_key(sqlite_instance): - """ - Test image editing with multiple image inputs using API key authentication. - Uses gpt-image-1 which supports 1-16 image inputs. - - Verifies that: - 1. Multiple images can be passed to the edit endpoint - 2. 
The model processes multiple image inputs correctly - """ - endpoint_value = _get_required_env_var("OPENAI_IMAGE_ENDPOINT2") - api_key_value = _get_required_env_var("OPENAI_IMAGE_API_KEY2") - model_name_value = os.getenv("OPENAI_IMAGE_MODEL2") or "gpt-image-1" - - target = OpenAIImageTarget( - endpoint=endpoint_value, - api_key=api_key_value, - model_name=model_name_value, - ) - - conv_id = str(uuid.uuid4()) - text_piece = MessagePiece( - role="user", - original_value="Combine these images into one", - original_value_data_type="text", - conversation_id=conv_id, - ) - image_piece1 = MessagePiece( - role="user", - original_value=str(SAMPLE_IMAGE_FILE), - original_value_data_type="image_path", - conversation_id=conv_id, - ) - image_piece2 = MessagePiece( - role="user", - original_value=str(SAMPLE_IMAGE_FILE), - original_value_data_type="image_path", - conversation_id=conv_id, - ) - - message = Message(message_pieces=[text_piece, image_piece1, image_piece2]) - result = await target.send_prompt_async(message=message) - - assert result is not None - assert len(result) >= 1 - assert result[0].message_pieces[0].response_error == "none" - - # Validate we got a valid image file path - output_path = Path(result[0].message_pieces[0].converted_value) - assert output_path.exists(), f"Output image file not found at path: {output_path}" - assert output_path.is_file(), f"Path exists but is not a file: {output_path}" - - @pytest.mark.asyncio @pytest.mark.parametrize( ("endpoint", "api_key", "model_name"), @@ -551,6 +445,107 @@ async def test_video_multiple_prompts_create_separate_files(sqlite_instance): ) +@pytest.mark.asyncio +async def test_video_remix_chain(sqlite_instance): + """Test text-to-video followed by remix using the returned video_id.""" + endpoint_value = _get_required_env_var("OPENAI_VIDEO2_ENDPOINT") + api_key_value = _get_required_env_var("OPENAI_VIDEO2_KEY") + model_name_value = _get_required_env_var("OPENAI_VIDEO2_MODEL") + + target = OpenAIVideoTarget( + 
endpoint=endpoint_value, + api_key=api_key_value, + model_name=model_name_value, + resolution_dimensions="1280x720", + n_seconds=4, + ) + + # Step 1: Generate initial video + text_piece = MessagePiece( + role="user", + original_value="A cat sitting on a windowsill", + converted_value="A cat sitting on a windowsill", + ) + result = await target.send_prompt_async(message=Message([text_piece])) + assert len(result) == 1 + response_piece = result[0].message_pieces[0] + assert response_piece.response_error == "none" + assert response_piece.prompt_metadata is not None + video_id = response_piece.prompt_metadata.get("video_id") + assert video_id, "Response must include video_id in prompt_metadata for chaining" + + # Step 2: Remix using the returned video_id + remix_piece = MessagePiece( + role="user", + original_value="Make it a watercolor painting style", + converted_value="Make it a watercolor painting style", + prompt_metadata={"video_id": video_id}, + ) + remix_result = await target.send_prompt_async(message=Message([remix_piece])) + assert len(remix_result) == 1 + remix_response = remix_result[0].message_pieces[0] + assert remix_response.response_error == "none" + + remix_path = Path(remix_response.converted_value) + assert remix_path.exists(), f"Remixed video file not found: {remix_path}" + assert remix_path.is_file() + + +@pytest.mark.asyncio +async def test_video_image_to_video(sqlite_instance): + """Test image-to-video mode using an image as the first frame.""" + endpoint_value = _get_required_env_var("OPENAI_VIDEO2_ENDPOINT") + api_key_value = _get_required_env_var("OPENAI_VIDEO2_KEY") + model_name_value = _get_required_env_var("OPENAI_VIDEO2_MODEL") + + target = OpenAIVideoTarget( + endpoint=endpoint_value, + api_key=api_key_value, + model_name=model_name_value, + resolution_dimensions="1280x720", + n_seconds=4, + ) + + # Prepare an image matching the video resolution (API requires exact match). 
+ # Resize a sample image to 1280x720 and save as a temporary JPEG. + from PIL import Image + + from pyrit.common.path import HOME_PATH + + sample_image = HOME_PATH / "assets" / "pyrit_architecture.png" + resized = Image.open(sample_image).resize((1280, 720)).convert("RGB") + import tempfile + + tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) + resized.save(tmp, format="JPEG") + tmp.close() + image_path = tmp.name + + # Use the image for image-to-video + conversation_id = str(uuid.uuid4()) + text_piece = MessagePiece( + role="user", + original_value="Animate this image with gentle motion", + converted_value="Animate this image with gentle motion", + conversation_id=conversation_id, + ) + image_piece = MessagePiece( + role="user", + original_value=image_path, + converted_value=image_path, + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + result = await target.send_prompt_async(message=Message([text_piece, image_piece])) + assert len(result) == 1 + response_piece = result[0].message_pieces[0] + assert response_piece.response_error == "none", f"Image-to-video failed: {response_piece.converted_value}" + + video_path = Path(response_piece.converted_value) + assert video_path.exists(), f"Video file not found: {video_path}" + assert video_path.is_file() + + ################################################## # Optional tests - not run in pipeline, only locally # Need RUN_ALL_TESTS=true environment variable to run diff --git a/tests/unit/models/test_message.py b/tests/unit/models/test_message.py index 01bbf4fe6..c94a733ab 100644 --- a/tests/unit/models/test_message.py +++ b/tests/unit/models/test_message.py @@ -61,6 +61,49 @@ def test_get_piece_raises_value_error_for_empty_request() -> None: Message(message_pieces=[]) +def test_get_pieces_by_type_returns_matching_pieces() -> None: + conversation_id = "test-conv" + text_piece = MessagePiece( + role="user", original_value="hello", converted_value="hello", 
+        conversation_id=conversation_id
+    )
+    image_piece = MessagePiece(
+        role="user",
+        original_value="/img.png",
+        converted_value="/img.png",
+        converted_value_data_type="image_path",
+        conversation_id=conversation_id,
+    )
+    msg = Message([text_piece, image_piece])
+
+    result = msg.get_pieces_by_type(data_type="text")
+    assert len(result) == 1
+    assert result[0] is text_piece
+
+    result = msg.get_pieces_by_type(data_type="image_path")
+    assert len(result) == 1
+    assert result[0] is image_piece
+
+
+def test_get_pieces_by_type_returns_empty_for_no_match() -> None:
+    piece = MessagePiece(role="user", original_value="hello", converted_value="hello")
+    msg = Message([piece])
+    assert msg.get_pieces_by_type(data_type="image_path") == []
+
+
+def test_get_piece_by_type_returns_first_match() -> None:
+    conversation_id = "test-conv"
+    text1 = MessagePiece(role="user", original_value="a", converted_value="a", conversation_id=conversation_id)
+    text2 = MessagePiece(role="user", original_value="b", converted_value="b", conversation_id=conversation_id)
+    msg = Message([text1, text2])
+    assert msg.get_piece_by_type(data_type="text") is text1
+
+
+def test_get_piece_by_type_returns_none_for_no_match() -> None:
+    piece = MessagePiece(role="user", original_value="hello", converted_value="hello")
+    msg = Message([piece])
+    assert msg.get_piece_by_type(data_type="image_path") is None
+
+
 def test_get_all_values_returns_all_converted_strings(message_pieces: list[MessagePiece]) -> None:
     response_one = Message(message_pieces=message_pieces[:2])
     response_two = Message(message_pieces=message_pieces[2:])
diff --git a/tests/unit/target/test_video_target.py b/tests/unit/target/test_video_target.py
index dbf16e6bc..a17835f57 100644
--- a/tests/unit/target/test_video_target.py
+++ b/tests/unit/target/test_video_target.py
@@ -54,8 +54,9 @@ def test_video_initialization_invalid_duration(patch_central_database):
     )
 
 
-def test_video_validate_request_length(video_target: OpenAIVideoTarget):
-    with
 pytest.raises(ValueError, match="single message piece"):
+def test_video_validate_request_multiple_text_pieces(video_target: OpenAIVideoTarget):
+    """Test validation rejects multiple text pieces."""
+    with pytest.raises(ValueError, match="Expected exactly 1 text piece"):
         conversation_id = str(uuid.uuid4())
         msg1 = MessagePiece(
             role="user", original_value="test1", converted_value="test1", conversation_id=conversation_id
@@ -66,8 +67,9 @@ def test_video_validate_request_length(video_target: OpenAIVideoTarget):
     video_target._validate_request(message=Message([msg1, msg2]))
 
 
-def test_video_validate_prompt_type(video_target: OpenAIVideoTarget):
-    with pytest.raises(ValueError, match="text prompt input"):
+def test_video_validate_prompt_type_image_only(video_target: OpenAIVideoTarget):
+    """Test validation rejects image-only input (must have text)."""
+    with pytest.raises(ValueError, match="Expected exactly 1 text piece"):
         msg = MessagePiece(
             role="user", original_value="test", converted_value="test", converted_value_data_type="image_path"
         )
@@ -348,3 +350,528 @@ def test_check_content_filter_no_error_object(video_target: OpenAIVideoTarget):
     mock_video.error = None
 
     assert video_target._check_content_filter(mock_video) is False
+
+
+# Tests for image-to-video and remix features
+
+
+class TestVideoTargetValidation:
+    """Tests for video target validation with new features."""
+
+    def test_validate_accepts_text_only(self, video_target: OpenAIVideoTarget):
+        """Test validation accepts single text piece (text-to-video mode)."""
+        msg = MessagePiece(role="user", original_value="test prompt", converted_value="test prompt")
+        # Should not raise
+        video_target._validate_request(message=Message([msg]))
+
+    def test_validate_accepts_text_and_image(self, video_target: OpenAIVideoTarget):
+        """Test validation accepts text + image (image-to-video mode)."""
+        conversation_id = str(uuid.uuid4())
+        msg_text = MessagePiece(
+            role="user",
+            original_value="animate this",
converted_value="animate this", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + # Should not raise + video_target._validate_request(message=Message([msg_text, msg_image])) + + def test_validate_rejects_multiple_images(self, video_target: OpenAIVideoTarget): + """Test validation rejects multiple image pieces.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate", + converted_value="animate", + conversation_id=conversation_id, + ) + msg_img1 = MessagePiece( + role="user", + original_value="/path/img1.png", + converted_value="/path/img1.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + msg_img2 = MessagePiece( + role="user", + original_value="/path/img2.png", + converted_value="/path/img2.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="at most 1 image piece"): + video_target._validate_request(message=Message([msg_text, msg_img1, msg_img2])) + + def test_validate_rejects_unsupported_types(self, video_target: OpenAIVideoTarget): + """Test validation rejects unsupported data types.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="test", + converted_value="test", + conversation_id=conversation_id, + ) + msg_audio = MessagePiece( + role="user", + original_value="/path/audio.wav", + converted_value="/path/audio.wav", + converted_value_data_type="audio_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="Unsupported piece types"): + video_target._validate_request(message=Message([msg_text, msg_audio])) + + def test_validate_rejects_remix_with_image(self, video_target: OpenAIVideoTarget): + """Test validation rejects remix mode 
combined with image input.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="remix prompt", + converted_value="remix prompt", + prompt_metadata={"video_id": "existing_video_123"}, + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + with pytest.raises(ValueError, match="Cannot use image input in remix mode"): + video_target._validate_request(message=Message([msg_text, msg_image])) + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetImageToVideo: + """Tests for image-to-video functionality.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + @pytest.mark.asyncio + async def test_image_to_video_calls_create_with_input_reference(self, video_target: OpenAIVideoTarget): + """Test that image-to-video mode passes input_reference to create_and_poll.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate this image", + converted_value="animate this image", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.png", + converted_value="/path/image.png", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + + mock_video = MagicMock() + mock_video.id = "video_img2vid" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/output.mp4" + mock_serializer.save_data = AsyncMock() + + mock_image_serializer = MagicMock() + mock_image_serializer.read_data = 
AsyncMock(return_value=b"image bytes") + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime, + ): + # First call returns image serializer, second call returns video serializer + mock_factory.side_effect = [mock_image_serializer, mock_serializer] + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_mime.return_value = "image/png" + + response = await video_target.send_prompt_async(message=Message([msg_text, msg_image])) + + # Verify create_and_poll was called with input_reference as tuple with MIME type + mock_create.assert_called_once() + call_kwargs = mock_create.call_args.kwargs + # input_reference should be (filename, bytes, content_type) tuple + input_ref = call_kwargs["input_reference"] + assert isinstance(input_ref, tuple) + assert input_ref[0] == "image.png" # filename + assert input_ref[1] == b"image bytes" # content + assert input_ref[2] == "image/png" # MIME type + assert call_kwargs["prompt"] == "animate this image" + + # Verify response + assert len(response) == 1 + assert response[0].message_pieces[0].converted_value_data_type == "video_path" + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetRemix: + """Tests for video remix functionality.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + @pytest.mark.asyncio + async def test_remix_calls_remix_and_poll(self, video_target: OpenAIVideoTarget): + """Test that remix mode calls remix() and poll().""" + msg = 
MessagePiece( + role="user", + original_value="make it more dramatic", + converted_value="make it more dramatic", + prompt_metadata={"video_id": "existing_video_123"}, + conversation_id=str(uuid.uuid4()), + ) + + mock_remix_video = MagicMock() + mock_remix_video.id = "remixed_video_456" + mock_remix_video.status = "in_progress" + + mock_polled_video = MagicMock() + mock_polled_video.id = "remixed_video_456" + mock_polled_video.status = "completed" + mock_polled_video.error = None + mock_polled_video.remixed_from_video_id = "existing_video_123" + + mock_video_response = MagicMock() + mock_video_response.content = b"remixed video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/remixed.mp4" + mock_serializer.save_data = AsyncMock() + + with ( + patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix, + patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + ): + mock_remix.return_value = mock_remix_video + mock_poll.return_value = mock_polled_video + mock_download.return_value = mock_video_response + mock_factory.return_value = mock_serializer + + response = await video_target.send_prompt_async(message=Message([msg])) + + # Verify remix was called with correct params + mock_remix.assert_called_once_with("existing_video_123", prompt="make it more dramatic") + # Verify poll was called (since status was in_progress) + mock_poll.assert_called_once_with("remixed_video_456") + + # Verify response + assert len(response) == 1 + assert response[0].message_pieces[0].converted_value_data_type == "video_path" + + @pytest.mark.asyncio + async def test_remix_skips_poll_if_completed(self, video_target: OpenAIVideoTarget): + """Test that remix mode skips poll() if 
already completed.""" + msg = MessagePiece( + role="user", + original_value="remix prompt", + converted_value="remix prompt", + prompt_metadata={"video_id": "existing_video_123"}, + conversation_id=str(uuid.uuid4()), + ) + + mock_video = MagicMock() + mock_video.id = "remixed_video" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = "existing_video_123" + + mock_video_response = MagicMock() + mock_video_response.content = b"remixed video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/remixed.mp4" + mock_serializer.save_data = AsyncMock() + + with ( + patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix, + patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + ): + mock_remix.return_value = mock_video + mock_download.return_value = mock_video_response + mock_factory.return_value = mock_serializer + + await video_target.send_prompt_async(message=Message([msg])) + + # Verify poll was NOT called since status was already completed + mock_poll.assert_not_called() + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetMetadata: + """Tests for video_id metadata storage in responses.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + @pytest.mark.asyncio + async def test_response_includes_video_id_metadata(self, video_target: OpenAIVideoTarget): + """Test that response includes video_id in prompt_metadata for chaining.""" + msg = MessagePiece( + role="user", + original_value="test prompt", + converted_value="test prompt", + 
conversation_id=str(uuid.uuid4()), + ) + + mock_video = MagicMock() + mock_video.id = "new_video_789" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/video.mp4" + mock_serializer.save_data = AsyncMock() + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + ): + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_factory.return_value = mock_serializer + + response = await video_target.send_prompt_async(message=Message([msg])) + + # Verify response contains video_id in metadata for chaining + response_piece = response[0].message_pieces[0] + assert response_piece.prompt_metadata is not None + assert response_piece.prompt_metadata.get("video_id") == "new_video_789" + + +@pytest.mark.usefixtures("patch_central_database") +class TestVideoTargetEdgeCases: + """Tests for edge cases and error scenarios.""" + + @pytest.fixture + def video_target(self) -> OpenAIVideoTarget: + return OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + ) + + def test_validate_rejects_empty_message(self, video_target: OpenAIVideoTarget): + """Test that empty messages are rejected (by Message constructor).""" + with pytest.raises(ValueError, match="at least one message piece"): + Message([]) + + def test_validate_rejects_no_text_piece(self, video_target: OpenAIVideoTarget): + """Test validation rejects message without text piece.""" + msg = MessagePiece( + role="user", + original_value="/path/image.png", + 
converted_value="/path/image.png", + converted_value_data_type="image_path", + ) + with pytest.raises(ValueError, match="Expected exactly 1 text piece"): + video_target._validate_request(message=Message([msg])) + + @pytest.mark.asyncio + async def test_image_to_video_with_jpeg(self, video_target: OpenAIVideoTarget): + """Test image-to-video with JPEG image format.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate", + converted_value="animate", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.jpg", + converted_value="/path/image.jpg", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + + mock_video = MagicMock() + mock_video.id = "video_jpeg" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/output.mp4" + mock_serializer.save_data = AsyncMock() + + mock_image_serializer = MagicMock() + mock_image_serializer.read_data = AsyncMock(return_value=b"jpeg bytes") + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime, + ): + mock_factory.side_effect = [mock_image_serializer, mock_serializer] + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_mime.return_value = "image/jpeg" + + response = await video_target.send_prompt_async(message=Message([msg_text, msg_image])) + + # Verify JPEG MIME type is 
used + call_kwargs = mock_create.call_args.kwargs + input_ref = call_kwargs["input_reference"] + assert input_ref[2] == "image/jpeg" + + @pytest.mark.asyncio + async def test_image_to_video_with_unknown_mime_defaults_to_png(self, video_target: OpenAIVideoTarget): + """Test image-to-video defaults to PNG when MIME type cannot be determined.""" + conversation_id = str(uuid.uuid4()) + msg_text = MessagePiece( + role="user", + original_value="animate", + converted_value="animate", + conversation_id=conversation_id, + ) + msg_image = MessagePiece( + role="user", + original_value="/path/image.unknown", + converted_value="/path/image.unknown", + converted_value_data_type="image_path", + conversation_id=conversation_id, + ) + + mock_video = MagicMock() + mock_video.id = "video_unknown" + mock_video.status = "completed" + mock_video.error = None + mock_video.remixed_from_video_id = None + + mock_video_response = MagicMock() + mock_video_response.content = b"video data" + + mock_serializer = MagicMock() + mock_serializer.value = "/path/to/output.mp4" + mock_serializer.save_data = AsyncMock() + + mock_image_serializer = MagicMock() + mock_image_serializer.read_data = AsyncMock(return_value=b"unknown bytes") + + with ( + patch.object(video_target._async_client.videos, "create_and_poll", new_callable=AsyncMock) as mock_create, + patch.object( + video_target._async_client.videos, "download_content", new_callable=AsyncMock + ) as mock_download, + patch("pyrit.prompt_target.openai.openai_video_target.data_serializer_factory") as mock_factory, + patch("pyrit.prompt_target.openai.openai_video_target.DataTypeSerializer.get_mime_type") as mock_mime, + ): + mock_factory.side_effect = [mock_image_serializer, mock_serializer] + mock_create.return_value = mock_video + mock_download.return_value = mock_video_response + mock_mime.return_value = None # MIME type cannot be determined + + response = await video_target.send_prompt_async(message=Message([msg_text, msg_image])) + + # Verify 
default PNG MIME type is used + call_kwargs = mock_create.call_args.kwargs + input_ref = call_kwargs["input_reference"] + assert input_ref[2] == "image/png" # Default + + @pytest.mark.asyncio + async def test_remix_with_failed_status(self, video_target: OpenAIVideoTarget): + """Test remix mode handles failed video generation.""" + msg = MessagePiece( + role="user", + original_value="remix this", + converted_value="remix this", + prompt_metadata={"video_id": "existing_video"}, + conversation_id=str(uuid.uuid4()), + ) + + mock_video = MagicMock() + mock_video.id = "failed_remix" + mock_video.status = "failed" + mock_error = MagicMock() + mock_error.code = "internal_error" + mock_video.error = mock_error + + with ( + patch.object(video_target._async_client.videos, "remix", new_callable=AsyncMock) as mock_remix, + patch.object(video_target._async_client.videos, "poll", new_callable=AsyncMock) as mock_poll, + ): + mock_remix.return_value = mock_video + # Don't need poll since status is already "failed" + + response = await video_target.send_prompt_async(message=Message([msg])) + + # Verify response is processing error + response_piece = response[0].message_pieces[0] + assert response_piece.response_error == "processing" + + def test_supported_resolutions(self, video_target: OpenAIVideoTarget): + """Test that all supported resolutions are valid.""" + for resolution in OpenAIVideoTarget.SUPPORTED_RESOLUTIONS: + target = OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + resolution_dimensions=resolution, + ) + assert target._size == resolution + + def test_supported_durations(self, video_target: OpenAIVideoTarget): + """Test that all supported durations are valid.""" + for duration in OpenAIVideoTarget.SUPPORTED_DURATIONS: + target = OpenAIVideoTarget( + endpoint="https://api.openai.com/v1", + api_key="test", + model_name="sora-2", + n_seconds=duration, + ) + assert target._n_seconds == duration