From 83a45edf1179ac4c688f923f1c71673fd6c40d6c Mon Sep 17 00:00:00 2001 From: user1303836 Date: Sun, 8 Mar 2026 11:54:36 -0400 Subject: [PATCH] Accept arXiv categories and list URLs --- .../discord/cogs/source_management.py | 59 ++++++++-- tests/test_discord/test_source_management.py | 103 ++++++++++++++++++ 2 files changed, 152 insertions(+), 10 deletions(-) diff --git a/src/intelstream/discord/cogs/source_management.py b/src/intelstream/discord/cogs/source_management.py index 8aaac4c..048c280 100644 --- a/src/intelstream/discord/cogs/source_management.py +++ b/src/intelstream/discord/cogs/source_management.py @@ -26,6 +26,7 @@ _TWITTER_USERNAME_RE = re.compile(r"^[A-Za-z0-9_]{1,15}$") +_ARXIV_CATEGORY_RE = re.compile(r"^[A-Za-z0-9.-]+$") def _is_valid_twitter_username(username: str) -> bool: @@ -36,6 +37,45 @@ class InvalidSourceURLError(ValueError): pass +def _parse_arxiv_identifier(value: str) -> str: + candidate = value.strip() + if not candidate: + raise InvalidSourceURLError("Arxiv category cannot be empty.") + + if candidate.startswith(("arxiv.org/", "www.arxiv.org/")): + candidate = f"https://{candidate}" + + parsed = urlparse(candidate) + identifier = candidate + + if parsed.scheme or parsed.netloc: + host = parsed.netloc.lower() + if host not in ("arxiv.org", "www.arxiv.org"): + raise InvalidSourceURLError(f"Invalid Arxiv URL: {value}. Expected arxiv.org domain.") + + path = parsed.path.rstrip("/") + if path.startswith("/list/"): + parts = [part for part in path.split("/") if part] + if len(parts) < 2: + raise InvalidSourceURLError( + f"Invalid Arxiv URL: {value}. Could not extract category." + ) + identifier = parts[1] + elif path.startswith("/rss/"): + identifier = path.split("/rss/", 1)[1].strip("/") + else: + raise InvalidSourceURLError( + f"Invalid Arxiv URL: {value}. Expected an arxiv.org list or RSS URL." + ) + + if not identifier or not _ARXIV_CATEGORY_RE.fullmatch(identifier): + raise InvalidSourceURLError( + f"Invalid Arxiv category: {value}. Expected format like cs.AI or stat.ML." + ) + + return identifier + + def parse_source_identifier(source_type: SourceType, url: str) -> tuple[str, str | None]: parsed = urlparse(url) @@ -86,9 +126,7 @@ def parse_source_identifier(source_type: SourceType, url: str) -> tuple[str, str return identifier, url elif source_type == SourceType.ARXIV: - identifier = url.strip() - if not identifier: - raise InvalidSourceURLError("Arxiv category cannot be empty.") + identifier = _parse_arxiv_identifier(url) feed_url = f"https://arxiv.org/rss/{identifier}" return identifier, feed_url @@ -248,7 +286,14 @@ async def source_add( ) return - safe, error_msg = is_safe_url(url) + try: + identifier, feed_url = parse_source_identifier(stype, url) + except InvalidSourceURLError as e: + await interaction.followup.send(str(e), ephemeral=True) + return + + validation_url = feed_url if stype == SourceType.ARXIV and feed_url else url + safe, error_msg = is_safe_url(validation_url) if not safe: await interaction.followup.send(f"URL not allowed: {error_msg}", ephemeral=True) return @@ -298,12 +343,6 @@ async def source_add( ) return - try: - identifier, feed_url = parse_source_identifier(stype, url) - except InvalidSourceURLError as e: - await interaction.followup.send(str(e), ephemeral=True) - return - existing = await self.bot.repository.get_source_by_identifier(identifier) if existing: await interaction.followup.send( diff --git a/tests/test_discord/test_source_management.py b/tests/test_discord/test_source_management.py index 13bdd1f..966cc81 100644 --- a/tests/test_discord/test_source_management.py +++ b/tests/test_discord/test_source_management.py @@ -101,6 +101,39 @@ def test_parse_arxiv_empty(self): with pytest.raises(InvalidSourceURLError, match="cannot be empty"): parse_source_identifier(SourceType.ARXIV, " ") + def test_parse_arxiv_category(self): + identifier, feed_url = parse_source_identifier(SourceType.ARXIV, "cs.AI") + assert identifier == "cs.AI" + assert feed_url == "https://arxiv.org/rss/cs.AI" + + def test_parse_arxiv_list_url(self): + identifier, feed_url = parse_source_identifier( + SourceType.ARXIV, + "https://arxiv.org/list/cs.AI/", + ) + assert identifier == "cs.AI" + assert feed_url == "https://arxiv.org/rss/cs.AI" + + def test_parse_arxiv_recent_list_url(self): + identifier, feed_url = parse_source_identifier( + SourceType.ARXIV, + "https://arxiv.org/list/cs.AI/recent", + ) + assert identifier == "cs.AI" + assert feed_url == "https://arxiv.org/rss/cs.AI" + + def test_parse_arxiv_rss_url(self): + identifier, feed_url = parse_source_identifier( + SourceType.ARXIV, + "https://arxiv.org/rss/stat.ML", + ) + assert identifier == "stat.ML" + assert feed_url == "https://arxiv.org/rss/stat.ML" + + def test_parse_arxiv_wrong_domain(self): + with pytest.raises(InvalidSourceURLError, match=r"Expected arxiv\.org domain"): + parse_source_identifier(SourceType.ARXIV, "https://example.com/list/cs.AI/") + def test_parse_page_no_host(self): with pytest.raises(InvalidSourceURLError, match="No host found"): parse_source_identifier(SourceType.PAGE, "not-a-url") @@ -366,6 +399,76 @@ async def test_add_source_with_summarize_true(self, source_management, mock_bot) call_kwargs = mock_bot.repository.add_source.call_args.kwargs assert call_kwargs["skip_summary"] is False + async def test_add_arxiv_source_with_category_identifier(self, source_management, mock_bot): + interaction = MagicMock(spec=discord.Interaction) + interaction.response = MagicMock() + interaction.response.defer = AsyncMock() + interaction.followup = MagicMock() + interaction.followup.send = AsyncMock() + interaction.user = MagicMock() + interaction.user.id = 123 + interaction.guild_id = 456 + interaction.channel_id = 789 + + source_type_choice = MagicMock() + source_type_choice.value = "arxiv" + source_type_choice.name = "Arxiv" + + mock_bot.repository.get_source_by_identifier = AsyncMock(return_value=None) + mock_bot.repository.get_source_by_name = AsyncMock(return_value=None) + + mock_source = MagicMock() + mock_source.id = "new-source-id" + mock_bot.repository.add_source = AsyncMock(return_value=mock_source) + + await source_management.source_add.callback( + source_management, + interaction, + source_type=source_type_choice, + name="arxiv_cs.AI", + url="cs.AI", + ) + + mock_bot.repository.add_source.assert_called_once() + call_kwargs = mock_bot.repository.add_source.call_args.kwargs + assert call_kwargs["identifier"] == "cs.AI" + assert call_kwargs["feed_url"] == "https://arxiv.org/rss/cs.AI" + + async def test_add_arxiv_source_with_list_url(self, source_management, mock_bot): + interaction = MagicMock(spec=discord.Interaction) + interaction.response = MagicMock() + interaction.response.defer = AsyncMock() + interaction.followup = MagicMock() + interaction.followup.send = AsyncMock() + interaction.user = MagicMock() + interaction.user.id = 123 + interaction.guild_id = 456 + interaction.channel_id = 789 + + source_type_choice = MagicMock() + source_type_choice.value = "arxiv" + source_type_choice.name = "Arxiv" + + mock_bot.repository.get_source_by_identifier = AsyncMock(return_value=None) + mock_bot.repository.get_source_by_name = AsyncMock(return_value=None) + + mock_source = MagicMock() + mock_source.id = "new-source-id" + mock_bot.repository.add_source = AsyncMock(return_value=mock_source) + + await source_management.source_add.callback( + source_management, + interaction, + source_type=source_type_choice, + name="arxiv_cs.AI", + url="https://arxiv.org/list/cs.AI/", + ) + + mock_bot.repository.add_source.assert_called_once() + call_kwargs = mock_bot.repository.add_source.call_args.kwargs + assert call_kwargs["identifier"] == "cs.AI" + assert call_kwargs["feed_url"] == "https://arxiv.org/rss/cs.AI" + class TestSourceManagementList: async def test_list_sources_empty(self, source_management, mock_bot):