diff --git a/apollo/formsframework/utils.py b/apollo/formsframework/utils.py index a944ee913..a3bf854b3 100644 --- a/apollo/formsframework/utils.py +++ b/apollo/formsframework/utils.py @@ -8,6 +8,7 @@ from xlwt import Workbook from apollo.formsframework.models import Form +from apollo.submissions.qa.query_builder import build_expression from apollo.utils import generate_identifier gt_constraint_regex = re.compile(r'(?:.*\.\s*\>={0,1}\s*)(\d+)') @@ -182,6 +183,10 @@ def _process_qa_worksheet(qa_data): if 'name' in qa_dict: if current_name != qa_dict['name']: if current_check is not None: + if 'expression' not in current_check: + current_check.update( + expression=build_expression(current_check)) + current_check.pop('criteria', None) quality_checks.append(current_check) current_name = qa_dict['name'] current_check = { @@ -210,9 +215,15 @@ def _process_qa_worksheet(qa_data): 'comparator': qa_dict['relation'], 'rvalue': qa_dict['right'] } + qa_check.update(expression=build_expression(qa_check)) + qa_check.pop('comparator') + qa_check.pop('lvalue') + qa_check.pop('rvalue') quality_checks.append(qa_check) if current_check is not None: + current_check.update(expression=build_expression(current_check)) + current_check.pop('criteria', None) quality_checks.append(current_check) return quality_checks @@ -280,8 +291,7 @@ def export_form(form): 'accredited_voters_tag', 'invalid_votes_tag', 'registered_voters_tag', 'blank_votes_tag', 'quality_checks_enabled', 'vote_shares'] - qa_header = ['name', 'description', 'left', 'relation', 'right', - 'conjunction'] + qa_header = ['name', 'description', 'expression'] # output headers for col, value in enumerate(survey_header): @@ -415,22 +425,10 @@ def export_form(form): if quality_checks and qa_sheet: row = 1 for check in quality_checks: - if 'criteria' in check: - for term in check['criteria']: - qa_sheet.write(row, 0, check['name']) - qa_sheet.write(row, 1, check['description']) - qa_sheet.write(row, 2, term['lvalue']) - qa_sheet.write(row, 3, term['comparator']) - qa_sheet.write(row, 4, term['rvalue']) - qa_sheet.write(row, 5, term['conjunction']) - row += 1 - else: + if 'expression' in check: qa_sheet.write(row, 0, check['name']) qa_sheet.write(row, 1, check['description']) - qa_sheet.write(row, 2, check['lvalue']) - qa_sheet.write(row, 3, check['comparator']) - qa_sheet.write(row, 4, check['rvalue']) - qa_sheet.write(row, 5, '&&') + qa_sheet.write(row, 2, check['expression']) row += 1 return book diff --git a/apollo/formsframework/views_forms.py b/apollo/formsframework/views_forms.py index 954476650..e51bc844a 100644 --- a/apollo/formsframework/views_forms.py +++ b/apollo/formsframework/views_forms.py @@ -245,25 +245,9 @@ def quality_controls(view, form_id): quality_control['name'] = quality_check['name'] quality_control['description'] = quality_check['description'] - quality_control['criteria'] = [] - if 'criteria' in quality_check: - for index, criterion in enumerate(quality_check['criteria']): - quality_control['criteria'].append({ - 'lvalue': criterion['lvalue'], - 'comparator': criterion['comparator'], - 'rvalue': criterion['rvalue'], - 'conjunction': criterion['conjunction'], - 'id': str(index) - }) - else: - quality_control['criteria'].append({ - 'lvalue': quality_check['lvalue'], - 'comparator': quality_check['comparator'], - 'rvalue': quality_check['rvalue'], - 'conjunction': '&&', - 'id': '0' - }) + if 'expression' in quality_check: + quality_control['expression'] = quality_check['expression'] quality_controls.append(quality_control) @@ -421,8 +405,7 @@ def export_form(id): workbook.save(memory_file) memory_file.seek(0) current_timestamp = datetime.utcnow() - filename = slugify( - f'{form.name}-{current_timestamp:%Y %m %d %H%M%S}') + '.xls' + filename = slugify(f'{form.name}-{current_timestamp:%Y %m %d %H%M%S}.xls') return send_file( memory_file, attachment_filename=filename, diff --git a/apollo/submissions/qa/query_builder.py b/apollo/submissions/qa/query_builder.py index f6f253bce..61f016fb5 100644 --- a/apollo/submissions/qa/query_builder.py +++ b/apollo/submissions/qa/query_builder.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- '''Query builder module for checklist QA''' +import enum import operator as op from arpeggio import PTNodeVisitor, visit_parse_tree @@ -63,6 +64,13 @@ } +class OperandType(enum.IntEnum): + BOOLEAN = enum.auto() + NULL = enum.auto() + NUMERIC = enum.auto() + TEXT = enum.auto() + + class BaseVisitor(PTNodeVisitor): def __init__(self, defaults=True, **kwargs): @@ -70,15 +78,23 @@ def __init__(self, defaults=True, **kwargs): def visit_number(self, node, children): if node.value.isdigit(): - return int(node.value) + value = int(node.value) else: - return float(node.value) + value = float(node.value) + + return value, OperandType.NUMERIC + + def visit_null(self, node, children): + return node.value, OperandType.NULL def visit_factor(self, node, children): if len(children) == 1: return children[0] + multiplier = -1 if children[0] == '-' else 1 - return multiplier * children[-1] + value, op_type = children[-1] + + return multiplier * value, op_type def visit_value(self, node, children): return children[-1] @@ -87,26 +103,45 @@ def visit_exponent(self, node, children): if len(children) == 1: return children[0] - exponent = children[0] - for i in children[1:]: + exponent, l_op_type = children[0] + if l_op_type != OperandType.NUMERIC: + raise ValueError('Only numeric operands supported for *') + for item, r_op_type in children[1:]: # exponent **= i - exponent = func.pow(exponent, i) + if r_op_type != OperandType.NUMERIC: + raise ValueError('Only numeric operands supported for *') + exponent = func.pow(exponent, item) - return exponent + return exponent, l_op_type def visit_product(self, node, children): - product = children[0] + if len(children) == 1: + return children[0] + + product, l_op_type = children[0] for i in range(2, len(children), 2): + item, r_op_type = children[i] sign = children[i - 1] - product = OPERATIONS[sign](product, children[i]) - return product + if r_op_type != OperandType.NUMERIC: + raise ValueError(f'Only numeric operands supported for {sign}') + product = OPERATIONS[sign](product, item) + + return product, l_op_type def visit_sum(self, node, children): - total = children[0] + if len(children) == 1: + return children[0] + + total, l_op_type = children[0] for i in range(2, len(children), 2): + item, r_op_type = children[i] sign = children[i - 1] - total = OPERATIONS[sign](total, children[i]) + + if r_op_type != OperandType.NUMERIC: + raise ValueError(f'Only numeric operands supported for {sign}') + + total = OPERATIONS[sign](total, item) return total @@ -116,36 +151,76 @@ def visit_concat(self, node, children): self.uses_concat = True - operand = func.cast(children[0], String) + term, _ = children[0] + operand = func.cast(term, String) for i in children[1:]: operand = concat_op(operand, func.cast(i, String)) - return operand + return operand, OperandType.TEXT def visit_comparison(self, node, children): - if getattr(self, 'uses_concat', False): - comparison = func.cast(children[0], String) \ - if children[0] != 'NULL' else None + if len(children) == 1: + return children[0] + + uses_concat = getattr(self, 'uses_concat', False) + first_term, l_op_type = children[0] + if uses_concat: + comparison = func.cast(first_term, String) \ + if first_term != 'NULL' else None + l_op_type = OperandType.NULL if first_term == 'NULL' else l_op_type else: - comparison = children[0] if children[0] != 'NULL' else None + comparison = first_term if first_term != 'NULL' else None + l_op_type = OperandType.NULL if first_term == 'NULL' else l_op_type + for i in range(2, len(children), 2): sign = children[i - 1] - if getattr(self, 'uses_concat', False): - item = func.cast(children[i], String) \ - if children[i] != 'NULL' else None + term, r_op_type = children[i] + + if uses_concat: + item = func.cast(term, String) \ + if term != 'NULL' else None + r_op_type = OperandType.NULL if term == 'NULL' else r_op_type else: - item = children[i] if children[i] != 'NULL' else None + item = term if term != 'NULL' else None + r_op_type = OperandType.NULL if term == 'NULL' else r_op_type + + if l_op_type == OperandType.NULL or r_op_type == OperandType.NULL: + if sign not in ('!=', '='): + raise ValueError('Invalid comparison for null operand') + + if l_op_type != r_op_type: + if ( + l_op_type != OperandType.NULL and + r_op_type != OperandType.NULL + ): + raise ValueError('Cannot compare different types') + comparison = OPERATIONS[sign](comparison, item) - return comparison + return comparison, OperandType.BOOLEAN def visit_expression(self, node, children): - expression = children[0] if children[0] != 'NULL' else None + if len(children) == 1: + return children[0] + + first_term, l_op_type = children[0] + expression = first_term if first_term != 'NULL' else None + l_op_type = OperandType.NULL if first_term == 'NULL' else l_op_type + for i in range(2, len(children), 2): sign = children[i - 1] - expression = OPERATIONS[sign](expression, children[i]) + term, r_op_type = children[i] - return expression + if ( + l_op_type != OperandType.BOOLEAN or + r_op_type != OperandType.BOOLEAN + ): + raise ValueError( + 'Invalid operation for non-boolean expression') + + expression = OPERATIONS[sign](expression, term) + + return expression, OperandType.BOOLEAN class InlineQATreeVisitor(BaseVisitor): @@ -163,29 +238,41 @@ def visit_variable(self, node, children): if field['type'] == 'multiselect': return 'NULL' - return self.submission.data.get(var_name, 'NULL') + field_value = self.submission.data.get(var_name, 'NULL') + if field_value == 'NULL': + op_type = OperandType.NULL + else: + if FIELD_TYPE_CASTS.get(field['type']) == Integer: + op_type = OperandType.NUMERIC + elif FIELD_TYPE_CASTS.get(field['type']) == String: + op_type = OperandType.TEXT + else: + raise ValueError(f'Unknown data type for field {var_name}') + + return field_value, op_type def visit_lookup(self, node, children): top_level_attr, symbol, name = children + op_type = OperandType.NUMERIC if top_level_attr in ['location', 'participant']: attribute = getattr(self.submission, top_level_attr) if symbol == '.': - return getattr(attribute, name) + return getattr(attribute, name), op_type else: - return attribute.extra_data.get(name) + return attribute.extra_data.get(name), op_type else: - return getattr(self.submission, name) + return getattr(self.submission, name), op_type def visit_comparison(self, node, children): if len(children) > 1: # left and right are indices 0 and 2 respectively - left = children[0] - right = children[2] + left = children[0][0] + right = children[2][0] if isinstance(left, str) and isinstance(right, str): # both sides are NULL - return 'NULL' + return 'NULL', OperandType.NULL return super().visit_comparison(node, children) @@ -199,19 +286,21 @@ def __init__(self, defaults=True, **kwargs): def visit_lookup(self, node, children): top_level_attr, symbol, name = children + op_type = OperandType.NUMERIC if top_level_attr == 'location': if symbol == '.': - return getattr(Location, name).cast(Integer) + return getattr(Location, name).cast(Integer), op_type else: - return Location.extra_data[name].astext.cast(Integer) + return Location.extra_data[name].astext.cast(Integer), op_type elif top_level_attr == 'participant': if symbol == '.': - return getattr(Participant, name).cast(Integer) + return getattr(Participant, name).cast(Integer), op_type else: - return Participant.extra_data[name].astext.cast(Integer) + return ( + Participant.extra_data[name].astext.cast(Integer), op_type) else: - return getattr(Submission, name).cast(Integer) + return getattr(Submission, name).cast(Integer), op_type def visit_variable(self, node, children): var_name = node.value @@ -220,6 +309,10 @@ def visit_variable(self, node, children): self.lock_null = True return null() + field = self.form.get_field_by_tag(var_name) + if field['type'] == 'multiselect': + raise ValueError('QA not supported for multi-value fields') + # casting is necessary because PostgreSQL will throw # a fit if you attempt some operations that mix JSONB # with other types @@ -249,8 +342,15 @@ def visit_variable(self, node, children): else: self.prev_cast_type = cast_type if cast_type is not None: - return Submission.data[var_name].astext.cast(cast_type) - return Submission.data[var_name] + if cast_type == Integer: + op_type = OperandType.NUMERIC + elif cast_type == String: + op_type = OperandType.TEXT + return ( + Submission.data[var_name].astext.cast(cast_type), op_type) + + # this is an error + raise ValueError('Unknown value type') def generate_qa_query(expression, form): @@ -263,7 +363,7 @@ def generate_qa_query(expression, form): visitor = QATreeVisitor(form=form) - return visit_parse_tree(tree, visitor), visitor.variables + return visit_parse_tree(tree, visitor)[0], visitor.variables def generate_qa_queries(form): @@ -316,7 +416,7 @@ def get_logical_check_stats(query, form, condition): query = query.join( Participant, Participant.id == Submission.participant_id) - if 'null' in complete_expression.lower(): + if 'null' not in complete_expression.lower(): null_query = or_(*[ Submission.data[tag] == None # noqa for tag in question_codes @@ -329,7 +429,7 @@ def get_logical_check_stats(query, form, condition): (and_(null_query == False, qa_query == True, ~Submission.verified_fields.has_all(array(question_codes))), 'Flagged'), # noqa (and_(null_query == False, qa_query == True, Submission.verified_fields.has_all(array(question_codes))), 'Verified'), # noqa (and_(null_query == False, qa_query == False), 'OK'), # noqa - (or_(null_query == False, qa_query == None), 'Missing') # noqa + (or_(null_query == True, qa_query == None), 'Missing') # noqa ]) else: qa_case_query = case([ @@ -375,11 +475,13 @@ def get_inline_qa_status(submission, condition): # most likely return None, set() - return result, used_tags + return result[0], used_tags def build_expression(logical_check): - if 'criteria' in logical_check: + if 'expression' in logical_check: + control_expression = logical_check.get('expression') + elif 'criteria' in logical_check: control_expression = '' for index, cond in enumerate(logical_check['criteria']): diff --git a/apollo/templates/admin/quality_assurance.html b/apollo/templates/admin/quality_assurance.html index ab0f0df74..3d722a89f 100644 --- a/apollo/templates/admin/quality_assurance.html +++ b/apollo/templates/admin/quality_assurance.html @@ -38,9 +38,9 @@