diff --git a/chatkit/server.py b/chatkit/server.py
index 19e3563..8f8c004 100644
--- a/chatkit/server.py
+++ b/chatkit/server.py
@@ -1,4 +1,5 @@
 import asyncio
+import base64
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterator
 from contextlib import contextmanager
@@ -41,6 +42,7 @@
     ErrorEvent,
     FeedbackKind,
     HiddenContextItem,
+    InputTranscribeReq,
     ItemsFeedbackReq,
     ItemsListReq,
     NonStreamingReq,
@@ -69,6 +71,7 @@
     ThreadStreamEvent,
     ThreadsUpdateReq,
     ThreadUpdatedEvent,
+    TranscriptionResult,
     UserMessageInput,
     UserMessageItem,
     WidgetComponentUpdated,
@@ -319,6 +322,14 @@ async def add_feedback(  # noqa: B027
         """Persist user feedback for one or more thread items."""
         pass
 
+    async def transcribe(
+        self, audio_bytes: bytes, mime_type: str, context: TContext
+    ) -> TranscriptionResult:
+        """Transcribe speech audio to text. Override this method to support dictation."""
+        raise NotImplementedError(
+            "transcribe() must be overridden to support the input.transcribe request."
+        )
+
     def action(
         self,
         thread: ThreadMetadata,
@@ -446,6 +457,12 @@ async def _process_non_streaming(
                     request.params.attachment_id, context=context
                 )
                 return b"{}"
+            case InputTranscribeReq():
+                audio_bytes = base64.b64decode(request.params.audio_base64)
+                transcription_result = await self.transcribe(
+                    audio_bytes, request.params.mime_type, context=context
+                )
+                return self._serialize(transcription_result)
             case ItemsListReq():
                 items_list_params = request.params
                 items = await self.store.load_thread_items(
diff --git a/chatkit/types.py b/chatkit/types.py
index 3280b13..9c9656c 100644
--- a/chatkit/types.py
+++ b/chatkit/types.py
@@ -174,6 +174,29 @@ class AttachmentCreateParams(BaseModel):
     mime_type: str
 
 
+class InputTranscribeParams(BaseModel):
+    """Parameters for speech transcription."""
+
+    audio_base64: str
+    """Base64-encoded audio bytes."""
+
+    mime_type: str
+    """MIME type of the audio payload (e.g. 'audio/webm', 'audio/wav')."""
+
+
+class InputTranscribeReq(BaseReq):
+    """Request to transcribe an audio payload into text."""
+
+    type: Literal["input.transcribe"] = "input.transcribe"
+    params: InputTranscribeParams
+
+
+class TranscriptionResult(BaseModel):
+    """Input speech transcription result."""
+
+    text: str
+
+
 class ItemsListReq(BaseReq):
     """Request to list items inside a thread."""
 
@@ -236,6 +259,7 @@ class ThreadDeleteParams(BaseModel):
     | AttachmentsDeleteReq
     | ThreadsUpdateReq
     | ThreadsDeleteReq
+    | InputTranscribeReq
 )
 """Union of request types that yield immediate responses."""
 
diff --git a/docs/guides/accept-rich-user-input.md b/docs/guides/accept-rich-user-input.md
index 052e7eb..8bd8a4e 100644
--- a/docs/guides/accept-rich-user-input.md
+++ b/docs/guides/accept-rich-user-input.md
@@ -172,6 +172,81 @@ Set `ImageAttachment.preview_url` to allow the client to render thumbnails.
 - If your preview URLs are **permanent/public**, set `preview_url` once when creating the attachment and persist it.
 - If your storage uses **expiring URLs**, generate a fresh `preview_url` when returning attachment metadata (for example, in `Store.load_thread_items` and `Store.load_attachment`) rather than persisting a long-lived URL. In this case, returning a short-lived signed URL directly is the simplest approach. Alternatively, you may return a redirect that resolves to a temporary signed URL, as long as the final URL serves image bytes with appropriate CORS headers.
 
+## Dictation: speech-to-text input
+
+Enable dictation so users can record audio and have it transcribed into text before sending.
+
+### Enable dictation in the client
+
+Turn on dictation in the composer:
+
+```ts
+const chatkit = useChatKit({
+  // ...
+  composer: {
+    dictation: {
+      enabled: true,
+    },
+  },
+});
+```
+
+This maps to `ChatKitOptions.composer.dictation`.
+
+### Implement `ChatKitServer.transcribe`
+
+When dictation is enabled, the client records audio and sends it to your backend for transcription. Implement `ChatKitServer.transcribe` to accept audio bytes and return a transcription result.
+
+Example transcription method using the OpenAI Audio API:
+
+```python
+import io
+
+from chatkit.types import TranscriptionResult
+from openai import AsyncOpenAI
+
+
+# In your ChatKitServer subclass:
+async def transcribe(
+    self, audio_bytes: bytes, mime_type: str, context: RequestContext
+) -> TranscriptionResult:
+    # Name the in-memory file so the API can infer the audio container format.
+    ext = "m4a" if mime_type.startswith("audio/mp4") else "webm"
+    audio_file = io.BytesIO(audio_bytes)
+    audio_file.name = f"audio.{ext}"
+
+    # Use the async client so transcription does not block the event loop.
+    client = AsyncOpenAI()
+    transcription = await client.audio.transcriptions.create(
+        model="gpt-4o-transcribe",
+        file=audio_file,
+    )
+    return TranscriptionResult(text=transcription.text)
+```
+
+Return a `TranscriptionResult` whose `text` is the final text that should appear in the composer.
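+
+For reference, dictation audio reaches your server as an `input.transcribe` request whose params carry the recording as base64; the server decodes it, calls `transcribe`, and responds with the serialized `TranscriptionResult`. A minimal sketch of that request, built here with stand-in audio bytes:
+
+```python
+import base64
+
+from chatkit.types import InputTranscribeParams, InputTranscribeReq
+
+# Stand-in for the audio bytes the client recorder captures.
+recorded_bytes = b"..."
+
+# The server base64-decodes audio_base64, passes the bytes and MIME type
+# to transcribe(), and replies with the serialized TranscriptionResult,
+# e.g. {"text": "..."}.
+request = InputTranscribeReq(
+    params=InputTranscribeParams(
+        audio_base64=base64.b64encode(recorded_bytes).decode("ascii"),
+        mime_type="audio/webm",
+    )
+)
+```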
+
 ## @-mentions: tag entities in user messages
 
 Enable @-mentions so users can tag entities (like documents, tickets, or users) instead of pasting raw identifiers. Mentions travel through ChatKit as structured tags so the model can resolve entities instead of guessing from free text.
diff --git a/tests/test_chatkit_server.py b/tests/test_chatkit_server.py
index 750f7c2..066ffae 100644
--- a/tests/test_chatkit_server.py
+++ b/tests/test_chatkit_server.py
@@ -1,4 +1,5 @@
 import asyncio
+import base64
 import sqlite3
 from contextlib import contextmanager
 from datetime import datetime
@@ -38,6 +39,8 @@
     FileAttachment,
     ImageAttachment,
     InferenceOptions,
+    InputTranscribeParams,
+    InputTranscribeReq,
     ItemFeedbackParams,
     ItemsFeedbackReq,
     ItemsListParams,
@@ -75,6 +78,7 @@
     ThreadUpdatedEvent,
     ThreadUpdateParams,
     ToolChoice,
+    TranscriptionResult,
     UserMessageInput,
     UserMessageItem,
     UserMessageTextContent,
@@ -159,6 +163,7 @@ def make_server(
     ]
     | None = None,
     file_store: AttachmentStore | None = None,
+    transcribe_callback: Callable[[bytes, str, Any], TranscriptionResult] | None = None,
 ):
     global server_id
     db_path = f"file:{server_id}?mode=memory&cache=shared"
@@ -206,6 +211,13 @@
         return handle_feedback(thread_id, item_ids, feedback, context)
 
+    async def transcribe(
+        self, audio_bytes: bytes, mime_type: str, context: Any
+    ) -> TranscriptionResult:
+        if transcribe_callback is None:
+            return await super().transcribe(audio_bytes, mime_type, context)
+        return transcribe_callback(audio_bytes, mime_type, context)
+
     async def process_streaming(
         self, request_obj, context: Any | None = None
     ) -> list[ThreadStreamEvent]:
@@ -1887,6 +1899,36 @@
     assert any(e.type == "thread.item.done" for e in events)
 
 
+async def test_input_transcribe_decodes_base64_and_passes_mime_type():
+    audio_bytes = b"hello audio"
+    audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
+    seen: dict[str, Any] = {}
+
+    def transcribe_callback(
+        audio: bytes, mime: str, context: Any
+    ) -> TranscriptionResult:
+        seen["audio"] = audio
+        seen["mime"] = mime
+        seen["context"] = context
+        return TranscriptionResult(text="ok")
+
+    with make_server(transcribe_callback=transcribe_callback) as server:
+        result = await server.process_non_streaming(
+            InputTranscribeReq(
+                params=InputTranscribeParams(
+                    audio_base64=audio_b64,
+                    mime_type="audio/wav",
+                )
+            )
+        )
+        parsed = TypeAdapter(TranscriptionResult).validate_json(result.json)
+        assert parsed.text == "ok"
+
+    assert seen["audio"] == audio_bytes
+    assert seen["mime"] == "audio/wav"
+    assert seen["context"] == DEFAULT_CONTEXT
+
+
 async def test_retry_after_item_passes_tools_to_responder():
     pass