Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions am_bot/cogs/quarantine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from am_bot.constants import (
QUARANTINE_HONEYPOT_CHANNEL_ID,
QUARANTINE_LOG_CHANNEL_ID,
QUARANTINE_ROLE_ID,
STAFF_ROLE_ID,
)
Expand Down Expand Up @@ -260,6 +261,47 @@ async def _purge_channel(
logger.warning(f"HTTP error purging in {channel.name}: {e}")
return 0

async def _log_quarantine(
self, message: discord.Message, reason: str
) -> None:
member = message.author
guild = message.guild

try:
log_channel = guild.get_channel(QUARANTINE_LOG_CHANNEL_ID)
if log_channel is None:
logger.error(
"Unable to retrieve quarantine log channel: "
f"{QUARANTINE_LOG_CHANNEL_ID}"
)
return

embed = discord.Embed(
title=" Quarantine Log",
description=(
f"{member.mention} has been quarantined "
"for violating rules."
),
color=discord.Color.red(),
timestamp=datetime.now(timezone.utc),
)

embed.add_field(
name="Member",
value=f"{member.name}#{member.discriminator}",
inline=False,
)
embed.add_field(
name="Message Content",
value=f"`{discord.utils.escape_markdown(message.content)}`",
inline=False,
)
embed.add_field(name="Reason", value=reason, inline=False)

await log_channel.send(embed=embed)
except Exception as e:
logger.error(f"Failed to send quarantine log: {e}")

async def _handle_quarantine(
self, message: discord.Message, reason: str
) -> None:
Expand All @@ -271,6 +313,8 @@ async def _handle_quarantine(
f"Quarantine triggered for {member} ({member.id}): {reason}"
)

await self._log_quarantine(message, reason)

# Delete the triggering message
try:
await message.delete()
Expand Down
1 change: 1 addition & 0 deletions am_bot/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@
STARBOARD_TEXT_CHANNEL_ID = 863887933089906718
INVITE_HELP_TEXT_CHANNEL_ID = 372253844953890817
QUARANTINE_HONEYPOT_CHANNEL_ID = 1453143051332616223
QUARANTINE_LOG_CHANNEL_ID = 1458280538518323303
96 changes: 96 additions & 0 deletions tests/test_quarantine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import MagicMock, patch

import discord
import pytest

from tests.conftest import (
Expand Down Expand Up @@ -563,6 +564,101 @@ async def test_on_message_spam_detection_triggers_quarantine(self, cog):
# Message should be deleted (quarantine triggered)
message.delete.assert_called_once()

@pytest.mark.asyncio
async def test_log_quarantine(self, cog):
"""Test quarantine logging functionality."""
# Create mock objects
member = make_mock_member(name="test_user", discriminator=1234)
message = make_mock_message(
author=member,
content="Test message",
)

# Mock log channel and embed creation
with (
patch("am_bot.cogs.quarantine.QUARANTINE_LOG_CHANNEL_ID", 12345),
):
log_channel = make_mock_channel(channel_id=12345)
message.guild.get_channel.return_value = log_channel

await cog._log_quarantine(message, "Test reason")

# Verify send was called with the correct embed data
log_channel.send.assert_called_once()

# Extract the embed from the call arguments
call_args = log_channel.send.call_args
embed = call_args.kwargs[
"embed"
] # Get the embed from keyword arguments

assert embed.title == " Quarantine Log"
assert (
f"{member.mention} has been quarantined for violating rules."
in embed.description
)
assert embed.color == discord.Color.red()

# Verify fields
assert len(embed.fields) == 3

member_field = embed.fields[0]
assert (
member_field.name == "Member"
and member_field.value
== f"{member.name}#{member.discriminator}"
)

content_field = embed.fields[1]
assert content_field.name == "Message Content"
assert f"`{message.content}`" in content_field.value

reason_field = embed.fields[2]
assert (
reason_field.name == "Reason"
and reason_field.value == "Test reason"
)

@pytest.mark.asyncio
async def test_log_quarantine_channel_not_found(self, cog):
"""Test log handling when channel doesn't exist."""

message = make_mock_message()
with (
patch("am_bot.cogs.quarantine.QUARANTINE_LOG_CHANNEL_ID", 54321),
patch("am_bot.cogs.quarantine.logger.error") as mock_logging_error,
):
message.guild.get_channel.return_value = None

await cog._log_quarantine(message, "Test reason")

# Verify error was logged
mock_logging_error.assert_called_once_with(
f"Unable to retrieve quarantine log channel: {54321}"
)

@pytest.mark.asyncio
async def test_log_quarantine_error_handling(self, cog):
"""Test error handling in _log_quarantine."""

message = make_mock_message()
log_channel = make_mock_channel(channel_id=12345)
with (
patch("am_bot.cogs.quarantine.QUARANTINE_LOG_CHANNEL_ID", 12345),
patch.object(discord.Embed, "__init__") as mock_embed,
patch("am_bot.cogs.quarantine.logger.error") as mock_logging_error,
):
message.guild.get_channel.return_value = log_channel
# Simulate an exception during embed creation
mock_embed.side_effect = Exception("Embed error")

await cog._log_quarantine(message, "Test reason")

# Verify error was logged
mock_logging_error.assert_called_once_with(
"Failed to send quarantine log: Embed error"
)


class TestQuarantineSlashCommand:
"""Tests for the /quarantine slash command."""
Expand Down