diff --git a/openedx_tagging/core/tagging/models/utils.py b/openedx_tagging/core/tagging/models/utils.py index 86a5f128f..8945e9d5d 100644 --- a/openedx_tagging/core/tagging/models/utils.py +++ b/openedx_tagging/core/tagging/models/utils.py @@ -1,8 +1,9 @@ """ Utilities for tagging and taxonomy models """ -from django.db.models import Aggregate, CharField -from django.db.models.expressions import Func +from django.db import connection as db_connection +from django.db.models import Aggregate, CharField, TextField +from django.db.models.expressions import Combinable, Func RESERVED_TAG_CHARS = [ '\t', # Used in the database to separate tag levels in the "lineage" field @@ -34,21 +35,48 @@ def as_sqlite(self, compiler, connection, **extra_context): ) -class StringAgg(Aggregate): # pylint: disable=abstract-method +class StringAgg(Aggregate, Combinable): """ Aggregate function that collects the values of some column across all rows, - and creates a string by concatenating those values, with "," as a separator. + and creates a string by concatenating those values, with a specified separator. - This is the same as Django's django.contrib.postgres.aggregates.StringAgg, - but this version works with MySQL and SQLite. + This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite. """ + # Default function is for MySQL (GROUP_CONCAT) function = 'GROUP_CONCAT' template = '%(function)s(%(distinct)s%(expressions)s)' - def __init__(self, expression, distinct=False, **extra): + def __init__(self, expression, distinct=False, delimiter=',', **extra): + self.delimiter = delimiter + # Handle the distinct option and output type + distinct_str = 'DISTINCT ' if distinct else '' + + extra.update({ + 'distinct': distinct_str, + 'output_field': CharField(), + }) + + # Check the database backend (PostgreSQL, MySQL, or SQLite) + if 'postgresql' in db_connection.vendor.lower(): + self.function = 'STRING_AGG' + self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)' + extra.update({ + "delimiter": self.delimiter, + "output_field": TextField(), + }) + + # Initialize the parent class with the necessary parameters super().__init__( expression, - distinct='DISTINCT ' if distinct else '', - output_field=CharField(), **extra, ) + + # Implementing abstract methods from Combinable + def __rand__(self, other): + return self._combine(other, 'AND', False) + + def __ror__(self, other): + return self._combine(other, 'OR', False) + + def __rxor__(self, other): + return self._combine(other, 'XOR', False)