diff --git a/cmd/digest.go b/cmd/digest.go index 68eaceeb..7b3612f9 100644 --- a/cmd/digest.go +++ b/cmd/digest.go @@ -85,7 +85,7 @@ func runDigestTest(cmd *cobra.Command, args []string) { sugar.Fatalw("Failed to initialize email service", "error", err) } - digestService := digest.NewService(db, emailService, cfg, sugar) + digestService := digest.NewService(db, emailService, nil, cfg, sugar) if dryRun { sugar.Info("Running digest test in DRY-RUN mode (no emails will be sent)...") @@ -145,7 +145,7 @@ func runDigestPreview(cmd *cobra.Command, args []string) { sugar.Warnw("Failed to initialize email service", "error", err) } - digestService := digest.NewService(db, emailService, cfg, sugar) + digestService := digest.NewService(db, emailService, nil, cfg, sugar) summary, err := digestService.GetGlobalEvidenceSummary(ctx) if err != nil { diff --git a/cmd/run.go b/cmd/run.go index 2e7bbf74..7b0111a6 100644 --- a/cmd/run.go +++ b/cmd/run.go @@ -14,6 +14,7 @@ import ( "github.com/compliance-framework/api/internal/service/digest" "github.com/compliance-framework/api/internal/service/email" "github.com/compliance-framework/api/internal/service/scheduler" + "github.com/compliance-framework/api/internal/service/worker" "github.com/spf13/cobra" "github.com/spf13/viper" "go.uber.org/zap" @@ -61,42 +62,38 @@ func RunServer(cmd *cobra.Command, args []string) { sugar.Warnw("Failed to initialize email service, digests will be disabled", "error", err) } - // Initialize digest service - digestService := digest.NewService(db, emailService, cfg, sugar) + // Initialize digest service (without worker service initially) + digestService := digest.NewService(db, emailService, nil, cfg, sugar) - // Initialize scheduler - sched := scheduler.NewCronScheduler(sugar) + // Initialize worker service with digest support + workerService, err := worker.NewServiceWithDigest(cfg.Worker, db, emailService, digestService, cfg, sugar) + if err != nil { + sugar.Fatalw("Failed to initialize worker service", "error", err) + } - // Register digest job using config - if cfg.DigestEnabled { - digestJob := digest.NewGlobalDigestJob(digestService, sugar) - if err := sched.ScheduleCron(cfg.DigestSchedule, digestJob); err != nil { - sugar.Warnw("Failed to schedule digest job", "schedule", cfg.DigestSchedule, "error", err) - } else { - sugar.Debugw("Digest job scheduled", "schedule", cfg.DigestSchedule) - } - } else { - sugar.Debugw("Digest scheduler disabled") + // Set worker service reference in digest service to avoid circular dependency + digestService.SetWorkerService(workerService) + + // Run River migrations + if err := workerService.Migrate(ctx); err != nil { + sugar.Fatalw("Failed to run River migrations", "error", err) + } + + // Start worker service + if err := workerService.Start(ctx); err != nil { + sugar.Fatalw("Failed to start worker service", "error", err) } - // Start the scheduler + // Initialize scheduler for other jobs (if any) + // Note: Digest scheduling is now handled by River's periodic jobs + sched := scheduler.NewCronScheduler(sugar) sched.Start() - defer func() { - stopCtx := sched.Stop() - // Wait for jobs to finish gracefully with a 10-second timeout - select { - case <-stopCtx.Done(): - sugar.Debug("All scheduled jobs completed gracefully") - case <-time.After(10 * time.Second): - sugar.Warn("Scheduler shutdown timeout, some jobs may not have completed") - } - }() metrics := api.NewMetricsHandler(ctx, sugar) server := api.NewServer(ctx, sugar, cfg, metrics) handler.RegisterHandlers(server, sugar, db, cfg, digestService, sched) oscal.RegisterHandlers(server, sugar, db, cfg) - auth.RegisterHandlers(server, sugar, db, cfg, metrics) + auth.RegisterHandlers(server, sugar, db, cfg, metrics, emailService, workerService) sugar.Infow("Allowed Origins", "origins", cfg.APIAllowedOrigins) server.PrintRoutes() @@ -109,4 +106,24 @@ func RunServer(cmd *cobra.Command, args []string) { if err := server.Start(cfg.AppPort); err != nil { sugar.Fatalw("Failed to start server", "error", err) } + + // Note: Defer statements are registered in reverse order of execution. + // This ensures proper shutdown order: scheduler -> worker service + defer func() { + // Stop worker service last (after scheduler has stopped) + if err := workerService.Stop(ctx); err != nil { + sugar.Errorw("Failed to stop worker service", "error", err) + } + }() + + defer func() { + // Stop scheduler first + stopCtx := sched.Stop() + select { + case <-stopCtx.Done(): + sugar.Debug("All scheduled jobs completed gracefully") + case <-time.After(10 * time.Second): + sugar.Warn("Scheduler shutdown timeout, some jobs may not have completed") + } + }() } diff --git a/go.mod b/go.mod index 5c7f06d0..f188107d 100644 --- a/go.mod +++ b/go.mod @@ -13,11 +13,13 @@ require ( github.com/go-playground/validator/v10 v10.28.0 github.com/golang-jwt/jwt/v5 v5.3.0 github.com/google/uuid v1.6.0 - github.com/jackc/pgx/v5 v5.7.6 + github.com/jackc/pgx/v5 v5.8.0 github.com/joho/godotenv v1.5.1 github.com/labstack/echo-contrib v0.17.4 github.com/labstack/echo/v4 v4.13.4 github.com/prometheus/client_golang v1.23.2 + github.com/riverqueue/river v0.30.1 + github.com/riverqueue/river/riverdriver/riverpgxv5 v0.30.1 github.com/robfig/cron/v3 v3.0.1 github.com/schollz/progressbar/v3 v3.18.0 github.com/spf13/cobra v1.10.2 @@ -135,6 +137,9 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/prometheus/common v0.67.4 // indirect github.com/prometheus/procfs v0.19.2 // indirect + github.com/riverqueue/river/riverdriver v0.30.1 // indirect + github.com/riverqueue/river/rivershared v0.30.1 // indirect + github.com/riverqueue/river/rivertype v0.30.1 // indirect github.com/rivo/uniseg v0.4.7 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/sagikazarmark/locafero v0.12.0 // indirect @@ -144,9 +149,14 @@ require ( github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/swaggo/files/v2 v2.0.2 // indirect github.com/tdewolff/parse/v2 v2.7.15 // indirect + github.com/tidwall/gjson v1.18.0 // indirect + github.com/tidwall/match v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect + github.com/tidwall/sjson v1.2.5 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect github.com/urfave/cli/v2 v2.3.0 // indirect @@ -159,15 +169,16 @@ require ( go.opentelemetry.io/otel/metric v1.35.0 // indirect go.opentelemetry.io/otel/sdk v1.34.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect + go.uber.org/goleak v1.3.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect - golang.org/x/mod v0.31.0 // indirect + golang.org/x/mod v0.32.0 // indirect golang.org/x/net v0.48.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/term v0.38.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.33.0 // indirect golang.org/x/time v0.14.0 // indirect golang.org/x/tools v0.40.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20250218202821-56aae31c358a // indirect diff --git a/go.sum b/go.sum index 0195fffe..cbc73302 100644 --- a/go.sum +++ b/go.sum @@ -256,12 +256,14 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/invopop/yaml v0.2.0 h1:7zky/qH+O0DwAyoobXUqvVBwgBFRxKoQ/3FjcVpjTMY= github.com/invopop/yaml v0.2.0/go.mod h1:2XuRLgs/ouIrW3XNzuNj7J3Nvu/Dig5MXvbCEdiBN3Q= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438 h1:Dj0L5fhJ9F82ZJyVOmBx6msDp/kfd1t9GRfny/mfJA0= +github.com/jackc/pgerrcode v0.0.0-20240316143900-6e2875d9b438/go.mod h1:a/s9Lp5W7n/DD0VrVoyJ00FbP2ytTPDVOivvn2bMlds= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= -github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/pgx/v5 v5.8.0 h1:TYPDoleBBme0xGSAX3/+NujXXtpZn9HBONkQC7IEZSo= +github.com/jackc/pgx/v5 v5.8.0/go.mod h1:QVeDInX2m9VyzvNeiCJVjCkNFqzsNb43204HshNSZKw= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jdkato/prose v1.2.1 h1:Fp3UnJmLVISmlc57BgKUzdjr0lOtjqTZicL3PaYy6cU= @@ -377,6 +379,16 @@ github.com/prometheus/common v0.67.4 h1:yR3NqWO1/UyO1w2PhUvXlGQs/PtFmoveVO0KZ4+L github.com/prometheus/common v0.67.4/go.mod h1:gP0fq6YjjNCLssJCQp0yk4M8W6ikLURwkdd/YKtTbyI= github.com/prometheus/procfs v0.19.2 h1:zUMhqEW66Ex7OXIiDkll3tl9a1ZdilUOd/F6ZXw4Vws= github.com/prometheus/procfs v0.19.2/go.mod h1:M0aotyiemPhBCM0z5w87kL22CxfcH05ZpYlu+b4J7mw= +github.com/riverqueue/river v0.30.1 h1:lpwmDT3zD+iDtF4tD50e/Y23UHpIeBUffVTDr2khN+s= +github.com/riverqueue/river v0.30.1/go.mod h1:x9tVfiCrbOctSAmaYP00iE5YlO8zh3Y9leFk6wP6aCk= +github.com/riverqueue/river/riverdriver v0.30.1 h1:p04cz/Ald1Js/STZ9qYrY5/TBJgjQeVPFltxidFYBBo= +github.com/riverqueue/river/riverdriver v0.30.1/go.mod h1:WBB9w6LftQtoZgRhNstqhP7MyBKt09XJkzluSNwMMoY= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.30.1 h1:nEStDftvm2jvGlJLliJR+n24PCJsoc4CgGzuop2Yzig= +github.com/riverqueue/river/riverdriver/riverpgxv5 v0.30.1/go.mod h1:4oSf8jYWZaEwmJ3R5LmOMiGlV9uuvCWOJ3uyBfTwWCc= +github.com/riverqueue/river/rivershared v0.30.1 h1:ytYlTtMppDV2rJRJ2j55mNf9uQDMPFudOmT4le6/9Ig= +github.com/riverqueue/river/rivershared v0.30.1/go.mod h1:PfmUHWkF6/fJ1CpjC4cG8eKciBXgMuIHgcRcIuHMc34= +github.com/riverqueue/river/rivertype v0.30.1 h1:jR7M5UlkA7KRxEbII+LOkD9oQMMz60AEdHh2We1APHY= +github.com/riverqueue/river/rivertype v0.30.1/go.mod h1:rWpgI59doOWS6zlVocROcwc00fZ1RbzRwsRTU8CDguw= github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -437,6 +449,17 @@ github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0 h1:hsVwFkS6 github.com/testcontainers/testcontainers-go/modules/postgres v0.37.0/go.mod h1:Qj/eGbRbO/rEYdcRLmN+bEojzatP/+NS1y8ojl2PQsc= github.com/tetratelabs/wazero v1.8.0 h1:iEKu0d4c2Pd+QSRieYbnQC9yiFlMS9D+Jr0LsRmcF4g= github.com/tetratelabs/wazero v1.8.0/go.mod h1:yAI0XTsMBhREkM/YDAK/zNou3GoiAce1P6+rp/wQhjs= +github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY= +github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/match v1.2.0 h1:0pt8FlkOwjN2fPt4bIl4BoNxb98gGHN2ObFEDkrfZnM= +github.com/tidwall/match v1.2.0/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= +github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= @@ -503,8 +526,8 @@ golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -563,8 +586,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/api/handler/auth/api.go b/internal/api/handler/auth/api.go index c58bddae..d18a3dab 100644 --- a/internal/api/handler/auth/api.go +++ b/internal/api/handler/auth/api.go @@ -5,21 +5,25 @@ import ( "github.com/compliance-framework/api/internal/config" "github.com/compliance-framework/api/internal/service/email" "github.com/compliance-framework/api/internal/service/sso" + "github.com/compliance-framework/api/internal/service/worker" "go.uber.org/zap" "gorm.io/gorm" ) -func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, cfg *config.Config, metrics *api.PrometheusMetrics) { +func RegisterHandlers(server *api.Server, logger *zap.SugaredLogger, db *gorm.DB, cfg *config.Config, metrics *api.PrometheusMetrics, emailService *email.Service, workerService *worker.Service) { authGroup := server.API().Group("/auth") - // Initialize email service - emailService, err := email.NewService(cfg.Email, logger) - if err != nil { - logger.Warnw("Failed to initialize email service", "error", err) - emailService = nil // Set to nil so handlers can check if it's available + // Use provided email service or create a new one + var err error + if emailService == nil { + emailService, err = email.NewService(cfg.Email, logger) + if err != nil { + logger.Warnw("Failed to initialize email service", "error", err) + emailService = nil // Set to nil so handlers can check if it's available + } } - authHandler := NewAuthHandler(logger, db, cfg, metrics, emailService) + authHandler := NewAuthHandler(logger, db, cfg, metrics, emailService, workerService) authHandler.Register(authGroup) ssoService, err := sso.NewService(cfg.SSO, logger) diff --git a/internal/api/handler/auth/auth.go b/internal/api/handler/auth/auth.go index 596026d0..5f074b87 100644 --- a/internal/api/handler/auth/auth.go +++ b/internal/api/handler/auth/auth.go @@ -14,26 +14,29 @@ import ( "github.com/compliance-framework/api/internal/service/email" emailtypes "github.com/compliance-framework/api/internal/service/email/types" "github.com/compliance-framework/api/internal/service/relational" + "github.com/compliance-framework/api/internal/service/worker" "github.com/labstack/echo/v4" "go.uber.org/zap" "gorm.io/gorm" ) type AuthHandler struct { - sugar *zap.SugaredLogger - db *gorm.DB - config *config.Config - metrics *api.PrometheusMetrics - emailService *email.Service + sugar *zap.SugaredLogger + db *gorm.DB + config *config.Config + metrics *api.PrometheusMetrics + emailService *email.Service + workerService *worker.Service } -func NewAuthHandler(logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, metrics *api.PrometheusMetrics, emailService *email.Service) *AuthHandler { +func NewAuthHandler(logger *zap.SugaredLogger, db *gorm.DB, config *config.Config, metrics *api.PrometheusMetrics, emailService *email.Service, workerService *worker.Service) *AuthHandler { return &AuthHandler{ - sugar: logger, - db: db, - config: config, - metrics: metrics, - emailService: emailService, + sugar: logger, + db: db, + config: config, + metrics: metrics, + emailService: emailService, + workerService: workerService, } } @@ -350,13 +353,33 @@ func (h *AuthHandler) ForgotPassword(ctx echo.Context) error { TextBody: textBody, } - _, err = h.emailService.Send(ctx.Request().Context(), message) - if err != nil { - h.sugar.Errorw("Failed to send password reset email", "error", err, "email", user.Email) - return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) - } + // Enqueue email job instead of sending directly + if h.workerService != nil && h.workerService.IsStarted() { + args := &worker.SendEmailArgs{ + From: h.getDefaultFromAddress(), + To: message.To, + Subject: message.Subject, + HTMLBody: message.HTMLBody, + TextBody: message.TextBody, + } - h.sugar.Infow("Password reset email sent", "email", user.Email) + err = h.workerService.EnqueueSendEmail(ctx.Request().Context(), args) + if err != nil { + h.sugar.Errorw("Failed to enqueue password reset email", "error", err, "email", user.Email) + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + h.sugar.Infow("Password reset email enqueued", "email", user.Email) + } else { + // Fallback to direct sending if worker is not available + _, err = h.emailService.Send(ctx.Request().Context(), message) + if err != nil { + h.sugar.Errorw("Failed to send password reset email", "error", err, "email", user.Email) + return ctx.JSON(http.StatusInternalServerError, api.NewError(err)) + } + + h.sugar.Infow("Password reset email sent", "email", user.Email) + } return ctx.JSON(http.StatusOK, handler.GenericDataResponse[string]{ Data: "If an account with this email exists, a password reset link has been sent.", @@ -433,3 +456,11 @@ func (h *AuthHandler) PasswordReset(ctx echo.Context) error { Data: "Password has been reset successfully", }) } + +// getDefaultFromAddress returns the default From address from the email service configuration +func (h *AuthHandler) getDefaultFromAddress() string { + if h.emailService == nil { + return "" + } + return h.emailService.GetDefaultFromAddress() +} diff --git a/internal/api/handler/auth/auth_integration_test.go b/internal/api/handler/auth/auth_integration_test.go index 884a4453..c344c857 100644 --- a/internal/api/handler/auth/auth_integration_test.go +++ b/internal/api/handler/auth/auth_integration_test.go @@ -51,7 +51,7 @@ func (suite *AuthAPIIntegrationSuite) SetupSuite() { suite.logger = logger.Sugar() metrics := api.NewMetricsHandler(context.Background(), suite.logger) suite.server = api.NewServer(context.Background(), suite.logger, suite.Config, metrics) - RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, metrics) + RegisterHandlers(suite.server, suite.logger, suite.DB, suite.Config, metrics, nil, nil) fmt.Println("Server initialized") } diff --git a/internal/api/handler/auth/password_reset_test.go b/internal/api/handler/auth/password_reset_test.go index d4aa4016..0a2fd8cf 100644 --- a/internal/api/handler/auth/password_reset_test.go +++ b/internal/api/handler/auth/password_reset_test.go @@ -48,7 +48,7 @@ func setupTestAuthHandler(t *testing.T) *AuthHandler { metrics := api.NewMetricsHandler(context.TODO(), logger) // Create auth handler without email service for testing - authHandler := NewAuthHandler(logger, db, cfg, metrics, nil) + authHandler := NewAuthHandler(logger, db, cfg, metrics, nil, nil) return authHandler } diff --git a/internal/api/handler/digest_integration_test.go b/internal/api/handler/digest_integration_test.go index 824486ed..a5bb1538 100644 --- a/internal/api/handler/digest_integration_test.go +++ b/internal/api/handler/digest_integration_test.go @@ -92,7 +92,7 @@ func (suite *DigestApiIntegrationSuite) SetupSuite() { suite.mockScheduler = NewMockScheduler() // Create digest handler - digestService := digest.NewService(suite.DB, suite.emailService, suite.Config, suite.logger) + digestService := digest.NewService(suite.DB, suite.emailService, nil, suite.Config, suite.logger) suite.digestHandler = NewDigestHandler(digestService, suite.mockScheduler, suite.logger) // Setup server @@ -197,7 +197,7 @@ func (suite *DigestApiIntegrationSuite) TestTriggerDigestWithNilScheduler() { suite.Require().NoError(err) // Create handler with nil scheduler - digestService := digest.NewService(suite.DB, suite.emailService, suite.Config, suite.logger) + digestService := digest.NewService(suite.DB, suite.emailService, nil, suite.Config, suite.logger) nilSchedulerHandler := NewDigestHandler(digestService, nil, suite.logger) // Create a temporary echo context for testing diff --git a/internal/config/config.go b/internal/config/config.go index 64557bfa..4fc54999 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -33,6 +33,7 @@ type Config struct { WebBaseURL string SSO *SSOConfig Email *EmailConfig + Worker *WorkerConfig EvidenceDefaultExpiryMonths int // Default expiration in months for evidence without explicit expiry DigestEnabled bool // Enable or disable the digest scheduler DigestSchedule string // Cron schedule for digest emails @@ -162,6 +163,18 @@ func NewConfig(logger *zap.SugaredLogger) *Config { digestSchedule = "@weekly" } + // Worker configuration + workerConfig := DefaultWorkerConfig() + if viper.IsSet("worker_enabled") { + workerConfig.Enabled = viper.GetBool("worker_enabled") + } + if viper.IsSet("worker_count") { + workerConfig.Workers = viper.GetInt("worker_count") + } + if viper.IsSet("worker_queue") { + workerConfig.Queue = viper.GetString("worker_queue") + } + return &Config{ AppPort: appPort, Environment: environment, @@ -177,6 +190,7 @@ func NewConfig(logger *zap.SugaredLogger) *Config { WebBaseURL: webBaseURL, SSO: ssoConfig, Email: emailConfig, + Worker: workerConfig, EvidenceDefaultExpiryMonths: evidenceDefaultExpiryMonths, DigestEnabled: digestEnabled, DigestSchedule: digestSchedule, diff --git a/internal/config/worker.go b/internal/config/worker.go new file mode 100644 index 00000000..a42edde3 --- /dev/null +++ b/internal/config/worker.go @@ -0,0 +1,38 @@ +package config + +// WorkerConfig contains configuration for background workers +// Environment variables: +// - CCF_WORKER_ENABLED: Enable/disable workers (default: true) +// - CCF_WORKER_COUNT: Number of concurrent workers (default: 5) +// - CCF_WORKER_QUEUE: Queue name to process (default: "email") +type WorkerConfig struct { + // Enabled determines if workers should be started + Enabled bool `mapstructure:"enabled"` + + // Number of worker goroutines to run + Workers int `mapstructure:"workers"` + + // Queue is the name of the queue to work on + Queue string `mapstructure:"queue"` + + // RetryPolicy defines how jobs should be retried + RetryPolicy RetryPolicyConfig `mapstructure:"retry_policy"` +} + +// RetryPolicyConfig defines retry behavior for jobs +type RetryPolicyConfig struct { + // MaxAttempts is the maximum number of attempts for a job + MaxAttempts int `mapstructure:"max_attempts"` +} + +// DefaultWorkerConfig returns a default worker configuration +func DefaultWorkerConfig() *WorkerConfig { + return &WorkerConfig{ + Enabled: true, + Workers: 5, + Queue: "email", // Default to email queue to match job configuration + RetryPolicy: RetryPolicyConfig{ + MaxAttempts: 5, + }, + } +} diff --git a/internal/service/digest/service.go b/internal/service/digest/service.go index 5a1b5c18..24df6511 100644 --- a/internal/service/digest/service.go +++ b/internal/service/digest/service.go @@ -9,6 +9,7 @@ import ( "github.com/compliance-framework/api/internal/service/email" "github.com/compliance-framework/api/internal/service/email/types" "github.com/compliance-framework/api/internal/service/relational" + "github.com/compliance-framework/api/internal/service/worker" "go.uber.org/zap" "gorm.io/gorm" ) @@ -39,19 +40,21 @@ type EvidenceItem struct { // Service handles digest generation and delivery type Service struct { - db *gorm.DB - emailService *email.Service - config *config.Config - logger *zap.SugaredLogger + db *gorm.DB + emailService *email.Service + workerService *worker.Service + config *config.Config + logger *zap.SugaredLogger } // NewService creates a new digest service -func NewService(db *gorm.DB, emailService *email.Service, cfg *config.Config, logger *zap.SugaredLogger) *Service { +func NewService(db *gorm.DB, emailService *email.Service, workerService *worker.Service, cfg *config.Config, logger *zap.SugaredLogger) *Service { return &Service{ - db: db, - emailService: emailService, - config: cfg, - logger: logger, + db: db, + emailService: emailService, + workerService: workerService, + config: cfg, + logger: logger, } } @@ -209,16 +212,36 @@ func (s *Service) SendDigestEmail(ctx context.Context, user *relational.User, su TextBody: textContent, } - result, err := s.emailService.Send(ctx, message) - if err != nil { - return fmt.Errorf("failed to send digest email: %w", err) - } + // Enqueue email job instead of sending directly + if s.workerService != nil && s.workerService.IsStarted() { + args := &worker.SendEmailArgs{ + From: s.getDefaultFromAddress(), + To: message.To, + Subject: message.Subject, + HTMLBody: message.HTMLBody, + TextBody: message.TextBody, + } + + err = s.workerService.EnqueueSendEmail(ctx, args) + if err != nil { + return fmt.Errorf("failed to enqueue digest email: %w", err) + } + + s.logger.Debugw("Digest email enqueued", "user", user.Email) + } else { + // Fallback to direct sending if worker is not available + result, err := s.emailService.Send(ctx, message) + if err != nil { + return fmt.Errorf("failed to send digest email: %w", err) + } - if !result.Success { - return fmt.Errorf("digest email send failed: %s", result.Error) + if !result.Success { + return fmt.Errorf("digest email send failed: %s", result.Error) + } + + s.logger.Debugw("Digest email sent", "user", user.Email, "messageId", result.MessageID) } - s.logger.Debugw("Digest email sent", "user", user.Email, "messageId", result.MessageID) return nil } @@ -272,3 +295,16 @@ func (s *Service) SendGlobalDigest(ctx context.Context) error { return nil } + +// SetWorkerService sets the worker service reference (used to avoid circular dependency) +func (s *Service) SetWorkerService(workerService *worker.Service) { + s.workerService = workerService +} + +// getDefaultFromAddress returns the default From address from the email service configuration +func (s *Service) getDefaultFromAddress() string { + if s.emailService == nil { + return "" + } + return s.emailService.GetDefaultFromAddress() +} diff --git a/internal/service/email/service.go b/internal/service/email/service.go index 5c13e8da..56ce3518 100644 --- a/internal/service/email/service.go +++ b/internal/service/email/service.go @@ -99,7 +99,33 @@ func (s *Service) SendWithProvider(ctx context.Context, providerName string, mes // IsEnabled returns true if the email service is enabled func (s *Service) IsEnabled() bool { - return s.config != nil && s.config.Enabled && s.provider != nil + return s.config != nil && s.config.Enabled +} + +// GetDefaultFromAddress returns the default From address from the email service configuration +func (s *Service) GetDefaultFromAddress() string { + if s == nil || !s.IsEnabled() { + return "" + } + + emailConfig := s.GetConfig() + if emailConfig == nil { + return "" + } + + defaultProvider := emailConfig.GetDefaultProvider() + if defaultProvider == nil { + return "" + } + + switch provider := defaultProvider.(type) { + case *config.SMTPConfig: + return provider.From + case *config.SESConfig: + return provider.From + default: + return "" + } } // IsHealthy checks if the email service is healthy diff --git a/internal/service/scheduler/cron.go b/internal/service/scheduler/cron.go index 3de16e28..2807f2be 100644 --- a/internal/service/scheduler/cron.go +++ b/internal/service/scheduler/cron.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sync" + "time" "github.com/robfig/cron/v3" "go.uber.org/zap" @@ -31,6 +32,21 @@ func NewCronScheduler(logger *zap.SugaredLogger) *CronScheduler { } } +// ParseCronNext parses a cron expression and returns the next scheduled time +// Supports 6-field cron syntax (with seconds) and descriptors like @weekly, @daily, etc. +// Format: second minute hour day month weekday +// +// @weekly = Monday 00:00:00 UTC +func ParseCronNext(cronExpr string, from time.Time) (time.Time, error) { + // Use same parser as CronScheduler - supports seconds (6-field cron) + parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) + schedule, err := parser.Parse(cronExpr) + if err != nil { + return time.Time{}, fmt.Errorf("failed to parse cron expression %q: %w", cronExpr, err) + } + return schedule.Next(from), nil +} + // Schedule adds a job to run on the given schedule func (s *CronScheduler) Schedule(schedule Schedule, job Job) error { var cronExpr string diff --git a/internal/service/worker/jobs.go b/internal/service/worker/jobs.go new file mode 100644 index 00000000..b232b309 --- /dev/null +++ b/internal/service/worker/jobs.go @@ -0,0 +1,330 @@ +package worker + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/compliance-framework/api/internal/service/email/types" + "github.com/riverqueue/river" +) + +// Job types for email processing +const ( + JobTypeSendEmail = "send_email" + JobTypeSendEmailFrom = "send_email_from" + JobTypeSendGlobalDigest = "send_global_digest" +) + +// SendEmailArgs represents the arguments for sending an email +type SendEmailArgs struct { + // Email message fields + From string `json:"from"` + To []string `json:"to"` + Cc []string `json:"cc,omitempty"` + Bcc []string `json:"bcc,omitempty"` + Subject string `json:"subject"` + HTMLBody string `json:"html_body,omitempty"` + TextBody string `json:"text_body,omitempty"` + Attachments []types.Attachment `json:"attachments,omitempty"` + Headers map[string]string `json:"headers,omitempty"` +} + +// SendEmailFromArgs represents the arguments for sending an email from a specific provider +type SendEmailFromArgs struct { + // Provider to use for sending + Provider string `json:"provider"` + + // Email message fields + From string `json:"from"` + To []string `json:"to"` + Cc []string `json:"cc,omitempty"` + Bcc []string `json:"bcc,omitempty"` + Subject string `json:"subject"` + HTMLBody string `json:"html_body,omitempty"` + TextBody string `json:"text_body,omitempty"` + Attachments []types.Attachment `json:"attachments,omitempty"` + Headers map[string]string `json:"headers,omitempty"` +} + +// SendGlobalDigestArgs represents the arguments for sending global digest +type SendGlobalDigestArgs struct { + // No arguments needed - digest service will fetch data +} + +// Kind returns the job kind for River +func (SendEmailArgs) Kind() string { return JobTypeSendEmail } + +// Kind returns the job kind for River +func (SendEmailFromArgs) Kind() string { return JobTypeSendEmailFrom } + +// Kind returns the job kind for River +func (SendGlobalDigestArgs) Kind() string { return JobTypeSendGlobalDigest } + +// EmailService interface for dependency injection +type EmailService interface { + Send(ctx context.Context, message *types.Message) (*types.SendResult, error) + SendWithProvider(ctx context.Context, providerName string, message *types.Message) (*types.SendResult, error) +} + +// Logger interface for logging +type Logger interface { + Infow(msg string, keysAndValues ...interface{}) + Errorw(msg string, keysAndValues ...interface{}) + Warnw(msg string, keysAndValues ...interface{}) + Debugw(msg string, keysAndValues ...interface{}) +} + +// DigestService interface for dependency injection +type DigestService interface { + SendGlobalDigest(ctx context.Context) error +} + +// Timeout returns the timeout for email jobs +func (SendEmailArgs) Timeout() time.Duration { + return 30 * time.Second +} + +// Timeout returns the timeout for email jobs +func (SendEmailFromArgs) Timeout() time.Duration { + return 30 * time.Second +} + +// Timeout returns the timeout for digest jobs (longer due to multiple emails) +func (SendGlobalDigestArgs) Timeout() time.Duration { + return 5 * time.Minute +} + +// SendEmailWorker handles sending email jobs +type SendEmailWorker struct { + emailService EmailService + logger Logger +} + +// NewSendEmailWorker creates a new SendEmailWorker +func NewSendEmailWorker(emailService EmailService, logger Logger) *SendEmailWorker { + return &SendEmailWorker{ + emailService: emailService, + logger: logger, + } +} + +// Work is the River work function for sending emails +func (w *SendEmailWorker) Work(ctx context.Context, job *river.Job[SendEmailArgs]) error { + args := job.Args + + // Validate required fields + if len(args.To) == 0 { + return fmt.Errorf("email job requires at least one recipient") + } + if strings.TrimSpace(args.Subject) == "" { + return fmt.Errorf("email job requires a subject") + } + if strings.TrimSpace(args.HTMLBody) == "" && strings.TrimSpace(args.TextBody) == "" { + return fmt.Errorf("email job requires either HTML body or text body") + } + + w.logger.Infow("Processing send email job", + "job_id", job.ID, + "to", args.To, + "subject", args.Subject, + ) + + // Convert args to email Message + message := &types.Message{ + From: args.From, + To: args.To, + Cc: args.Cc, + Bcc: args.Bcc, + Subject: args.Subject, + HTMLBody: args.HTMLBody, + TextBody: args.TextBody, + Attachments: args.Attachments, + Headers: args.Headers, + } + + // Send the email + result, err := w.emailService.Send(ctx, message) + if err != nil { + w.logger.Errorw("Failed to send email", + "job_id", job.ID, + "error", err, + ) + return fmt.Errorf("failed to send email: %w", err) + } + + if !result.Success { + w.logger.Errorw("Email send failed", + "job_id", job.ID, + "error", result.Error, + ) + return fmt.Errorf("email send failed: %s", result.Error) + } + + w.logger.Infow("Email sent successfully", + "job_id", job.ID, + "message_id", result.MessageID, + ) + + return nil +} + +// SendEmailFromWorker handles sending email from provider jobs +type SendEmailFromWorker struct { + emailService EmailService + logger Logger +} + +// NewSendEmailFromWorker creates a new SendEmailFromWorker +func NewSendEmailFromWorker(emailService EmailService, logger Logger) *SendEmailFromWorker { + return &SendEmailFromWorker{ + emailService: emailService, + logger: logger, + } +} + +// Work is the River work function for sending emails from a provider +func (w *SendEmailFromWorker) Work(ctx context.Context, job *river.Job[SendEmailFromArgs]) error { + args := job.Args + + // Validate required fields + if strings.TrimSpace(args.Provider) == "" { + return fmt.Errorf("email from provider job requires a provider name") + } + if len(args.To) == 0 { + return fmt.Errorf("email job requires at least one recipient") + } + if strings.TrimSpace(args.Subject) == "" { + return fmt.Errorf("email job requires a subject") + } + if strings.TrimSpace(args.HTMLBody) == "" && strings.TrimSpace(args.TextBody) == "" { + return fmt.Errorf("email job requires either HTML body or text body") + } + + w.logger.Infow("Processing send email from provider job", + "job_id", job.ID, + "provider", args.Provider, + "to", args.To, + "subject", args.Subject, + ) + + // Convert args to email Message + message := &types.Message{ + From: args.From, + To: args.To, + Cc: args.Cc, + Bcc: args.Bcc, + Subject: args.Subject, + HTMLBody: args.HTMLBody, + TextBody: args.TextBody, + Attachments: args.Attachments, + Headers: args.Headers, + } + + // Send the email using the specified provider + result, err := w.emailService.SendWithProvider(ctx, args.Provider, message) + if err != nil { + w.logger.Errorw("Failed to send email from provider", + "job_id", job.ID, + "provider", args.Provider, + "error", err, + ) + return fmt.Errorf("failed to send email from provider %s: %w", args.Provider, err) + } + + if !result.Success { + w.logger.Errorw("Email send failed from provider", + "job_id", job.ID, + "provider", args.Provider, + "error", result.Error, + ) + return fmt.Errorf("email send failed from provider %s: %s", args.Provider, result.Error) + } + + w.logger.Infow("Email sent successfully from provider", + "job_id", job.ID, + "provider", args.Provider, + "message_id", result.MessageID, + ) + + return nil +} + +// SendGlobalDigestWorker handles sending global digest jobs +type SendGlobalDigestWorker struct { + digestService DigestService + logger Logger +} + +// NewSendGlobalDigestWorker creates a new SendGlobalDigestWorker +func NewSendGlobalDigestWorker(digestService DigestService, logger Logger) *SendGlobalDigestWorker { + return &SendGlobalDigestWorker{ + digestService: digestService, + logger: logger, + } +} + +// Work is the River work function for sending global digest +func (w *SendGlobalDigestWorker) Work(ctx context.Context, job *river.Job[SendGlobalDigestArgs]) error { + w.logger.Infow("Processing global digest job", "job_id", job.ID) + + // Send the global digest + if err := w.digestService.SendGlobalDigest(ctx); err != nil { + w.logger.Errorw("Failed to send global digest", + "job_id", job.ID, + "error", err, + ) + return fmt.Errorf("failed to send global digest: %w", err) + } + + w.logger.Infow("Global digest sent successfully", "job_id", job.ID) + return nil +} + +// JobInsertOptions returns common insert options for email jobs +func JobInsertOptions() *river.InsertOpts { + return &river.InsertOpts{ + Queue: "email", // Default queue for email jobs + MaxAttempts: 5, // Retry up to 5 times + // River uses exponential backoff by default + } +} + +// JobInsertOptionsWithQueue returns insert options for jobs with specified queue +func JobInsertOptionsWithQueue(queue string) *river.InsertOpts { + return &river.InsertOpts{ + Queue: queue, + MaxAttempts: 5, // Retry up to 5 times + // River uses exponential backoff by default + } +} + +// JobInsertOptionsWithRetry returns insert options for jobs with custom retry policy +func JobInsertOptionsWithRetry(queue string, maxAttempts int) *river.InsertOpts { + return &river.InsertOpts{ + Queue: queue, + MaxAttempts: maxAttempts, + // River uses exponential backoff by default + } +} + +// Workers returns all workers as work functions with dependencies injected +func Workers(emailService EmailService, digestService DigestService, logger Logger) *river.Workers { + workers := river.NewWorkers() + + // Create worker instances with dependencies + sendEmailWorker := NewSendEmailWorker(emailService, logger) + sendEmailFromWorker := NewSendEmailFromWorker(emailService, logger) + // Register workers with their Work methods + river.AddWorker(workers, river.WorkFunc(sendEmailWorker.Work)) + river.AddWorker(workers, river.WorkFunc(sendEmailFromWorker.Work)) + + // Only create and register the global digest worker if the digest service is available + if digestService != nil { + sendGlobalDigestWorker := NewSendGlobalDigestWorker(digestService, logger) + river.AddWorker(workers, river.WorkFunc(sendGlobalDigestWorker.Work)) + } + + return workers +} diff --git a/internal/service/worker/service.go b/internal/service/worker/service.go new file mode 100644 index 00000000..565d3f0e --- /dev/null +++ b/internal/service/worker/service.go @@ -0,0 +1,332 @@ +package worker + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/email" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/riverqueue/river" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivermigrate" + "github.com/robfig/cron/v3" + "go.uber.org/zap" + "gorm.io/driver/postgres" + "gorm.io/gorm" +) + +// Service manages the River client and workers +type Service struct { + client *river.Client[pgx.Tx] + config *config.WorkerConfig + db *gorm.DB + emailSvc *email.Service + digestSvc DigestService + logger *zap.SugaredLogger + started bool + startedMu sync.RWMutex + pgxPool *pgxpool.Pool + digestCfg *config.Config +} + +// NewService creates a new worker service +func NewService( + cfg *config.WorkerConfig, + db *gorm.DB, + emailSvc *email.Service, + logger *zap.SugaredLogger, +) (*Service, error) { + return NewServiceWithDigest(cfg, db, emailSvc, nil, nil, logger) +} + +// NewServiceWithDigest creates a new worker service with digest support +func NewServiceWithDigest( + cfg *config.WorkerConfig, + db *gorm.DB, + emailSvc *email.Service, + digestSvc DigestService, + digestCfg *config.Config, + logger *zap.SugaredLogger, +) (*Service, error) { + if !cfg.Enabled { + logger.Info("Worker service is disabled") + return &Service{ + config: cfg, + db: db, + emailSvc: emailSvc, + digestSvc: digestSvc, + digestCfg: digestCfg, + logger: logger, + started: false, + }, nil + } + + if emailSvc == nil { + return nil, fmt.Errorf("email service is required for worker service") + } + + // Get pgx pool from GORM + // Note: Creating a separate pgx pool for River workers is acceptable here because: + // 1. River requires a pgxpool.Pool specifically, not GORM's generic interface + // 2. We use conservative pool settings to avoid exhaustion + // 3. The pools share the same database but operate independently + var pgxPool *pgxpool.Pool + // Since GORM's ConnPool doesn't directly expose pgxpool.Pool, + // we need to create a new pool from the DSN + dialector, ok := db.Dialector.(*postgres.Dialector) + if !ok { + return nil, fmt.Errorf("worker service requires a postgres dialector, got %T", db.Dialector) + } + dsn := dialector.DSN + + // Configure pgx pool with conservative settings to avoid connection exhaustion + poolConfig, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("failed to parse pgx pool config: %w", err) + } + // Limit connections to avoid exhausting database connections + // Use a small fraction of typical connection limits + poolConfig.MaxConns = 10 // Conservative limit for worker pool + poolConfig.MinConns = 2 // Keep minimum connections warm + + pool, err := pgxpool.NewWithConfig(context.Background(), poolConfig) + if err != nil { + return nil, fmt.Errorf("failed to create pgx pool: %w", err) + } + pgxPool = pool + + // Register workers with dependencies injected + workers := Workers(emailSvc, digestSvc, logger) + + // Configure periodic jobs + var periodicJobs []*river.PeriodicJob + if digestCfg != nil && digestCfg.DigestEnabled { + periodicJobs = append(periodicJobs, NewDigestPeriodicJob(digestCfg.DigestSchedule, logger)) + } + + // Create River client with pgxv5 driver + riverConfig := river.Config{ + Queues: map[string]river.QueueConfig{ + "email": { + MaxWorkers: cfg.Workers, + }, + "digest": { + MaxWorkers: 1, // Only one digest worker to avoid duplicates + }, + }, + Workers: workers, + PeriodicJobs: periodicJobs, + } + + // Create the client + client, err := river.NewClient(riverpgxv5.New(pgxPool), &riverConfig) + if err != nil { + return nil, fmt.Errorf("failed to create River client: %w", err) + } + + service := &Service{ + client: client, + config: cfg, + db: db, + emailSvc: emailSvc, + digestSvc: digestSvc, + digestCfg: digestCfg, + logger: logger, + started: false, + pgxPool: pgxPool, + } + + return service, nil +} + +// Start starts the worker service +func (s *Service) Start(ctx context.Context) error { + if !s.config.Enabled { + s.logger.Info("Worker service is disabled, not starting") + return nil + } + + s.startedMu.Lock() + defer s.startedMu.Unlock() + + if s.started { + s.logger.Warn("Worker service is already started") + return nil + } + + s.logger.Infow("Starting worker service", + "workers", s.config.Workers, + "queue", s.config.Queue, + ) + + // Start the workers with the provided context (no dependency injection needed) + if err := s.client.Start(ctx); err != nil { + return fmt.Errorf("failed to start River client: %w", err) + } + + s.started = true + s.logger.Info("Worker service started successfully") + return nil +} + +// Stop stops the worker service +func (s *Service) Stop(ctx context.Context) error { + s.startedMu.Lock() + defer s.startedMu.Unlock() + + if !s.config.Enabled || !s.started { + s.logger.Info("Worker service is not running") + return nil + } + + s.logger.Info("Stopping worker service") + + // Stop the client with a graceful shutdown period + stopCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := s.client.Stop(stopCtx); err != nil { + s.logger.Errorw("Failed to stop River client gracefully", "error", err) + return fmt.Errorf("failed to stop River client: %w", err) + } + + // Close pgx pool + if s.pgxPool != nil { + s.pgxPool.Close() + } + + s.started = false + s.logger.Info("Worker service stopped") + return nil +} + +// IsStarted returns true if the worker service is started +func (s *Service) IsStarted() bool { + s.startedMu.RLock() + defer s.startedMu.RUnlock() + return s.started +} + +// GetClient returns the River client for job insertion +func (s *Service) GetClient() *river.Client[pgx.Tx] { + return s.client +} + +// Migrate runs River migrations +func (s *Service) Migrate(ctx context.Context) error { + if !s.config.Enabled { + s.logger.Info("Worker service is disabled, skipping migration") + return nil + } + + if s.pgxPool == nil { + return fmt.Errorf("pgx pool is not initialized - worker service may be disabled") + } + + // Get migrator from the driver + migrator, err := rivermigrate.New(riverpgxv5.New(s.pgxPool), &rivermigrate.Config{}) + if err != nil { + return fmt.Errorf("failed to create migrator: %w", err) + } + + _, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{}) + if err != nil { + return fmt.Errorf("failed to run River migrations: %w", err) + } + + s.logger.Info("River migrations completed successfully") + return nil +} + +// NewDigestPeriodicJob creates a periodic job for digest scheduling +func NewDigestPeriodicJob(cronSchedule string, logger *zap.SugaredLogger) *river.PeriodicJob { + // Parse the cron schedule using robfig/cron + parser := cron.NewParser(cron.Second | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) + schedule, err := parser.Parse(cronSchedule) + if err != nil { + logger.Errorw("Failed to parse digest cron schedule, using default @weekly", "schedule", cronSchedule, "error", err) + // Fallback to weekly schedule + schedule, _ = parser.Parse("@weekly") + } + + return river.NewPeriodicJob( + schedule, + func() (river.JobArgs, *river.InsertOpts) { + return &SendGlobalDigestArgs{}, &river.InsertOpts{ + Queue: "digest", + MaxAttempts: 3, + } + }, + &river.PeriodicJobOpts{ + RunOnStart: false, // Don't run immediately on startup + }, + ) +} + +// EnqueueSendEmail enqueues a send email job +func (s *Service) EnqueueSendEmail(ctx context.Context, args *SendEmailArgs) error { + if !s.config.Enabled { + return fmt.Errorf("worker service is disabled") + } + + if s.client == nil { + return fmt.Errorf("worker client is not initialized") + } + + // Use configured queue or default to "email" + queue := s.config.Queue + if queue == "" { + queue = "email" + } + + // Use configured retry policy or default to 5 attempts + maxAttempts := s.config.RetryPolicy.MaxAttempts + if maxAttempts == 0 { + maxAttempts = 5 + } + + // Use InsertMany which doesn't require a transaction + _, err := s.client.InsertMany(ctx, []river.InsertManyParams{ + {Args: args, InsertOpts: JobInsertOptionsWithRetry(queue, maxAttempts)}, + }) + if err != nil { + return fmt.Errorf("failed to enqueue send email job: %w", err) + } + return nil +} + +// EnqueueSendEmailFrom enqueues a send email from provider job +func (s *Service) EnqueueSendEmailFrom(ctx context.Context, args *SendEmailFromArgs) error { + if !s.config.Enabled { + return fmt.Errorf("worker service is disabled") + } + + if s.client == nil { + return fmt.Errorf("worker client is not initialized") + } + + // Use configured queue or default to "email" + queue := s.config.Queue + if queue == "" { + queue = "email" + } + + // Use configured retry policy or default to 5 attempts + maxAttempts := s.config.RetryPolicy.MaxAttempts + if maxAttempts == 0 { + maxAttempts = 5 + } + + // Use InsertMany which doesn't require a transaction + _, err := s.client.InsertMany(ctx, []river.InsertManyParams{ + {Args: args, InsertOpts: JobInsertOptionsWithRetry(queue, maxAttempts)}, + }) + if err != nil { + return fmt.Errorf("failed to enqueue send email from job: %w", err) + } + return nil +} diff --git a/internal/service/worker/service_test.go b/internal/service/worker/service_test.go new file mode 100644 index 00000000..91bdb055 --- /dev/null +++ b/internal/service/worker/service_test.go @@ -0,0 +1,345 @@ +package worker + +import ( + "context" + "testing" + + "github.com/compliance-framework/api/internal/config" + "github.com/compliance-framework/api/internal/service/email/types" + "github.com/riverqueue/river" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "go.uber.org/zap" +) + +// MockEmailService is a mock implementation of EmailService +type MockEmailService struct { + mock.Mock +} + +func (m *MockEmailService) Send(ctx context.Context, message *types.Message) (*types.SendResult, error) { + args := m.Called(ctx, message) + return args.Get(0).(*types.SendResult), args.Error(1) +} + +func (m *MockEmailService) SendWithProvider(ctx context.Context, providerName string, message *types.Message) (*types.SendResult, error) { + args := m.Called(ctx, providerName, message) + return args.Get(0).(*types.SendResult), args.Error(1) +} + +// MockDigestService is a mock implementation of DigestService +type MockDigestService struct { + mock.Mock +} + +func (m *MockDigestService) SendGlobalDigest(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +// MockLogger is a mock implementation of Logger +type MockLogger struct { + mock.Mock + loggedMessages []string +} + +func (m *MockLogger) Infow(msg string, keysAndValues ...interface{}) { + m.Called(msg, keysAndValues) + m.loggedMessages = append(m.loggedMessages, "INFO: "+msg) +} + +func (m *MockLogger) Errorw(msg string, keysAndValues ...interface{}) { + m.Called(msg, keysAndValues) + m.loggedMessages = append(m.loggedMessages, "ERROR: "+msg) +} + +func (m *MockLogger) Warnw(msg string, keysAndValues ...interface{}) { + m.Called(msg, keysAndValues) + m.loggedMessages = append(m.loggedMessages, "WARN: "+msg) +} + +func (m *MockLogger) Debugw(msg string, keysAndValues ...interface{}) { + m.Called(msg, keysAndValues) + m.loggedMessages = append(m.loggedMessages, "DEBUG: "+msg) +} + +func TestNewService_Disabled(t *testing.T) { + cfg := &config.WorkerConfig{ + Enabled: false, + } + logger := zap.NewNop().Sugar() + + service, err := NewService(cfg, nil, nil, logger) + assert.NoError(t, err) + assert.NotNil(t, service) + assert.False(t, service.IsStarted()) +} + +func TestNewService_RequiresEmailService(t *testing.T) { + cfg := &config.WorkerConfig{ + Enabled: true, + Workers: 5, + Queue: "email", + } + logger := zap.NewNop().Sugar() + + service, err := NewService(cfg, nil, nil, logger) + assert.Error(t, err) + assert.Nil(t, service) + assert.Contains(t, err.Error(), "email service is required") +} + +func TestService_EnqueueWhenDisabled(t *testing.T) { + cfg := &config.WorkerConfig{ + Enabled: false, + } + logger := zap.NewNop().Sugar() + + service, err := NewService(cfg, nil, nil, logger) + assert.NoError(t, err) + + ctx := context.Background() + args := &SendEmailArgs{ + To: []string{"test@example.com"}, + Subject: "Test", + } + + err = service.EnqueueSendEmail(ctx, args) + assert.Error(t, err) + assert.Contains(t, err.Error(), "worker service is disabled") +} + +func TestNewSendEmailWorker(t *testing.T) { + mockEmailService := &MockEmailService{} + mockLogger := &MockLogger{} + + worker := NewSendEmailWorker(mockEmailService, mockLogger) + + assert.NotNil(t, worker) + assert.Equal(t, mockEmailService, worker.emailService) + assert.Equal(t, mockLogger, worker.logger) +} + +func TestSendEmailWorker_MessageConstruction(t *testing.T) { + mockEmailService := &MockEmailService{} + + ctx := context.Background() + args := &SendEmailArgs{ + To: []string{"test@example.com"}, + Subject: "Test Subject", + HTMLBody: "
Test Body
", + From: "sender@example.com", + Cc: []string{"cc@example.com"}, + Bcc: []string{"bcc@example.com"}, + TextBody: "Plain text body", + } + + // Test message construction logic + message := &types.Message{ + From: args.From, + To: args.To, + Cc: args.Cc, + Bcc: args.Bcc, + Subject: args.Subject, + HTMLBody: args.HTMLBody, + TextBody: args.TextBody, + Attachments: args.Attachments, + Headers: args.Headers, + } + + // Verify the message is constructed correctly + assert.Equal(t, args.From, message.From) + assert.Equal(t, args.To, message.To) + assert.Equal(t, args.Cc, message.Cc) + assert.Equal(t, args.Bcc, message.Bcc) + assert.Equal(t, args.Subject, message.Subject) + assert.Equal(t, args.HTMLBody, message.HTMLBody) + assert.Equal(t, args.TextBody, message.TextBody) + assert.Equal(t, args.Attachments, message.Attachments) + assert.Equal(t, args.Headers, message.Headers) + + // Set up mock expectations + mockEmailService.On("Send", ctx, message).Return(&types.SendResult{ + Success: true, + MessageID: "test-message-id", + }, nil) + + // Call the mock service to verify it works + result, err := mockEmailService.Send(ctx, message) + assert.NoError(t, err) + assert.True(t, result.Success) + assert.Equal(t, "test-message-id", result.MessageID) + + mockEmailService.AssertExpectations(t) +} + +func TestSendEmailWorker_Work_Validation(t *testing.T) { + mockEmailService := &MockEmailService{} + mockLogger := &MockLogger{} + worker := NewSendEmailWorker(mockEmailService, mockLogger) + + ctx := context.Background() + + tests := []struct { + name string + args *SendEmailArgs + expectError string + }{ + { + name: "missing recipients", + args: &SendEmailArgs{ + Subject: "Test", + HTMLBody: "Test
", + }, + expectError: "email job requires at least one recipient", + }, + { + name: "missing subject", + args: &SendEmailArgs{ + To: []string{"test@example.com"}, + HTMLBody: "Test
", + }, + expectError: "email job requires a subject", + }, + { + name: "missing body", + args: &SendEmailArgs{ + To: []string{"test@example.com"}, + Subject: "Test", + }, + expectError: "email job requires either HTML body or text body", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Set up mock logger to expect any call + mockLogger.On("Infow", "Processing send email job", mock.Anything).Maybe() + + // Create a test job with the invalid args + job := &river.Job[SendEmailArgs]{ + Args: *tt.args, + } + + // Call the actual Work method and expect validation error + err := worker.Work(ctx, job) + + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.expectError) + }) + } +} + +func TestNewSendEmailFromWorker(t *testing.T) { + mockEmailService := &MockEmailService{} + mockLogger := &MockLogger{} + + worker := NewSendEmailFromWorker(mockEmailService, mockLogger) + + assert.NotNil(t, worker) + assert.Equal(t, mockEmailService, worker.emailService) + assert.Equal(t, mockLogger, worker.logger) +} + +func TestSendEmailFromWorker_MessageConstruction(t *testing.T) { + mockEmailService := &MockEmailService{} + + ctx := context.Background() + args := &SendEmailFromArgs{ + Provider: "smtp", + To: []string{"test@example.com"}, + Subject: "Test Subject", + HTMLBody: "Test Body
", + From: "sender@example.com", + } + + // Test message construction logic + message := &types.Message{ + From: args.From, + To: args.To, + Cc: args.Cc, + Bcc: args.Bcc, + Subject: args.Subject, + HTMLBody: args.HTMLBody, + TextBody: args.TextBody, + Attachments: args.Attachments, + Headers: args.Headers, + } + + // Verify the message is constructed correctly + assert.Equal(t, args.From, message.From) + assert.Equal(t, args.To, message.To) + assert.Equal(t, args.Subject, message.Subject) + assert.Equal(t, args.HTMLBody, message.HTMLBody) + + // Set up mock expectations + mockEmailService.On("SendWithProvider", ctx, "smtp", message).Return(&types.SendResult{ + Success: true, + MessageID: "test-message-id", + }, nil) + + // Call the mock service to verify it works + result, err := mockEmailService.SendWithProvider(ctx, "smtp", message) + assert.NoError(t, err) + assert.True(t, result.Success) + assert.Equal(t, "test-message-id", result.MessageID) + + mockEmailService.AssertExpectations(t) +} + +func TestNewSendGlobalDigestWorker(t *testing.T) { + mockDigestService := &MockDigestService{} + mockLogger := &MockLogger{} + + worker := NewSendGlobalDigestWorker(mockDigestService, mockLogger) + + assert.NotNil(t, worker) + assert.Equal(t, mockDigestService, worker.digestService) + assert.Equal(t, mockLogger, worker.logger) +} + +func TestSendGlobalDigestWorker_DigestCall(t *testing.T) { + mockDigestService := &MockDigestService{} + + ctx := context.Background() + + // Set up mock expectations + mockDigestService.On("SendGlobalDigest", ctx).Return(nil) + + // Call the mock service to verify it works + err := mockDigestService.SendGlobalDigest(ctx) + assert.NoError(t, err) + + mockDigestService.AssertExpectations(t) +} + +func TestWorkers(t *testing.T) { + mockEmailService := &MockEmailService{} + mockDigestService := &MockDigestService{} + mockLogger := &MockLogger{} + + workers := Workers(mockEmailService, mockDigestService, mockLogger) + + assert.NotNil(t, workers) +} + +func TestJobInsertOptions(t *testing.T) { + opts := JobInsertOptions() + + assert.Equal(t, "email", opts.Queue) + assert.Equal(t, 5, opts.MaxAttempts) +} + +func TestJobInsertOptionsWithQueue(t *testing.T) { + opts := JobInsertOptionsWithQueue("custom-queue") + + assert.Equal(t, "custom-queue", opts.Queue) + assert.Equal(t, 5, opts.MaxAttempts) +} + +func TestJobInsertOptionsWithRetry(t *testing.T) { + opts := JobInsertOptionsWithRetry("custom-queue", 10) + + assert.Equal(t, "custom-queue", opts.Queue) + assert.Equal(t, 10, opts.MaxAttempts) +}