Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion learning_assistant/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,34 @@ def render_prompt_template(request, user_id, course_run_id, unit_usage_key, cour
# buffer. This limit also prevents an error from occurring wherein unusually long prompt templates cause an
# error due to using too many tokens.
UNIT_CONTENT_MAX_CHAR_LENGTH = getattr(settings, 'CHAT_COMPLETION_UNIT_CONTENT_MAX_CHAR_LENGTH', 11750)

# --- Proportional trimming logic ---
# Cap unit content at UNIT_CONTENT_MAX_CHAR_LENGTH characters. For a list of
# {content_type, content_text} dicts, each item is trimmed in proportion to its
# share of the total text so every content type retains some context.
if isinstance(unit_content, list):
    trimmed_unit_content = []

    # Total characters across all items (after whitespace normalization).
    total_chars = sum(len(str(item.get("content_text", "")).strip()) for item in unit_content)

    if total_chars <= UNIT_CONTENT_MAX_CHAR_LENGTH:
        # Already within budget: keep every item's text intact — no reason to
        # lose characters to proportional rounding.
        for item in unit_content:
            trimmed_unit_content.append({
                "content_type": item.get("content_type", ""),
                "content_text": str(item.get("content_text", "")).strip(),
            })
    else:
        # Over budget: give each item a floor-divided proportional share, but
        # never more than the budget still remaining. Capping by `remaining`
        # guarantees the combined length cannot exceed the limit, which the
        # previous max(1, int(...)) per-item rounding could not guarantee.
        remaining = UNIT_CONTENT_MAX_CHAR_LENGTH
        for item in unit_content:
            ctype = item.get("content_type", "")
            text = str(item.get("content_text", "")).strip()

            if not text:
                # Preserve empty items so content types stay aligned for the template.
                trimmed_unit_content.append({"content_type": ctype, "content_text": ""})
                continue

            allowed_chars = min(
                remaining,
                max(1, (len(text) * UNIT_CONTENT_MAX_CHAR_LENGTH) // total_chars),
            )
            trimmed_text = text[:allowed_chars]
            trimmed_unit_content.append({"content_type": ctype, "content_text": trimmed_text})
            remaining -= len(trimmed_text)

    # Keep the trimmed content as a list of dictionaries.
    unit_content = trimmed_unit_content

else:
    # Non-list (plain string) content: simple prefix truncation.
    unit_content = unit_content[0:UNIT_CONTENT_MAX_CHAR_LENGTH]

course_data = get_cache_course_data(course_id, ['skill_names', 'title'])
skill_names = course_data['skill_names']
Expand Down
57 changes: 49 additions & 8 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,26 @@ def test_get_block_content(self, mock_get_children_contents, mock_get_single_blo
self.assertEqual(items, content_items)

@ddt.data(
'This is content.',
''
'This is content.', # Short string case
'', # Empty string case
'A' * 20000, # Long string case to test trimming
[ # VIDEO content case
{'content_type': 'VIDEO', 'content_text': f"Video transcript {i} " + ("A" * 2000)} for i in range(10)
],
[ # TEXT content case
{'content_type': 'TEXT', 'content_text': f"Paragraph {i} " + ("B" * 1000)} for i in range(20)
],
[ # Mixed VIDEO + TEXT case
{'content_type': 'VIDEO', 'content_text': "Video intro " + ("C" * 1000)},
{'content_type': 'TEXT', 'content_text': "Some explanation " + ("D" * 1000)},
],
[ # Explicitly test empty content in a list
{'content_type': 'TEXT', 'content_text': ''},
],
)
@patch('learning_assistant.api.get_cache_course_data')
@patch('learning_assistant.api.get_block_content')
def test_render_prompt_template(self, unit_content, mock_get_content, mock_cache):
mock_get_content.return_value = (len(unit_content), unit_content)
skills_content = ['skills']
title = 'title'
mock_cache.return_value = {'skill_names': skills_content, 'title': title}
Expand All @@ -216,17 +229,45 @@ def test_render_prompt_template(self, unit_content, mock_get_content, mock_cache
unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD'
course_id = 'edx+test'
template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
unit_content_max_length = settings.CHAT_COMPLETION_UNIT_CONTENT_MAX_CHAR_LENGTH

# Determine total content length for mock
if isinstance(unit_content, list):
total_length = sum(len(c['content_text']) for c in unit_content)
else:
total_length = len(unit_content)

mock_get_content.return_value = (total_length, unit_content)

prompt_text = render_prompt_template(
request, user_id, course_run_id, unit_usage_key, course_id, template_string
)

if unit_content:
self.assertIn(unit_content, prompt_text)
else:
self.assertNotIn('The following text is useful.', prompt_text)
self.assertIn(str(skills_content), prompt_text)
self.assertIn(title, prompt_text)
self.assertIn(str(skills_content), prompt_text)

if isinstance(unit_content, list):
with patch('learning_assistant.api.Environment') as mock_env_cls:
mock_template = mock_env_cls.return_value.from_string.return_value
mock_template.render.return_value = "rendered prompt"

prompt_text = render_prompt_template(
request, user_id, course_run_id, unit_usage_key, course_id, template_string
)

args, kwargs = mock_template.render.call_args
trimmed_unit_content = kwargs['unit_content']
total_trimmed_chars = sum(len(item['content_text']) for item in trimmed_unit_content)
self.assertLessEqual(total_trimmed_chars, unit_content_max_length)
self.assertEqual(prompt_text, "rendered prompt")
elif isinstance(unit_content, str):
if unit_content:
trimmed = unit_content[0:unit_content_max_length]
self.assertIn(trimmed, prompt_text)
if len(unit_content) > unit_content_max_length:
self.assertNotIn(unit_content, prompt_text)
else:
self.assertNotIn('The following text is useful.', prompt_text)

@patch('learning_assistant.api.get_cache_course_data', MagicMock())
@patch('learning_assistant.api.get_block_content')
Expand Down