Pass the database engine to the Alembic process as an extra argument.

If the path of the default database engine is overridden via `--workdir`
option then it won't be visible to the new `python` subprocess spawned
for Alembic.
This commit is contained in:
Fabio Manganiello 2023-08-19 13:02:05 +02:00
parent c2b3ec8ce3
commit 181da63c89
Signed by untrusted user: blacklight
GPG key ID: D90FBA7F76362774
2 changed files with 29 additions and 11 deletions

View file

@ -24,6 +24,7 @@ from sqlalchemy import (
UniqueConstraint, UniqueConstraint,
inspect as schema_inspect, inspect as schema_inspect,
) )
from sqlalchemy.engine import Engine
from sqlalchemy.orm import ColumnProperty, backref, relationship from sqlalchemy.orm import ColumnProperty, backref, relationship
from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.orm.exc import ObjectDeletedError
@ -303,6 +304,24 @@ def _discover_entity_types():
entities_registry[obj] = {} # type: ignore entities_registry[obj] = {} # type: ignore
def _get_db():
"""
Utility method to get the db plugin.
"""
from platypush.context import get_plugin
db = get_plugin('db')
assert db
return db
def _get_db_engine() -> Engine:
"""
Utility method to get the db engine.
"""
return _get_db().get_engine()
def get_entities_registry() -> EntityRegistryType: def get_entities_registry() -> EntityRegistryType:
""" """
:returns: A copy of the entities registry. :returns: A copy of the entities registry.
@ -314,13 +333,9 @@ def init_entities_db():
""" """
Initializes the entities database. Initializes the entities database.
""" """
from platypush.context import get_plugin
run_db_migrations() run_db_migrations()
_discover_entity_types() _discover_entity_types()
db = get_plugin('db') _get_db().create_all(_get_db_engine(), Base)
assert db
db.create_all(db.get_engine(), Base)
def run_db_migrations(): def run_db_migrations():
@ -339,6 +354,8 @@ def run_db_migrations():
'alembic', 'alembic',
'-c', '-c',
alembic_ini, alembic_ini,
'-x',
f'DBNAME={_get_db_engine().url}',
'upgrade', 'upgrade',
'head', 'head',
], ],

View file

@ -74,14 +74,15 @@ def run_migrations_online() -> None:
def set_db_engine(): def set_db_engine():
engine_url = context.get_x_argument(as_dictionary=True).get('DBNAME')
if not engine_url:
db_conf = Config.get('db') db_conf = Config.get('db')
assert db_conf, 'Could not retrieve the database configuration' assert db_conf, 'Could not retrieve the database configuration'
engine = db_conf['engine'] engine_url = db_conf['engine']
assert engine, 'No database engine configured' assert engine_url, 'No database engine configured'
config = context.config
section = config.config_ini_section section = config.config_ini_section
config.set_section_option(section, 'DB_ENGINE', engine) config.set_section_option(section, 'DB_ENGINE', engine_url)
set_db_engine() set_db_engine()