Conversation
|
|
||
| # Start backend with uvicorn | ||
| # Default to no initializers | ||
| if initializers is None: |
Chicken/egg, but pretty quickly it'd be nice to update this to use the config, and we can just get rid of initializers here: #1343
Maybe we should check in with Victor and use it to begin with?
Right. Don't want to block on that but that was my feeling, too, while reading that PR.
| """ | ||
| Attack-related request and response models. | ||
|
|
||
| All interactions in the UI are modeled as "attacks" - including manual conversations. |
One thing to think about is that we likely want the same types of separation we have in workflows. I could also see us wanting to benchmark things or to set up prompt generation.
But for now this is probably good to get started.
That will probably be different since there are no attack results, though.
| """ | ||
|
|
||
| piece_id: str = Field(..., description="Unique piece identifier") | ||
| data_type: str = Field(default="text", description="Data type: 'text', 'image', 'audio', 'video', etc.") |
Wouldn't we need both original_value_data_type and converted_value_data_type?
Correct! I was so focused on the mime type I forgot that there's just one 🙂
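For illustration, a minimal sketch of how that pair could look, mirroring the existing mime-type pair (the class name and exact field names here are assumptions, not the final API):

```python
# Sketch only: both data-type fields alongside the existing fields.
from typing import Optional

from pydantic import BaseModel, Field


class MessagePiece(BaseModel):
    piece_id: str = Field(..., description="Unique piece identifier")
    original_value_data_type: str = Field(default="text", description="Data type of the original value")
    converted_value_data_type: str = Field(default="text", description="Data type of the converted value")
    original_value: Optional[str] = Field(default=None, description="Original value before conversion")
    converted_value: str = Field(..., description="Converted value (text or base64 for media)")
```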
| original_value: Optional[str] = Field(default=None, description="Original value before conversion") | ||
| original_value_mime_type: Optional[str] = Field(default=None, description="MIME type of original value") | ||
| converted_value: str = Field(..., description="Converted value (text or base64 for media)") | ||
| converted_value_mime_type: Optional[str] = Field(default=None, description="MIME type of converted value") |
I'm guessing we need mime type to upload/download content. If we need it here I could see us needing it elsewhere. Should we add it to MessagePiece? Should we make this a thinner wrapper of that?
It's for adding files (imagine any modality other than text).
Pull request overview
This PR adds comprehensive backend REST APIs to support upcoming frontend development, implementing an attack-centric design where all user interactions (including manual conversations) are modeled as "attacks". The implementation includes service layers, API routes, Pydantic models, error handling middleware, registries for managing target/converter instances, a CLI tool, and comprehensive test coverage.
Changes:
- Adds three main services (AttackService, ConverterService, TargetService) for managing attacks, converters, and targets
- Implements REST API routes for attacks, targets, converters, labels, health, and version endpoints
- Creates Pydantic models for request/response validation with RFC 7807 error handling
- Adds instance registries for targets and converters with singleton pattern
- Implements `pyrit_backend` CLI command for starting the server with initializer support
- Includes 900+ lines of comprehensive unit tests for all services and routes
Reviewed changes
Copilot reviewed 35 out of 35 changed files in this pull request and generated 21 comments.
Show a summary per file
| File | Description |
|---|---|
| pyrit/backend/services/attack_service.py | Service layer for managing attack lifecycle, messages, and scoring (586 lines) |
| pyrit/backend/services/converter_service.py | Service for managing converter instances and previewing conversions (303 lines) |
| pyrit/backend/services/target_service.py | Service for managing target instance creation and retrieval (187 lines) |
| pyrit/backend/routes/attacks.py | REST API endpoints for attack CRUD operations and messaging (249 lines) |
| pyrit/backend/routes/targets.py | REST API endpoints for target instance management (101 lines) |
| pyrit/backend/routes/converters.py | REST API endpoints for converter instances and preview (134 lines) |
| pyrit/backend/routes/labels.py | REST API endpoint for retrieving filter label options (88 lines) |
| pyrit/backend/models/attacks.py | Pydantic models for attack requests/responses (201 lines) |
| pyrit/backend/models/targets.py | Pydantic models for target instances (52 lines) |
| pyrit/backend/models/converters.py | Pydantic models for converter instances and preview (98 lines) |
| pyrit/backend/models/common.py | Common models including RFC 7807 error responses and sensitive field filtering (93 lines) |
| pyrit/backend/middleware/error_handlers.py | RFC 7807 compliant error handler middleware (182 lines) |
| pyrit/registry/instance_registries/target_registry.py | Registry for managing target instances (88 lines) |
| pyrit/registry/instance_registries/converter_registry.py | Registry for managing converter instances (108 lines) |
| pyrit/cli/pyrit_backend.py | CLI command for starting the backend server with initialization support (217 lines) |
| tests/unit/backend/*.py | Comprehensive unit tests for all services, routes, models, and error handlers (2000+ lines) |
| pyrit/backend/main.py | Updated to register new routes and error handlers |
| pyproject.toml | Adds pyrit_backend CLI entry point and documentation exemption |
| frontend/dev.py | Updated to use pyrit_backend CLI instead of direct uvicorn |
| async def list_converters(self) -> ConverterInstanceListResponse: | ||
| """ | ||
| List all converter instances. | ||
|
|
||
| Returns: | ||
| ConverterInstanceListResponse containing all registered converters. | ||
| """ | ||
| items = [ | ||
| self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() | ||
| ] | ||
| return ConverterInstanceListResponse(items=items) | ||
|
|
||
| async def get_converter(self, converter_id: str) -> Optional[ConverterInstance]: | ||
| """ | ||
| Get a converter instance by ID. | ||
|
|
||
| Returns: | ||
| ConverterInstance if found, None otherwise. | ||
| """ | ||
| obj = self._registry.get_instance_by_name(converter_id) | ||
| if obj is None: | ||
| return None | ||
| return self._build_instance_from_object(converter_id, obj) | ||
|
|
||
| def get_converter_object(self, converter_id: str) -> Optional[Any]: | ||
| """ | ||
| Get the actual converter object. | ||
|
|
||
| Returns: | ||
| The PromptConverter object if found, None otherwise. | ||
| """ | ||
| return self._registry.get_instance_by_name(converter_id) | ||
|
|
||
| async def create_converter(self, request: CreateConverterRequest) -> CreateConverterResponse: | ||
| """ | ||
| Create a new converter instance from API request. | ||
|
|
||
| Instantiates the converter with the given type and params, | ||
| then registers it in the registry. | ||
|
|
||
| Args: | ||
| request: The create converter request with type and params. | ||
|
|
||
| Returns: | ||
| CreateConverterResponse with the new converter's details. | ||
|
|
||
| Raises: | ||
| ValueError: If the converter type is not found. | ||
| """ | ||
| converter_id = str(uuid.uuid4()) | ||
|
|
||
| # Resolve any converter references in params and instantiate | ||
| params = self._resolve_converter_params(request.params) | ||
| converter_class = self._get_converter_class(request.type) | ||
| converter_obj = converter_class(**params) | ||
| self._registry.register_instance(converter_obj, name=converter_id) | ||
|
|
||
| return CreateConverterResponse( | ||
| converter_id=converter_id, | ||
| type=request.type, | ||
| display_name=request.display_name, | ||
| params=request.params, | ||
| ) | ||
|
|
||
| async def preview_conversion(self, request: ConverterPreviewRequest) -> ConverterPreviewResponse: |
All async functions and methods MUST end with _async suffix according to PyRIT coding guidelines. The following async functions in this file are missing the _async suffix:
- `list_converters`
- `get_converter`
- `create_converter`
- `preview_conversion`
These should be renamed to:
- `list_converters_async`
- `get_converter_async`
- `create_converter_async`
- `preview_conversion_async`
| async def list_targets(self) -> TargetListResponse: | ||
| """ | ||
| List all target instances. | ||
|
|
||
| Returns: | ||
| TargetListResponse containing all registered targets. | ||
| """ | ||
| items = [ | ||
| self._build_instance_from_object(name, obj) for name, obj in self._registry.get_all_instances().items() | ||
| ] | ||
| return TargetListResponse(items=items) | ||
|
|
||
| async def get_target(self, target_id: str) -> Optional[TargetInstance]: | ||
| """ | ||
| Get a target instance by ID. | ||
|
|
||
| Returns: | ||
| TargetInstance if found, None otherwise. | ||
| """ | ||
| obj = self._registry.get_instance_by_name(target_id) | ||
| if obj is None: | ||
| return None | ||
| return self._build_instance_from_object(target_id, obj) | ||
|
|
||
| def get_target_object(self, target_id: str) -> Optional[Any]: | ||
| """ | ||
| Get the actual target object for use in attacks. | ||
|
|
||
| Returns: | ||
| The PromptTarget object if found, None otherwise. | ||
| """ | ||
| return self._registry.get_instance_by_name(target_id) | ||
|
|
||
| async def create_target(self, request: CreateTargetRequest) -> CreateTargetResponse: |
All async functions and methods MUST end with _async suffix according to PyRIT coding guidelines. The following async functions in this file are missing the _async suffix:
- `list_targets`
- `get_target`
- `create_target`
These should be renamed to:
- `list_targets_async`
- `get_target_async`
- `create_target_async`
| pagination=PaginationInfo(limit=limit, has_more=has_more, next_cursor=next_cursor, prev_cursor=cursor), | ||
| ) | ||
|
|
||
| async def get_attack(self, attack_id: str) -> Optional[AttackSummary]: |
Functions with more than 1 parameter MUST use * after self to enforce keyword-only arguments according to PyRIT coding guidelines. The function get_attack has 2 parameters (self, attack_id) but doesn't enforce keyword-only arguments.
Change to:
async def get_attack(self, *, attack_id: str) -> Optional[AttackSummary]:
| def test_get_attack_service_returns_attack_service(self) -> None: | ||
| """Test that get_attack_service returns an AttackService instance.""" | ||
| # Reset singleton for clean test | ||
| import pyrit.backend.services.attack_service as module |
Module 'pyrit.backend.services.attack_service' is imported with both 'import' and 'import from'.
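For reference, one way to keep a single import style in these tests (a sketch; resetting via the module-level _attack_service attribute matches the singleton pattern shown later in this review):

```python
# Sketch: import the module once and access everything through it.
import pyrit.backend.services.attack_service as attack_service_module


def test_get_attack_service_returns_attack_service() -> None:
    """The getter should return an AttackService instance."""
    attack_service_module._attack_service = None  # reset singleton for a clean test
    service = attack_service_module.get_attack_service()
    assert isinstance(service, attack_service_module.AttackService)
```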
| def test_get_attack_service_returns_same_instance(self) -> None: | ||
| """Test that get_attack_service returns the same instance.""" | ||
| # Reset singleton for clean test | ||
| import pyrit.backend.services.attack_service as module |
Module 'pyrit.backend.services.attack_service' is imported with both 'import' and 'import from'.
|
|
||
| import pytest | ||
|
|
||
| import pyrit.backend.services.converter_service as converter_service_module |
Module 'pyrit.backend.services.converter_service' is imported with both 'import' and 'import from'.
|
|
||
| def test_get_target_service_returns_target_service(self) -> None: | ||
| """Test that get_target_service returns a TargetService instance.""" | ||
| import pyrit.backend.services.target_service as module |
Module 'pyrit.backend.services.target_service' is imported with both 'import' and 'import from'.
|
|
||
| def test_get_target_service_returns_same_instance(self) -> None: | ||
| """Test that get_target_service returns the same instance.""" | ||
| import pyrit.backend.services.target_service as module |
Module 'pyrit.backend.services.target_service' is imported with both 'import' and 'import from'.
| # Copyright (c) Microsoft Corporation. | ||
| # Licensed under the MIT license. | ||
|
|
||
| """ |
Concern: Model Duplication/Drift and Confusion
A lot of these are really similar to our core models (MessagePiece, Message, Score, AttackSummary). Additionally, converters and targets here are really similar to the corresponding Identifiers.
Right now, the translation logic (in attack_service.py) is fragile. A name change breaks things, and things like type could all be different (e.g. score_value is float in the backend but a string in the model)
I also think this would be easier to use if all the same fields are available. It could be confusing to program here and not have access to converted_data_type. As a new user I'd be asking myself "what is converted_type and how does it map?". I think it would be really useful to be able to access all the model pieces in the same ways.
Proposal
Could we define the Pydantic models with the same field names and drift detection? E.g. use from_attributes=True
# pyrit/backend/models/message_piece.py
class MessagePieceSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)
# Same names as domain MessagePiece
id: Optional[str] = None # serialized from UUID
role: Literal["system", "user", "assistant", "simulated_assistant"]
original_value: str
converted_value: Optional[str] = None
original_value_data_type: str = "text"
converted_value_data_type: Optional[str] = None
response_error: str = "none"
sequence: int = -1
conversation_id: Optional[str] = None
# ... all 25 fields with same names ...
# API-only extras
friendly_name: Optional[str] = None
_EXTRA_FIELDS = {"friendly_name"}
# Potentially add this as an abstract method in a base class to ensure we add it in all backend models
@classmethod
def from_model(cls, obj, **extras):
instance = cls.model_validate(obj)
for k, v in extras.items():
setattr(instance, k, v)
return instance
Apply the same paradigm to identifiers:
# pyrit/backend/models/converters.py
class ConverterIdentifierSchema(BaseModel):
model_config = ConfigDict(from_attributes=True)
# Same fields as domain ConverterIdentifier
__type__: str
__module__: str
id: Optional[str] = None
supported_input_types: Optional[tuple] = None
supported_output_types: Optional[tuple] = None
sub_converter_identifiers: Optional[List["ConverterIdentifierSchema"]] = None
# ... etc
# API extras
display_name: Optional[str] = None
Add some kind of drift detection test. We could do it at runtime, or even just in tests; I'd be happy with either:
# tests/unit/backend/test_model_mirrors.py
class TestMessagePieceMirror:
def test_all_domain_fields_present(self):
domain_fields = get_init_param_names(MessagePiece)
schema_fields = set(MessagePieceSchema.model_fields.keys())
extra_fields = MessagePieceSchema._EXTRA_FIELDS
missing = domain_fields - (schema_fields - extra_fields)
assert not missing, f"Schema missing domain fields: {missing}"
Then, instead of the manual mapping in attack_service, translate with from_model:
# Before (in attack_service.py)
pieces = [MessagePiece(piece_id=str(p.id), data_type=p.converted_value_data_type, ...)]
# After
pieces = [MessagePieceSchema.from_model(p) for p in msg.message_pieces]
I want to add some thoughts here since this is going to set precedent for the rest of the backend module.
I think keeping separate API models is the right call. MessagePiece is huge (we have around 25 fields, converter_identifiers, scorer_identifier, attack_identifier, original_value_sha256, prompt_metadata, originator, etc) and most of that is internal plumbing that an API consumer shouldn't ever see. I think exposing these 8 fields in the API isn't really duplication. It's the whole point of DTOs to keep the public contract clean and let the domain evolve without dragging the API around with it (here is a great article on this).
The from_attributes you suggested to mirror everything pushes us in the opposite direction. It basically couples the API schema to domain internals. Any refactor of MessagePiece becomes an API change. Also, we end up with a bunch of Optional fields that are never actually populated, which is misleading for the users.
Even with the current setup we have today (backend and frontend live together and run locally), this is open source. People will deploy this differently (host backend separately, build their own frontend, etc.). Keeping the API contract independent now would save us a lot of painful untangling later.
I know you mentioned the current mapping is fragile, but I'm not too worried because our mapping code is explicit attribute access. If a domain field gets renamed, lint/type checking should catch it (or even mypy). That said, I do agree the translation code could be organized better; e.g. pulling it into dedicated mapper functions (e.g. pyrit/backend/mappers/) would centralize the changes and make it easy to add a few mapping tests.
I think the approach in this PR is correct, let's just clean up the translation layer a bit and I think we'd be in great shape.
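As a rough illustration of the mapper idea (module path, helper name, and the exact DTO import are suggestions, and the field list is abridged):

```python
# pyrit/backend/mappers/messages.py (sketch)
from typing import Any

from pyrit.backend.models.attacks import MessagePiece as MessagePieceDto  # assumed DTO location


def to_message_piece_dto(piece: Any) -> MessagePieceDto:
    """Single place that translates a domain MessagePiece into the API DTO."""
    return MessagePieceDto(
        piece_id=str(piece.id),
        data_type=piece.converted_value_data_type,
        original_value=piece.original_value,
        converted_value=piece.converted_value,
    )
```

A couple of small tests against a function like this would cover the drift concern without coupling the schema to the domain model.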
bashirpartovi left a comment
Great job Roman, this is my first pass at your PR. I do have a few more comments but need to think about them more.
| @app.exception_handler(Exception) | ||
| async def global_exception_handler_async(request: object, exc: Exception) -> JSONResponse: | ||
| """ | ||
| Handle all unhandled exceptions globally. | ||
|
|
||
| Note: This is a fallback handler. Most exceptions are handled by | ||
| the RFC 7807 error handlers in middleware/error_handlers.py. | ||
|
|
Just looking at this for the first time. It looks like you have an issue here (not really related to the PR but...).
In main.py, you are registering a generic Exception handler that doesn't follow RFC 7807. But in error_handlers.py, you already have a compliant handler:
@app.exception_handler(Exception)
async def generic_exception_handler(...) -> JSONResponse:
...
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=problem.model_dump(exclude_none=True), # RFC 7807 format
)
Since register_error_handlers(app) is called before the decorator in main.py runs, the main.py handler overwrites the RFC 7807 one. This means all unhandled exceptions return the non-standard format.
Can you remove the handler from main.py and just use the one in error_handlers.py?
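A minimal sketch of what main.py could look like after that cleanup (assuming register_error_handlers is exported from the middleware module):

```python
# main.py (sketch): rely solely on the RFC 7807 handlers from error_handlers.py.
from fastapi import FastAPI

from pyrit.backend.middleware.error_handlers import register_error_handlers  # assumed import path

app = FastAPI()

# Registers HTTPException, validation, and generic Exception handlers,
# all returning RFC 7807 problem-detail bodies.
register_error_handlers(app)

# No @app.exception_handler(Exception) defined here anymore, so nothing
# overwrites the compliant fallback handler.
```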
|
|
||
|
|
||
| # Initialize PyRIT on startup to load .env and .env.local files | ||
| @app.on_event("startup") |
As far as I know, this is deprecated, no?
I think the new standard is the lifespan context manager, e.g.:
@asynccontextmanager
async def lifespan(app: FastAPI):
await initialize_pyrit_async(memory_db_type="SQLite")
yield
app = FastAPI(..., lifespan=lifespan)| return AttackSummary( | ||
| attack_id=ar.conversation_id, | ||
| name=ar.attack_identifier.get("name"), | ||
| target_id=ar.attack_identifier.get("target_id", ""), | ||
| target_type=ar.attack_identifier.get("target_type", ""), | ||
| outcome=self._map_outcome(ar.outcome), | ||
| last_message_preview=last_preview, | ||
| message_count=message_count, | ||
| created_at=created_at, | ||
| updated_at=updated_at, | ||
| ) |
I think you are missing labels
| return AttackSummary( | |
| attack_id=ar.conversation_id, | |
| name=ar.attack_identifier.get("name"), | |
| target_id=ar.attack_identifier.get("target_id", ""), | |
| target_type=ar.attack_identifier.get("target_type", ""), | |
| outcome=self._map_outcome(ar.outcome), | |
| last_message_preview=last_preview, | |
| message_count=message_count, | |
| created_at=created_at, | |
| updated_at=updated_at, | |
| ) | |
| return AttackSummary( | |
| attack_id=ar.conversation_id, | |
| name=ar.attack_identifier.get("name"), | |
| target_id=ar.attack_identifier.get("target_id", ""), | |
| target_type=ar.attack_identifier.get("target_type", ""), | |
| outcome=self._map_outcome(ar.outcome), | |
| last_message_preview=last_preview, | |
| message_count=message_count, | |
| created_at=created_at, | |
| updated_at=updated_at, | |
| labels=ar.metadata.get("labels", {}) | |
| ) |
| metadata={ | ||
| "created_at": now.isoformat(), | ||
| "updated_at": now.isoformat(), | ||
| **(request.labels or {}), |
I think this is wrong. In get_attack, you are doing this:
labels=ar.metadata.get("labels", {}),
which means you are expecting metadata to have a key called labels. But here, you are spreading the labels as individual keys in the metadata. I think what you meant was this:
metadata={
"created_at": now.isoformat(),
"updated_at": now.isoformat(),
"labels": request.labels or {},
}| _attack_service: Optional[AttackService] = None | ||
|
|
||
|
|
||
| def get_attack_service() -> AttackService: | ||
| """ | ||
| Get the global attack service instance. | ||
|
|
||
| Returns: | ||
| The singleton AttackService instance. | ||
| """ | ||
| global _attack_service | ||
| if _attack_service is None: | ||
| _attack_service = AttackService() | ||
| return _attack_service |
All services use this pattern but this is not thread-safe. When you have async workers, you could get multiple instances.
You should try to use locks or caching:
from functools import lru_cache
@lru_cache(maxsize=1)
def get_attack_service() -> AttackService:
return AttackService()
| # Configure CORS | ||
| app.add_middleware( | ||
| CORSMiddleware, | ||
| allow_origins=["http://localhost:3000", "http://localhost:5173"], # Vite default ports |
This should not be hardcoded. If it gets deployed anywhere other than the local env, it will break. You should try using env vars to read these, something like this:
CORS_ORIGINS = os.getenv("PYRIT_CORS_ORIGINS", "http://localhost:3000,http://localhost:5173").split(",")
app.add_middleware(
CORSMiddleware,
allow_origins=CORS_ORIGINS,
...
)
| class TargetListResponse(BaseModel): | ||
| """Response for listing target instances.""" | ||
|
|
||
| items: List[TargetInstance] = Field(..., description="List of target instances") |
I am assuming the reason we don't have pagination here is that we are not expecting someone to create many targets? Could this grow over time, especially if we have a DB locally?
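If it does grow, one low-friction option (a sketch; the import paths for PaginationInfo and TargetInstance are assumptions) is to reuse the cursor pagination the attack listing already returns:

```python
# Sketch: optional cursor pagination on the target list, mirroring attacks.
from typing import List, Optional

from pydantic import BaseModel, Field

from pyrit.backend.models.common import PaginationInfo  # assumed location
from pyrit.backend.models.targets import TargetInstance


class TargetListResponse(BaseModel):
    """Response for listing target instances."""

    items: List[TargetInstance] = Field(..., description="List of target instances")
    pagination: Optional[PaginationInfo] = Field(default=None, description="Cursor pagination metadata")
```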
| @router.get( | ||
| "", | ||
| response_model=TargetListResponse, | ||
| ) |
Do we need error responses here?
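If we do, a sketch of how the route could advertise them with the shared ProblemDetail model (the router prefix, handler name, and the TargetListResponse import path here are assumptions):

```python
# Sketch: declare RFC 7807 error shapes in the OpenAPI schema via `responses`.
from fastapi import APIRouter

from pyrit.backend.models.common import ProblemDetail
from pyrit.backend.models.targets import TargetListResponse

router = APIRouter(prefix="/targets", tags=["targets"])


@router.get(
    "",
    response_model=TargetListResponse,
    responses={500: {"model": ProblemDetail, "description": "Unexpected server error"}},
)
async def list_targets_async() -> TargetListResponse:
    ...
```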
| from pyrit.backend.models.common import ProblemDetail | ||
| from pyrit.backend.services.attack_service import get_attack_service | ||
|
|
||
| router = APIRouter(prefix="/attacks", tags=["attacks"]) |
Should we add a DELETE operation for attacks? I am assuming this can grow over time.
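If we do want it, a rough sketch of what the endpoint could look like (the delete_attack_async service method is an assumption; the rest mirrors the existing route setup):

```python
# Sketch: DELETE /attacks/{attack_id}, 204 on success, 404 on a miss.
from fastapi import HTTPException, status


@router.delete("/{attack_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_attack_async(attack_id: str) -> None:
    service = get_attack_service()
    deleted = await service.delete_attack_async(attack_id=attack_id)  # assumed service method
    if not deleted:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Attack {attack_id} not found")
```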
Description
Adding backend APIs to support upcoming frontend development. This is based on an initial proposal and review.
Tests and Documentation
Includes tests for all APIs.