From 67265dd73d26a347cb983292fe6dfb7d5c010ba1 Mon Sep 17 00:00:00 2001
From: naincy128
Date: Sat, 6 Sep 2025 10:01:46 +0000
Subject: [PATCH] fix: large system message in XpertAssistant

---
 learning_assistant/api.py | 24 +++++++++++++++++++++++-
 tests/test_api.py         | 57 ++++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 72 insertions(+), 9 deletions(-)

diff --git a/learning_assistant/api.py b/learning_assistant/api.py
index 4867b70..2526f98 100644
--- a/learning_assistant/api.py
+++ b/learning_assistant/api.py
@@ -133,7 +133,29 @@ def render_prompt_template(request, user_id, course_run_id, unit_usage_key, cour
     # buffer. This limit also prevents an error from occurring wherein unusually long prompt templates cause an
     # error due to using too many tokens.
     UNIT_CONTENT_MAX_CHAR_LENGTH = getattr(settings, 'CHAT_COMPLETION_UNIT_CONTENT_MAX_CHAR_LENGTH', 11750)
-    unit_content = unit_content[0:UNIT_CONTENT_MAX_CHAR_LENGTH]
+
+    if isinstance(unit_content, list):
+        # Trim each content block proportionally to its share of the total characters,
+        # so no single block can consume the entire budget.
+        trimmed_unit_content = []
+        total_chars = sum(len(str(item.get('content_text', '')).strip()) for item in unit_content) or 1
+
+        for item in unit_content:
+            ctype = item.get('content_type', '')
+            text = str(item.get('content_text', '')).strip()
+
+            if not text:
+                trimmed_unit_content.append({'content_type': ctype, 'content_text': ''})
+                continue
+
+            # Floor of one character per non-empty block, so no block disappears entirely.
+            allowed_chars = max(1, int((len(text) / total_chars) * UNIT_CONTENT_MAX_CHAR_LENGTH))
+            trimmed_unit_content.append({'content_type': ctype, 'content_text': text[:allowed_chars]})
+
+        unit_content = trimmed_unit_content
+    else:
+        # Plain string content is truncated directly to the limit.
+        unit_content = unit_content[0:UNIT_CONTENT_MAX_CHAR_LENGTH]
 
     course_data = get_cache_course_data(course_id, ['skill_names', 'title'])
     skill_names = course_data['skill_names']
diff --git a/tests/test_api.py b/tests/test_api.py
index 2ea98c6..cc16cba 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -197,13 +197,26 @@ def test_get_block_content(self, mock_get_children_contents, mock_get_single_blo
         self.assertEqual(items, content_items)
 
     @ddt.data(
-        'This is content.',
-        ''
+        'This is content.',  # Short string case
+        '',  # Empty string case
+        'A' * 20000,  # Long string case to test trimming
+        [  # VIDEO content case
+            {'content_type': 'VIDEO', 'content_text': f'Video transcript {i} ' + 'A' * 2000} for i in range(10)
+        ],
+        [  # TEXT content case
+            {'content_type': 'TEXT', 'content_text': f'Paragraph {i} ' + 'B' * 1000} for i in range(20)
+        ],
+        [  # Mixed VIDEO and TEXT case
+            {'content_type': 'VIDEO', 'content_text': 'Video intro ' + 'C' * 1000},
+            {'content_type': 'TEXT', 'content_text': 'Some explanation ' + 'D' * 1000},
+        ],
+        [  # Empty content inside a list
+            {'content_type': 'TEXT', 'content_text': ''},
+        ],
     )
     @patch('learning_assistant.api.get_cache_course_data')
     @patch('learning_assistant.api.get_block_content')
     def test_render_prompt_template(self, unit_content, mock_get_content, mock_cache):
-        mock_get_content.return_value = (len(unit_content), unit_content)
         skills_content = ['skills']
         title = 'title'
         mock_cache.return_value = {'skill_names': skills_content, 'title': title}
@@ -216,17 +229,45 @@ def test_render_prompt_template(self, unit_content, mock_get_content, mock_cache
         unit_usage_key = 'block-v1:edX+A+B+type@vertical+block@verticalD'
         course_id = 'edx+test'
         template_string = getattr(settings, 'LEARNING_ASSISTANT_PROMPT_TEMPLATE', '')
+        unit_content_max_length = getattr(settings, 'CHAT_COMPLETION_UNIT_CONTENT_MAX_CHAR_LENGTH', 11750)
+
+        # Mock the total content length to match whichever shape this case provides
+        if isinstance(unit_content, list):
+            total_length = sum(len(c['content_text']) for c in unit_content)
+        else:
+            total_length = len(unit_content)
+
+        mock_get_content.return_value = (total_length, unit_content)
 
         prompt_text = render_prompt_template(
             request, user_id, course_run_id, unit_usage_key, course_id, template_string
         )
 
-        if unit_content:
-            self.assertIn(unit_content, prompt_text)
-        else:
-            self.assertNotIn('The following text is useful.', prompt_text)
-        self.assertIn(str(skills_content), prompt_text)
         self.assertIn(title, prompt_text)
+        self.assertIn(str(skills_content), prompt_text)
+
+        if isinstance(unit_content, list):
+            with patch('learning_assistant.api.Environment') as mock_env_cls:
+                mock_template = mock_env_cls.return_value.from_string.return_value
+                mock_template.render.return_value = 'rendered prompt'
+
+                prompt_text = render_prompt_template(
+                    request, user_id, course_run_id, unit_usage_key, course_id, template_string
+                )
+
+                render_kwargs = mock_template.render.call_args.kwargs
+                trimmed_unit_content = render_kwargs['unit_content']
+                total_trimmed_chars = sum(len(item['content_text']) for item in trimmed_unit_content)
+                self.assertLessEqual(total_trimmed_chars, unit_content_max_length)
+                self.assertEqual(prompt_text, 'rendered prompt')
+        elif isinstance(unit_content, str):
+            if unit_content:
+                trimmed = unit_content[0:unit_content_max_length]
+                self.assertIn(trimmed, prompt_text)
+                if len(unit_content) > unit_content_max_length:
+                    self.assertNotIn(unit_content, prompt_text)
+            else:
+                self.assertNotIn('The following text is useful.', prompt_text)
 
     @patch('learning_assistant.api.get_cache_course_data', MagicMock())
     @patch('learning_assistant.api.get_block_content')
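Reviewer note, not part of the patch: the list branch above allocates each block
floor(len(text) / total_chars * UNIT_CONTENT_MAX_CHAR_LENGTH) characters, with a
one-character floor for non-empty blocks. A minimal standalone sketch of that
allocation, using a hypothetical toy budget of 20 characters in place of the
11750-character default:

# Toy values only; this mirrors the trimming loop added to api.py above.
UNIT_CONTENT_MAX_CHAR_LENGTH = 20  # assumed small budget to make the math visible

unit_content = [
    {'content_type': 'VIDEO', 'content_text': 'A' * 30},  # 75% of all characters
    {'content_type': 'TEXT', 'content_text': 'B' * 10},   # 25% of all characters
]

total_chars = sum(len(item['content_text']) for item in unit_content) or 1  # 40

for item in unit_content:
    text = item['content_text']
    # Proportional share of the budget, floored, with a minimum of one character.
    allowed_chars = max(1, int((len(text) / total_chars) * UNIT_CONTENT_MAX_CHAR_LENGTH))
    print(item['content_type'], allowed_chars)
# prints: VIDEO 15 (30/40 * 20), then TEXT 5 (10/40 * 20)

Because each share is floored, the trimmed blocks total 15 + 5 = 20 characters and
stay within the budget; only the one-character minimum for very small blocks can
push the total slightly over it.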