Skip to content

Commit

Permalink
strip trailing whitespace in function name prior to parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
olirice committed May 21, 2024
1 parent cf42ef8 commit a349e40
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/alembic_utils/pg_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def literal_signature(self) -> str:
"""
# May already be quoted if loading from database or SQL file
name, remainder = self.signature.split("(", 1)
return '"' + name + '"(' + remainder
return '"' + name.strip() + '"(' + remainder

def to_sql_statement_create(self):
"""Generates a SQL "create function" statement for PGFunction"""
Expand All @@ -79,12 +79,12 @@ def to_sql_statement_drop(self, cascade=False):
template = "{function_name}({parameters})"
result = parse(template, self.signature, case_sensitive=False)
try:
function_name = result["function_name"]
function_name = result["function_name"].strip()
parameters_str = result["parameters"].strip()
except TypeError:
# Did not match, NoneType is not scriptable
result = parse("{function_name}()", self.signature, case_sensitive=False)
function_name = result["function_name"]
function_name = result["function_name"].strip()
parameters_str = ""

# NOTE: Will fail if a text field has a default and that deafult contains a comma...
Expand Down
17 changes: 16 additions & 1 deletion src/test/test_pg_function.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

from sqlalchemy import text

from alembic_utils.pg_function import PGFunction
Expand All @@ -6,7 +8,7 @@

TO_UPPER = PGFunction(
schema="public",
signature="toUpper(some_text text default 'my text!')",
signature="toUpper (some_text text default 'my text!')",
definition="""
returns text
as
Expand All @@ -15,6 +17,19 @@
)


def test_trailing_whitespace_stripped():
sql_statements: List[str] = [
str(TO_UPPER.to_sql_statement_create()),
str(next(iter(TO_UPPER.to_sql_statement_create_or_replace()))),
str(TO_UPPER.to_sql_statement_drop()),
]

for statement in sql_statements:
print(statement)
assert '"toUpper"' in statement
assert not '"toUpper "' in statement


def test_create_revision(engine) -> None:
register_entities([TO_UPPER], entity_types=[PGFunction])

Expand Down

0 comments on commit a349e40

Please sign in to comment.