diff --git a/platypush/plugins/openai/__init__.py b/platypush/plugins/openai/__init__.py index 70b5538b09..0002796fa7 100644 --- a/platypush/plugins/openai/__init__.py +++ b/platypush/plugins/openai/__init__.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from datetime import datetime as dt from enum import Enum from threading import RLock -from typing import Iterable, List, Optional +from typing import IO, Iterable, List, Optional import requests @@ -265,6 +265,84 @@ class OpenaiPlugin(Plugin): self._update_context(msg) return msg["content"] + def _process_transcribe_response(self, resp: requests.Response) -> str: + rs_json = None + + try: + rs_json = resp.json() + except Exception: + pass + + self.logger.debug("OpenAI response: %s", rs_json) + resp.raise_for_status() + return (rs_json or {}).get("text", "") + + def transcribe_file( + self, + f: IO, + model: Optional[str] = 'whisper-1', + timeout: Optional[float] = None, + ) -> str: + resp = requests.post( + "https://api.openai.com/v1/audio/transcriptions", + timeout=timeout or self.timeout, + headers={ + "Authorization": f"Bearer {self._api_key}", + }, + files={ + "file": f, + }, + data={ + "model": model or self.model, + }, + ) + + return self._process_transcribe_response(resp) + + def transcribe_raw( + self, + audio: bytes, + extension: str, + model: Optional[str] = 'whisper-1', + timeout: Optional[float] = None, + ) -> str: + resp = requests.post( + "https://api.openai.com/v1/audio/transcriptions", + timeout=timeout or self.timeout, + headers={ + "Authorization": f"Bearer {self._api_key}", + }, + files={ + "file": (f"audio.{extension}", audio), + }, + data={ + "model": model or self.model, + }, + ) + + return self._process_transcribe_response(resp) + + @action + def transcribe( + self, + audio: str, + model: Optional[str] = 'whisper-1', + timeout: Optional[float] = None, + ) -> str: + """ + Perform speech-to-text on an audio file. + + :param audio: The audio file to transcribe. + :param model: The model to use for speech-to-text. Default: + ``whisper-1``. If not set, the configured default model will be + used. + :param timeout: Timeout for the API request. If not set, the default + timeout will be used. + :return: The transcribed text. + """ + with open(os.path.expanduser(audio), "rb") as f: + return self.transcribe_file(f, model=model, timeout=timeout) + def _update_context(self, *entries: dict): """ Update the context with a new entry.