[#348] Added openai.transcribe action.

This API is the foundation for the `assistant.openai` plugin.
This commit is contained in:
Fabio Manganiello 2024-06-02 01:00:07 +02:00
parent f356fcd844
commit 9cca928d4b

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass
from datetime import datetime as dt from datetime import datetime as dt
from enum import Enum from enum import Enum
from threading import RLock from threading import RLock
from typing import Iterable, List, Optional from typing import IO, Iterable, List, Optional
import requests import requests
@ -265,6 +265,84 @@ class OpenaiPlugin(Plugin):
self._update_context(msg) self._update_context(msg)
return msg["content"] return msg["content"]
def _process_transcribe_response(self, resp: requests.Response) -> str:
rs_json = None
try:
rs_json = resp.json()
except Exception:
pass
self.logger.debug("OpenAI response: %s", rs_json)
resp.raise_for_status()
return (rs_json or {}).get("text", "")
def transcribe_file(
self,
f: IO,
model: Optional[str] = 'whisper-1',
timeout: Optional[float] = None,
) -> str:
resp = requests.post(
"https://api.openai.com/v1/audio/transcriptions",
timeout=timeout or self.timeout,
headers={
"Authorization": f"Bearer {self._api_key}",
},
files={
"file": f,
},
data={
"model": model or self.model,
},
)
return self._process_transcribe_response(resp)
def transcribe_raw(
self,
audio: bytes,
extension: str,
model: Optional[str] = 'whisper-1',
timeout: Optional[float] = None,
) -> str:
resp = requests.post(
"https://api.openai.com/v1/audio/transcriptions",
timeout=timeout or self.timeout,
headers={
"Authorization": f"Bearer {self._api_key}",
},
files={
"file": (f"audio.{extension}", audio),
},
data={
"model": model or self.model,
},
)
return self._process_transcribe_response(resp)
@action
def transcribe(
self,
audio: str,
model: Optional[str] = 'whisper-1',
timeout: Optional[float] = None,
) -> str:
"""
Perform speech-to-text on an audio file.
:param audio: The audio file to transcribe.
:param model: The model to use for speech-to-text. Default:
``whisper-1``. If not set, the configured default model will be
used.
:param timeout: Timeout for the API request. If not set, the default
timeout will be used.
:return: The transcribed text.
"""
with open(os.path.expanduser(audio), "rb") as f:
return self.transcribe_file(f, model=model, timeout=timeout)
def _update_context(self, *entries: dict): def _update_context(self, *entries: dict):
""" """
Update the context with a new entry. Update the context with a new entry.