Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions openedx_tagging/core/tagging/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""
Utilities for tagging and taxonomy models
"""
from django.db.models import Aggregate, CharField
from django.db.models.expressions import Func
from django.db import connection as db_connection
from django.db.models import Aggregate, CharField, TextField
from django.db.models.expressions import Combinable, Func

RESERVED_TAG_CHARS = [
'\t', # Used in the database to separate tag levels in the "lineage" field
Expand Down Expand Up @@ -34,21 +35,48 @@ def as_sqlite(self, compiler, connection, **extra_context):
)


class StringAgg(Aggregate): # pylint: disable=abstract-method
class StringAgg(Aggregate, Combinable):
"""
Aggregate function that collects the values of some column across all rows,
and creates a string by concatenating those values, with "," as a separator.
and creates a string by concatenating those values, with a specified separator.

This is the same as Django's django.contrib.postgres.aggregates.StringAgg,
but this version works with MySQL and SQLite.
This version supports PostgreSQL (STRING_AGG), MySQL (GROUP_CONCAT), and SQLite.
"""
# Default function is for MySQL (GROUP_CONCAT)
function = 'GROUP_CONCAT'
template = '%(function)s(%(distinct)s%(expressions)s)'

def __init__(self, expression, distinct=False, **extra):
def __init__(self, expression, distinct=False, delimiter=',', **extra):
self.delimiter = delimiter
# Handle the distinct option and output type
distinct_str = 'DISTINCT ' if distinct else ''

extra.update({
'distinct': distinct_str,
'output_field': CharField(),
})

# Check the database backend (PostgreSQL, MySQL, or SQLite)
if 'postgresql' in db_connection.vendor.lower():
self.function = 'STRING_AGG'
self.template = '%(function)s(%(distinct)s%(expressions)s, %(delimiter)s)'
extra.update({
"delimiter": self.delimiter,
"output_field": TextField(),
})

# Initialize the parent class with the necessary parameters
super().__init__(
expression,
distinct='DISTINCT ' if distinct else '',
output_field=CharField(),
**extra,
)

# Implementing abstract methods from Combinable
def __rand__(self, other):
return self._combine(other, 'AND', False)

def __ror__(self, other):
return self._combine(other, 'OR', False)

def __rxor__(self, other):
return self._combine(other, 'XOR', False)