49 lines
1.1 KiB
Python
49 lines
1.1 KiB
Python
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
|
|
from sqlalchemy import __version__
|
|
|
|
def _parse_version(version: str) -> tuple:
    """
    Extract the leading numeric release components of a version string.

    Pre-release and development suffixes are ignored, so ``'2.0.0b4'`` and
    ``'2.0.0rc1'`` both yield ``(2, 0, 0)`` and ``'2.1.0.dev0'`` yields
    ``(2, 1, 0)``. A string that does not start with a digit yields ``()``.

    :param version: A version string such as ``sqlalchemy.__version__``.
    :return: Tuple of ints, comparable against e.g. ``(1, 4, 0)``.
    """
    import re

    # Only the dotted run of integers at the start is meaningful for ordering;
    # a naive ``int()`` over every dot-separated part fails on e.g. '0b4'.
    match = re.match(r'\d+(?:\.\d+)*', version)
    if not match:
        return ()
    return tuple(int(part) for part in match.group().split('.'))


sa_version = _parse_version(__version__)

# ``declarative_base`` is imported from ``sqlalchemy.orm`` on SQLAlchemy
# >= 1.4 and from ``sqlalchemy.ext.declarative`` on older releases.
if sa_version >= (1, 4, 0):
    from sqlalchemy.orm import declarative_base
else:
    from sqlalchemy.ext.declarative import declarative_base

# Shared declarative base; its ``metadata`` is the registry consulted by
# ``is_defined``.
Base = declarative_base()
|
|
|
|
|
|
@dataclass
class DbContext:
    """
    Context flags for the database session.
    """

    # While True, ``is_defined`` returns False for every table name.
    override_definitions: bool = False


# Single module-level instance holding the flags. NOTE: this is ordinary
# module state, so it is shared by every thread in the process.
_ctx = DbContext()


@contextmanager
def override_definitions():
    """
    Temporarily override the definitions of the entities in the entities
    registry.

    This is useful when the entities are being imported off-context, like
    e.g. in the `inspect` or `alembic` modules.

    The previous value of the flag is restored on exit, even if the body of
    the ``with`` statement raises. Nested uses therefore compose correctly.
    """
    previous = _ctx.override_definitions
    _ctx.override_definitions = True
    try:
        yield
    finally:
        # Restore instead of resetting to False, so that an enclosing
        # ``override_definitions()`` block remains in effect.
        _ctx.override_definitions = previous
|
|
|
|
|
|
def is_defined(table_name: str) -> bool:
    """
    Tell whether a table with the given name is registered on ``Base.metadata``.

    While the module-level ``override_definitions`` flag is set, the answer
    is always ``False``, regardless of what the metadata contains.

    :param table_name: Name of the table associated to the entity class.
    :return: ``True`` only if overrides are off and the table is registered.
    """
    if _ctx.override_definitions:
        return False
    return table_name in Base.metadata
|