Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions examples/gather_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@

from lacquer import parser
from lacquer.tree import AliasedRelation
from lacquer.tree import Join
from lacquer.tree import DefaultTraversalVisitor
from lacquer.tree import FunctionCall
from lacquer.tree import Join
from lacquer.tree import QualifiedNameReference
from lacquer.tree import SingleColumn
from lacquer.tree import Table
Expand Down Expand Up @@ -48,13 +49,25 @@ def visit_query_specification(self, node, context):
self.tables.append(node.from_)
self.tables.reverse()

def get_all_qualified(expression, qualified=None):
if qualified is None:
qualified = []

if isinstance(expression, QualifiedNameReference):
qualified.append(expression)
elif isinstance(expression, FunctionCall):
for argument in expression.arguments:
get_all_qualified(argument, qualified)

return qualified

def print_column_resolution_order(columns, tables):
table_columns = []
tables_and_aliases = OrderedDict()
for i in range(len(columns)):
column = columns[i]
if isinstance(column.expression, QualifiedNameReference):
table_columns.append((column, i))
for qualified in get_all_qualified(column.expression):
table_columns.append((qualified, i))

for table in tables:
if isinstance(table, AliasedRelation):
Expand All @@ -67,8 +80,7 @@ def print_column_resolution_order(columns, tables):

print("\nTable Column Resolution:")
for (column, position) in table_columns:
names = column.expression.name.parts
column_name = names[-1]
names = column.name.parts
resolution = []
if len(names) > 1:
qualified_table_name = ".".join(names[:-1])
Expand Down Expand Up @@ -114,7 +126,8 @@ def visit_subquery_expression(self, node, context):
check_extracted_columns("select (select 1 from foo), a "
"from c join d using(foo) join e using (bar)")
check_extracted_columns("select 1, 20+a from c join d using(foo) join e using (bar)")

check_extracted_columns("select sum(foo) from a", True)
check_extracted_columns("select concat(concat(foo)) from a", True)
print("Running subquery checkers\n\n")
check_has_subquery("select a from b")
check_has_subquery("select a from (select a from b)")
Expand Down
2 changes: 1 addition & 1 deletion test/test_presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,4 +332,4 @@ def select_list_with_items(*args):


def simple_query(select, from_=None):
return Query(query_body=QuerySpecification(select=select, from_=from_))
return Query(query_body=QuerySpecification(select=select, from_=from_))